diff --git a/netty-http-bouncycastle/src/main/java/module-info.java b/netty-http-bouncycastle/src/main/java/module-info.java index d509a26..1461f09 100644 --- a/netty-http-bouncycastle/src/main/java/module-info.java +++ b/netty-http-bouncycastle/src/main/java/module-info.java @@ -1,8 +1,10 @@ +import org.xbib.netty.http.server.api.ServerCertificateProvider; + module org.xbib.netty.http.bouncycastle { exports org.xbib.netty.http.bouncycastle; requires org.xbib.netty.http.server.api; requires org.bouncycastle.pkix; requires org.bouncycastle.provider; - provides org.xbib.netty.http.server.api.security.ServerCertificateProvider with + provides ServerCertificateProvider with org.xbib.netty.http.bouncycastle.BouncyCastleSelfSignedCertificateProvider; } diff --git a/netty-http-bouncycastle/src/main/java/org/xbib/netty/http/bouncycastle/BouncyCastleSelfSignedCertificateProvider.java b/netty-http-bouncycastle/src/main/java/org/xbib/netty/http/bouncycastle/BouncyCastleSelfSignedCertificateProvider.java index f6599ab..fc36da0 100644 --- a/netty-http-bouncycastle/src/main/java/org/xbib/netty/http/bouncycastle/BouncyCastleSelfSignedCertificateProvider.java +++ b/netty-http-bouncycastle/src/main/java/org/xbib/netty/http/bouncycastle/BouncyCastleSelfSignedCertificateProvider.java @@ -1,7 +1,7 @@ package org.xbib.netty.http.bouncycastle; import org.bouncycastle.operator.OperatorCreationException; -import org.xbib.netty.http.server.api.security.ServerCertificateProvider; +import org.xbib.netty.http.server.api.ServerCertificateProvider; import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; diff --git a/netty-http-bouncycastle/src/main/resources/META-INF/services/org.xbib.netty.http.server.api.security.ServerCertificateProvider b/netty-http-bouncycastle/src/main/resources/META-INF/services/org.xbib.netty.http.server.api.ServerCertificateProvider similarity index 100% rename from netty-http-bouncycastle/src/main/resources/META-INF/services/org.xbib.netty.http.server.api.security.ServerCertificateProvider rename to netty-http-bouncycastle/src/main/resources/META-INF/services/org.xbib.netty.http.server.api.ServerCertificateProvider diff --git a/netty-http-client-api/src/main/java/org/xbib/netty/http/client/api/Request.java b/netty-http-client-api/src/main/java/org/xbib/netty/http/client/api/Request.java index 2e9868b..f6e2a80 100644 --- a/netty-http-client-api/src/main/java/org/xbib/netty/http/client/api/Request.java +++ b/netty-http-client-api/src/main/java/org/xbib/netty/http/client/api/Request.java @@ -12,6 +12,7 @@ import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.multipart.InterfaceHttpData; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http2.HttpConversionUtil; import io.netty.util.AsciiString; import org.xbib.net.URL; @@ -76,12 +77,28 @@ public final class Request implements AutoCloseable { private TimeoutListener timeoutListener; - private Request(URL url, HttpVersion httpVersion, HttpMethod httpMethod, - HttpHeaders headers, Collection cookies, ByteBuf content, List bodyData, - long timeoutInMillis, boolean followRedirect, int maxRedirect, int redirectCount, - boolean isBackOff, BackOff backOff, - ResponseListener responseListener, ExceptionListener exceptionListener, - TimeoutListener timeoutListener) { + private final WebSocketFrame webSocketFrame; + + private final WebSocketResponseListener webSocketResponseListener; + + private Request(URL url, + HttpVersion httpVersion, + HttpMethod httpMethod, + HttpHeaders headers, + Collection cookies, + ByteBuf content, + List bodyData, + long timeoutInMillis, + boolean followRedirect, + int maxRedirect, + int redirectCount, + boolean isBackOff, + BackOff backOff, + ResponseListener responseListener, + ExceptionListener exceptionListener, + TimeoutListener timeoutListener, + WebSocketFrame webSocketFrame, + WebSocketResponseListener webSocketResponseListener) { this.url = url; this.httpVersion = httpVersion; this.httpMethod = httpMethod; @@ -98,6 +115,8 @@ public final class Request implements AutoCloseable { this.responseListener = responseListener; this.exceptionListener = exceptionListener; this.timeoutListener = timeoutListener; + this.webSocketFrame = webSocketFrame; + this.webSocketResponseListener = webSocketResponseListener; } public URL url() { @@ -157,6 +176,14 @@ public final class Request implements AutoCloseable { return backOff; } + public WebSocketFrame getWebSocketFrame() { + return webSocketFrame; + } + + public WebSocketResponseListener getWebSocketResponseListener() { + return webSocketResponseListener; + } + public boolean canRedirect() { if (!followRedirect) { return false; @@ -356,6 +383,10 @@ public final class Request implements AutoCloseable { private TimeoutListener timeoutListener; + private WebSocketFrame webSocketFrame; + + private WebSocketResponseListener webSocketResponseListener; + Builder(ByteBufAllocator allocator) { this.allocator = allocator; this.httpMethod = DEFAULT_METHOD; @@ -622,6 +653,16 @@ public final class Request implements AutoCloseable { return this; } + public Builder setWebSocketFrame(WebSocketFrame webSocketFrame) { + this.webSocketFrame = webSocketFrame; + return this; + } + + public Builder setWebSocketResponseListener(WebSocketResponseListener webSocketResponseListener) { + this.webSocketResponseListener = webSocketResponseListener; + return this; + } + public Request build() { DefaultHttpHeaders validatedHeaders = new DefaultHttpHeaders(true); validatedHeaders.set(headers); @@ -670,7 +711,7 @@ public final class Request implements AutoCloseable { } return new Request(url, httpVersion, httpMethod, validatedHeaders, cookies, content, bodyData, timeoutInMillis, followRedirect, maxRedirects, 0, enableBackOff, backOff, - responseListener, exceptionListener, timeoutListener); + responseListener, exceptionListener, timeoutListener, webSocketFrame, webSocketResponseListener); } private void addHeader(AsciiString name, Object value) { diff --git a/netty-http-client-api/src/main/java/org/xbib/netty/http/client/api/WebSocketResponseListener.java b/netty-http-client-api/src/main/java/org/xbib/netty/http/client/api/WebSocketResponseListener.java new file mode 100644 index 0000000..3287416 --- /dev/null +++ b/netty-http-client-api/src/main/java/org/xbib/netty/http/client/api/WebSocketResponseListener.java @@ -0,0 +1,9 @@ +package org.xbib.netty.http.client.api; + +import io.netty.handler.codec.http.websocketx.WebSocketFrame; + +@FunctionalInterface +public interface WebSocketResponseListener { + + void onResponse(F frame); +} diff --git a/netty-http-client/NOTICE.txt b/netty-http-client/NOTICE.txt new file mode 100644 index 0000000..0410686 --- /dev/null +++ b/netty-http-client/NOTICE.txt @@ -0,0 +1,8 @@ + +http2 web socket implementation based on the work of Maksym Ostroverkhov + +https://github.com/jauntsdn/netty-websocket-http2/ + +Apache License 2.0 + +forked at 20 October 2021 diff --git a/netty-http-client/src/main/java/org/xbib/netty/http/client/ClientConfig.java b/netty-http-client/src/main/java/org/xbib/netty/http/client/ClientConfig.java index f1dbf9d..21e88a5 100644 --- a/netty-http-client/src/main/java/org/xbib/netty/http/client/ClientConfig.java +++ b/netty-http-client/src/main/java/org/xbib/netty/http/client/ClientConfig.java @@ -245,7 +245,7 @@ public class ClientConfig { private Boolean poolSecure = Defaults.POOL_SECURE; - private List serverNamesForIdentification = new ArrayList<>(); + private final List serverNamesForIdentification = new ArrayList<>(); private Http2Settings http2Settings = Defaults.HTTP2_SETTINGS; diff --git a/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/http/Http1ChannelInitializer.java b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/http/Http1ChannelInitializer.java index 1ce4946..08cf26e 100644 --- a/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/http/Http1ChannelInitializer.java +++ b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/http/Http1ChannelInitializer.java @@ -4,9 +4,12 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory; +import io.netty.handler.codec.http.websocketx.WebSocketVersion; import io.netty.handler.logging.LogLevel; import io.netty.handler.ssl.ApplicationProtocolNames; import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; @@ -14,9 +17,12 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.stream.ChunkedWriteHandler; import org.xbib.netty.http.client.Client; import org.xbib.netty.http.client.ClientConfig; +import org.xbib.netty.http.client.handler.ws1.Http1WebSocketClientHandler; import org.xbib.netty.http.common.HttpChannelInitializer; import org.xbib.netty.http.client.handler.http2.Http2ChannelInitializer; import org.xbib.netty.http.common.HttpAddress; + +import java.net.URI; import java.util.logging.Level; import java.util.logging.Logger; @@ -96,17 +102,23 @@ public class Http1ChannelInitializer extends ChannelInitializer impleme private void configureCleartext(Channel channel) { ChannelPipeline pipeline = channel.pipeline(); - //pipeline.addLast("client-chunk-compressor", new HttpChunkContentCompressor(6)); - pipeline.addLast("http-client-chunk-writer", new ChunkedWriteHandler()); - pipeline.addLast("http-client-codec", new HttpClientCodec(clientConfig.getMaxInitialLineLength(), + pipeline.addLast("http-client-chunk-writer", + new ChunkedWriteHandler()); + pipeline.addLast("http-client-codec", + new HttpClientCodec(clientConfig.getMaxInitialLineLength(), clientConfig.getMaxHeadersSize(), clientConfig.getMaxChunkSize())); if (clientConfig.isEnableGzip()) { pipeline.addLast("http-client-decompressor", new HttpContentDecompressor()); } - HttpObjectAggregator httpObjectAggregator = new HttpObjectAggregator(clientConfig.getMaxContentLength(), - false); + HttpObjectAggregator httpObjectAggregator = + new HttpObjectAggregator(clientConfig.getMaxContentLength(), false); httpObjectAggregator.setMaxCumulationBufferComponents(clientConfig.getMaxCompositeBufferComponents()); - pipeline.addLast("http-client-aggregator", httpObjectAggregator); - pipeline.addLast("http-client-handler", httpResponseHandler); + pipeline.addLast("http-client-aggregator", + httpObjectAggregator); + //pipeline.addLast( "http-client-ws-protocol-handler", + // new Http1WebSocketClientHandler(WebSocketClientHandshakerFactory.newHandshaker(URI.create("/websocket"), + // WebSocketVersion.V13, null, false, new DefaultHttpHeaders()))); + pipeline.addLast("http-client-handler", + httpResponseHandler); } } diff --git a/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/http/HttpChunkContentCompressor.java b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/http/HttpChunkContentCompressor.java index 5da36fe..e50c1ad 100644 --- a/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/http/HttpChunkContentCompressor.java +++ b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/http/HttpChunkContentCompressor.java @@ -11,10 +11,6 @@ import io.netty.handler.codec.http.HttpContentCompressor; */ public class HttpChunkContentCompressor extends HttpContentCompressor { - HttpChunkContentCompressor(int compressionLevel) { - super(compressionLevel); - } - @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (msg instanceof ByteBuf) { diff --git a/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws1/Http1WebSocketClientHandler.java b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws1/Http1WebSocketClientHandler.java new file mode 100644 index 0000000..a4c7250 --- /dev/null +++ b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws1/Http1WebSocketClientHandler.java @@ -0,0 +1,53 @@ +package org.xbib.netty.http.client.handler.ws1; + +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler; + +import java.io.IOException; + +public class Http1WebSocketClientHandler extends ChannelInboundHandlerAdapter { + + final WebSocketClientHandshaker handshaker; + + public Http1WebSocketClientHandler(WebSocketClientHandshaker handshaker) { + this.handshaker = handshaker; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelActive(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelInactive(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof CloseWebSocketFrame) { + handshaker.close(ctx.channel(), (CloseWebSocketFrame) msg) + .addListener(ChannelFutureListener.CLOSE); + } else { + ctx.fireChannelRead(msg); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + String actualProtocol = handshaker.actualSubprotocol(); + if (actualProtocol.equals("")) { + } + else { + throw new IOException("Invalid Websocket Protocol"); + } + } else { + ctx.fireUserEventTriggered(evt); + } + } +} \ No newline at end of file diff --git a/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientBuilder.java b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientBuilder.java new file mode 100644 index 0000000..e22edd3 --- /dev/null +++ b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientBuilder.java @@ -0,0 +1,147 @@ +package org.xbib.netty.http.client.handler.ws2; + +import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.handler.codec.http.websocketx.extensions.compression.PerMessageDeflateClientExtensionHandshaker; +import org.xbib.netty.http.common.ws.Preconditions; + +/** + * Builder for {@link Http2WebSocketClientHandler} + */ +public final class Http2WebSocketClientBuilder { + + private static final short DEFAULT_STREAM_WEIGHT = 16; + + private static final boolean MASK_PAYLOAD = true; + + private WebSocketDecoderConfig webSocketDecoderConfig; + + private PerMessageDeflateClientExtensionHandshaker perMessageDeflateClientExtensionHandshaker; + + private long handshakeTimeoutMillis = 15_000; + + private short streamWeight; + + private long closedWebSocketRemoveTimeoutMillis = 30_000; + + private boolean isSingleWebSocketPerConnection; + + Http2WebSocketClientBuilder() {} + + /** @return new {@link Http2WebSocketClientBuilder} instance */ + public static Http2WebSocketClientBuilder create() { + return new Http2WebSocketClientBuilder(); + } + + /** + * @param webSocketDecoderConfig websocket decoder configuration. Must be non-null + * @return this {@link Http2WebSocketClientBuilder} instance + */ + public Http2WebSocketClientBuilder decoderConfig(WebSocketDecoderConfig webSocketDecoderConfig) { + this.webSocketDecoderConfig = Preconditions.requireNonNull(webSocketDecoderConfig, "webSocketDecoderConfig"); + return this; + } + + /** + * @param handshakeTimeoutMillis websocket handshake timeout. Must be positive + * @return this {@link Http2WebSocketClientBuilder} instance + */ + public Http2WebSocketClientBuilder handshakeTimeoutMillis(long handshakeTimeoutMillis) { + this.handshakeTimeoutMillis = + Preconditions.requirePositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); + return this; + } + + /** + * @param closedWebSocketRemoveTimeoutMillis delay until websockets handler forgets closed + * websocket. Necessary to gracefully handle incoming http2 frames racing with outgoing stream + * termination frame. + * @return this {@link Http2WebSocketClientBuilder} instance + */ + public Http2WebSocketClientBuilder closedWebSocketRemoveTimeoutMillis(long closedWebSocketRemoveTimeoutMillis) { + this.closedWebSocketRemoveTimeoutMillis = + Preconditions.requirePositive(closedWebSocketRemoveTimeoutMillis, "closedWebSocketRemoveTimeoutMillis"); + return this; + } + + /** + * @param isCompressionEnabled enables permessage-deflate compression with default configuration + * @return this {@link Http2WebSocketClientBuilder} instance + */ + public Http2WebSocketClientBuilder compression(boolean isCompressionEnabled) { + if (isCompressionEnabled) { + if (perMessageDeflateClientExtensionHandshaker == null) { + perMessageDeflateClientExtensionHandshaker = + new PerMessageDeflateClientExtensionHandshaker(); + } + } else { + perMessageDeflateClientExtensionHandshaker = null; + } + return this; + } + + /** + * Enables permessage-deflate compression with extended configuration. Parameters are described in + * netty's PerMessageDeflateClientExtensionHandshaker + * + * @param compressionLevel sets compression level. Range is [0; 9], default is 6 + * @param allowClientWindowSize allows server to customize the client's inflater window size, + * default is false + * @param requestedServerWindowSize requested server window size if server inflater is + * customizable + * @param allowClientNoContext allows server to activate client_no_context_takeover, default is + * false + * @param requestedServerNoContext whether client needs to activate server_no_context_takeover if + * server is compatible, default is false. + * @return this {@link Http2WebSocketClientBuilder} instance + */ + public Http2WebSocketClientBuilder compression(int compressionLevel, boolean allowClientWindowSize, + int requestedServerWindowSize, boolean allowClientNoContext, boolean requestedServerNoContext) { + perMessageDeflateClientExtensionHandshaker = + new PerMessageDeflateClientExtensionHandshaker(compressionLevel, allowClientWindowSize, + requestedServerWindowSize, allowClientNoContext, requestedServerNoContext); + return this; + } + + /** + * @param weight sets websocket http2 stream weight. Must belong to [1; 256] range + * @return this {@link Http2WebSocketClientBuilder} instance + */ + public Http2WebSocketClientBuilder streamWeight(int weight) { + this.streamWeight = Preconditions.requireRange(weight, 1, 256, "streamWeight"); + return this; + } + + /** + * @param isSingleWebSocketPerConnection optimize for at most 1 websocket per connection + * @return this {@link Http2WebSocketClientBuilder} instance + */ + public Http2WebSocketClientBuilder assumeSingleWebSocketPerConnection(boolean isSingleWebSocketPerConnection) { + this.isSingleWebSocketPerConnection = isSingleWebSocketPerConnection; + return this; + } + + /** @return new {@link Http2WebSocketClientHandler} instance */ + public Http2WebSocketClientHandler build() { + PerMessageDeflateClientExtensionHandshaker compressionHandshaker = perMessageDeflateClientExtensionHandshaker; + boolean hasCompression = compressionHandshaker != null; + WebSocketDecoderConfig config = webSocketDecoderConfig; + if (config == null) { + config = WebSocketDecoderConfig.newBuilder() + .expectMaskedFrames(false) + .allowMaskMismatch(false) + .allowExtensions(hasCompression) + .build(); + } else { + boolean isAllowExtensions = config.allowExtensions(); + if (!isAllowExtensions && hasCompression) { + config = config.toBuilder().allowExtensions(true).build(); + } + } + short weight = streamWeight; + if (weight == 0) { + weight = DEFAULT_STREAM_WEIGHT; + } + return new Http2WebSocketClientHandler(config, MASK_PAYLOAD, weight, handshakeTimeoutMillis, + closedWebSocketRemoveTimeoutMillis, compressionHandshaker, isSingleWebSocketPerConnection); + } +} diff --git a/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientHandler.java b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientHandler.java new file mode 100644 index 0000000..5111934 --- /dev/null +++ b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientHandler.java @@ -0,0 +1,175 @@ +package org.xbib.netty.http.client.handler.ws2; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.EventLoop; +import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.handler.codec.http.websocketx.extensions.compression.PerMessageDeflateClientExtensionHandshaker; +import io.netty.handler.codec.http2.*; +import io.netty.handler.ssl.SslHandler; +import org.xbib.netty.http.common.ws.Http2WebSocket; +import org.xbib.netty.http.common.ws.Http2WebSocketChannelHandler; +import org.xbib.netty.http.common.ws.Http2WebSocketProtocol; +import org.xbib.netty.http.common.ws.Http2WebSocketValidator; + +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +/** + * Provides client-side support for websocket-over-http2. Creates sub channel for http2 stream of + * successfully handshaked websocket. Subchannel is compatible with http1 websocket handlers. Should + * be used in tandem with {@link Http2WebSocketClientHandshaker} + */ +public final class Http2WebSocketClientHandler extends Http2WebSocketChannelHandler { + + private static final AtomicReferenceFieldUpdater HANDSHAKER = + AtomicReferenceFieldUpdater.newUpdater(Http2WebSocketClientHandler.class, Http2WebSocketClientHandshaker.class, "handshaker"); + + private final long handshakeTimeoutMillis; + + private final PerMessageDeflateClientExtensionHandshaker compressionHandshaker; + + private final short streamWeight; + + private CharSequence scheme; + + private Boolean supportsWebSocket; + + private boolean supportsWebSocketCalled; + + private volatile Http2Connection.Endpoint streamIdFactory; + + private volatile Http2WebSocketClientHandshaker handshaker; + + Http2WebSocketClientHandler( + WebSocketDecoderConfig webSocketDecoderConfig, + boolean isEncoderMaskPayload, + short streamWeight, + long handshakeTimeoutMillis, + long closedWebSocketRemoveTimeoutMillis, + PerMessageDeflateClientExtensionHandshaker compressionHandshaker, + boolean isSingleWebSocketPerConnection) { + super( + webSocketDecoderConfig, + isEncoderMaskPayload, + closedWebSocketRemoveTimeoutMillis, + isSingleWebSocketPerConnection); + this.streamWeight = streamWeight; + this.handshakeTimeoutMillis = handshakeTimeoutMillis; + this.compressionHandshaker = compressionHandshaker; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + this.scheme = + ctx.pipeline().get(SslHandler.class) != null + ? Http2WebSocketProtocol.SCHEME_HTTPS + : Http2WebSocketProtocol.SCHEME_HTTP; + this.streamIdFactory = http2Handler.connection().local(); + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) + throws Http2Exception { + if (supportsWebSocket != null) { + super.onSettingsRead(ctx, settings); + return; + } + Long extendedConnectEnabled = + settings.get(Http2WebSocketProtocol.SETTINGS_ENABLE_CONNECT_PROTOCOL); + boolean supports = + supportsWebSocket = extendedConnectEnabled != null && extendedConnectEnabled == 1; + Http2WebSocketClientHandshaker listener = HANDSHAKER.get(this); + if (listener != null) { + supportsWebSocketCalled = true; + listener.onSupportsWebSocket(supports); + } + super.onSettingsRead(ctx, settings); + } + + @Override + public void onHeadersRead( + ChannelHandlerContext ctx, + int streamId, + Http2Headers headers, + int padding, + boolean endOfStream) + throws Http2Exception { + boolean proceed = handshakeWebSocket(streamId, headers, endOfStream); + if (proceed) { + next().onHeadersRead(ctx, streamId, headers, padding, endOfStream); + } + } + + @Override + public void onHeadersRead( + ChannelHandlerContext ctx, + int streamId, + Http2Headers headers, + int streamDependency, + short weight, + boolean exclusive, + int padding, + boolean endOfStream) + throws Http2Exception { + boolean proceed = handshakeWebSocket(streamId, headers, endOfStream); + if (proceed) { + next().onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream); + } + } + + Http2WebSocketClientHandshaker handShaker() { + Http2WebSocketClientHandshaker h = HANDSHAKER.get(this); + if (h != null) { + return h; + } + Http2Connection.Endpoint streamIdFactory = this.streamIdFactory; + if (streamIdFactory == null) { + throw new IllegalStateException( + "webSocket handshaker cant be created before channel is registered"); + } + Http2WebSocketClientHandshaker handShaker = + new Http2WebSocketClientHandshaker( + webSocketsParent, + streamIdFactory, + config, + isEncoderMaskPayload, + streamWeight, + scheme, + handshakeTimeoutMillis, + compressionHandshaker); + + if (HANDSHAKER.compareAndSet(this, null, handShaker)) { + EventLoop el = ctx.channel().eventLoop(); + if (el.inEventLoop()) { + onSupportsWebSocket(handShaker); + } else { + el.execute(() -> onSupportsWebSocket(handShaker)); + } + return handShaker; + } + return HANDSHAKER.get(this); + } + + private boolean handshakeWebSocket(int streamId, Http2Headers responseHeaders, boolean endOfStream) { + Http2WebSocket webSocket = webSocketRegistry.get(streamId); + if (webSocket != null) { + if (!Http2WebSocketValidator.isValid(responseHeaders)) { + handShaker().reject(streamId, webSocket, responseHeaders, endOfStream); + } else { + handShaker().handshake(webSocket, responseHeaders, endOfStream); + } + return false; + } + return true; + } + + private void onSupportsWebSocket(Http2WebSocketClientHandshaker handshaker) { + if (supportsWebSocketCalled) { + return; + } + Boolean supports = supportsWebSocket; + if (supports != null) { + handshaker.onSupportsWebSocket(supports); + } + } +} diff --git a/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientHandshaker.java b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientHandshaker.java new file mode 100644 index 0000000..de28325 --- /dev/null +++ b/netty-http-client/src/main/java/org/xbib/netty/http/client/handler/ws2/Http2WebSocketClientHandshaker.java @@ -0,0 +1,502 @@ +package org.xbib.netty.http.client.handler.ws2; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketClientExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; +import io.netty.handler.codec.http.websocketx.extensions.compression.PerMessageDeflateClientExtensionHandshaker; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Connection; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2LocalFlowController; +import io.netty.util.AsciiString; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.ScheduledFuture; +import org.xbib.netty.http.common.ws.Http2WebSocket; +import org.xbib.netty.http.common.ws.Http2WebSocketChannel; +import org.xbib.netty.http.common.ws.Http2WebSocketChannelHandler; +import org.xbib.netty.http.common.ws.Http2WebSocketEvent; +import org.xbib.netty.http.common.ws.Http2WebSocketExtensions; +import org.xbib.netty.http.common.ws.Http2WebSocketMessages; +import org.xbib.netty.http.common.ws.Http2WebSocketProtocol; +import org.xbib.netty.http.common.ws.Preconditions; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Objects; +import java.util.Queue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Establishes websocket-over-http2 on provided connection channel + */ +public final class Http2WebSocketClientHandshaker { + + private static final Logger logger = Logger.getLogger(Http2WebSocketClientHandshaker.class.getName()); + + private static final int ESTIMATED_DEFERRED_HANDSHAKES = 4; + + private static final AtomicIntegerFieldUpdater WEBSOCKET_CHANNEL_SERIAL = + AtomicIntegerFieldUpdater.newUpdater(Http2WebSocketClientHandshaker.class, "webSocketChannelSerial"); + + private static final Http2Headers EMPTY_HEADERS = new DefaultHttp2Headers(false); + + private final Http2Connection.Endpoint streamIdFactory; + + private final WebSocketDecoderConfig webSocketDecoderConfig; + + private final Http2WebSocketChannelHandler.WebSocketsParent webSocketsParent; + + private final short streamWeight; + + private final CharSequence scheme; + + private final PerMessageDeflateClientExtensionHandshaker compressionHandshaker; + + private final boolean isEncoderMaskPayload; + + private final long timeoutMillis; + + private Queue deferred; + + private Boolean supportsWebSocket; + + private volatile int webSocketChannelSerial; + + private CharSequence compressionExtensionHeader; + + Http2WebSocketClientHandshaker(Http2WebSocketChannelHandler.WebSocketsParent webSocketsParent, + Http2Connection.Endpoint streamIdFactory, + WebSocketDecoderConfig webSocketDecoderConfig, + boolean isEncoderMaskPayload, + short streamWeight, + CharSequence scheme, + long handshakeTimeoutMillis, PerMessageDeflateClientExtensionHandshaker compressionHandshaker) { + this.webSocketsParent = webSocketsParent; + this.streamIdFactory = streamIdFactory; + this.webSocketDecoderConfig = webSocketDecoderConfig; + this.isEncoderMaskPayload = isEncoderMaskPayload; + this.timeoutMillis = handshakeTimeoutMillis; + this.streamWeight = streamWeight; + this.scheme = scheme; + this.compressionHandshaker = compressionHandshaker; + } + + /** + * Creates new {@link Http2WebSocketClientHandshaker} for given connection channel + * + * @param channel connection channel. Pipeline must contain {@link Http2WebSocketClientHandler} + * and netty http2 codec (e.g. Http2ConnectionHandler or Http2FrameCodec) + * @return new {@link Http2WebSocketClientHandshaker} instance + */ + public static Http2WebSocketClientHandshaker create(Channel channel) { + Objects.requireNonNull(channel, "channel"); + return Preconditions.requireHandler(channel, Http2WebSocketClientHandler.class).handShaker(); + } + + /** + * Starts websocket-over-http2 handshake using given path + * + * @param path websocket path, must be non-empty + * @param webSocketHandler http1 websocket handler added to pipeline of subchannel created for + * successfully handshaked http2 websocket + * @return ChannelFuture with result of handshake. Its channel accepts http1 WebSocketFrames as + * soon as this method returns. + */ + public ChannelFuture handshake(String path, ChannelHandler webSocketHandler) { + return handshake(path, "", EMPTY_HEADERS, webSocketHandler); + } + + /** + * Starts websocket-over-http2 handshake using given path and request headers + * + * @param path websocket path, must be non-empty + * @param requestHeaders request headers, must be non-null + * @param webSocketHandler http1 websocket handler added to pipeline of subchannel created for + * successfully handshaked http2 websocket + * @return ChannelFuture with result of handshake. Its channel accepts http1 WebSocketFrames as + * soon as this method returns. + */ + public ChannelFuture handshake( + String path, Http2Headers requestHeaders, ChannelHandler webSocketHandler) { + return handshake(path, "", requestHeaders, webSocketHandler); + } + + /** + * Starts websocket-over-http2 handshake using given path and subprotocol + * + * @param path websocket path, must be non-empty + * @param subprotocol websocket subprotocol, must be non-null + * @param webSocketHandler http1 websocket handler added to pipeline of subchannel created for + * successfully handshaked http2 websocket + * @return ChannelFuture with result of handshake. Its channel accepts http1 WebSocketFrames as + * soon as this method returns. + */ + public ChannelFuture handshake(String path, String subprotocol, ChannelHandler webSocketHandler) { + return handshake(path, subprotocol, EMPTY_HEADERS, webSocketHandler); + } + + /** + * Starts websocket-over-http2 handshake using given path, subprotocol and request headers + * + * @param path websocket path, must be non-empty + * @param subprotocol websocket subprotocol, must be non-null + * @param requestHeaders request headers, must be non-null + * @param webSocketHandler http1 websocket handler added to pipeline of subchannel created for + * successfully handshaked http2 websocket + * @return ChannelFuture with result of handshake. Its channel accepts http1 WebSocketFrames as + * soon as this method returns. + */ + public ChannelFuture handshake(String path, String subprotocol, + Http2Headers requestHeaders, ChannelHandler webSocketHandler) { + Preconditions.requireNonEmpty(path, "path"); + Preconditions.requireNonNull(subprotocol, "subprotocol"); + Preconditions.requireNonNull(requestHeaders, "requestHeaders"); + Preconditions.requireNonNull(webSocketHandler, "webSocketHandler"); + long startNanos = System.nanoTime(); + ChannelHandlerContext ctx = webSocketsParent.context(); + if (!ctx.channel().isOpen()) { + return ctx.newFailedFuture(new ClosedChannelException()); + } + int serial = WEBSOCKET_CHANNEL_SERIAL.getAndIncrement(this); + Http2WebSocketChannel webSocketChannel = new Http2WebSocketChannel(webSocketsParent, serial, path, + subprotocol, webSocketDecoderConfig, isEncoderMaskPayload, webSocketHandler).initialize(); + Handshake handshake = new Handshake(webSocketChannel, requestHeaders, timeoutMillis, startNanos); + handshake.future().addListener(future -> { + Throwable cause = future.cause(); + if (cause != null && !(cause instanceof WebSocketHandshakeException)) { + Http2WebSocketEvent.fireHandshakeError(webSocketChannel, null, System.nanoTime(), cause); + } + }); + EventLoop el = ctx.channel().eventLoop(); + if (el.inEventLoop()) { + handshakeOrDefer(handshake, el); + } else { + el.execute(() -> handshakeOrDefer(handshake, el)); + } + return webSocketChannel.handshakePromise(); + } + + void handshake(Http2WebSocket webSocket, Http2Headers responseHeaders, boolean endOfStream) { + if (webSocket == Http2WebSocket.CLOSED) { + return; + } + Http2WebSocketChannel webSocketChannel = (Http2WebSocketChannel) webSocket; + ChannelPromise handshakePromise = webSocketChannel.handshakePromise(); + if (handshakePromise.isDone()) { + return; + } + String errorMessage = null; + WebSocketClientExtension compressionExtension = null; + String status = responseHeaders.status().toString(); + switch (status) { + case "200": + if (endOfStream) { + errorMessage = Http2WebSocketMessages.HANDSHAKE_UNEXPECTED_RESULT; + } else { + String clientSubprotocol = webSocketChannel.subprotocol(); + CharSequence serverSubprotocol = + responseHeaders.get(Http2WebSocketProtocol.HEADER_WEBSOCKET_SUBPROTOCOL_NAME); + if (!isEqual(clientSubprotocol, serverSubprotocol)) { + errorMessage = + Http2WebSocketMessages.HANDSHAKE_UNEXPECTED_SUBPROTOCOL + clientSubprotocol; + } + if (errorMessage == null) { + PerMessageDeflateClientExtensionHandshaker handshaker = compressionHandshaker; + if (handshaker != null) { + CharSequence extensionsHeader = responseHeaders.get(Http2WebSocketProtocol.HEADER_WEBSOCKET_EXTENSIONS_NAME); + WebSocketExtensionData compression = Http2WebSocketExtensions.decode(extensionsHeader); + if (compression != null) { + compressionExtension = handshaker.handshakeExtension(compression); + } + } + } + } + break; + case "400": + CharSequence webSocketVersion = + responseHeaders.get(Http2WebSocketProtocol.HEADER_WEBSOCKET_VERSION_NAME); + errorMessage = webSocketVersion != null + ? Http2WebSocketMessages.HANDSHAKE_UNSUPPORTED_VERSION + webSocketVersion + : Http2WebSocketMessages.HANDSHAKE_BAD_REQUEST; + break; + case "404": + errorMessage = Http2WebSocketMessages.HANDSHAKE_PATH_NOT_FOUND + + webSocketChannel.path() + + Http2WebSocketMessages.HANDSHAKE_PATH_NOT_FOUND_SUBPROTOCOLS + + webSocketChannel.subprotocol(); + break; + default: + errorMessage = Http2WebSocketMessages.HANDSHAKE_GENERIC_ERROR + status; + } + if (errorMessage != null) { + Exception cause = new WebSocketHandshakeException(errorMessage); + if (handshakePromise.tryFailure(cause)) { + Http2WebSocketEvent.fireHandshakeError(webSocketChannel, responseHeaders, System.nanoTime(), cause); + } + return; + } + if (compressionExtension != null) { + webSocketChannel.compression(compressionExtension.newExtensionEncoder(), compressionExtension.newExtensionDecoder()); + } + if (handshakePromise.trySuccess()) { + Http2WebSocketEvent.fireHandshakeSuccess(webSocketChannel, responseHeaders, System.nanoTime()); + } + } + + void reject(int streamId, Http2WebSocket webSocket, Http2Headers headers, boolean endOfStream) { + Http2WebSocketEvent.fireHandshakeValidationStartAndError(webSocketsParent.context().channel(), + streamId, headers.set(AsciiString.of("x-websocket-endofstream"), AsciiString.of(endOfStream ? "true" : "false"))); + if (webSocket == Http2WebSocket.CLOSED) { + return; + } + Http2WebSocketChannel webSocketChannel = (Http2WebSocketChannel) webSocket; + ChannelPromise handshakePromise = webSocketChannel.handshakePromise(); + if (handshakePromise.isDone()) { + return; + } + Exception cause = new WebSocketHandshakeException(Http2WebSocketMessages.HANDSHAKE_INVALID_RESPONSE_HEADERS); + if (handshakePromise.tryFailure(cause)) { + Http2WebSocketEvent.fireHandshakeError(webSocketChannel, headers, System.nanoTime(), cause); + } + } + + void onSupportsWebSocket(boolean supportsWebSocket) { + if (!supportsWebSocket) { + logger.log(Level.SEVERE, Http2WebSocketMessages.HANDSHAKE_UNSUPPORTED_BOOTSTRAP); + } + this.supportsWebSocket = supportsWebSocket; + handshakeDeferred(supportsWebSocket); + } + + private void handshakeOrDefer(Handshake handshake, EventLoop eventLoop) { + if (handshake.isDone()) { + return; + } + Http2WebSocketChannel webSocketChannel = handshake.webSocketChannel(); + Http2Headers requestHeaders = handshake.requestHeaders(); + long startNanos = handshake.startNanos(); + ChannelFuture registered = eventLoop.register(webSocketChannel); + if (!registered.isSuccess()) { + Throwable cause = registered.cause(); + Exception e = new WebSocketHandshakeException("websocket handshake channel registration error", cause); + Http2WebSocketEvent.fireHandshakeStartAndError(webSocketChannel.parent(), + webSocketChannel.serial(), webSocketChannel.path(), webSocketChannel.subprotocol(), + requestHeaders, startNanos, System.nanoTime(), e); + handshake.complete(e); + return; + } + Http2WebSocketEvent.fireHandshakeStart(webSocketChannel, requestHeaders, startNanos); + Boolean supports = supportsWebSocket; + if (supports == null) { + Queue d = deferred; + if (d == null) { + d = deferred = new ArrayDeque<>(ESTIMATED_DEFERRED_HANDSHAKES); + } + handshake.startTimeout(); + d.add(handshake); + return; + } + if (supports) { + handshake.startTimeout(); + } + handshakeImmediate(handshake, supports); + } + + private void handshakeDeferred(boolean supportsWebSocket) { + Queue d = deferred; + if (d == null) { + return; + } + deferred = null; + Handshake handshake = d.poll(); + while (handshake != null) { + handshakeImmediate(handshake, supportsWebSocket); + handshake = d.poll(); + } + } + + private void handshakeImmediate(Handshake handshake, boolean supportsWebSocket) { + Http2WebSocketChannel webSocketChannel = handshake.webSocketChannel(); + Http2Headers customHeaders = handshake.requestHeaders(); + if (handshake.isDone()) { + return; + } + if (!supportsWebSocket) { + WebSocketHandshakeException e = new WebSocketHandshakeException(Http2WebSocketMessages.HANDSHAKE_UNSUPPORTED_BOOTSTRAP); + Http2WebSocketEvent.fireHandshakeError(webSocketChannel, null, System.nanoTime(), e); + handshake.complete(e); + return; + } + int streamId = streamIdFactory.incrementAndGetNextStreamId(); + webSocketsParent.register(streamId, webSocketChannel.setStreamId(streamId)); + String authority = authority(); + String path = webSocketChannel.path(); + Http2Headers headers = Http2WebSocketProtocol.extendedConnect(new DefaultHttp2Headers() + .scheme(scheme) + .authority(authority) + .path(path) + .set(Http2WebSocketProtocol.HEADER_WEBSOCKET_VERSION_NAME, + Http2WebSocketProtocol.HEADER_WEBSOCKET_VERSION_VALUE)); + PerMessageDeflateClientExtensionHandshaker handshaker = compressionHandshaker; + if (handshaker != null) { + headers.set(Http2WebSocketProtocol.HEADER_WEBSOCKET_EXTENSIONS_NAME, + compressionExtensionHeader(handshaker)); + } + String subprotocol = webSocketChannel.subprotocol(); + if (!subprotocol.isEmpty()) { + headers.set(Http2WebSocketProtocol.HEADER_WEBSOCKET_SUBPROTOCOL_NAME, subprotocol); + } + if (!customHeaders.isEmpty()) { + headers.setAll(customHeaders); + } + short pendingStreamWeight = webSocketChannel.pendingStreamWeight(); + short weight = pendingStreamWeight > 0 ? pendingStreamWeight : streamWeight; + webSocketsParent.writeHeaders(webSocketChannel.streamId(), headers, false, weight) + .addListener(future -> { + if (!future.isSuccess()) { + handshake.complete(future.cause()); + return; + } + webSocketChannel.setStreamWeightAttribute(weight); + }); + } + + private String authority() { + return ((InetSocketAddress) webSocketsParent.context().channel().remoteAddress()).getHostString(); + } + + private CharSequence compressionExtensionHeader(PerMessageDeflateClientExtensionHandshaker handshaker) { + CharSequence header = compressionExtensionHeader; + if (header == null) { + header = compressionExtensionHeader = AsciiString.of(Http2WebSocketExtensions.encode(handshaker.newRequestData())); + } + return header; + } + + private static boolean isEqual(String str, CharSequence seq) { + if ((seq == null || seq.length() == 0) && str.isEmpty()) { + return true; + } + if (seq == null) { + return false; + } + return str.contentEquals(seq); + } + + static class Handshake { + private final Future channelClose; + private final ChannelPromise handshake; + private final long timeoutMillis; + private boolean done; + private ScheduledFuture timeoutFuture; + private Future handshakeCompleteFuture; + private GenericFutureListener channelCloseListener; + private final Http2WebSocketChannel webSocketChannel; + private final Http2Headers requestHeaders; + private final long handshakeStartNanos; + + public Handshake(Http2WebSocketChannel webSocketChannel, Http2Headers requestHeaders, + long timeoutMillis, long handshakeStartNanos) { + this.channelClose = webSocketChannel.closeFuture(); + this.handshake = webSocketChannel.handshakePromise(); + this.timeoutMillis = timeoutMillis; + this.webSocketChannel = webSocketChannel; + this.requestHeaders = requestHeaders; + this.handshakeStartNanos = handshakeStartNanos; + } + + public void startTimeout() { + ChannelPromise h = handshake; + Channel channel = h.channel(); + if (done) { + return; + } + GenericFutureListener l = channelCloseListener = future -> onConnectionClose(); + channelClose.addListener(l); + if (done) { + return; + } + handshakeCompleteFuture = h.addListener(future -> onHandshakeComplete(future.cause())); + if (done) { + return; + } + timeoutFuture = channel.eventLoop().schedule(this::onTimeout, timeoutMillis, TimeUnit.MILLISECONDS); + } + + public void complete(Throwable e) { + onHandshakeComplete(e); + } + + public boolean isDone() { + return done; + } + + public ChannelFuture future() { + return handshake; + } + + public Http2WebSocketChannel webSocketChannel() { + return webSocketChannel; + } + + public Http2Headers requestHeaders() { + return requestHeaders; + } + + public long startNanos() { + return handshakeStartNanos; + } + private void onConnectionClose() { + if (!done) { + handshake.tryFailure(new ClosedChannelException()); + done(); + } + } + + private void onHandshakeComplete(Throwable cause) { + if (!done) { + if (cause != null) { + handshake.tryFailure(cause); + } else { + handshake.trySuccess(); + } + done(); + } + } + + private void onTimeout() { + if (!done) { + handshake.tryFailure(new TimeoutException()); + done(); + } + } + + private void done() { + done = true; + GenericFutureListener closeListener = channelCloseListener; + if (closeListener != null) { + channelClose.removeListener(closeListener); + } + cancel(handshakeCompleteFuture); + cancel(timeoutFuture); + } + + private void cancel(Future future) { + if (future != null) { + future.cancel(true); + } + } + } +} diff --git a/netty-http-client/src/test/java/org/xbib/netty/http/client/test/http1/GoogleTest.java b/netty-http-client/src/test/java/org/xbib/netty/http/client/test/http1/GoogleTest.java index a77469d..03f26e7 100644 --- a/netty-http-client/src/test/java/org/xbib/netty/http/client/test/http1/GoogleTest.java +++ b/netty-http-client/src/test/java/org/xbib/netty/http/client/test/http1/GoogleTest.java @@ -23,7 +23,7 @@ class GoogleTest { void testHttp1WithTlsV13() throws Exception { AtomicBoolean success = new AtomicBoolean(); Client client = Client.builder() - .setTransportLayerSecurityProtocols(new String[] { "TLSv1.3" }) + .setTransportLayerSecurityProtocols("TLSv1.3") .build(); try { Request request = Request.get().url("https://www.google.com/") diff --git a/netty-http-client/src/test/java/org/xbib/netty/http/client/test/http1/Http1Test.java b/netty-http-client/src/test/java/org/xbib/netty/http/client/test/http1/Http1Test.java index 244f60b..38e6975 100644 --- a/netty-http-client/src/test/java/org/xbib/netty/http/client/test/http1/Http1Test.java +++ b/netty-http-client/src/test/java/org/xbib/netty/http/client/test/http1/Http1Test.java @@ -24,7 +24,8 @@ class Http1Test { Client client = Client.builder() .build(); try { - Request request = Request.get().url("https://xbib.org") + Request request = Request.get() + .url("https://xbib.org") .setResponseListener(resp -> logger.log(Level.FINE, "got response: " + resp.getHeaders() + resp.getBodyAsString(StandardCharsets.UTF_8) + diff --git a/netty-http-client/websocket.md b/netty-http-client/websocket.md new file mode 100644 index 0000000..a54ccf9 --- /dev/null +++ b/netty-http-client/websocket.md @@ -0,0 +1,258 @@ + +# netty-websocket-http2 + +Netty based implementation of [rfc8441](https://tools.ietf.org/html/rfc8441) - bootstrapping websockets with http/2 + +Library addresses two use cases: for application servers and clients, +It is transparent use of existing http1 websocket handlers on top of http2 streams; for gateways/proxies, +It is websockets-over-http2 support with no http1 dependencies and minimal overhead. + +[https://jauntsdn.com/post/netty-websocket-http2/](https://jauntsdn.com/post/netty-websocket-http2/) + +### websocket channel API +Intended for application servers and clients. +Allows transparent application of existing http1 websocket handlers on top of http2 stream. + +* Server +```groovy +EchoWebSocketHandler http1WebSocketHandler = new EchoWebSocketHandler(); + + Http2WebSocketServerHandler http2webSocketHandler = + Http2WebSocketServerBuilder.create() + .acceptor( + (ctx, path, subprotocols, request, response) -> { + switch (path) { + case "/echo": + if (subprotocols.contains("echo.jauntsdn.com") + && acceptUserAgent(request, response)) { + /*selecting subprotocol for accepted requests is mandatory*/ + Http2WebSocketAcceptor.Subprotocol + .accept("echo.jauntsdn.com", response); + return ctx.executor() + .newSucceededFuture(http1WebSocketHandler); + } + break; + case "/echo_all": + if (subprotocols.isEmpty() + && acceptUserAgent(request, response)) { + return ctx.executor() + .newSucceededFuture(http1WebSocketHandler); + } + break; + } + return ctx.executor() + .newFailedFuture( + new WebSocketHandshakeException( + "websocket rejected, path: " + path)); + }) + .build(); + + ch.pipeline() + .addLast(sslHandler, + http2frameCodec, + http2webSocketHandler); +``` + +* Client +```groovy + Channel channel = + new Bootstrap() + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + + Http2WebSocketClientHandler http2WebSocketClientHandler = + Http2WebSocketClientBuilder.create() + .handshakeTimeoutMillis(15_000) + .build(); + + ch.pipeline() + .addLast( + sslHandler, + http2FrameCodec, + http2WebSocketClientHandler); + } + }) + .connect(address) + .sync() + .channel(); + +Http2WebSocketClientHandshaker handShaker = Http2WebSocketClientHandshaker.create(channel); + +Http2Headers headers = + new DefaultHttp2Headers().set("user-agent", "jauntsdn-websocket-http2-client/1.1.2"); +ChannelFuture handshakeFuture = + /*http1 websocket handler*/ + handShaker.handshake("/echo", headers, new EchoWebSocketHandler()); + +handshakeFuture.channel().writeAndFlush(new TextWebSocketFrame("hello http2 websocket")); +``` +Successfully handshaked http2 stream spawns websocket subchannel, with provided +http1 websocket handlers on its pipeline. + +Runnable demo is available in `netty-websocket-http2-example` module - +[channelserver](https://github.com/jauntsdn/netty-websocket-http2/blob/develop/netty-websocket-http2-example/src/main/java/com/jauntsdn/netty/handler/codec/http2/websocketx/example/channelserver/Main.java), +[channelclient](https://github.com/jauntsdn/netty-websocket-http2/blob/develop/netty-websocket-http2-example/src/main/java/com/jauntsdn/netty/handler/codec/http2/websocketx/example/channelclient/Main.java). + +### websocket handshake only API +Intended for intermediaries/proxies. +Only verifies whether http2 stream is valid websocket, then passes it down the pipeline as `POST` request with `x-protocol=websocket` header. + +```groovy + Http2WebSocketServerHandler http2webSocketHandler = + Http2WebSocketServerBuilder.buildHandshakeOnly(); + + Http2StreamsHandler http2StreamsHandler = new Http2StreamsHandler(); + ch.pipeline() + .addLast(sslHandler, + frameCodec, + http2webSocketHandler, + http2StreamsHandler); +``` + +Works with both callbacks-style `Http2ConnectionHandler` and frames based `Http2FrameCodec`. + +``` +Http2WebSocketServerBuilder.buildHandshakeOnly(); +``` + +Runnable demo is available in `netty-websocket-http2-example` module - +[handshakeserver](https://github.com/jauntsdn/netty-websocket-http2/blob/develop/netty-websocket-http2-example/src/main/java/com/jauntsdn/netty/handler/codec/http2/websocketx/example/handshakeserver/Main.java), +[channelclient](https://github.com/jauntsdn/netty-websocket-http2/blob/develop/netty-websocket-http2-example/src/main/java/com/jauntsdn/netty/handler/codec/http2/websocketx/example/channelclient/Main.java). + +### configuration +Initial settings of server http2 codecs (`Http2ConnectionHandler` or `Http2FrameCodec`) should contain [SETTINGS_ENABLE_CONNECT_PROTOCOL=1](https://tools.ietf.org/html/rfc8441#section-9.1) +parameter to advertise websocket-over-http2 support. + +Also server http2 codecs must disable built-in headers validation because It is not compatible +with rfc8441 due to newly introduced `:protocol` pseudo-header. All websocket handlers provided by this library +do headers validation on their own - both for websocket and non-websocket requests. + +Above configuration may be done with utility methods of `Http2WebSocketServerBuilder` + +``` +public static Http2FrameCodecBuilder configureHttp2Server( + Http2FrameCodecBuilder http2Builder); + +public static Http2ConnectionHandlerBuilder configureHttp2Server( + Http2ConnectionHandlerBuilder http2Builder) +``` + +### compression & subprotocols +Client and server `permessage-deflate` compression configuration is shared by all streams +```groovy +Http2WebSocketServerBuilder.compression(enabled); +``` +or +```groovy +Http2WebSocketServerBuilder.compression( + compressionLevel, + allowServerWindowSize, + preferredClientWindowSize, + allowServerNoContext, + preferredClientNoContext); +``` +Client subprotocols are configured on per-path basis +```groovy +EchoWebSocketHandler http1WebsocketHandler = new EchoWebSocketHandler(); +ChannelFuture handshake = + handShaker.handshake("/echo", "subprotocol", headers, http1WebsocketHandler); +``` +On a server It is responsibility of `Http2WebSocketAcceptor` to select supported subprotocol with +```groovy +Http2WebSocketAcceptor.Subprotocol.accept(subprotocol, response); +``` +### lifecycle + +Handshake events and several shutdown options are available when +using `Websocket channel` style APIs. + +#### handshake events + +Events are fired on parent channel, also on websocket channel if one gets created +* `Http2WebSocketHandshakeStartEvent(websocketId, path, subprotocols, timestampNanos, requestHeaders)` +* `Http2WebSocketHandshakeErrorEvent(webSocketId, path, subprotocols, timestampNanos, responseHeaders, error)` +* `Http2WebSocketHandshakeSuccessEvent(webSocketId, path, subprotocols, timestampNanos, responseHeaders)` + +#### close events + +Outbound `Http2WebSocketLocalCloseEvent` on websocket channel pipeline closes +http2 stream by sending empty `DATA` frame with `END_STREAM` flag set. + +Graceful and `RST` stream shutdown by remote endpoint is represented with inbound `Http2WebSocketRemoteCloseEvent` +(with type `CLOSE_REMOTE_ENDSTREAM` and `CLOSE_REMOTE_RESET` respectively) on websocket channel pipeline. + +Graceful connection shutdown by remote with `GO_AWAY` frame is represented by inbound `Http2WebSocketRemoteGoAwayEvent` +on websocket channel pipeline. + +#### shutdown + +Closing websocket channel terminates its http2 stream by sending `RST` frame. + +#### validation & write error events + +Both API style handlers send `Http2WebSocketHandshakeErrorEvent` for invalid websocket-over-http2 and http requests. +For http2 frame write errors `Http2WebSocketWriteErrorEvent` is sent on parent channel if auto-close is not enabled; +otherwise exception is delivered with `ChannelPipeline.fireExceptionCaught` followed by immediate close. + +### flow control + +Inbound flow control is done automatically as soon as `DATA` frames are received. +Library relies on netty's `DefaultHttp2LocalFlowController` for refilling receive window. + +Outbound flow control is expressed as websocket channels writability change on send window +exhaust/refill, provided by `DefaultHttp2RemoteFlowController`. + +### websocket stream weight + +Initial stream weight is configured with + +```groovy +Http2WebSocketClientBuilder.streamWeight(weight); +``` +it can be updated by firing `Http2WebSocketStreamWeightUpdateEvent` on websocket channel pipeline. + +### performance + +Library relies on capabilities provided by netty's `Http2ConnectionHandler` so performance characteristics should be similar. +[netty-websocket-http2-perftest](https://github.com/jauntsdn/netty-websocket-http2/tree/develop/netty-websocket-http2-perftest/src/main/java/com/jauntsdn/netty/handler/codec/http2/websocketx/perftest) +module contains application that gives rough throughput/latency estimate. The application is started with `perf_server.sh`, `perf_client.sh`. + +On modern box one can expect following results for single websocket: + +To evaluate performance with multiple connections we compose an application comprised with simple echo server, and client +sending batches of messages periodically over single websocket per connection (approximately models chat application) + +With 25k active connections each sending batches of 5-10 messages of 0.2-0.5 KBytes over single websocket every 15-30seconds, +the results are as follows (measured over time spans of 5 seconds): + +### examples + +`netty-websocket-http2-example` module contains demos showcasing both API styles, with this library/browser as clients. + +* `channelserver, channelclient` packages for websocket subchannel API demos. +* `handshakeserver, channelclient` packages for handshake only API demo. +* `lwsclient` package for client demo that runs against [https://libwebsockets.org/testserver/](https://libwebsockets.org/testserver/) which hosts websocket-over-http2 + server implemented with [libwebsockets](https://github.com/warmcat/libwebsockets) - popular C-based networking library. + +### browser example +`Channelserver` example serves web page at `https://www.localhost:8099` that sends pings to `/echo` endpoint. + +Currently only `Mozilla Firefox` and latest `Google Chrome` support websockets-over-http2. + +### build & binaries +``` +./gradlew +``` + +Releases are published on MavenCentral +```groovy +repositories { + mavenCentral() +} + +dependencies { + implementation 'com.jauntsdn.netty:netty-websocket-http2:1.1.2' +} +``` diff --git a/netty-http-common/src/main/java/module-info.java b/netty-http-common/src/main/java/module-info.java index 20e2672..b1cafc3 100644 --- a/netty-http-common/src/main/java/module-info.java +++ b/netty-http-common/src/main/java/module-info.java @@ -4,6 +4,7 @@ module org.xbib.netty.http.common { exports org.xbib.netty.http.common.mime; exports org.xbib.netty.http.common.security; exports org.xbib.netty.http.common.util; + exports org.xbib.netty.http.common.ws; requires org.xbib.net.url; requires io.netty.buffer; requires io.netty.common; diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocket.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocket.java new file mode 100644 index 0000000..3fe54ba --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocket.java @@ -0,0 +1,55 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2Flags; +import io.netty.handler.codec.http2.Http2FrameAdapter; +import io.netty.handler.codec.http2.Http2FrameListener; + +public interface Http2WebSocket extends Http2FrameListener { + + @Override + void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData); + + void trySetWritable(); + + void fireExceptionCaught(Throwable t); + + void streamClosed(); + + void closeForcibly(); + + Http2WebSocket CLOSED = new Http2WebSocketClosedChannel(); + + class Http2WebSocketClosedChannel extends Http2FrameAdapter implements Http2WebSocket { + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) {} + + @Override + public void streamClosed() {} + + @Override + public void trySetWritable() {} + + @Override + public void fireExceptionCaught(Throwable t) {} + + @Override + public void closeForcibly() {} + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + int processed = super.onDataRead(ctx, streamId, data, padding, endOfStream); + data.release(); + return processed; + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload) { + payload.release(); + } + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketChannel.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketChannel.java new file mode 100644 index 0000000..f40ce26 --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketChannel.java @@ -0,0 +1,1318 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.handler.codec.http.websocketx.*; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionDecoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionEncoder; +import io.netty.handler.codec.http2.*; +import io.netty.util.AttributeKey; +import io.netty.util.DefaultAttributeMap; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.internal.StringUtil; +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.logging.Level; +import java.util.logging.Logger; + +public final class Http2WebSocketChannel extends DefaultAttributeMap + implements Channel, Http2WebSocket, GenericFutureListener { + + private static final Logger logger = Logger.getLogger(Http2WebSocketChannel.class.getName()); + + private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16); + + private static final AttributeKey STREAM_WEIGHT_KEY = + AttributeKey.newInstance("com.jauntsdn.netty.handler.codec.http2.websocketx.stream_weight"); + + private static final GenericFutureListener FRAME_WRITE_LISTENER = + new FrameWriteListener(); + + private static final MessageSizeEstimator.Handle MESSAGE_SIZE_ESTIMATOR_INSTANCE = + DefaultMessageSizeEstimator.DEFAULT.newHandle(); + + private static final AtomicLongFieldUpdater TOTAL_PENDING_SIZE_UPDATER = + AtomicLongFieldUpdater.newUpdater(Http2WebSocketChannel.class, "totalPendingSize"); + + private static final AtomicIntegerFieldUpdater UNWRITABLE_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(Http2WebSocketChannel.class, "unwritable"); + + private final Http2StreamChannelConfig config = new Http2StreamChannelConfig(this); + + private final Http2ChannelUnsafe unsafe = new Http2ChannelUnsafe(); + + private final ChannelId channelId; + + private final ChannelPipeline pipeline; + + private final Http2WebSocketChannelHandler.WebSocketsParent webSocketChannelParent; + + private final int websocketChannelSerial; + + private final String path; + + private final String subprotocol; + + private final ChannelPromise closePromise; + + private final ChannelPromise handshakePromise; + + private GenericFutureListener handshakePromiseListener; + + private volatile int streamId; + + private volatile boolean registered; + + private volatile long totalPendingSize; + + private volatile int unwritable; + + private Runnable fireChannelWritabilityChangedTask; + + private boolean outboundClosed; + + Boolean closeInitiator; + + /** + * This variable represents if a read is in progress for the current channel or was requested. + * Note that depending upon the {@link RecvByteBufAllocator} behavior a read may extend beyond the + * {@link Http2ChannelUnsafe#beginRead()} method scope. The {@link Http2ChannelUnsafe#beginRead()} + * loop may drain all pending data, and then if the parent channel is reading this channel may + * still accept frames. + */ + private ReadStatus readStatus = ReadStatus.IDLE; + + private Queue inboundBuffer; + private boolean readCompletePending; + private short pendingStreamWeight; + private WebSocketExtensionEncoder compressionEncoder; + private WebSocketExtensionDecoder compressionDecoder; + boolean isHandshakeCompleted; + + public Http2WebSocketChannel(Http2WebSocketChannelHandler.WebSocketsParent webSocketChannelParent, + int websocketChannelSerial, + String path, + String subprotocol, + WebSocketDecoderConfig config, + boolean isEncoderMaskPayload, + WebSocketExtensionEncoder compressionEncoder, + WebSocketExtensionDecoder compressionDecoder, + ChannelHandler websocketHandler) { + this.isHandshakeCompleted = true; + this.webSocketChannelParent = webSocketChannelParent; + this.websocketChannelSerial = websocketChannelSerial; + this.path = path; + this.subprotocol = subprotocol; + channelId = new Http2WebSocketChannelId(parent().id(), websocketChannelSerial); + ChannelPipeline pl = pipeline = new WebSocketChannelPipeline(this); + if (compressionEncoder != null && compressionDecoder != null) { + pl.addLast(new WebSocket13FrameDecoder(config), + compressionDecoder, + new WebSocket13FrameEncoder(isEncoderMaskPayload), + compressionEncoder); + } else { + pl.addLast(new WebSocket13FrameDecoder(config), new WebSocket13FrameEncoder(isEncoderMaskPayload)); + } + if (config.withUTF8Validator()) { + pl.addLast(new Utf8FrameValidator()); + } + pl.addLast(websocketHandler); + closePromise = pl.newPromise(); + handshakePromise = null; + parent().closeFuture().addListener(this); + } + + public Http2WebSocketChannel(Http2WebSocketChannelHandler.WebSocketsParent webSocketChannelParent, + int websocketChannelSerial, + String path, + String subprotocol, + WebSocketDecoderConfig config, + boolean isEncoderMaskPayload, + ChannelHandler websocketHandler) { + this.webSocketChannelParent = webSocketChannelParent; + this.websocketChannelSerial = websocketChannelSerial; + this.path = path; + this.subprotocol = subprotocol; + channelId = new Http2WebSocketChannelId(parent().id(), websocketChannelSerial); + ChannelPipeline pl = pipeline = new WebSocketChannelPipeline(this); + PreHandshakeHandler preHandshakeHandler = new PreHandshakeHandler(); + pl.addLast(preHandshakeHandler, websocketHandler); + closePromise = pl.newPromise(); + handshakePromise = pl.newPromise(); + handshakePromiseListener = + new CompleteClientHandshake(config, isEncoderMaskPayload, preHandshakeHandler); + } + + public Http2WebSocketChannel initialize() { + GenericFutureListener handshakeListener = handshakePromiseListener; + handshakePromiseListener = null; + handshakePromise.addListener(handshakeListener); + parent().closeFuture().addListener(this); + return this; + } + + class CompleteClientHandshake implements GenericFutureListener { + private final WebSocketDecoderConfig config; + private final boolean isEncoderMaskPayload; + private final PreHandshakeHandler preHandshakeHandler; + + public CompleteClientHandshake(WebSocketDecoderConfig config, + boolean isEncoderMaskPayload, + PreHandshakeHandler preHandshakeHandler) { + this.config = config; + this.isEncoderMaskPayload = isEncoderMaskPayload; + this.preHandshakeHandler = preHandshakeHandler; + } + + @Override + public void operationComplete(ChannelFuture future) { + isHandshakeCompleted = true; + Throwable cause = future.cause(); + if (cause != null) { + preHandshakeHandler.cancel(cause); + return; + } + WebSocketDecoderConfig config = this.config; + ChannelPipeline pl = pipeline(); + if (config.withUTF8Validator()) { + pl.addFirst(new Utf8FrameValidator()); + } + WebSocketExtensionEncoder encoder = compressionEncoder; + WebSocketExtensionDecoder decoder = compressionDecoder; + if (encoder != null && decoder != null) { + pl.addFirst(new WebSocket13FrameDecoder(config), + decoder, + new WebSocket13FrameEncoder(isEncoderMaskPayload), + encoder); + } else { + pl.addFirst(new WebSocket13FrameDecoder(config), new WebSocket13FrameEncoder(isEncoderMaskPayload)); + } + preHandshakeHandler.complete(); + } + } + + public int serial() { + return websocketChannelSerial; + } + + public String path() { + return path; + } + + public String subprotocol() { + return subprotocol; + } + + public short pendingStreamWeight() { + short weight = pendingStreamWeight; + pendingStreamWeight = 0; + return weight; + } + + public void compression(WebSocketExtensionEncoder compressionEncoder, WebSocketExtensionDecoder compressionDecoder) { + this.compressionEncoder = compressionEncoder; + this.compressionDecoder = compressionDecoder; + } + + @Override + public void operationComplete(ChannelFuture future) { + streamClosed(); + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) { + int readableBytes = data.readableBytes(); + if (padding > 0) { + data.release(); + pipeline().fireExceptionCaught(new IllegalArgumentException("Http2WebSocketChannel received padded DATA frame, padding length: " + padding)); + close(); + return readableBytes; + } + if (!isHandshakeCompleted) { + data.release(); + pipeline().fireExceptionCaught(new IllegalArgumentException("Http2WebSocketChannel received DATA frame before handshake completion")); + close(); + return readableBytes; + } + + fireChildRead(data, endOfStream); + return readableBytes; + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) { + pipeline().fireUserEventTriggered(Http2WebSocketEvent.Http2WebSocketRemoteCloseEvent.reset(serial(), path, subprotocol, System.nanoTime())); + streamClosed(); + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) { + pipeline().fireUserEventTriggered(new Http2WebSocketEvent.Http2WebSocketRemoteGoAwayEvent(serial(), path, subprotocol, System.nanoTime(), errorCode)); + streamClosed(); + } + + public Http2WebSocketChannel setStreamId(int streamId) { + this.streamId = streamId; + return this; + } + + public ChannelPromise handshakePromise() { + return handshakePromise; + } + + private void incrementPendingOutboundBytes(long size, boolean invokeLater) { + if (size == 0) { + return; + } + long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size); + if (newWriteBufferSize > config().getWriteBufferHighWaterMark()) { + setUnwritable(invokeLater); + } + } + + private void decrementPendingOutboundBytes(long size, boolean invokeLater) { + if (size == 0) { + return; + } + long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); + // Once the totalPendingSize dropped below the low water-mark we can mark the child channel + // as writable again. Before doing so we also need to ensure the parent channel is writable to + // prevent excessive buffering in the parent outbound buffer. If the parent is not writable + // we will mark the child channel as writable once the parent becomes writable by calling + // trySetWritable() later. + if (newWriteBufferSize < config().getWriteBufferLowWaterMark() && parent().isWritable()) { + setWritable(invokeLater); + } + } + + @Override + public void trySetWritable() { + // The parent is writable again but the child channel itself may still not be writable. + // Lets try to set the child channel writable to match the state of the parent channel + // if (and only if) the totalPendingSize is smaller then the low water-mark. + // If this is not the case we will try again later once we drop under it. + if (totalPendingSize < config().getWriteBufferLowWaterMark()) { + setWritable(false); + } + } + + @Override + public void fireExceptionCaught(Throwable t) { + pipeline.fireExceptionCaught(t); + } + + @Override + public void closeForcibly() { + unsafe.closeForcibly(); + } + + private void setWritable(boolean invokeLater) { + for (; ; ) { + final int oldValue = unwritable; + final int newValue = oldValue & ~1; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue != 0 && newValue == 0) { + fireChannelWritabilityChanged(invokeLater); + } + break; + } + } + } + + private void setUnwritable(boolean invokeLater) { + for (; ; ) { + final int oldValue = unwritable; + final int newValue = oldValue | 1; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue == 0 && newValue != 0) { + fireChannelWritabilityChanged(invokeLater); + } + break; + } + } + } + + private void fireChannelWritabilityChanged(boolean invokeLater) { + final ChannelPipeline pipeline = pipeline(); + if (invokeLater) { + Runnable task = fireChannelWritabilityChangedTask; + if (task == null) { + fireChannelWritabilityChangedTask = task = pipeline::fireChannelWritabilityChanged; + } + eventLoop().execute(task); + } else { + pipeline.fireChannelWritabilityChanged(); + } + } + + public int streamId() { + return streamId; + } + + @Override + public void streamClosed() { + Http2ChannelUnsafe u = unsafe; + u.streamClosed(); + } + + boolean isCloseInitiator() { + Boolean ci = closeInitiator; + return ci != null && ci; + } + + void trySetCloseInitiator(boolean isCloseInitiator) { + if (closeInitiator == null) { + closeInitiator = isCloseInitiator; + } + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return !closePromise.isDone(); + } + + @Override + public boolean isActive() { + return isOpen(); + } + + @Override + public boolean isWritable() { + return unwritable == 0; + } + + @Override + public ChannelId id() { + return channelId; + } + + @Override + public EventLoop eventLoop() { + return parent().eventLoop(); + } + + @Override + public Channel parent() { + return webSocketChannelParent.context().channel(); + } + + @Override + public boolean isRegistered() { + return registered; + } + + @Override + public SocketAddress localAddress() { + return parent().localAddress(); + } + + @Override + public SocketAddress remoteAddress() { + return parent().remoteAddress(); + } + + @Override + public ChannelFuture closeFuture() { + return closePromise; + } + + @Override + public long bytesBeforeUnwritable() { + long bytes = config().getWriteBufferHighWaterMark() - totalPendingSize; + // If bytes is negative we know we are not writable, but if bytes is non-negative we have to + // check + // writability. Note that totalPendingSize and isWritable() use different volatile variables + // that are not + // synchronized together. totalPendingSize will be updated before isWritable(). + if (bytes > 0) { + return isWritable() ? bytes : 0; + } + return 0; + } + + @Override + public long bytesBeforeWritable() { + long bytes = totalPendingSize - config().getWriteBufferLowWaterMark(); + // If bytes is negative we know we are writable, but if bytes is non-negative we have to check + // writability. + // Note that totalPendingSize and isWritable() use different volatile variables that are not + // synchronized + // together. totalPendingSize will be updated before isWritable(). + if (bytes > 0) { + return isWritable() ? 0 : bytes; + } + return 0; + } + + @Override + public Unsafe unsafe() { + return unsafe; + } + + @Override + public ChannelPipeline pipeline() { + return pipeline; + } + + @Override + public ByteBufAllocator alloc() { + return config().getAllocator(); + } + + @Override + public Channel read() { + pipeline().read(); + return this; + } + + @Override + public Channel flush() { + pipeline().flush(); + return this; + } + + @Override + public ChannelFuture bind(SocketAddress localAddress) { + return pipeline().bind(localAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress) { + return pipeline().connect(remoteAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { + return pipeline().connect(remoteAddress, localAddress); + } + + @Override + public ChannelFuture disconnect() { + return pipeline().disconnect(); + } + + @Override + public ChannelFuture close() { + return pipeline().close(); + } + + @Override + public ChannelFuture deregister() { + return pipeline().deregister(); + } + + @Override + public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { + return pipeline().bind(localAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { + return pipeline().connect(remoteAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + return pipeline().connect(remoteAddress, localAddress, promise); + } + + @Override + public ChannelFuture disconnect(ChannelPromise promise) { + return pipeline().disconnect(promise); + } + + @Override + public ChannelFuture close(ChannelPromise promise) { + return pipeline().close(promise); + } + + @Override + public ChannelFuture deregister(ChannelPromise promise) { + return pipeline().deregister(promise); + } + + @Override + public ChannelFuture write(Object msg) { + return pipeline().write(msg); + } + + @Override + public ChannelFuture write(Object msg, ChannelPromise promise) { + return pipeline().write(msg, promise); + } + + @Override + public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { + return pipeline().writeAndFlush(msg, promise); + } + + @Override + public ChannelFuture writeAndFlush(Object msg) { + return pipeline().writeAndFlush(msg); + } + + @Override + public ChannelPromise newPromise() { + return pipeline().newPromise(); + } + + @Override + public ChannelProgressivePromise newProgressivePromise() { + return pipeline().newProgressivePromise(); + } + + @Override + public ChannelFuture newSucceededFuture() { + return pipeline().newSucceededFuture(); + } + + @Override + public ChannelFuture newFailedFuture(Throwable cause) { + return pipeline().newFailedFuture(cause); + } + + @Override + public ChannelPromise voidPromise() { + return pipeline().voidPromise(); + } + + @Override + public int hashCode() { + return id().hashCode(); + } + + @Override + public boolean equals(Object o) { + return this == o; + } + + @Override + public int compareTo(Channel o) { + if (this == o) { + return 0; + } + return id().compareTo(o.id()); + } + + @Override + public String toString() { + return parent().toString(); + } + + /** + * Receive a read message. This does not notify handlers unless a read is in progress on the + * channel. + */ + public void fireChildRead(ByteBuf data, boolean endOfStream) { + assert eventLoop().inEventLoop(); + if (endOfStream) { + trySetCloseInitiator(false); + } + if (!isActive()) { + ReferenceCountUtil.release(data); + } else if (readStatus != ReadStatus.IDLE) { + // If a read is in progress or has been requested, there cannot be anything in the queue, + // otherwise we would have drained it from the queue and processed it during the read cycle. + Queue inbound = inboundBuffer; + assert inbound == null || inbound.isEmpty(); + Http2ChannelUnsafe u = unsafe; + final RecvByteBufAllocator.Handle allocHandle = u.recvBufAllocHandle(); + u.doRead0(data, allocHandle); + // We currently don't need to check for readEOS because the parent channel and child channel + // are limited + // to the same EventLoop thread. There are a limited number of frame types that may come after + // EOS is + // read (unknown, reset) and the trade off is less conditionals for the hot path + // (headers/data) at the + // cost of additional readComplete notifications on the rare path. + if (allocHandle.continueReading()) { + maybeAddChannelToReadCompletePendingQueue(); + } else { + u.notifyReadComplete(allocHandle, true); + } + } else { + Queue inbound = inboundBuffer; + if (inbound == null) { + inbound = inboundBuffer = new ArrayDeque<>(4); + } + inbound.add(data); + } + if (endOfStream) { + pipeline().fireUserEventTriggered(Http2WebSocketEvent.Http2WebSocketRemoteCloseEvent.endStream(serial(), + path, subprotocol, System.nanoTime())); + } + } + + public void fireChildReadComplete() { + assert eventLoop().inEventLoop(); + assert readStatus != ReadStatus.IDLE || !readCompletePending; + unsafe.notifyReadComplete(unsafe.recvBufAllocHandle(), false); + } + + public void setStreamWeightAttribute(short streamWeight) { + attr(STREAM_WEIGHT_KEY).set(streamWeight); + } + + public Short streamWeightAttribute() { + if (!hasAttr(STREAM_WEIGHT_KEY)) { + return null; + } + return attr(STREAM_WEIGHT_KEY).get(); + } + + private final class Http2ChannelUnsafe implements Unsafe { + private final VoidChannelPromise unsafeVoidPromise = new VoidChannelPromise(Http2WebSocketChannel.this, false); + private RecvByteBufAllocator.Handle recvHandle; + private boolean writeDoneAndNoFlush; + private boolean closeInitiated; + private boolean streamClosed; + + @Override + public void connect(final SocketAddress remoteAddress, SocketAddress localAddress, final ChannelPromise promise) { + if (!promise.setUncancellable()) { + return; + } + promise.setFailure(new UnsupportedOperationException()); + } + + @Override + public RecvByteBufAllocator.Handle recvBufAllocHandle() { + RecvByteBufAllocator.Handle h = recvHandle; + if (h == null) { + h = recvHandle = config().getRecvByteBufAllocator().newHandle(); + h.reset(config()); + } + return h; + } + + @Override + public SocketAddress localAddress() { + return parent().unsafe().localAddress(); + } + + @Override + public SocketAddress remoteAddress() { + return parent().unsafe().remoteAddress(); + } + + @Override + public void register(EventLoop eventLoop, ChannelPromise promise) { + if (!promise.setUncancellable()) { + return; + } + if (registered) { + promise.setFailure(new UnsupportedOperationException("Re-register is not supported")); + return; + } + registered = true; + promise.setSuccess(); + ChannelPipeline pl = pipeline(); + pl.fireChannelRegistered(); + if (isActive()) { + pl.fireChannelActive(); + } + } + + @Override + public void bind(SocketAddress localAddress, ChannelPromise promise) { + if (!promise.setUncancellable()) { + return; + } + promise.setFailure(new UnsupportedOperationException()); + } + + @Override + public void disconnect(ChannelPromise promise) { + close(promise); + } + + @Override + public void close(final ChannelPromise promise) { + if (!promise.setUncancellable()) { + return; + } + if (closeInitiated) { + if (closePromise.isDone()) { + promise.setSuccess(); + } else if (!(promise instanceof VoidChannelPromise)) { + closePromise.addListener(future -> promise.setSuccess()); + } + return; + } + closeInitiated = true; + parent().closeFuture().removeListener(Http2WebSocketChannel.this); + readCompletePending = false; + final boolean wasActive = isActive(); + if (parent().isActive() && !streamClosed && streamId > 0) { + trySetCloseInitiator(true); + writeRstStream().addListener(FRAME_WRITE_LISTENER); + } + Queue inbound = inboundBuffer; + if (inbound != null) { + inboundBuffer = null; + for (; ; ) { + ByteBuf msg = inbound.poll(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + } + outboundClosed = true; + closePromise.setSuccess(); + promise.setSuccess(); + fireChannelInactiveAndDeregister(voidPromise(), wasActive); + } + + @Override + public void closeForcibly() { + close(unsafe().voidPromise()); + } + + @Override + public void deregister(ChannelPromise promise) { + fireChannelInactiveAndDeregister(promise, false); + } + + private void fireChannelInactiveAndDeregister( + final ChannelPromise promise, final boolean fireChannelInactive) { + if (!promise.setUncancellable()) { + return; + } + if (!registered) { + promise.setSuccess(); + return; + } + invokeLater(() -> { + ChannelPipeline pl = pipeline; + if (fireChannelInactive) { + pl.fireChannelInactive(); + } + if (registered) { + registered = false; + pl.fireChannelUnregistered(); + } + safeSetSuccess(promise); + }); + } + + private void safeSetSuccess(ChannelPromise promise) { + if (!(promise instanceof VoidChannelPromise) && !promise.trySuccess()) { + logger.log(Level.WARNING, "failed to mark a promise as success because it is done already: " + promise); + } + } + + private void invokeLater(Runnable task) { + try { + eventLoop().execute(task); + } catch (RejectedExecutionException e) { + logger.log(Level.WARNING, "can't invoke task later as EventLoop rejected it", e); + } + } + + @Override + public void beginRead() { + if (!isActive()) { + return; + } + switch (readStatus) { + case IDLE: + readStatus = ReadStatus.IN_PROGRESS; + doBeginRead(); + break; + case IN_PROGRESS: + readStatus = ReadStatus.REQUESTED; + break; + default: + break; + } + } + + private ByteBuf pollQueuedMessage() { + Queue inbound = inboundBuffer; + return inbound == null ? null : inbound.poll(); + } + + void doBeginRead() { + while (readStatus != ReadStatus.IDLE) { + ByteBuf message = pollQueuedMessage(); + if (message == null) { + if (streamClosed) { + unsafe.closeForcibly(); + } + flush(); + break; + } + final RecvByteBufAllocator.Handle allocHandle = recvBufAllocHandle(); + allocHandle.reset(config()); + boolean continueReading = false; + do { + doRead0(message, allocHandle); + } while ((streamClosed || (continueReading = allocHandle.continueReading())) + && (message = pollQueuedMessage()) != null); + if (continueReading && isParentReadInProgress() && !streamClosed) { + maybeAddChannelToReadCompletePendingQueue(); + } else { + notifyReadComplete(allocHandle, true); + } + } + } + + void streamClosed() { + streamClosed = true; + doBeginRead(); + } + + void notifyReadComplete(RecvByteBufAllocator.Handle allocHandle, boolean forceReadComplete) { + if (!readCompletePending && !forceReadComplete) { + return; + } + readCompletePending = false; + if (readStatus == ReadStatus.REQUESTED) { + readStatus = ReadStatus.IN_PROGRESS; + } else { + readStatus = ReadStatus.IDLE; + } + allocHandle.readComplete(); + pipeline().fireChannelReadComplete(); + flush(); + if (streamClosed) { + unsafe.closeForcibly(); + } + } + + void doRead0(ByteBuf data, RecvByteBufAllocator.Handle allocHandle) { + final int bytes = data.readableBytes(); + allocHandle.attemptedBytesRead(bytes); + allocHandle.lastBytesRead(bytes); + allocHandle.incMessagesRead(1); + pipeline().fireChannelRead(data); + } + + @Override + public void write(Object msg, final ChannelPromise promise) { + if (!promise.setUncancellable()) { + ReferenceCountUtil.release(msg); + return; + } + if (!isActive() || outboundClosed && (msg instanceof ByteBuf)) { + ReferenceCountUtil.release(msg); + promise.setFailure(new ClosedChannelException()); + logger.log(Level.FINE, "Websocket channel frame dropped because outbound is closed"); + return; + } + try { + if (msg instanceof ByteBuf) { + writeData((ByteBuf) msg, false, promise); + } else { + String msgStr = msg.toString(); + ReferenceCountUtil.release(msg); + promise.setFailure(new IllegalArgumentException("Message must be an " + StringUtil.simpleClassName(ByteBuf.class) + ": " + msgStr)); + } + } catch (Throwable t) { + promise.tryFailure(t); + } + } + + ChannelFuture writeData(ByteBuf dataFrameContents, boolean endOfStream, final ChannelPromise promise) { + ChannelFuture f = webSocketChannelParent.writeData(streamId, dataFrameContents, endOfStream, promise); + if (f.isDone()) { + writeComplete(f); + } else { + final long bytes = MESSAGE_SIZE_ESTIMATOR_INSTANCE.size(dataFrameContents); + incrementPendingOutboundBytes(bytes, false); + f.addListener( + (ChannelFuture future) -> { + writeComplete(future); + decrementPendingOutboundBytes(bytes, false); + }); + writeDoneAndNoFlush = true; + } + return f; + } + + private void writeComplete(ChannelFuture future) { + Throwable cause = future.cause(); + if (cause != null) { + Throwable error = wrapStreamClosedError(cause); + if (error instanceof IOException) { + if (config.isAutoClose()) { + closeForcibly(); + } else { + outboundClosed = true; + } + } + } + } + + private Throwable wrapStreamClosedError(Throwable cause) { + if (cause instanceof Http2Exception + && ((Http2Exception) cause).error() == Http2Error.STREAM_CLOSED) { + return new ClosedChannelException().initCause(cause); + } + return cause; + } + + @Override + public void flush() { + if (!writeDoneAndNoFlush || isParentReadInProgress()) { + return; + } + writeDoneAndNoFlush = false; + webSocketChannelParent.context().flush(); + } + + @Override + public ChannelPromise voidPromise() { + return unsafeVoidPromise; + } + + @Override + public ChannelOutboundBuffer outboundBuffer() { + return null; + } + } + + /** + * {@link ChannelConfig} so that the high and low writebuffer watermarks can reflect the outbound + * flow control window, without having to create a new {@link WriteBufferWaterMark} object + * whenever the flow control window changes. + */ + private static final class Http2StreamChannelConfig extends DefaultChannelConfig { + Http2StreamChannelConfig(Channel channel) { + super(channel); + } + + @Override + public ChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + if (!(allocator.newHandle() instanceof RecvByteBufAllocator.ExtendedHandle)) { + throw new IllegalArgumentException("allocator.newHandle() must return an object of type: " + + RecvByteBufAllocator.ExtendedHandle.class); + } + super.setRecvByteBufAllocator(allocator); + return this; + } + } + + private void maybeAddChannelToReadCompletePendingQueue() { + if (!readCompletePending) { + readCompletePending = true; + addChannelToReadCompletePendingQueue(); + } + } + + ChannelFuture writeRstStream() { + logger.log(Level.FINE, "Websocket channel writing RST frame for path: " + + path + ", streamId: " + streamId + ", errorCode: " + Http2Error.CANCEL.code()); + return webSocketChannelParent.writeRstStream(streamId, Http2Error.CANCEL.code()); + } + + ChannelFuture writePriority(short weight) { + logger.log(Level.FINE, "Websocket channel writing PRIORITY frame for path: " + + path + ", streamId: " + streamId + ", weight: " + weight); + return webSocketChannelParent.writePriority(streamId, weight); + } + + private boolean isParentReadInProgress() { + return webSocketChannelParent.isParentReadInProgress(); + }; + + private void addChannelToReadCompletePendingQueue() { + webSocketChannelParent.addChannelToReadCompletePendingQueue(this); + } + + @Override + public void onPriorityRead( + ChannelHandlerContext ctx, + int streamId, + int streamDependency, + short weight, + boolean exclusive) {} + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) {} + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) {} + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) {} + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) {} + + @Override + public void onPushPromiseRead( + ChannelHandlerContext ctx, + int streamId, + int promisedStreamId, + Http2Headers headers, + int padding) {} + + @Override + public void onWindowUpdateRead( + ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) {} + + @Override + public void onUnknownFrame( + ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload) { + payload.release(); + } + + @Override + public void onHeadersRead( + ChannelHandlerContext ctx, + int streamId, + Http2Headers headers, + int padding, + boolean endOfStream) {} + + @Override + public void onHeadersRead( + ChannelHandlerContext ctx, + int streamId, + Http2Headers headers, + int streamDependency, + short weight, + boolean exclusive, + int padding, + boolean endOfStream) {} + + private class WebSocketChannelPipeline extends DefaultChannelPipeline { + + protected WebSocketChannelPipeline(Channel channel) { + super(channel); + } + + @Override + protected void incrementPendingOutboundBytes(long size) { + Http2WebSocketChannel.this.incrementPendingOutboundBytes(size, true); + } + + @Override + protected void decrementPendingOutboundBytes(long size) { + Http2WebSocketChannel.this.decrementPendingOutboundBytes(size, true); + } + + @Override + protected void onUnhandledInboundUserEventTriggered(Object evt) { + if (evt instanceof Http2WebSocketEvent) { + if (closePromise.isDone()) { + return; + } + Http2WebSocketEvent webSocketEvent = (Http2WebSocketEvent) evt; + switch (webSocketEvent.type()) { + case CLOSE_LOCAL_ENDSTREAM: + logger.log(Level.FINE, "Graceful local close of websocket, streamId: " + streamId + ", path: " + path); + trySetCloseInitiator(true); + ChannelHandlerContext ctx = webSocketChannelParent.context(); + Http2ChannelUnsafe u = unsafe; + u.writeData(Unpooled.EMPTY_BUFFER, true, ctx.newPromise()) + .addListener(FRAME_WRITE_LISTENER); + u.flush(); + u.streamClosed(); + break; + case WEIGHT_UPDATE: + if (handshakePromise == null) { + logger.log(Level.FINE, "Attempted to send PRIORITY frame for stream: " + streamId + " as server, ignoring"); + return; + } + short weight = webSocketEvent.cast().streamWeight(); + if (streamId == 0) { + pendingStreamWeight = weight; + return; + } + writePriority(weight).addListener((ChannelFuture future) -> { + Throwable cause = future.cause(); + if (cause != null) { + Http2WebSocketEvent.fireFrameWriteError(future.channel(), cause); + } else { + setStreamWeightAttribute(weight); + } + }); + break; + default: + break; + } + return; + } + super.onUnhandledInboundUserEventTriggered(evt); + } + } + + static class FrameWriteListener implements GenericFutureListener { + @Override + public void operationComplete(ChannelFuture future) { + Throwable cause = future.cause(); + if (cause != null) { + Http2WebSocketEvent.fireFrameWriteError(future.channel(), cause); + } + } + } + + static class Http2WebSocketChannelId implements ChannelId { + + private final int id; + private final ChannelId parentId; + + Http2WebSocketChannelId(ChannelId parentId, int id) { + this.parentId = parentId; + this.id = id; + } + + @Override + public String asShortText() { + return parentId.asShortText() + '/' + id; + } + + @Override + public String asLongText() { + return parentId.asLongText() + '/' + id; + } + + @Override + public int compareTo(ChannelId o) { + if (o instanceof Http2WebSocketChannelId) { + Http2WebSocketChannelId otherId = (Http2WebSocketChannelId) o; + int res = parentId.compareTo(otherId.parentId); + if (res == 0) { + return id - otherId.id; + } else { + return res; + } + } + return parentId.compareTo(o); + } + + @Override + public int hashCode() { + return id * 31 + parentId.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Http2WebSocketChannelId)) { + return false; + } + Http2WebSocketChannelId otherId = (Http2WebSocketChannelId) obj; + return id == otherId.id && parentId.equals(otherId.parentId); + } + + @Override + public String toString() { + return asShortText(); + } + } + + static class PreHandshakeHandler extends ChannelOutboundHandlerAdapter { + Queue outboundBuffer; + boolean isDone; + ChannelHandlerContext ctx; + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + super.handlerAdded(ctx); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (isDone) { + ReferenceCountUtil.safeRelease(msg); + return; + } + if (!(msg instanceof WebSocketFrame)) { + super.write(ctx, msg, promise); + return; + } + + Queue outbound = outboundBuffer; + if (outbound == null) { + outbound = outboundBuffer = new ArrayDeque<>(); + } + outbound.offer(new PendingOutbound((WebSocketFrame) msg, promise)); + } + + void complete() { + Queue outbound = outboundBuffer; + ChannelHandlerContext c = ctx; + if (outbound == null) { + c.pipeline().remove(this); + return; + } + outboundBuffer = null; + PendingOutbound o = outbound.poll(); + do { + c.write(o.webSocketFrame, o.completePromise); + o = outbound.poll(); + } while (o != null); + c.flush(); + c.pipeline().remove(this); + } + + void cancel(Throwable cause) { + isDone = true; + Queue outbound = outboundBuffer; + if (outbound == null) { + ctx.close(); + return; + } + outboundBuffer = null; + + PendingOutbound o = outbound.poll(); + do { + o.completePromise.tryFailure(cause); + o.webSocketFrame.release(); + o = outbound.poll(); + } while (o != null); + ctx.close(); + } + + static class PendingOutbound { + final WebSocketFrame webSocketFrame; + final ChannelPromise completePromise; + + PendingOutbound(WebSocketFrame webSocketFrame, ChannelPromise completePromise) { + this.webSocketFrame = webSocketFrame; + this.completePromise = completePromise; + } + } + } + + /** The current status of the read-processing for a {@link Http2WebSocketChannel}. */ + private enum ReadStatus { + /** No read in progress and no read was requested (yet) */ + IDLE, + + /** Reading in progress */ + IN_PROGRESS, + + /** A read operation was requested. */ + REQUESTED + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketChannelHandler.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketChannelHandler.java new file mode 100644 index 0000000..e6c55d9 --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketChannelHandler.java @@ -0,0 +1,392 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.handler.codec.http2.*; +import io.netty.util.collection.IntCollections; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.ScheduledFuture; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +public abstract class Http2WebSocketChannelHandler extends Http2WebSocketHandler { + + protected final WebSocketDecoderConfig config; + + protected final boolean isEncoderMaskPayload; + + protected final long closedWebSocketRemoveTimeoutMillis; + + protected final Supplier> webSocketRegistryFactory; + + protected IntObjectMap webSocketRegistry = IntCollections.emptyMap(); + + protected ChannelHandlerContext ctx; + + protected WebSocketsParent webSocketsParent; + + protected boolean isAutoRead; + + public Http2WebSocketChannelHandler(WebSocketDecoderConfig webSocketDecoderConfig, + boolean isEncoderMaskPayload, + long closedWebSocketRemoveTimeoutMillis, + boolean isSingleWebSocketPerConnection) { + this.config = webSocketDecoderConfig; + this.isEncoderMaskPayload = isEncoderMaskPayload; + this.closedWebSocketRemoveTimeoutMillis = closedWebSocketRemoveTimeoutMillis; + this.webSocketRegistryFactory = webSocketRegistryFactory(isSingleWebSocketPerConnection); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + this.ctx = ctx; + this.isAutoRead = ctx.channel().config().isAutoRead(); + Http2ConnectionEncoder encoder = http2Handler.encoder(); + this.webSocketsParent = new WebSocketsParent(encoder); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + webSocketRegistry.clear(); + super.channelInactive(ctx); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isWritable()) { + IntObjectMap webSockets = this.webSocketRegistry; + if (!webSockets.isEmpty()) { + webSockets.forEach((key, webSocket) -> webSocket.trySetWritable()); + } + } + super.channelWritabilityChanged(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + webSocketsParent.setReadInProgress(); + super.channelRead(ctx, msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + webSocketsParent.processPendingReadCompleteQueue(); + super.channelReadComplete(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (!(cause instanceof Http2Exception.StreamException)) { + super.exceptionCaught(ctx, cause); + return; + } + IntObjectMap webSockets = this.webSocketRegistry; + if (!webSockets.isEmpty()) { + Http2Exception.StreamException streamException = (Http2Exception.StreamException) cause; + Http2WebSocket webSocket = webSockets.get(streamException.streamId()); + if (webSocket == null) { + super.exceptionCaught(ctx, cause); + return; + } + if (webSocket != Http2WebSocket.CLOSED) { + try { + ClosedChannelException e = new ClosedChannelException(); + e.initCause(streamException); + webSocket.fireExceptionCaught(e); + } finally { + webSocket.closeForcibly(); + } + } + return; + } + super.exceptionCaught(ctx, cause); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + IntObjectMap webSockets = this.webSocketRegistry; + if (!webSockets.isEmpty()) { + webSockets.forEach((key, webSocket) -> webSocket.streamClosed()); + } + super.close(ctx, promise); + } + + @Override + public void onGoAwayRead( + ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + IntObjectMap webSockets = this.webSocketRegistry; + if (!webSockets.isEmpty()) { + webSockets.forEach( + (key, webSocket) -> webSocket.onGoAwayRead(ctx, lastStreamId, errorCode, debugData)); + } + next().onGoAwayRead(ctx, lastStreamId, errorCode, debugData); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) + throws Http2Exception { + webSocketOrNext(streamId).onRstStreamRead(ctx, streamId, errorCode); + } + + @Override + public int onDataRead( + ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + return webSocketOrNext(streamId).onDataRead(ctx, streamId, data.retain(), padding, endOfStream); + } + + @Override + public void onHeadersRead( + ChannelHandlerContext ctx, + int streamId, + Http2Headers headers, + int padding, + boolean endOfStream) + throws Http2Exception { + webSocketOrNext(streamId).onHeadersRead(ctx, streamId, headers, padding, endOfStream); + } + + @Override + public void onHeadersRead( + ChannelHandlerContext ctx, + int streamId, + Http2Headers headers, + int streamDependency, + short weight, + boolean exclusive, + int padding, + boolean endOfStream) + throws Http2Exception { + webSocketOrNext(streamId) + .onHeadersRead( + ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream); + } + + @Override + public void onPriorityRead( + ChannelHandlerContext ctx, + int streamId, + int streamDependency, + short weight, + boolean exclusive) + throws Http2Exception { + webSocketOrNext(streamId).onPriorityRead(ctx, streamId, streamDependency, weight, exclusive); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + webSocketOrNext(streamId).onWindowUpdateRead(ctx, streamId, windowSizeIncrement); + } + + @Override + public void onUnknownFrame( + ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload) + throws Http2Exception { + webSocketOrNext(streamId).onUnknownFrame(ctx, frameType, streamId, flags, payload); + } + + Http2FrameListener webSocketOrNext(int streamId) { + Http2WebSocket webSocket = webSocketRegistry.get(streamId); + if (webSocket != null) { + ChannelHandlerContext c = ctx; + if (!isAutoRead) { + c.read(); + } + return webSocket; + } + return next; + } + + void registerWebSocket(int streamId, Http2WebSocketChannel webSocket) { + IntObjectMap registry = webSocketRegistry; + if (registry == IntCollections.emptyMap()) { + webSocketRegistry = registry = webSocketRegistryFactory.get(); + } + registry.put(streamId, webSocket); + IntObjectMap finalRegistry = registry; + webSocket + .closeFuture() + .addListener( + future -> { + Channel channel = ctx.channel(); + ChannelFuture connectionCloseFuture = channel.closeFuture(); + if (connectionCloseFuture.isDone()) { + return; + } + /*stream is remotely closed already so there will be no frames stream received*/ + if (!webSocket.isCloseInitiator()) { + finalRegistry.remove(streamId); + return; + } + finalRegistry.put(streamId, Http2WebSocket.CLOSED); + removeAfterTimeout( + streamId, finalRegistry, connectionCloseFuture, channel.eventLoop()); + }); + } + + void removeAfterTimeout(int streamId, IntObjectMap webSockets, ChannelFuture connectionCloseFuture, + EventLoop eventLoop) { + RemoveWebSocket removeWebSocket = + new RemoveWebSocket(streamId, webSockets, connectionCloseFuture); + ScheduledFuture removeWebSocketFuture = + eventLoop.schedule( + removeWebSocket, closedWebSocketRemoveTimeoutMillis, TimeUnit.MILLISECONDS); + removeWebSocket.removeWebSocketFuture(removeWebSocketFuture); + } + + private static class RemoveWebSocket implements Runnable, GenericFutureListener { + + private final IntObjectMap webSockets; + + private final int streamId; + + private final ChannelFuture connectionCloseFuture; + + private ScheduledFuture removeWebSocketFuture; + + RemoveWebSocket(int streamId, IntObjectMap webSockets, ChannelFuture connectionCloseFuture) { + this.streamId = streamId; + this.webSockets = webSockets; + this.connectionCloseFuture = connectionCloseFuture; + } + + void removeWebSocketFuture(ScheduledFuture removeWebSocketFuture) { + this.removeWebSocketFuture = removeWebSocketFuture; + connectionCloseFuture.addListener(this); + } + + @Override + public void operationComplete(ChannelFuture future) { + removeWebSocketFuture.cancel(true); + } + + @Override + public void run() { + webSockets.remove(streamId); + connectionCloseFuture.removeListener(this); + } + } + + @SuppressWarnings("Convert2MethodRef") + static Supplier> webSocketRegistryFactory( + boolean isSingleWebSocketPerConnection) { + if (isSingleWebSocketPerConnection) { + return () -> new Http2WebSocketHandlerContainers.SingleElementOptimizedMap<>(); + } else { + return () -> new IntObjectHashMap<>(4); + } + } + + /** + * Provides DATA, RST, WINDOW_UPDATE frame write operations to websocket channel. Also hosts code + * derived from netty so It can be attributed properly + */ + public class WebSocketsParent { + static final int READ_COMPLETE_PENDING_QUEUE_MAX_SIZE = Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; + + final Queue readCompletePendingQueue = new ArrayDeque<>(8); + + boolean parentReadInProgress; + + final Http2ConnectionEncoder connectionEncoder; + + public WebSocketsParent(Http2ConnectionEncoder connectionEncoder) { + this.connectionEncoder = connectionEncoder; + } + + public ChannelFuture writeHeaders(int streamId, Http2Headers headers, boolean endStream) { + ChannelHandlerContext c = ctx; + ChannelPromise p = c.newPromise(); + return connectionEncoder.writeHeaders(c, streamId, headers, 0, endStream, p); + } + + public ChannelFuture writeHeaders( + int streamId, Http2Headers headers, boolean endStream, short weight) { + ChannelHandlerContext c = ctx; + ChannelPromise p = c.newPromise(); + ChannelFuture channelFuture = + connectionEncoder.writeHeaders(c, streamId, headers, 0, weight, false, 0, endStream, p); + c.flush(); + return channelFuture; + } + + public ChannelFuture writeData(int streamId, ByteBuf data, boolean endStream, ChannelPromise promise) { + ChannelHandlerContext c = ctx; + return connectionEncoder.writeData(c, streamId, data, 0, endStream, promise); + } + + public ChannelFuture writeRstStream(int streamId, long errorCode) { + ChannelHandlerContext c = ctx; + ChannelPromise p = c.newPromise(); + ChannelFuture channelFuture = connectionEncoder.writeRstStream(c, streamId, errorCode, p); + c.flush(); + return channelFuture; + } + + public ChannelFuture writePriority(int streamId, short weight) { + ChannelHandlerContext c = ctx; + ChannelPromise p = c.newPromise(); + ChannelFuture channelFuture = + connectionEncoder.writePriority(c, streamId, 0, weight, false, p); + c.flush(); + return channelFuture; + } + + public boolean isParentReadInProgress() { + return parentReadInProgress; + } + + public void addChannelToReadCompletePendingQueue(Http2WebSocketChannel webSocketChannel) { + Queue q = readCompletePendingQueue; + while (q.size() >= READ_COMPLETE_PENDING_QUEUE_MAX_SIZE) { + processPendingReadCompleteQueue(); + } + q.offer(webSocketChannel); + } + + public ChannelHandlerContext context() { + return ctx; + } + + public void register(final int streamId, Http2WebSocketChannel webSocket) { + registerWebSocket(streamId, webSocket); + } + + void setReadInProgress() { + parentReadInProgress = true; + } + + void processPendingReadCompleteQueue() { + parentReadInProgress = true; + Queue q = readCompletePendingQueue; + Http2WebSocketChannel childChannel = q.poll(); + if (childChannel != null) { + try { + do { + childChannel.fireChildReadComplete(); + childChannel = q.poll(); + } while (childChannel != null); + } finally { + parentReadInProgress = false; + q.clear(); + ctx.flush(); + } + } else { + parentReadInProgress = false; + } + } + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketEvent.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketEvent.java new file mode 100644 index 0000000..1bb6488 --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketEvent.java @@ -0,0 +1,379 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; +import io.netty.handler.codec.http2.Http2Headers; + +/** + * Base type for websocket-over-http2 events + */ +public abstract class Http2WebSocketEvent { + + private final Type type; + + public Http2WebSocketEvent(Type type) { + this.type = type; + } + + public static void fireFrameWriteError(Channel parentChannel, Throwable t) { + ChannelPipeline parentPipeline = parentChannel.pipeline(); + if (parentChannel.config().isAutoClose()) { + parentPipeline.fireExceptionCaught(t); + parentChannel.close(); + return; + } + if (t instanceof Exception) { + parentPipeline.fireUserEventTriggered(new Http2WebSocketWriteErrorEvent(Http2WebSocketMessages.WRITE_ERROR, t)); + return; + } + parentPipeline.fireExceptionCaught(t); + } + + public static void fireHandshakeValidationStartAndError(Channel parentChannel, int streamId, Http2Headers headers) { + long timestamp = System.nanoTime(); + Http2WebSocketEvent.fireHandshakeStartAndError(parentChannel, streamId, + nonNullString(headers.path()), + nonNullString(headers.get(Http2WebSocketProtocol.HEADER_WEBSOCKET_SUBPROTOCOL_NAME)), + headers, timestamp, timestamp, WebSocketHandshakeException.class.getName(), + Http2WebSocketMessages.HANDSHAKE_INVALID_REQUEST_HEADERS); + } + + public static void fireHandshakeStartAndError(Channel parentChannel, + int serial, String path, String subprotocols, + Http2Headers requestHeaders, long startNanos, long errorNanos, Throwable t) { + ChannelPipeline parentPipeline = parentChannel.pipeline(); + if (t instanceof Exception) { + parentPipeline.fireUserEventTriggered(new Http2WebSocketHandshakeStartEvent(serial, path, subprotocols, startNanos, requestHeaders)); + parentPipeline.fireUserEventTriggered( + new Http2WebSocketHandshakeErrorEvent(serial, path, subprotocols, errorNanos, null, t)); + return; + } + parentPipeline.fireExceptionCaught(t); + } + + public static void fireHandshakeStartAndError(Channel parentChannel, int serial, String path, String subprotocols, + Http2Headers requestHeaders, long startNanos, long errorNanos, String errorName, String errorMessage) { + ChannelPipeline parentPipeline = parentChannel.pipeline(); + parentPipeline.fireUserEventTriggered(new Http2WebSocketHandshakeStartEvent(serial, path, subprotocols, + startNanos, requestHeaders)); + parentPipeline.fireUserEventTriggered(new Http2WebSocketHandshakeErrorEvent(serial, path, subprotocols, + errorNanos, null, errorName, errorMessage)); + } + + public static void fireHandshakeStartAndSuccess(Http2WebSocketChannel webSocketChannel, int serial, String path, String subprotocols, + Http2Headers requestHeaders, Http2Headers responseHeaders, long startNanos, long successNanos) { + ChannelPipeline parentPipeline = webSocketChannel.parent().pipeline(); + ChannelPipeline webSocketPipeline = webSocketChannel.pipeline(); + Http2WebSocketHandshakeStartEvent startEvent = new Http2WebSocketHandshakeStartEvent(serial, path, subprotocols, + startNanos, requestHeaders); + Http2WebSocketHandshakeSuccessEvent successEvent = new Http2WebSocketHandshakeSuccessEvent(serial, path, subprotocols, + successNanos, responseHeaders); + parentPipeline.fireUserEventTriggered(startEvent); + parentPipeline.fireUserEventTriggered(successEvent); + webSocketPipeline.fireUserEventTriggered(startEvent); + webSocketPipeline.fireUserEventTriggered(successEvent); + } + + public static void fireHandshakeStart(Http2WebSocketChannel webSocketChannel, Http2Headers requestHeaders, long timestampNanos) { + ChannelPipeline parentPipeline = webSocketChannel.parent().pipeline(); + ChannelPipeline webSocketPipeline = webSocketChannel.pipeline(); + Http2WebSocketHandshakeStartEvent startEvent = new Http2WebSocketHandshakeStartEvent( + webSocketChannel.serial(), + webSocketChannel.path(), + webSocketChannel.subprotocol(), + timestampNanos, + requestHeaders); + parentPipeline.fireUserEventTriggered(startEvent); + webSocketPipeline.fireUserEventTriggered(startEvent); + } + + public static void fireHandshakeError(Http2WebSocketChannel webSocketChannel, Http2Headers responseHeaders, + long timestampNanos, Throwable t) { + ChannelPipeline parentPipeline = webSocketChannel.parent().pipeline(); + if (t instanceof Exception) { + String path = webSocketChannel.path(); + ChannelPipeline webSocketPipeline = webSocketChannel.pipeline(); + Http2WebSocketHandshakeErrorEvent errorEvent = new Http2WebSocketHandshakeErrorEvent(webSocketChannel.serial(), + path, webSocketChannel.subprotocol(), timestampNanos, responseHeaders, t); + parentPipeline.fireUserEventTriggered(errorEvent); + webSocketPipeline.fireUserEventTriggered(errorEvent); + return; + } + parentPipeline.fireExceptionCaught(t); + } + + public static void fireHandshakeSuccess(Http2WebSocketChannel webSocketChannel, Http2Headers responseHeaders, long timestampNanos) { + String path = webSocketChannel.path(); + ChannelPipeline parentPipeline = webSocketChannel.parent().pipeline(); + ChannelPipeline webSocketPipeline = webSocketChannel.pipeline(); + Http2WebSocketHandshakeSuccessEvent successEvent = new Http2WebSocketHandshakeSuccessEvent(webSocketChannel.serial(), + path, webSocketChannel.subprotocol(), timestampNanos, responseHeaders); + parentPipeline.fireUserEventTriggered(successEvent); + webSocketPipeline.fireUserEventTriggered(successEvent); + } + + public Type type() { + return type; + } + + @SuppressWarnings("unchecked") + public T cast() { + return (T) this; + } + + public enum Type { + HANDSHAKE_START, + HANDSHAKE_SUCCESS, + HANDSHAKE_ERROR, + CLOSE_LOCAL_ENDSTREAM, + CLOSE_REMOTE_ENDSTREAM, + CLOSE_REMOTE_RESET, + CLOSE_REMOTE_GOAWAY, + WEIGHT_UPDATE, + WRITE_ERROR + } + + /** + * Represents write error of frames that are not exposed to user code: HEADERS and RST_STREAM + * frames sent by server on handshake, DATA frames with END_STREAM flag for graceful shutdown, + * PRIORITY frames etc. + */ + public static class Http2WebSocketWriteErrorEvent extends Http2WebSocketEvent { + private final String message; + private final Throwable cause; + + Http2WebSocketWriteErrorEvent(String message, Throwable cause) { + super(Type.WRITE_ERROR); + this.message = message; + this.cause = cause; + } + + /** @return frame write error message */ + public String errorMessage() { + return message; + } + + /** @return frame write error */ + public Throwable error() { + return cause; + } + } + + /** Base type for websocket-over-http2 lifecycle events */ + public static class Http2WebSocketLifecycleEvent extends Http2WebSocketEvent { + private final int id; + private final String path; + private final String subprotocol; + private final long timestampNanos; + + Http2WebSocketLifecycleEvent(Type type, int id, String path, String subprotocol, long timestampNanos) { + super(type); + this.id = id; + this.path = path; + this.subprotocol = subprotocol; + this.timestampNanos = timestampNanos; + } + + /** @return id to correlate events of particular websocket */ + public int id() { + return id; + } + + /** @return websocket path */ + public String path() { + return path; + } + + /** @return websocket subprotocol */ + public String subprotocols() { + return subprotocol; + } + + /** @return event timestamp */ + public long timestampNanos() { + return timestampNanos; + } + } + + /** websocket-over-http2 handshake start event */ + public static class Http2WebSocketHandshakeStartEvent extends Http2WebSocketLifecycleEvent { + private final Http2Headers requestHeaders; + + Http2WebSocketHandshakeStartEvent(int id, String path, String subprotocol, long timestampNanos, Http2Headers requestHeaders) { + super(Type.HANDSHAKE_START, id, path, subprotocol, timestampNanos); + this.requestHeaders = requestHeaders; + } + + /** @return websocket request headers */ + public Http2Headers requestHeaders() { + return requestHeaders; + } + } + + /** + * websocket-over-http2 handshake error event + */ + public static class Http2WebSocketHandshakeErrorEvent extends Http2WebSocketLifecycleEvent { + private final Http2Headers responseHeaders; + private final String errorName; + private final String errorMessage; + private final Throwable error; + + Http2WebSocketHandshakeErrorEvent(int id, String path, String subprotocols, + long timestampNanos, Http2Headers responseHeaders, Throwable error) { + this(id, path, subprotocols, timestampNanos, responseHeaders, error, null, null); + } + + Http2WebSocketHandshakeErrorEvent(int id, String path, String subprotocols, + long timestampNanos, Http2Headers responseHeaders, String errorName, String errorMessage) { + this(id, path, subprotocols, timestampNanos, responseHeaders, null, errorName, errorMessage); + } + + private Http2WebSocketHandshakeErrorEvent(int id, String path, String subprotocols, long timestampNanos, + Http2Headers responseHeaders, Throwable error, String errorName, String errorMessage) { + super(Type.HANDSHAKE_ERROR, id, path, subprotocols, timestampNanos); + this.responseHeaders = responseHeaders; + this.errorName = errorName; + this.errorMessage = errorMessage; + this.error = error; + } + + /** @return response headers of failed websocket handshake */ + public Http2Headers responseHeaders() { + return responseHeaders; + } + + /** + * @return exception associated with failed websocket handshake. May be null, in this case + * {@link #errorName()} and {@link #errorMessage()} contain error details. + */ + public Throwable error() { + return error; + } + + /** + * @return name of error associated with failed websocket handshake. May be null, in this case + * {@link #error()} contains respective exception + */ + public String errorName() { + return errorName; + } + + /** + * @return message of error associated with failed websocket handshake. May be null, in this + * case {@link #error()} contains respective exception + */ + public String errorMessage() { + return errorMessage; + } + } + + /** + * websocket-over-http2 handshake success event + */ + public static class Http2WebSocketHandshakeSuccessEvent extends Http2WebSocketLifecycleEvent { + private final Http2Headers responseHeaders; + + Http2WebSocketHandshakeSuccessEvent(int id, String path, String subprotocols, + long timestampNanos, Http2Headers responseHeaders) { + super(Type.HANDSHAKE_SUCCESS, id, path, subprotocols, timestampNanos); + this.responseHeaders = responseHeaders; + } + + /** @return response headers of succeeded websocket handshake */ + public Http2Headers responseHeaders() { + return responseHeaders; + } + } + + /** + * websocket-over-http2 close by remote event. Graceful close is denoted by {@link + * Type#CLOSE_REMOTE_ENDSTREAM}, forced close is denoted by {@link Type#CLOSE_REMOTE_RESET} + */ + public static class Http2WebSocketRemoteCloseEvent extends Http2WebSocketLifecycleEvent { + private Http2WebSocketRemoteCloseEvent(Type type, int id, String path, String subprotocols, long timestampNanos) { + super(type, id, path, subprotocols, timestampNanos); + } + + static Http2WebSocketRemoteCloseEvent endStream(int id, String path, String subprotocols, long timestampNanos) { + return new Http2WebSocketRemoteCloseEvent(Type.CLOSE_REMOTE_ENDSTREAM, id, path, subprotocols, timestampNanos); + } + + static Http2WebSocketRemoteCloseEvent reset(int id, String path, String subprotocols, long timestampNanos) { + return new Http2WebSocketRemoteCloseEvent(Type.CLOSE_REMOTE_RESET, id, path, subprotocols, timestampNanos); + } + } + + /** + * graceful connection close by remote (GO_AWAY) event. + */ + public static class Http2WebSocketRemoteGoAwayEvent extends Http2WebSocketLifecycleEvent { + private final long errorCode; + + Http2WebSocketRemoteGoAwayEvent(int id, String path, String subprotocol, long timestampNanos, long errorCode) { + super(Type.CLOSE_REMOTE_GOAWAY, id, path, subprotocol, timestampNanos); + this.errorCode = errorCode; + } + + /** @return received GO_AWAY frame error code */ + public long errorCode() { + return errorCode; + } + } + + /** + * websocket-over-http2 local graceful close event. Firing {@link + * Http2WebSocketLocalCloseEvent#INSTANCE} on channel pipeline will close associated http2 stream + * locally by sending empty DATA frame with END_STREAN flag set + */ + public static final class Http2WebSocketLocalCloseEvent extends Http2WebSocketEvent { + + public static final Http2WebSocketLocalCloseEvent INSTANCE = new Http2WebSocketLocalCloseEvent(); + + Http2WebSocketLocalCloseEvent() { + super(Type.CLOSE_LOCAL_ENDSTREAM); + } + } + + /** + * websocket-over-http2 stream weight update event. Firing {@link + * Http2WebSocketLocalCloseEvent#INSTANCE} on channel pipeline will send PRIORITY frame for + * associated http2 stream + */ + public static final class Http2WebSocketStreamWeightUpdateEvent extends Http2WebSocketEvent { + + private final short streamWeight; + + Http2WebSocketStreamWeightUpdateEvent(short streamWeight) { + super(Type.WEIGHT_UPDATE); + this.streamWeight = Preconditions.requireRange(streamWeight, 1, 256, "streamWeight"); + } + + public short streamWeight() { + return streamWeight; + } + + public static Http2WebSocketStreamWeightUpdateEvent create(short streamWeight) { + return new Http2WebSocketStreamWeightUpdateEvent(streamWeight); + } + + /** + * @param webSocketChannel websocket-over-http2 channel + * @return weight of http2 stream associated with websocket channel + */ + public static Short streamWeight(Channel webSocketChannel) { + if (webSocketChannel instanceof Http2WebSocketChannel) { + return ((Http2WebSocketChannel) webSocketChannel).streamWeightAttribute(); + } + return null; + } + } + + private static String nonNullString(CharSequence seq) { + if (seq == null) { + return ""; + } + return seq.toString(); + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketExtensions.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketExtensions.java new file mode 100644 index 0000000..11bd3e9 --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketExtensions.java @@ -0,0 +1,86 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; +import io.netty.util.AsciiString; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public final class Http2WebSocketExtensions { + + static final String HEADER_WEBSOCKET_EXTENSIONS_VALUE_PERMESSAGE_DEFLATE = "permessage-deflate"; + + static final AsciiString HEADER_WEBSOCKET_EXTENSIONS_VALUE_PERMESSAGE_DEFLATE_ASCII = + AsciiString.of(HEADER_WEBSOCKET_EXTENSIONS_VALUE_PERMESSAGE_DEFLATE); + + static final Pattern HEADER_WEBSOCKET_EXTENSIONS_PARAMETER_PATTERN = + Pattern.compile("^([^=]+)(=[\\\"]?([^\\\"]+)[\\\"]?)?$"); + + public static WebSocketExtensionData decode(CharSequence extensionHeader) { + if (extensionHeader == null || extensionHeader.length() == 0) { + return null; + } + AsciiString asciiExtensionHeader = (AsciiString) extensionHeader; + for (AsciiString extension : asciiExtensionHeader.split(',')) { + AsciiString[] extensionParameters = extension.split(';'); + AsciiString name = extensionParameters[0].trim(); + if (HEADER_WEBSOCKET_EXTENSIONS_VALUE_PERMESSAGE_DEFLATE_ASCII.equals(name)) { + Map parameters; + if (extensionParameters.length > 1) { + parameters = new HashMap<>(extensionParameters.length - 1); + for (int i = 1; i < extensionParameters.length; i++) { + AsciiString parameter = extensionParameters[i].trim(); + Matcher parameterMatcher = + HEADER_WEBSOCKET_EXTENSIONS_PARAMETER_PATTERN.matcher(parameter); + if (parameterMatcher.matches()) { + String key = parameterMatcher.group(1); + if (key != null) { + String value = parameterMatcher.group(3); + parameters.put(key, value); + } + } + } + } else { + parameters = Collections.emptyMap(); + } + return new WebSocketExtensionData( + HEADER_WEBSOCKET_EXTENSIONS_VALUE_PERMESSAGE_DEFLATE, parameters); + } + } + return null; + } + + public static String encode(WebSocketExtensionData extensionData) { + String name = extensionData.name(); + Map params = extensionData.parameters(); + if (params.isEmpty()) { + return name; + } + StringBuilder sb = new StringBuilder(sizeOf(name, params)); + sb.append(name); + for (Map.Entry param : params.entrySet()) { + sb.append(";"); + sb.append(param.getKey()); + String value = param.getValue(); + if (value != null) { + sb.append("="); + sb.append(value); + } + } + return sb.toString(); + } + + static int sizeOf(String extensionName, Map extensionParameters) { + int size = extensionName.length(); + for (Map.Entry param : extensionParameters.entrySet()) { + size += param.getKey().length() + 1; + String value = param.getValue(); + if (value != null) { + size += value.length() + 1; + } + } + return size; + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketHandler.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketHandler.java new file mode 100644 index 0000000..0559b53 --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketHandler.java @@ -0,0 +1,126 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionHandler; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2Flags; +import io.netty.handler.codec.http2.Http2FrameListener; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.util.AsciiString; + +/** + * Base type for client and server websocket-over-http2 handlers + */ +public abstract class Http2WebSocketHandler extends ChannelDuplexHandler implements Http2FrameListener { + + static final AsciiString HEADER_WEBSOCKET_ENDOFSTREAM_NAME = AsciiString.of("x-websocket-endofstream"); + + static final AsciiString HEADER_WEBSOCKET_ENDOFSTREAM_VALUE_TRUE = AsciiString.of("true"); + + static final AsciiString HEADER_WEBSOCKET_ENDOFSTREAM_VALUE_FALSE = AsciiString.of("false"); + + protected Http2ConnectionHandler http2Handler; + + protected Http2FrameListener next; + + public Http2WebSocketHandler() {} + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + Http2ConnectionHandler http2Handler = this.http2Handler = Preconditions.requireHandler(ctx.channel(), Http2ConnectionHandler.class); + Http2ConnectionDecoder decoder = http2Handler.decoder(); + Http2FrameListener next = decoder.frameListener(); + decoder.frameListener(this); + this.next = next; + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) throws Http2Exception { + next().onGoAwayRead(ctx, lastStreamId, errorCode, debugData); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + next().onRstStreamRead(ctx, streamId, errorCode); + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) throws Http2Exception { + return next().onDataRead(ctx, streamId, data.retain(), padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endOfStream) throws Http2Exception { + next().onHeadersRead(ctx, streamId, headers, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception { + next().onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream); + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, + boolean exclusive) throws Http2Exception { + next().onPriorityRead(ctx, streamId, streamDependency, weight, exclusive); + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + next().onSettingsAckRead(ctx); + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) + throws Http2Exception { + next().onSettingsRead(ctx, settings); + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + next().onPingRead(ctx, data); + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + next().onPingAckRead(ctx, data); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception { + next().onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + next().onWindowUpdateRead(ctx, streamId, windowSizeIncrement); + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload) + throws Http2Exception { + next().onUnknownFrame(ctx, frameType, streamId, flags, payload); + } + + protected final Http2FrameListener next() { + return next; + } + + static AsciiString endOfStreamName() { + return HEADER_WEBSOCKET_ENDOFSTREAM_NAME; + } + + static AsciiString endOfStreamValue(boolean endOfStream) { + return endOfStream + ? HEADER_WEBSOCKET_ENDOFSTREAM_VALUE_TRUE + : HEADER_WEBSOCKET_ENDOFSTREAM_VALUE_FALSE; + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketHandlerContainers.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketHandlerContainers.java new file mode 100644 index 0000000..4d9d89c --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketHandlerContainers.java @@ -0,0 +1,172 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.util.collection.IntCollections; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import java.util.Collection; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +final class Http2WebSocketHandlerContainers { + + static final class SingleElementOptimizedMap implements IntObjectMap { + int singleKey; + T singleValue; + IntObjectMap delegate = IntCollections.emptyMap(); + + @Override + public T get(int key) { + int sk = singleKey; + if (key == sk) { + return singleValue; + } + if (sk == -1) { + return delegate.get(key); + } + return null; + } + + @Override + public T put(int key, T value) { + int sk = singleKey; + if (sk == 0 || key == sk) { + T sv = singleValue; + singleKey = key; + singleValue = value; + return sv; + } + IntObjectMap d = delegate; + if (d.isEmpty()) { + d = delegate = new IntObjectHashMap<>(4); + d.put(sk, singleValue); + singleKey = -1; + singleValue = null; + } + return d.put(key, value); + } + + @Override + public T remove(int key) { + int sk = singleKey; + if (key == sk) { + T sv = singleValue; + singleKey = 0; + singleValue = null; + return sv; + } + if (sk == -1) { + IntObjectMap d = delegate; + T removed = d.remove(key); + if (d.isEmpty()) { + singleKey = 0; + delegate = IntCollections.emptyMap(); + } + return removed; + } + return null; + } + + @Override + public boolean containsKey(int key) { + int sk = singleKey; + return sk == key || sk == -1 && delegate.containsKey(key); + } + + @Override + public int size() { + int sk = singleKey; + switch (sk) { + case 0: + return 0; + case -1: + return delegate.size(); + /*sk > 0*/ + default: + return 1; + } + } + + @Override + public boolean isEmpty() { + return singleKey == 0; + } + + @Override + public void clear() { + singleKey = 0; + singleValue = null; + delegate = IntCollections.emptyMap(); + } + + @Override + public void forEach(BiConsumer action) { + int sk = singleKey; + if (sk > 0) { + action.accept(sk, singleValue); + } else if (sk == -1) { + delegate.forEach(action); + } + } + + @Override + public Iterable> entries() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public boolean containsKey(Object key) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public boolean containsValue(Object value) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public T get(Object key) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public T put(Integer key, T value) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public T remove(Object key) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Set keySet() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Collection values() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Set> entrySet() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public boolean equals(Object o) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public int hashCode() { + throw new UnsupportedOperationException("Not implemented"); + } + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketMessages.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketMessages.java new file mode 100644 index 0000000..2f3cdb9 --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketMessages.java @@ -0,0 +1,27 @@ +package org.xbib.netty.http.common.ws; + +public interface Http2WebSocketMessages { + String HANDSHAKE_UNEXPECTED_RESULT = + "websocket handshake error: unexpected result - status=200, end_of_stream=true"; + String HANDSHAKE_UNSUPPORTED_VERSION = + "websocket handshake error: unsupported version; supported versions - "; + String HANDSHAKE_BAD_REQUEST = + "websocket handshake error: bad request"; + String HANDSHAKE_PATH_NOT_FOUND = + "websocket handshake error: path not found - "; + String HANDSHAKE_PATH_NOT_FOUND_SUBPROTOCOLS = + ", subprotocols - "; + String HANDSHAKE_UNEXPECTED_SUBPROTOCOL = + "websocket handshake error: unexpected subprotocol - "; + String HANDSHAKE_GENERIC_ERROR = + "websocket handshake error: "; + String HANDSHAKE_UNSUPPORTED_ACCEPTOR_TYPE = + "websocket handshake error: async acceptors are not supported"; + String HANDSHAKE_UNSUPPORTED_BOOTSTRAP = + "websocket handshake error: bootstrapping websockets with http2 is not supported by server"; + String HANDSHAKE_INVALID_REQUEST_HEADERS = + "websocket handshake error: invalid request headers"; + String HANDSHAKE_INVALID_RESPONSE_HEADERS = + "websocket handshake error: invalid response headers"; + String WRITE_ERROR = "websocket frame write error"; +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketProtocol.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketProtocol.java new file mode 100644 index 0000000..372d153 --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketProtocol.java @@ -0,0 +1,41 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.util.AsciiString; + +public final class Http2WebSocketProtocol { + + public static final char SETTINGS_ENABLE_CONNECT_PROTOCOL = 8; + + public static final AsciiString HEADER_METHOD_CONNECT = AsciiString.of("CONNECT"); + + public static final AsciiString HEADER_PROTOCOL_NAME = AsciiString.of(":protocol"); + + public static final AsciiString HEADER_PROTOCOL_VALUE = AsciiString.of("websocket"); + + public static final AsciiString SCHEME_HTTP = AsciiString.of("http"); + + public static final AsciiString SCHEME_HTTPS = AsciiString.of("https"); + + public static final AsciiString HEADER_WEBSOCKET_VERSION_NAME = AsciiString.of("sec-websocket-version"); + + public static final AsciiString HEADER_WEBSOCKET_VERSION_VALUE = AsciiString.of("13"); + + public static final AsciiString HEADER_WEBSOCKET_SUBPROTOCOL_NAME = AsciiString.of("sec-websocket-protocol"); + + public static final AsciiString HEADER_WEBSOCKET_EXTENSIONS_NAME = AsciiString.of("sec-websocket-extensions"); + + public static final AsciiString HEADER_PROTOCOL_NAME_HANDSHAKED = AsciiString.of("x-protocol"); + + public static final AsciiString HEADER_METHOD_CONNECT_HANDSHAKED = AsciiString.of("POST"); + + public static Http2Headers extendedConnect(Http2Headers headers) { + return headers.method(Http2WebSocketProtocol.HEADER_METHOD_CONNECT) + .set(Http2WebSocketProtocol.HEADER_PROTOCOL_NAME, Http2WebSocketProtocol.HEADER_PROTOCOL_VALUE); + } + + public static boolean isExtendedConnect(Http2Headers headers) { + return HEADER_METHOD_CONNECT.equals(headers.method()) + && HEADER_PROTOCOL_VALUE.equals(headers.get(HEADER_PROTOCOL_NAME)); + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketValidator.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketValidator.java new file mode 100644 index 0000000..755a7bb --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Http2WebSocketValidator.java @@ -0,0 +1,156 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.util.AsciiString; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public final class Http2WebSocketValidator { + static final AsciiString PSEUDO_HEADER_METHOD = AsciiString.of(":method"); + static final AsciiString PSEUDO_HEADER_SCHEME = AsciiString.of(":scheme"); + static final AsciiString PSEUDO_HEADER_AUTHORITY = AsciiString.of(":authority"); + static final AsciiString PSEUDO_HEADER_PATH = AsciiString.of(":path"); + static final AsciiString PSEUDO_HEADER_PROTOCOL = AsciiString.of(":protocol"); + static final AsciiString PSEUDO_HEADER_STATUS = AsciiString.of(":status"); + static final AsciiString PSEUDO_HEADER_METHOD_CONNECT = AsciiString.of("connect"); + + static final AsciiString HEADER_CONNECTION = AsciiString.of("connection"); + static final AsciiString HEADER_KEEPALIVE = AsciiString.of("keep-alive"); + static final AsciiString HEADER_PROXY_CONNECTION = AsciiString.of("proxy-connection"); + static final AsciiString HEADER_TRANSFER_ENCODING = AsciiString.of("transfer-encoding"); + static final AsciiString HEADER_UPGRADE = AsciiString.of("upgrade"); + static final AsciiString HEADER_TE = AsciiString.of("te"); + static final AsciiString HEADER_TE_TRAILERS = AsciiString.of("trailers"); + + static final Set INVALID_HEADERS = invalidHeaders(); + + public static boolean isValid(final Http2Headers responseHeaders) { + boolean isFirst = true; + for (Map.Entry header : responseHeaders) { + CharSequence name = header.getKey(); + if (isFirst) { + if (!PSEUDO_HEADER_STATUS.equals(name) || isEmpty(header.getValue())) { + return false; + } + isFirst = false; + } else if (Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat(name)) { + return false; + } + } + return containsValidHeaders(responseHeaders); + } + + static boolean containsValidPseudoHeaders( + Http2Headers requestHeaders, Set validPseudoHeaders) { + for (Map.Entry header : requestHeaders) { + CharSequence name = header.getKey(); + if (!Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat(name)) { + break; + } + if (!validPseudoHeaders.contains(name)) { + return false; + } + } + return true; + } + + static boolean containsValidHeaders(Http2Headers headers) { + for (CharSequence invalidHeader : INVALID_HEADERS) { + if (headers.contains(invalidHeader)) { + return false; + } + } + CharSequence te = headers.get(HEADER_TE); + return te == null || HEADER_TE_TRAILERS.equals(te); + } + + static Set validPseudoHeaders() { + Set result = new HashSet<>(); + result.add(PSEUDO_HEADER_SCHEME); + result.add(PSEUDO_HEADER_AUTHORITY); + result.add(PSEUDO_HEADER_PATH); + result.add(PSEUDO_HEADER_METHOD); + return result; + } + + private static Set invalidHeaders() { + Set result = new HashSet<>(); + result.add(HEADER_CONNECTION); + result.add(HEADER_KEEPALIVE); + result.add(HEADER_PROXY_CONNECTION); + result.add(HEADER_TRANSFER_ENCODING); + result.add(HEADER_UPGRADE); + return result; + } + + static boolean isEmpty(CharSequence seq) { + return seq == null || seq.length() == 0; + } + + static boolean isHttp(CharSequence scheme) { + return Http2WebSocketProtocol.SCHEME_HTTPS.equals(scheme) + || Http2WebSocketProtocol.SCHEME_HTTP.equals(scheme); + } + + public static class Http { + + private static final Set VALID_PSEUDO_HEADERS = validPseudoHeaders(); + + public static boolean isValid(final Http2Headers requestHeaders, boolean endOfStream) { + AsciiString authority = AsciiString.of(requestHeaders.authority()); + /*must be non-empty, not include userinfo subcomponent*/ + if (isEmpty(authority) || authority.contains("@")) { + return false; + } + AsciiString method = AsciiString.of(requestHeaders.method()); + if (isEmpty(method)) { + return false; + } + AsciiString scheme = AsciiString.of(requestHeaders.scheme()); + AsciiString path = AsciiString.of(requestHeaders.path()); + if (method.equals(PSEUDO_HEADER_METHOD_CONNECT)) { + if (!isEmpty(scheme) || !isEmpty(path)) { + return false; + } + } else { + if (isEmpty(scheme)) { + return false; + } + if (isEmpty(path) && isHttp(scheme)) { + return false; + } + } + return containsValidPseudoHeaders(requestHeaders, VALID_PSEUDO_HEADERS) + && containsValidHeaders(requestHeaders); + } + } + + public static class WebSocket { + + private static final Set VALID_PSEUDO_HEADERS; + + static { + Set headers = VALID_PSEUDO_HEADERS = validPseudoHeaders(); + headers.add(PSEUDO_HEADER_PROTOCOL); + } + + public static boolean isValid(final Http2Headers requestHeaders, boolean endOfStream) { + if (endOfStream) { + return false; + } + if (!isHttp(requestHeaders.scheme())) { + return false; + } + AsciiString authority = AsciiString.of(requestHeaders.authority()); + if (isEmpty(authority) || authority.contains("@")) { + return false; + } + if (isEmpty(requestHeaders.path())) { + return false; + } + return containsValidPseudoHeaders(requestHeaders, VALID_PSEUDO_HEADERS) + && containsValidHeaders(requestHeaders); + } + } +} diff --git a/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Preconditions.java b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Preconditions.java new file mode 100644 index 0000000..e37f27d --- /dev/null +++ b/netty-http-common/src/main/java/org/xbib/netty/http/common/ws/Preconditions.java @@ -0,0 +1,52 @@ +package org.xbib.netty.http.common.ws; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; + +public final class Preconditions { + + public static T requireNonNull(T t, String message) { + if (t == null) { + throw new IllegalArgumentException(message + " must be non null"); + } + return t; + } + + public static String requireNonEmpty(String string, String message) { + if (string == null || string.isEmpty()) { + throw new IllegalArgumentException(message + " must be non empty"); + } + return string; + } + + public static T requireHandler(Channel channel, Class handler) { + T h = channel.pipeline().get(handler); + if (h == null) { + throw new IllegalArgumentException( + handler.getSimpleName() + " is absent in the channel pipeline"); + } + return h; + } + + public static long requirePositive(long value, String message) { + if (value <= 0) { + throw new IllegalArgumentException(message + " must be positive: " + value); + } + return value; + } + + public static int requireNonNegative(int value, String message) { + if (value < 0) { + throw new IllegalArgumentException(message + " must be non-negative: " + value); + } + return value; + } + + public static short requireRange(int value, int from, int to, String message) { + if (value >= from && value <= to) { + return (short) value; + } + throw new IllegalArgumentException( + String.format("%s must belong to range [%d, %d]: ", message, from, to)); + } +} diff --git a/netty-http-server-api/src/main/java/module-info.java b/netty-http-server-api/src/main/java/module-info.java index f74d26d..0b3919b 100644 --- a/netty-http-server-api/src/main/java/module-info.java +++ b/netty-http-server-api/src/main/java/module-info.java @@ -1,7 +1,5 @@ module org.xbib.netty.http.server.api { exports org.xbib.netty.http.server.api; - exports org.xbib.netty.http.server.api.annotation; - exports org.xbib.netty.http.server.api.security; requires org.xbib.netty.http.common; requires org.xbib.net.url; requires io.netty.buffer; diff --git a/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/security/ServerCertificateProvider.java b/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/ServerCertificateProvider.java similarity index 93% rename from netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/security/ServerCertificateProvider.java rename to netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/ServerCertificateProvider.java index 9515da3..c9671a5 100644 --- a/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/security/ServerCertificateProvider.java +++ b/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/ServerCertificateProvider.java @@ -1,4 +1,4 @@ -package org.xbib.netty.http.server.api.security; +package org.xbib.netty.http.server.api; import java.io.InputStream; diff --git a/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/ServerConfig.java b/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/ServerConfig.java index 65eb31d..3bcf6fa 100644 --- a/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/ServerConfig.java +++ b/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/ServerConfig.java @@ -1,11 +1,16 @@ package org.xbib.netty.http.server.api; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.WriteBufferWaterMark; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.logging.LogLevel; import io.netty.handler.ssl.CipherSuiteFilter; +import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.SslProvider; import org.xbib.netty.http.common.HttpAddress; +import org.xbib.netty.http.common.security.SecurityUtil; + import java.security.KeyStore; import java.security.Provider; import java.util.Collection; @@ -61,8 +66,6 @@ public interface ServerConfig { boolean isCompressionEnabled(); - int getCompressionThreshold(); - boolean isDecompressionEnabled(); boolean isInstallHttp2Upgrade(); @@ -92,4 +95,174 @@ public interface ServerConfig { Domain> getDomain(String name); Domain> getDefaultDomain(); + + SimpleChannelInboundHandler getWebSocketFrameHandler(); + + interface Defaults { + + /** + * Default bind address. We do not want to use port 80 or 8080. + */ + HttpAddress ADDRESS = HttpAddress.http1("localhost", 8008); + + /** + * If frame logging/traffic logging is enabled or not. + */ + boolean DEBUG = false; + + /** + * Default debug log level. + */ + LogLevel DEBUG_LOG_LEVEL = LogLevel.DEBUG; + + String TRANSPORT_PROVIDER_NAME = null; + + /** + * Let Netty decide about parent thread count. + */ + int PARENT_THREAD_COUNT = 0; + + /** + * Let Netty decide about child thread count. + */ + int CHILD_THREAD_COUNT = 0; + + /** + * Blocking thread pool count. Disabled by default, use Netty threads. + */ + int BLOCKING_THREAD_COUNT = 0; + + /** + * Blocking thread pool queue count. Disabled by default, use Netty threads. + */ + int BLOCKING_QUEUE_COUNT = 0; + + /** + * Default for SO_REUSEADDR. + */ + boolean SO_REUSEADDR = true; + + /** + * Default for TCP_NODELAY. + */ + boolean TCP_NODELAY = true; + + /** + * Set TCP send buffer to 64k per socket. + */ + int TCP_SEND_BUFFER_SIZE = 64 * 1024; + + /** + * Set TCP receive buffer to 64k per socket. + */ + int TCP_RECEIVE_BUFFER_SIZE = 64 * 1024; + + /** + * Default for socket back log. + */ + int SO_BACKLOG = 10 * 1024; + + /** + * Default connect timeout in milliseconds. + */ + int CONNECT_TIMEOUT_MILLIS = 5000; + + /** + * Default connect timeout in milliseconds. + */ + int READ_TIMEOUT_MILLIS = 15000; + + /** + * Default idle timeout in milliseconds. + */ + int IDLE_TIMEOUT_MILLIS = 60000; + + /** + * Set HTTP chunk maximum size to 8k. + * See {@link io.netty.handler.codec.http.HttpClientCodec}. + */ + int MAX_CHUNK_SIZE = 8 * 1024; + + /** + * Set HTTP initial line length to 4k. + * See {@link io.netty.handler.codec.http.HttpClientCodec}. + */ + int MAX_INITIAL_LINE_LENGTH = 4 * 1024; + + /** + * Set HTTP maximum headers size to 8k. + * See {@link io.netty.handler.codec.http.HttpClientCodec}. + */ + int MAX_HEADERS_SIZE = 8 * 1024; + + /** + * Set maximum content length to 256 MB. + */ + int MAX_CONTENT_LENGTH = 256 * 1024 * 1024; + + /** + * HTTP/1 pipelining capacity. 1024 is very high, it means + * 1024 requests can be present for a single client. + */ + int PIPELINING_CAPACITY = 1024; + + /** + * This is Netty's default. + */ + int MAX_COMPOSITE_BUFFER_COMPONENTS = 1024; + + /** + * Default write buffer water mark. + */ + WriteBufferWaterMark WRITE_BUFFER_WATER_MARK = WriteBufferWaterMark.DEFAULT; + + /** + * Default for compression. + */ + boolean ENABLE_COMPRESSION = true; + + /** + * Default for decompression. + */ + boolean ENABLE_DECOMPRESSION = true; + + /** + * Default HTTP/2 settings. + */ + Http2Settings HTTP_2_SETTINGS = Http2Settings.defaultSettings(); + + /** + * Default for HTTP/2 upgrade under HTTP 1. + */ + boolean INSTALL_HTTP_UPGRADE2 = false; + + /** + * Default SSL provider. + */ + SslProvider SSL_PROVIDER = SecurityUtil.Defaults.DEFAULT_SSL_PROVIDER; + + /** + * Default SSL context provider (for JDK SSL only). + */ + Provider SSL_CONTEXT_PROVIDER = null; + + /** + * Transport layer security protocol versions. + * Do not use SSLv2, SSLv3, TLS 1.0, TLS 1.1. + */ + String[] PROTOCOLS = OpenSsl.isAvailable() && OpenSsl.version() <= 0x10101009L ? + new String[] { "TLSv1.2" } : + new String[] { "TLSv1.3", "TLSv1.2" }; + + /** + * Default ciphers. We care about HTTP/2. + */ + Iterable CIPHERS = SecurityUtil.Defaults.DEFAULT_CIPHERS; + + /** + * Default cipher suite filter. + */ + CipherSuiteFilter CIPHER_SUITE_FILTER = SecurityUtil.Defaults.DEFAULT_CIPHER_SUITE_FILTER; + + } } diff --git a/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/annotation/Endpoint.java b/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/annotation/Endpoint.java deleted file mode 100644 index 3c740d4..0000000 --- a/netty-http-server-api/src/main/java/org/xbib/netty/http/server/api/annotation/Endpoint.java +++ /dev/null @@ -1,35 +0,0 @@ -package org.xbib.netty.http.server.api.annotation; - -import org.xbib.netty.http.server.api.Filter; - -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -/** - * The {@code Endpoint} annotation decorates methods which are mapped - * to a HTTP endpoint within the server, and provide its contents. - * The annotated methods must have the same signature and contract - * as {@link Filter#handle}, but can have arbitrary names. - */ -@Retention(RetentionPolicy.RUNTIME) -@Target(ElementType.METHOD) -public @interface Endpoint { - - /** - * The path that this field maps to (must begin with '/'). - * - * @return the path that this field maps to - */ - String path(); - - /** - * The HTTP methods supported by this endpoint (default is "GET" and "HEAD"). - * - * @return the HTTP methods supported by this endpoint - */ - String[] methods() default {"GET", "HEAD"}; - - String[] contentTypes(); -} diff --git a/netty-http-server/src/main/java/module-info.java b/netty-http-server/src/main/java/module-info.java index 09f2f9a..5c49bcb 100644 --- a/netty-http-server/src/main/java/module-info.java +++ b/netty-http-server/src/main/java/module-info.java @@ -1,8 +1,9 @@ +import org.xbib.netty.http.server.api.ServerCertificateProvider; import org.xbib.netty.http.server.protocol.http1.Http1; import org.xbib.netty.http.server.protocol.http2.Http2; module org.xbib.netty.http.server { - uses org.xbib.netty.http.server.api.security.ServerCertificateProvider; + uses ServerCertificateProvider; uses org.xbib.netty.http.server.api.ServerProtocolProvider; uses org.xbib.netty.http.common.TransportProvider; exports org.xbib.netty.http.server; diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/BaseTransport.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/BaseTransport.java index 860fb99..435ba7c 100644 --- a/netty-http-server/src/main/java/org/xbib/netty/http/server/BaseTransport.java +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/BaseTransport.java @@ -34,8 +34,7 @@ public abstract class BaseTransport implements ServerTransport { * @param reqHeaders the request headers * @return whether further processing should be performed */ - protected static AcceptState acceptRequest(HttpVersion httpVersion, - HttpHeaders reqHeaders) { + protected static AcceptState acceptRequest(HttpVersion httpVersion, HttpHeaders reqHeaders) { if (httpVersion.majorVersion() == 1 || httpVersion.majorVersion() == 2) { if (!reqHeaders.contains(HttpHeaderNames.HOST)) { // RFC2616#14.23: missing Host header gets 400 diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/DefaultServerConfig.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/DefaultServerConfig.java index 01d2507..c961888 100644 --- a/netty-http-server/src/main/java/org/xbib/netty/http/server/DefaultServerConfig.java +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/DefaultServerConfig.java @@ -1,10 +1,11 @@ package org.xbib.netty.http.server; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.WriteBufferWaterMark; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.logging.LogLevel; import io.netty.handler.ssl.CipherSuiteFilter; -import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.SslProvider; import org.xbib.netty.http.common.HttpAddress; import org.xbib.netty.http.common.security.SecurityUtil; @@ -21,180 +22,6 @@ import javax.net.ssl.TrustManagerFactory; public class DefaultServerConfig implements ServerConfig { - interface Defaults { - - /** - * Default bind address. We do not want to use port 80 or 8080. - */ - HttpAddress ADDRESS = HttpAddress.http1("localhost", 8008); - - /** - * If frame logging/traffic logging is enabled or not. - */ - boolean DEBUG = false; - - /** - * Default debug log level. - */ - LogLevel DEBUG_LOG_LEVEL = LogLevel.DEBUG; - - String TRANSPORT_PROVIDER_NAME = null; - - /** - * Let Netty decide about parent thread count. - */ - int PARENT_THREAD_COUNT = 0; - - /** - * Let Netty decide about child thread count. - */ - int CHILD_THREAD_COUNT = 0; - - /** - * Blocking thread pool count. Disabled by default, use Netty threads. - */ - int BLOCKING_THREAD_COUNT = 0; - - /** - * Blocking thread pool queue count. Disabled by default, use Netty threads. - */ - int BLOCKING_QUEUE_COUNT = 0; - - /** - * Default for SO_REUSEADDR. - */ - boolean SO_REUSEADDR = true; - - /** - * Default for TCP_NODELAY. - */ - boolean TCP_NODELAY = true; - - /** - * Set TCP send buffer to 64k per socket. - */ - int TCP_SEND_BUFFER_SIZE = 64 * 1024; - - /** - * Set TCP receive buffer to 64k per socket. - */ - int TCP_RECEIVE_BUFFER_SIZE = 64 * 1024; - - /** - * Default for socket back log. - */ - int SO_BACKLOG = 10 * 1024; - - /** - * Default connect timeout in milliseconds. - */ - int CONNECT_TIMEOUT_MILLIS = 5000; - - /** - * Default connect timeout in milliseconds. - */ - int READ_TIMEOUT_MILLIS = 15000; - - /** - * Default idle timeout in milliseconds. - */ - int IDLE_TIMEOUT_MILLIS = 60000; - - /** - * Set HTTP chunk maximum size to 8k. - * See {@link io.netty.handler.codec.http.HttpClientCodec}. - */ - int MAX_CHUNK_SIZE = 8 * 1024; - - /** - * Set HTTP initial line length to 4k. - * See {@link io.netty.handler.codec.http.HttpClientCodec}. - */ - int MAX_INITIAL_LINE_LENGTH = 4 * 1024; - - /** - * Set HTTP maximum headers size to 8k. - * See {@link io.netty.handler.codec.http.HttpClientCodec}. - */ - int MAX_HEADERS_SIZE = 8 * 1024; - - /** - * Set maximum content length to 256 MB. - */ - int MAX_CONTENT_LENGTH = 256 * 1024 * 1024; - - /** - * HTTP/1 pipelining capacity. 1024 is very high, it means - * 1024 requests can be present for a single client. - */ - int PIPELINING_CAPACITY = 1024; - - /** - * This is Netty's default. - */ - int MAX_COMPOSITE_BUFFER_COMPONENTS = 1024; - - /** - * Default write buffer water mark. - */ - WriteBufferWaterMark WRITE_BUFFER_WATER_MARK = WriteBufferWaterMark.DEFAULT; - - /** - * Default for compression. - */ - boolean ENABLE_COMPRESSION = true; - - /** - * Default compression threshold. If a response size is over this value, - * it will be compressed, otherwise not. - */ - int COMPRESSION_THRESHOLD = 8192; - - /** - * Default for decompression. - */ - boolean ENABLE_DECOMPRESSION = true; - - /** - * Default HTTP/2 settings. - */ - Http2Settings HTTP_2_SETTINGS = Http2Settings.defaultSettings(); - - /** - * Default for HTTP/2 upgrade under HTTP 1. - */ - boolean INSTALL_HTTP_UPGRADE2 = false; - - /** - * Default SSL provider. - */ - SslProvider SSL_PROVIDER = SecurityUtil.Defaults.DEFAULT_SSL_PROVIDER; - - /** - * Default SSL context provider (for JDK SSL only). - */ - Provider SSL_CONTEXT_PROVIDER = null; - - /** - * Transport layer security protocol versions. - * Do not use SSLv2, SSLv3, TLS 1.0, TLS 1.1. - */ - String[] PROTOCOLS = OpenSsl.isAvailable() && OpenSsl.version() <= 0x10101009L ? - new String[] { "TLSv1.2" } : - new String[] { "TLSv1.3", "TLSv1.2" }; - - /** - * Default ciphers. We care about HTTP/2. - */ - Iterable CIPHERS = SecurityUtil.Defaults.DEFAULT_CIPHERS; - - /** - * Default cipher suite filter. - */ - CipherSuiteFilter CIPHER_SUITE_FILTER = SecurityUtil.Defaults.DEFAULT_CIPHER_SUITE_FILTER; - - } - private HttpAddress httpAddress = Defaults.ADDRESS; private boolean debug = Defaults.DEBUG; @@ -243,8 +70,6 @@ public class DefaultServerConfig implements ServerConfig { private boolean enableCompression = Defaults.ENABLE_COMPRESSION; - private int compressionThreshold = Defaults.COMPRESSION_THRESHOLD; - private boolean enableDecompression = Defaults.ENABLE_DECOMPRESSION; private Http2Settings http2Settings = Defaults.HTTP_2_SETTINGS; @@ -271,6 +96,8 @@ public class DefaultServerConfig implements ServerConfig { private boolean acceptInvalidCertificates = false; + private SimpleChannelInboundHandler webSocketFrameHandler; + public DefaultServerConfig() { this.domains = new LinkedList<>(); } @@ -492,15 +319,6 @@ public class DefaultServerConfig implements ServerConfig { return enableCompression; } - public ServerConfig setCompressionThreshold(int compressionThreshold) { - this.compressionThreshold = compressionThreshold; - return this; - } - - public int getCompressionThreshold() { - return compressionThreshold; - } - public ServerConfig setDecompression(boolean enabled) { this.enableDecompression = enabled; return this; @@ -642,4 +460,14 @@ public class DefaultServerConfig implements ServerConfig { domains.stream().filter(d -> d.getName().equals("*")).findFirst(); return domainOptional.orElse(domains.getFirst()); } + + public ServerConfig setWebSocketFrameHandler(SimpleChannelInboundHandler webSocketFrameHandler) { + this.webSocketFrameHandler = webSocketFrameHandler; + return this; + } + + @Override + public SimpleChannelInboundHandler getWebSocketFrameHandler() { + return webSocketFrameHandler; + } } diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/HttpServerDomain.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/HttpServerDomain.java index e6d5dc8..c7a521c 100644 --- a/netty-http-server/src/main/java/org/xbib/netty/http/server/HttpServerDomain.java +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/HttpServerDomain.java @@ -12,7 +12,7 @@ import org.xbib.netty.http.common.HttpAddress; import org.xbib.netty.http.common.HttpMethod; import org.xbib.netty.http.server.api.Domain; import org.xbib.netty.http.server.api.EndpointResolver; -import org.xbib.netty.http.server.api.security.ServerCertificateProvider; +import org.xbib.netty.http.server.api.ServerCertificateProvider; import org.xbib.netty.http.common.security.SecurityUtil; import org.xbib.netty.http.server.api.ServerRequest; import org.xbib.netty.http.server.api.ServerResponse; diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/Server.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/Server.java index 97cb624..92caeaa 100644 --- a/netty-http-server/src/main/java/org/xbib/netty/http/server/Server.java +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/Server.java @@ -5,12 +5,14 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.ServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.DomainWildcardMappingBuilder; @@ -21,6 +23,7 @@ import org.xbib.netty.http.common.HttpChannelInitializer; import org.xbib.netty.http.common.TransportProvider; import org.xbib.netty.http.server.api.Domain; import org.xbib.netty.http.server.api.EndpointResolver; +import org.xbib.netty.http.server.api.ServerConfig; import org.xbib.netty.http.server.api.ServerProtocolProvider; import org.xbib.netty.http.server.api.ServerRequest; import org.xbib.netty.http.server.api.ServerResponse; @@ -66,7 +69,7 @@ public final class Server implements AutoCloseable { } } - private final DefaultServerConfig serverConfig; + private final ServerConfig serverConfig; private final EventLoopGroup parentEventLoopGroup; @@ -99,7 +102,7 @@ public final class Server implements AutoCloseable { * @param executor an extra blocking thread pool executor or null */ @SuppressWarnings("unchecked") - private Server(DefaultServerConfig serverConfig, + private Server(ServerConfig serverConfig, ByteBufAllocator byteBufAllocator, EventLoopGroup parentEventLoopGroup, EventLoopGroup childEventLoopGroup, @@ -177,7 +180,7 @@ public final class Server implements AutoCloseable { return new Builder(httpServerDomain); } - public DefaultServerConfig getServerConfig() { + public ServerConfig getServerConfig() { return serverConfig; } @@ -373,7 +376,7 @@ public final class Server implements AutoCloseable { throw new IllegalStateException("no channel initializer found for major version " + majorVersion); } - private static EventLoopGroup createParentEventLoopGroup(DefaultServerConfig serverConfig, + private static EventLoopGroup createParentEventLoopGroup(ServerConfig serverConfig, EventLoopGroup parentEventLoopGroup ) { EventLoopGroup eventLoopGroup = parentEventLoopGroup; if (eventLoopGroup == null) { @@ -391,7 +394,7 @@ public final class Server implements AutoCloseable { return eventLoopGroup; } - private static EventLoopGroup createChildEventLoopGroup(DefaultServerConfig serverConfig, + private static EventLoopGroup createChildEventLoopGroup(ServerConfig serverConfig, EventLoopGroup childEventLoopGroup ) { EventLoopGroup eventLoopGroup = childEventLoopGroup; if (eventLoopGroup == null) { @@ -409,7 +412,7 @@ public final class Server implements AutoCloseable { return eventLoopGroup; } - private static Class createSocketChannelClass(DefaultServerConfig serverConfig, + private static Class createSocketChannelClass(ServerConfig serverConfig, Class socketChannelClass) { Class channelClass = socketChannelClass; if (channelClass == null) { @@ -684,6 +687,11 @@ public final class Server implements AutoCloseable { return this; } + public Builder setWebSocketFrameHandler(SimpleChannelInboundHandler webSocketFrameHandler) { + this.serverConfig.setWebSocketFrameHandler(webSocketFrameHandler); + return this; + } + public Server build() { int maxThreads = serverConfig.getBlockingThreadCount(); int maxQueue = serverConfig.getBlockingQueueCount(); @@ -732,9 +740,8 @@ public final class Server implements AutoCloseable { } } logger.log(Level.INFO, "configured domains: " + serverConfig.getDomains()); - return new Server(serverConfig, byteBufAllocator, - parentEventLoopGroup, childEventLoopGroup, socketChannelClass, - executor); + return new Server(serverConfig, byteBufAllocator, parentEventLoopGroup, childEventLoopGroup, + socketChannelClass, executor); } } } diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/endpoint/HttpEndpointResolver.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/endpoint/HttpEndpointResolver.java index 5bfbea5..34e7456 100644 --- a/netty-http-server/src/main/java/org/xbib/netty/http/server/endpoint/HttpEndpointResolver.java +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/endpoint/HttpEndpointResolver.java @@ -6,13 +6,8 @@ import org.xbib.netty.http.server.api.EndpointResolver; import org.xbib.netty.http.server.api.Filter; import org.xbib.netty.http.server.api.ServerRequest; import org.xbib.netty.http.server.api.ServerResponse; -import org.xbib.netty.http.server.api.annotation.Endpoint; -import org.xbib.netty.http.server.endpoint.service.MethodService; import java.io.IOException; -import java.lang.reflect.Method; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; @@ -108,32 +103,6 @@ public class HttpEndpointResolver implements EndpointResolver { return this; } - /** - * Adds a service for the methods of the given object that - * are annotated with the {@link Endpoint} annotation. - * @param classWithAnnotatedMethods class with annotated methods - * @return this builder - */ - public Builder addEndpoint(Object classWithAnnotatedMethods) { - Objects.requireNonNull(classWithAnnotatedMethods); - for (Class clazz = classWithAnnotatedMethods.getClass(); clazz != null; clazz = clazz.getSuperclass()) { - for (Method method : clazz.getDeclaredMethods()) { - Endpoint endpoint = method.getAnnotation(Endpoint.class); - if (endpoint != null) { - MethodService methodService = new MethodService(method, classWithAnnotatedMethods); - addEndpoint(HttpEndpoint.builder() - .setPrefix(prefix) - .setPath(endpoint.path()) - .setMethods(Arrays.asList(endpoint.methods())) - .setContentTypes(Arrays.asList(endpoint.contentTypes())) - .setBefore(Collections.singletonList(methodService)) - .build()); - } - } - } - return this; - } - public Builder setDispatcher(Filter dispatcher) { Objects.requireNonNull(dispatcher); this.dispatcher = dispatcher; diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/handler/ExtendedSNIHandler.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/handler/ExtendedSNIHandler.java index 25c529e..3a8a363 100644 --- a/netty-http-server/src/main/java/org/xbib/netty/http/server/handler/ExtendedSNIHandler.java +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/handler/ExtendedSNIHandler.java @@ -6,7 +6,7 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslHandler; import io.netty.util.Mapping; import org.xbib.netty.http.common.HttpAddress; -import org.xbib.netty.http.server.DefaultServerConfig; +import org.xbib.netty.http.server.api.ServerConfig; import java.net.InetSocketAddress; import java.util.Arrays; import java.util.logging.Level; @@ -18,12 +18,12 @@ public class ExtendedSNIHandler extends SniHandler { private static final Logger logger = Logger.getLogger(ExtendedSNIHandler.class.getName()); - private final DefaultServerConfig serverConfig; + private final ServerConfig serverConfig; private final HttpAddress httpAddress; public ExtendedSNIHandler(Mapping mapping, - DefaultServerConfig serverConfig, HttpAddress httpAddress) { + ServerConfig serverConfig, HttpAddress httpAddress) { super(mapping); this.serverConfig = serverConfig; this.httpAddress = httpAddress; @@ -35,7 +35,7 @@ public class ExtendedSNIHandler extends SniHandler { } private static SslHandler newSslHandler(SslContext sslContext, - DefaultServerConfig serverConfig, + ServerConfig serverConfig, ByteBufAllocator allocator, HttpAddress httpAddress) { InetSocketAddress peer = httpAddress.getInetSocketAddress(); diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/http1/Http1ChannelInitializer.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/http1/Http1ChannelInitializer.java index a4bcb12..7d7aceb 100644 --- a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/http1/Http1ChannelInitializer.java +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/http1/Http1ChannelInitializer.java @@ -15,6 +15,8 @@ import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.handler.logging.LogLevel; import io.netty.handler.ssl.SslContext; import io.netty.handler.stream.ChunkedWriteHandler; @@ -22,8 +24,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.util.Mapping; import org.xbib.netty.http.common.HttpAddress; import org.xbib.netty.http.server.Server; -import org.xbib.netty.http.server.DefaultServerConfig; import org.xbib.netty.http.common.HttpChannelInitializer; +import org.xbib.netty.http.server.api.ServerConfig; import org.xbib.netty.http.server.handler.ExtendedSNIHandler; import org.xbib.netty.http.server.handler.IdleTimeoutHandler; import org.xbib.netty.http.server.handler.TrafficLoggingHandler; @@ -39,7 +41,7 @@ public class Http1ChannelInitializer extends ChannelInitializer private final Server server; - private final DefaultServerConfig serverConfig; + private final ServerConfig serverConfig; private final HttpAddress httpAddress; @@ -89,8 +91,7 @@ public class Http1ChannelInitializer extends ChannelInitializer serverConfig.getMaxHeadersSize(), serverConfig.getMaxChunkSize())); if (serverConfig.isCompressionEnabled()) { pipeline.addLast("http-server-compressor", - new HttpContentCompressor(6, 15, 8, - serverConfig.getCompressionThreshold())); + new HttpContentCompressor()); } if (serverConfig.isDecompressionEnabled()) { pipeline.addLast("http-server-decompressor", @@ -99,10 +100,20 @@ public class Http1ChannelInitializer extends ChannelInitializer HttpObjectAggregator httpObjectAggregator = new HttpObjectAggregator(serverConfig.getMaxContentLength()); httpObjectAggregator.setMaxCumulationBufferComponents(serverConfig.getMaxCompositeBufferComponents()); - pipeline.addLast("http-server-aggregator", httpObjectAggregator); - pipeline.addLast("http-server-pipelining", new HttpPipeliningHandler(serverConfig.getPipeliningCapacity())); - pipeline.addLast("http-server-handler", new ServerMessages(server)); - pipeline.addLast("http-idle-timeout-handler", new IdleTimeoutHandler(serverConfig.getIdleTimeoutMillis())); + pipeline.addLast("http-server-aggregator", + httpObjectAggregator); + if (serverConfig.getWebSocketFrameHandler() != null) { + pipeline.addLast("http-server-ws-protocol-handler", + new WebSocketServerProtocolHandler("/websocket")); + pipeline.addLast("http-server-ws-handler", + serverConfig.getWebSocketFrameHandler()); + } + pipeline.addLast("http-server-pipelining", + new HttpPipeliningHandler(serverConfig.getPipeliningCapacity())); + pipeline.addLast("http-server-handler", + new ServerMessages(server)); + pipeline.addLast("http-idle-timeout-handler", + new IdleTimeoutHandler(serverConfig.getIdleTimeoutMillis())); } @Sharable @@ -118,6 +129,13 @@ public class Http1ChannelInitializer extends ChannelInitializer @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof WebSocketFrame) { + WebSocketFrame webSocketFrame = (WebSocketFrame) msg; + if (serverConfig.getWebSocketFrameHandler() != null) { + serverConfig.getWebSocketFrameHandler().channelRead(ctx, webSocketFrame); + } + return; + } if (msg instanceof HttpPipelinedRequest) { HttpPipelinedRequest httpPipelinedRequest = (HttpPipelinedRequest) msg; if (httpPipelinedRequest.getRequest() instanceof FullHttpRequest) { diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/http2/Http2ChannelInitializer.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/http2/Http2ChannelInitializer.java index 17afcda..6523427 100644 --- a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/http2/Http2ChannelInitializer.java +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/http2/Http2ChannelInitializer.java @@ -34,8 +34,8 @@ import io.netty.util.AsciiString; import io.netty.util.Mapping; import org.xbib.netty.http.common.HttpAddress; import org.xbib.netty.http.server.Server; -import org.xbib.netty.http.server.DefaultServerConfig; import org.xbib.netty.http.common.HttpChannelInitializer; +import org.xbib.netty.http.server.api.ServerConfig; import org.xbib.netty.http.server.handler.ExtendedSNIHandler; import org.xbib.netty.http.server.handler.IdleTimeoutHandler; import org.xbib.netty.http.server.handler.TrafficLoggingHandler; @@ -52,7 +52,7 @@ public class Http2ChannelInitializer extends ChannelInitializer private final Server server; - private final DefaultServerConfig serverConfig; + private final ServerConfig serverConfig; private final HttpAddress httpAddress; @@ -80,8 +80,10 @@ public class Http2ChannelInitializer extends ChannelInitializer configureCleartext(channel); } if (serverConfig.isDebug()) { - logger.log(Level.FINE, "HTTP/2 server channel initialized: " + - " address=" + httpAddress + " pipeline=" + channel.pipeline().names()); + if (logger.isLoggable(Level.FINEST)) { + logger.log(Level.FINEST, "HTTP/2 server channel initialized: " + + " address=" + httpAddress + " pipeline=" + channel.pipeline().names()); + } } } diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketAcceptor.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketAcceptor.java new file mode 100644 index 0000000..255af1e --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketAcceptor.java @@ -0,0 +1,43 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.util.concurrent.Future; +import org.xbib.netty.http.common.ws.Http2WebSocketProtocol; +import java.util.List; +import java.util.Objects; + +/** + * Accepts valid websocket-over-http2 request, optionally modifies request and response headers. + */ +public interface Http2WebSocketAcceptor { + + /** + * @param ctx ChannelHandlerContext of connection channel. Intended for creating acceptor result + * with context.executor().newFailedFuture(Throwable), + * context.executor().newSucceededFuture(ChannelHandler) + * @param path websocket path + * @param subprotocols requested websocket subprotocols. Accepted subprotocol must be set on + * response headers, e.g. with {@link Subprotocol#accept(String, Http2Headers)} + * @param request request headers + * @param response response headers + * @return Succeeded future for accepted request, failed for rejected request. It is an error to + * return non completed future + */ + Future accept(ChannelHandlerContext ctx, String path, List subprotocols, + Http2Headers request, Http2Headers response); + + final class Subprotocol { + private Subprotocol() {} + + public static void accept(String subprotocol, Http2Headers response) { + Objects.requireNonNull(subprotocol, "subprotocol"); + Objects.requireNonNull(response, "response"); + if (subprotocol.isEmpty()) { + return; + } + response.set(Http2WebSocketProtocol.HEADER_WEBSOCKET_SUBPROTOCOL_NAME, subprotocol); + } + } +} diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketChannelFutureListener.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketChannelFutureListener.java new file mode 100644 index 0000000..19daf3c --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketChannelFutureListener.java @@ -0,0 +1,27 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.util.concurrent.GenericFutureListener; +import org.xbib.netty.http.common.ws.Http2WebSocketEvent; + +/** + * ChannelFuture listener that gracefully closes websocket by sending empty DATA frame with + * END_STREAM flag set. + */ +public final class Http2WebSocketChannelFutureListener implements GenericFutureListener { + + public static final Http2WebSocketChannelFutureListener CLOSE = new Http2WebSocketChannelFutureListener(); + + private Http2WebSocketChannelFutureListener() {} + + @Override + public void operationComplete(ChannelFuture future) { + Channel channel = future.channel(); + Throwable cause = future.cause(); + if (cause != null) { + Http2WebSocketEvent.fireFrameWriteError(channel, cause); + } + channel.pipeline().fireUserEventTriggered(Http2WebSocketEvent.Http2WebSocketLocalCloseEvent.INSTANCE); + } +} diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketChannelInitializer.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketChannelInitializer.java new file mode 100644 index 0000000..58822a9 --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketChannelInitializer.java @@ -0,0 +1,46 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.handler.codec.http2.Http2FrameCodec; +import io.netty.handler.codec.http2.Http2FrameCodecBuilder; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; + +public class Http2WebSocketChannelInitializer extends ChannelInitializer { + private final SslContext sslContext; + + Http2WebSocketChannelInitializer(SslContext sslContext) { + this.sslContext = sslContext; + } + + @Override + protected void initChannel(SocketChannel ch) { + SslHandler sslHandler = sslContext.newHandler(ch.alloc()); + Http2FrameCodec http2frameCodec = Http2WebSocketServerBuilder + .configureHttp2Server(Http2FrameCodecBuilder.forServer()) + .build(); + ServerWebSocketHandler serverWebSocketHandler = new ServerWebSocketHandler(); + Http2WebSocketServerHandler http2webSocketHandler = + Http2WebSocketServerBuilder.create() + .decoderConfig(WebSocketDecoderConfig.newBuilder().allowExtensions(true).build()) + .compression(true) + .acceptor(new PathAcceptor("/test", serverWebSocketHandler)) + .build(); + ch.pipeline().addLast(sslHandler, http2frameCodec, http2webSocketHandler); + } + + @Sharable + private static class ServerWebSocketHandler extends SimpleChannelInboundHandler { + + @Override + protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame webSocketFrame) { + // echo + ctx.writeAndFlush(webSocketFrame.retain()); + } + } +} \ No newline at end of file diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketHandshakeOnlyServerHandler.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketHandshakeOnlyServerHandler.java new file mode 100644 index 0000000..1c84998 --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketHandshakeOnlyServerHandler.java @@ -0,0 +1,79 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.util.AsciiString; +import io.netty.util.concurrent.GenericFutureListener; +import org.xbib.netty.http.common.ws.Http2WebSocketEvent; +import org.xbib.netty.http.common.ws.Http2WebSocketHandler; +import org.xbib.netty.http.common.ws.Http2WebSocketProtocol; +import org.xbib.netty.http.common.ws.Http2WebSocketValidator; + +/** + * Provides server-side support for websocket-over-http2. Verifies websocket-over-http2 request + * validity. Invalid websocket requests are rejected by sending RST frame, valid websocket http2 + * stream frames are passed down the pipeline. Valid websocket stream request headers are modified + * as follows: :method=POST, x-protocol=websocket. Intended for proxies/intermidiaries that do not + * process websocket byte streams, but only route respective http2 streams - hence is not compatible + * with http1 websocket handlers. http1 websocket handlers support is provided by complementary + * {@link Http2WebSocketServerHandler} + */ +public final class Http2WebSocketHandshakeOnlyServerHandler extends Http2WebSocketHandler + implements GenericFutureListener { + + public Http2WebSocketHandshakeOnlyServerHandler() {} + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endOfStream) throws Http2Exception { + if (handshake(headers, endOfStream)) { + super.onHeadersRead(ctx, streamId, headers, padding, endOfStream); + } else { + reject(ctx, streamId, headers, endOfStream); + } + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endOfStream) + throws Http2Exception { + if (handshake(headers, endOfStream)) { + super.onHeadersRead( + ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream); + } else { + reject(ctx, streamId, headers, endOfStream); + } + } + + /*RST_STREAM frame write*/ + @Override + public void operationComplete(ChannelFuture future) { + Throwable cause = future.cause(); + if (cause != null) { + Http2WebSocketEvent.fireFrameWriteError(future.channel(), cause); + } + } + + private boolean handshake(Http2Headers headers, boolean endOfStream) { + if (Http2WebSocketProtocol.isExtendedConnect(headers)) { + boolean isValid = Http2WebSocketValidator.WebSocket.isValid(headers, endOfStream); + if (isValid) { + Http2WebSocketServerHandshaker.handshakeOnlyWebSocket(headers); + } + return isValid; + } + return Http2WebSocketValidator.Http.isValid(headers, endOfStream); + } + + private void reject(ChannelHandlerContext ctx, int streamId, Http2Headers headers, boolean endOfStream) { + Http2WebSocketEvent.fireHandshakeValidationStartAndError(ctx.channel(), streamId, + headers.set( AsciiString.of("x-websocket-endofstream"), AsciiString.of(endOfStream ? "true" : "false"))); + http2Handler.encoder() + .writeRstStream(ctx, streamId, Http2Error.PROTOCOL_ERROR.code(), ctx.newPromise()) + .addListener(this); + ctx.flush(); + } +} diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketPathNotFoundException.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketPathNotFoundException.java new file mode 100644 index 0000000..aba2ec3 --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketPathNotFoundException.java @@ -0,0 +1,10 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; + +public final class Http2WebSocketPathNotFoundException extends WebSocketHandshakeException { + + public Http2WebSocketPathNotFoundException(String message) { + super(message); + } +} diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerBuilder.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerBuilder.java new file mode 100644 index 0000000..e22cfd0 --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerBuilder.java @@ -0,0 +1,190 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.handler.codec.http.websocketx.extensions.compression.PerMessageDeflateServerExtensionHandshaker; +import io.netty.handler.codec.http2.Http2ConnectionHandlerBuilder; +import io.netty.handler.codec.http2.Http2FrameCodecBuilder; +import org.xbib.netty.http.common.ws.Http2WebSocketMessages; +import org.xbib.netty.http.common.ws.Http2WebSocketProtocol; +import org.xbib.netty.http.common.ws.Preconditions; + +import java.util.Objects; + +/** + * Builder for {@link Http2WebSocketServerHandler} + */ +public final class Http2WebSocketServerBuilder { + + private static final boolean MASK_PAYLOAD = false; + + private static final Http2WebSocketAcceptor REJECT_REQUESTS_ACCEPTOR = + (context, path, subprotocols, request, response) -> context.executor() + .newFailedFuture(new Http2WebSocketPathNotFoundException(Http2WebSocketMessages.HANDSHAKE_PATH_NOT_FOUND + path + + Http2WebSocketMessages.HANDSHAKE_PATH_NOT_FOUND_SUBPROTOCOLS + + subprotocols)); + + private WebSocketDecoderConfig webSocketDecoderConfig; + + private PerMessageDeflateServerExtensionHandshaker perMessageDeflateServerExtensionHandshaker; + + private long closedWebSocketRemoveTimeoutMillis = 30_000; + + private boolean isSingleWebSocketPerConnection; + + private Http2WebSocketAcceptor acceptor = REJECT_REQUESTS_ACCEPTOR; + + Http2WebSocketServerBuilder() {} + + /** + * Builds handshake-only {@link Http2WebSocketHandshakeOnlyServerHandler}. + * + * @return new {@link Http2WebSocketHandshakeOnlyServerHandler} instance + */ + public static Http2WebSocketHandshakeOnlyServerHandler buildHandshakeOnly() { + return new Http2WebSocketHandshakeOnlyServerHandler(); + } + + /** @return new {@link Http2WebSocketServerBuilder} instance */ + public static Http2WebSocketServerBuilder create() { + return new Http2WebSocketServerBuilder(); + } + + /** + * Utility method for configuring Http2FrameCodecBuilder with websocket-over-http2 support + * + * @param http2Builder {@link Http2FrameCodecBuilder} instance + * @return same {@link Http2FrameCodecBuilder} instance + */ + public static Http2FrameCodecBuilder configureHttp2Server(Http2FrameCodecBuilder http2Builder) { + Objects.requireNonNull(http2Builder, "http2Builder") + .initialSettings() + .put(Http2WebSocketProtocol.SETTINGS_ENABLE_CONNECT_PROTOCOL, (Long) 1L); + return http2Builder.validateHeaders(false); + } + + /** + * Utility method for configuring Http2ConnectionHandlerBuilder with websocket-over-http2 support + * + * @param http2Builder {@link Http2ConnectionHandlerBuilder} instance + * @return same {@link Http2ConnectionHandlerBuilder} instance + */ + public static Http2ConnectionHandlerBuilder configureHttp2Server(Http2ConnectionHandlerBuilder http2Builder) { + Objects.requireNonNull(http2Builder, "http2Builder") + .initialSettings() + .put(Http2WebSocketProtocol.SETTINGS_ENABLE_CONNECT_PROTOCOL, (Long) 1L); + return http2Builder.validateHeaders(false); + } + + /** + * @param webSocketDecoderConfig websocket decoder configuration. Must be non-null + * @return this {@link Http2WebSocketServerBuilder} instance + */ + public Http2WebSocketServerBuilder decoderConfig(WebSocketDecoderConfig webSocketDecoderConfig) { + this.webSocketDecoderConfig = + Preconditions.requireNonNull(webSocketDecoderConfig, "webSocketDecoderConfig"); + return this; + } + + /** + * @param closedWebSocketRemoveTimeoutMillis delay until websockets handler forgets closed + * websocket. Necessary to gracefully handle incoming http2 frames racing with outgoing stream + * termination frame. + * @return this {@link Http2WebSocketServerBuilder} instance + */ + public Http2WebSocketServerBuilder closedWebSocketRemoveTimeout(long closedWebSocketRemoveTimeoutMillis) { + this.closedWebSocketRemoveTimeoutMillis = + Preconditions.requirePositive(closedWebSocketRemoveTimeoutMillis, "closedWebSocketRemoveTimeoutMillis"); + return this; + } + + /** + * @param isCompressionEnabled enables permessage-deflate compression with default configuration + * @return this {@link Http2WebSocketServerBuilder} instance + */ + public Http2WebSocketServerBuilder compression(boolean isCompressionEnabled) { + if (isCompressionEnabled) { + if (perMessageDeflateServerExtensionHandshaker == null) { + perMessageDeflateServerExtensionHandshaker = + new PerMessageDeflateServerExtensionHandshaker(); + } + } else { + perMessageDeflateServerExtensionHandshaker = null; + } + return this; + } + + /** + * Enables permessage-deflate compression with extended configuration. Parameters are described in + * netty's PerMessageDeflateServerExtensionHandshaker + * + * @param compressionLevel sets compression level. Range is [0; 9], default is 6 + * @param allowServerWindowSize allows client to customize the server's inflater window size, + * default is false + * @param preferredClientWindowSize preferred client window size if client inflater is + * customizable + * @param allowServerNoContext allows client to activate server_no_context_takeover, default is + * false + * @param preferredClientNoContext whether server prefers to activate client_no_context_takeover + * if client is compatible, default is false + * @return this {@link Http2WebSocketServerBuilder} instance + */ + public Http2WebSocketServerBuilder compression(int compressionLevel, + boolean allowServerWindowSize, + int preferredClientWindowSize, + boolean allowServerNoContext, + boolean preferredClientNoContext) { + perMessageDeflateServerExtensionHandshaker = + new PerMessageDeflateServerExtensionHandshaker(compressionLevel, allowServerWindowSize, + preferredClientWindowSize, allowServerNoContext, preferredClientNoContext); + return this; + } + + /** + * Sets http1 websocket request acceptor + * + * @param acceptor websocket request acceptor. Must be non-null. + * @return this {@link Http2WebSocketServerBuilder} instance + */ + public Http2WebSocketServerBuilder acceptor(Http2WebSocketAcceptor acceptor) { + this.acceptor = Objects.requireNonNull(acceptor, "acceptor"); + return this; + } + + /** + * @param isSingleWebSocketPerConnection optimize for at most 1 websocket per connection + * @return this {@link Http2WebSocketServerBuilder} instance + */ + public Http2WebSocketServerBuilder assumeSingleWebSocketPerConnection(boolean isSingleWebSocketPerConnection) { + this.isSingleWebSocketPerConnection = isSingleWebSocketPerConnection; + return this; + } + + /** + * Builds subchannel based {@link Http2WebSocketServerHandler} compatible with http1 websocket + * handlers. + * + * @return new {@link Http2WebSocketServerHandler} instance + */ + public Http2WebSocketServerHandler build() { + boolean hasCompression = perMessageDeflateServerExtensionHandshaker != null; + WebSocketDecoderConfig config = webSocketDecoderConfig; + if (config == null) { + config = WebSocketDecoderConfig.newBuilder() + .expectMaskedFrames(true) + .allowMaskMismatch(false) + .allowExtensions(hasCompression) + .build(); + } else { + boolean isAllowExtensions = config.allowExtensions(); + if (!isAllowExtensions && hasCompression) { + config = config.toBuilder().allowExtensions(true).build(); + } + } + return new Http2WebSocketServerHandler(config, + MASK_PAYLOAD, + closedWebSocketRemoveTimeoutMillis, + perMessageDeflateServerExtensionHandshaker, + acceptor, + isSingleWebSocketPerConnection); + } +} diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerHandler.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerHandler.java new file mode 100644 index 0000000..b15f209 --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerHandler.java @@ -0,0 +1,74 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.channel.*; +import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.handler.codec.http.websocketx.extensions.compression.PerMessageDeflateServerExtensionHandshaker; +import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2Headers; +import org.xbib.netty.http.common.ws.Http2WebSocketChannelHandler; +import org.xbib.netty.http.common.ws.Http2WebSocketProtocol; +import org.xbib.netty.http.common.ws.Http2WebSocketValidator; + +/** + * Provides server-side support for websocket-over-http2. Creates sub channel for http2 stream of + * successfully handshaked websocket. Subchannel is compatible with http1 websocket handlers. + */ +public final class Http2WebSocketServerHandler extends Http2WebSocketChannelHandler { + + private final PerMessageDeflateServerExtensionHandshaker compressionHandshaker; + + private final Http2WebSocketAcceptor http2WebSocketAcceptor; + + private Http2WebSocketServerHandshaker handshaker; + + Http2WebSocketServerHandler(WebSocketDecoderConfig webSocketDecoderConfig, boolean isEncoderMaskPayload, + long closedWebSocketRemoveTimeoutMillis, + PerMessageDeflateServerExtensionHandshaker compressionHandshaker, + Http2WebSocketAcceptor http2WebSocketAcceptor, + boolean isSingleWebSocketPerConnection) { + super(webSocketDecoderConfig, isEncoderMaskPayload, closedWebSocketRemoveTimeoutMillis, isSingleWebSocketPerConnection); + this.compressionHandshaker = compressionHandshaker; + this.http2WebSocketAcceptor = http2WebSocketAcceptor; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + this.handshaker = new Http2WebSocketServerHandshaker(webSocketsParent, + config, isEncoderMaskPayload, http2WebSocketAcceptor, compressionHandshaker); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, final int streamId, Http2Headers headers, + int padding, boolean endOfStream) throws Http2Exception { + boolean proceed = handshakeWebSocket(streamId, headers, endOfStream); + if (proceed) { + next().onHeadersRead(ctx, streamId, headers, padding, endOfStream); + } + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception { + boolean proceed = handshakeWebSocket(streamId, headers, endOfStream); + if (proceed) { + next().onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream); + } + } + + private boolean handshakeWebSocket(int streamId, Http2Headers headers, boolean endOfStream) { + if (Http2WebSocketProtocol.isExtendedConnect(headers)) { + if (!Http2WebSocketValidator.WebSocket.isValid(headers, endOfStream)) { + handshaker.reject(streamId, headers, endOfStream); + } else { + handshaker.handshake(streamId, headers, endOfStream); + } + return false; + } + if (!Http2WebSocketValidator.Http.isValid(headers, endOfStream)) { + handshaker.reject(streamId, headers, endOfStream); + return false; + } + return true; + } +} diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerHandshaker.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerHandshaker.java new file mode 100644 index 0000000..b4d81a2 --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/Http2WebSocketServerHandshaker.java @@ -0,0 +1,345 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.websocketx.WebSocketDecoderConfig; +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionDecoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionEncoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtensionHandshaker; +import io.netty.handler.codec.http2.DefaultHttp2Headers; +import io.netty.handler.codec.http2.Http2Error; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.ReadOnlyHttp2Headers; +import io.netty.util.AsciiString; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.ScheduledFuture; +import org.xbib.netty.http.common.ws.Http2WebSocketChannel; +import org.xbib.netty.http.common.ws.Http2WebSocketChannelHandler; +import org.xbib.netty.http.common.ws.Http2WebSocketEvent; +import org.xbib.netty.http.common.ws.Http2WebSocketExtensions; +import org.xbib.netty.http.common.ws.Http2WebSocketMessages; +import org.xbib.netty.http.common.ws.Http2WebSocketProtocol; + +import java.nio.channels.ClosedChannelException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +final class Http2WebSocketServerHandshaker implements GenericFutureListener { + + private static final AsciiString HEADERS_STATUS_200 = AsciiString.of("200"); + + private static final ReadOnlyHttp2Headers HEADERS_OK = + ReadOnlyHttp2Headers.serverHeaders(false, HEADERS_STATUS_200); + + private static final ReadOnlyHttp2Headers HEADERS_UNSUPPORTED_VERSION = + ReadOnlyHttp2Headers.serverHeaders(false, AsciiString.of("400"), + AsciiString.of(Http2WebSocketProtocol.HEADER_WEBSOCKET_VERSION_NAME), + AsciiString.of(Http2WebSocketProtocol.HEADER_WEBSOCKET_VERSION_VALUE)); + + private static final ReadOnlyHttp2Headers HEADERS_REJECTED = + ReadOnlyHttp2Headers.serverHeaders(false, AsciiString.of("400")); + + private static final ReadOnlyHttp2Headers HEADERS_NOT_FOUND = + ReadOnlyHttp2Headers.serverHeaders(false, AsciiString.of("404")); + + private static final ReadOnlyHttp2Headers HEADERS_INTERNAL_ERROR = + ReadOnlyHttp2Headers.serverHeaders(false, AsciiString.of("500")); + + private final Http2WebSocketChannelHandler.WebSocketsParent webSocketsParent; + + private final WebSocketDecoderConfig webSocketDecoderConfig; + + private final boolean isEncoderMaskPayload; + + private final Http2WebSocketAcceptor http2WebSocketAcceptor; + + private final WebSocketServerExtensionHandshaker compressionHandshaker; + + Http2WebSocketServerHandshaker(Http2WebSocketChannelHandler.WebSocketsParent webSocketsParent, + WebSocketDecoderConfig webSocketDecoderConfig, + boolean isEncoderMaskPayload, + Http2WebSocketAcceptor http2WebSocketAcceptor, + WebSocketServerExtensionHandshaker compressionHandshaker) { + this.webSocketsParent = webSocketsParent; + this.webSocketDecoderConfig = webSocketDecoderConfig; + this.isEncoderMaskPayload = isEncoderMaskPayload; + this.http2WebSocketAcceptor = http2WebSocketAcceptor; + this.compressionHandshaker = compressionHandshaker; + } + + void reject(final int streamId, final Http2Headers requestHeaders, boolean endOfStream) { + Http2WebSocketEvent.fireHandshakeValidationStartAndError( + webSocketsParent.context().channel(), + streamId, + requestHeaders.set(AsciiString.of("x-websocket-endofstream"), AsciiString.of(endOfStream ? "true" : "false"))); + writeRstStream(streamId).addListener(this); + } + + void handshake(final int streamId, final Http2Headers requestHeaders, boolean endOfStream) { + long startNanos = System.nanoTime(); + ChannelHandlerContext ctx = webSocketsParent.context(); + String path = requestHeaders.path().toString(); + CharSequence webSocketVersion = requestHeaders.get(Http2WebSocketProtocol.HEADER_WEBSOCKET_VERSION_NAME); + CharSequence subprotocolsSeq = requestHeaders.get(Http2WebSocketProtocol.HEADER_WEBSOCKET_SUBPROTOCOL_NAME); + String subprotocols = nonNullString(subprotocolsSeq); + if (isUnsupportedWebSocketVersion(webSocketVersion)) { + Http2WebSocketEvent.fireHandshakeStartAndError(ctx.channel(), streamId, path, subprotocols, + requestHeaders, startNanos, System.nanoTime(), WebSocketHandshakeException.class.getName(), + Http2WebSocketMessages.HANDSHAKE_UNSUPPORTED_VERSION + webSocketVersion); + writeHeaders(ctx, streamId, HEADERS_UNSUPPORTED_VERSION, true).addListener(this); + return; + } + List requestedSubprotocols = parseSubprotocols(subprotocols); + WebSocketServerExtension compressionExtension = null; + WebSocketServerExtensionHandshaker compressionHandshaker = this.compressionHandshaker; + if (compressionHandshaker != null) { + CharSequence extensionsHeader = requestHeaders.get(Http2WebSocketProtocol.HEADER_WEBSOCKET_EXTENSIONS_NAME); + WebSocketExtensionData compression = Http2WebSocketExtensions.decode(extensionsHeader); + if (compression != null) { + compressionExtension = compressionHandshaker.handshakeExtension(compression); + } + } + boolean hasCompression = compressionExtension != null; + WebSocketExtensionEncoder compressionEncoder = null; + WebSocketExtensionDecoder compressionDecoder = null; + Http2Headers responseHeaders = new DefaultHttp2Headers(); + if (hasCompression) { + responseHeaders.set(Http2WebSocketProtocol.HEADER_WEBSOCKET_EXTENSIONS_NAME, + Http2WebSocketExtensions.encode(compressionExtension.newReponseData())); + compressionEncoder = compressionExtension.newExtensionEncoder(); + compressionDecoder = compressionExtension.newExtensionDecoder(); + } + Future acceptorResult; + try { + acceptorResult = http2WebSocketAcceptor.accept(ctx, path, requestedSubprotocols, requestHeaders, responseHeaders); + } catch (Exception e) { + acceptorResult = ctx.executor().newFailedFuture(e); + } + if (!acceptorResult.isDone()) { + acceptorResult.cancel(true); + Http2WebSocketEvent.fireHandshakeStartAndError(ctx.channel(), streamId, path, subprotocols, + requestHeaders, startNanos, System.nanoTime(), WebSocketHandshakeException.class.getName(), + Http2WebSocketMessages.HANDSHAKE_UNSUPPORTED_ACCEPTOR_TYPE); + writeHeaders(ctx, streamId, HEADERS_INTERNAL_ERROR, true).addListener(this); + return; + } + Throwable rejected = acceptorResult.cause(); + if (rejected != null) { + Http2WebSocketEvent.fireHandshakeStartAndError(ctx.channel(), streamId, path, subprotocols, + requestHeaders, startNanos, System.nanoTime(), rejected); + Http2Headers response = rejected instanceof Http2WebSocketPathNotFoundException ? + HEADERS_NOT_FOUND : HEADERS_REJECTED; + writeHeaders(ctx, streamId, response, true).addListener(this); + return; + } + CharSequence acceptedSubprotocolSeq = responseHeaders.get(Http2WebSocketProtocol.HEADER_WEBSOCKET_SUBPROTOCOL_NAME); + String acceptedSubprotocol = nonNullString(acceptedSubprotocolSeq); + if (!isExpectedSubprotocol(acceptedSubprotocol, requestedSubprotocols)) { + String subprotocolOrBlank = acceptedSubprotocol.isEmpty() ? "''" : acceptedSubprotocol; + Http2WebSocketEvent.fireHandshakeStartAndError(ctx.channel(), streamId, path, subprotocols, + requestHeaders, startNanos, System.nanoTime(), WebSocketHandshakeException.class.getName(), + Http2WebSocketMessages.HANDSHAKE_UNEXPECTED_SUBPROTOCOL + subprotocolOrBlank); + writeHeaders(ctx, streamId, HEADERS_NOT_FOUND, true).addListener(this); + return; + } + ChannelHandler webSocketHandler = acceptorResult.getNow(); + WebSocketExtensionEncoder finalCompressionEncoder = compressionEncoder; + WebSocketExtensionDecoder finalCompressionDecoder = compressionDecoder; + Http2Headers successHeaders = successHeaders(responseHeaders); + writeHeaders(ctx, streamId, successHeaders, false).addListener(future -> { + Throwable cause = future.cause(); + if (cause != null) { + Channel ch = ctx.channel(); + Http2WebSocketEvent.fireFrameWriteError(ch, future.cause()); + Http2WebSocketEvent.fireHandshakeStartAndError(ch, streamId, path, subprotocols, + requestHeaders, startNanos, System.nanoTime(), cause); + return; + } + Http2WebSocketChannel webSocket = new Http2WebSocketChannel(webSocketsParent, streamId, path, acceptedSubprotocol, + webSocketDecoderConfig, isEncoderMaskPayload, finalCompressionEncoder, + finalCompressionDecoder, webSocketHandler).setStreamId(streamId); + ChannelFuture registered = ctx.channel().eventLoop().register(webSocket); + if (!registered.isSuccess()) { + Http2WebSocketEvent.fireHandshakeStartAndError(ctx.channel(), streamId, path, subprotocols, + requestHeaders, startNanos, System.nanoTime(), registered.cause()); + writeRstStream(streamId).addListener(this); + webSocket.streamClosed(); + return; + } + if (!webSocket.isOpen()) { + Http2WebSocketEvent.fireHandshakeStartAndError(ctx.channel(), streamId, path, subprotocols, + requestHeaders, startNanos, System.nanoTime(), ClosedChannelException.class.getName(), + "websocket channel closed immediately after eventloop registration"); + return; + } + webSocketsParent.register(streamId, webSocket); + Http2WebSocketEvent.fireHandshakeStartAndSuccess(webSocket, streamId, path, subprotocols, + requestHeaders, successHeaders, startNanos, System.nanoTime()); + }); + } + + private boolean isExpectedSubprotocol(String subprotocol, List requestedSubprotocols) { + int requestedLength = requestedSubprotocols.size(); + if (subprotocol.isEmpty()) { + return requestedLength == 0; + } + for (int i = 0; i < requestedLength; i++) { + if (requestedSubprotocols.get(i).equals(subprotocol)) { + return true; + } + } + return false; + } + + @Override + public void operationComplete(ChannelFuture future) { + Throwable cause = future.cause(); + if (cause != null) { + Http2WebSocketEvent.fireFrameWriteError(future.channel(), cause); + } + } + + private ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, boolean endStream) { + ChannelFuture channelFuture = webSocketsParent.writeHeaders(streamId, headers, endStream); + ctx.flush(); + return channelFuture; + } + + private ChannelFuture writeRstStream(int streamId) { + return webSocketsParent.writeRstStream(streamId, Http2Error.PROTOCOL_ERROR.code()); + } + + static Http2Headers handshakeOnlyWebSocket(Http2Headers headers) { + headers.remove(Http2WebSocketProtocol.HEADER_PROTOCOL_NAME); + headers.method(Http2WebSocketProtocol.HEADER_METHOD_CONNECT_HANDSHAKED); + return headers.set( + Http2WebSocketProtocol.HEADER_PROTOCOL_NAME_HANDSHAKED, + Http2WebSocketProtocol.HEADER_PROTOCOL_VALUE); + } + + static List parseSubprotocols(String subprotocols) { + if (subprotocols.isEmpty()) { + return Collections.emptyList(); + } + if (subprotocols.indexOf(',') == -1) { + return Collections.singletonList(subprotocols); + } + return Arrays.asList(subprotocols.split(",")); + } + + private static String nonNullString(CharSequence seq) { + if (seq == null) { + return ""; + } + return seq.toString(); + } + + private static Http2Headers successHeaders(Http2Headers responseHeaders) { + if (responseHeaders.isEmpty()) { + return HEADERS_OK; + } + return responseHeaders.status(HEADERS_STATUS_200); + } + + private static boolean isUnsupportedWebSocketVersion(CharSequence webSocketVersion) { + return webSocketVersion == null + || !Http2WebSocketProtocol.HEADER_WEBSOCKET_VERSION_VALUE.contentEquals(webSocketVersion); + } + + static class Handshake { + private final Future channelClose; + private final ChannelPromise handshake; + private final long timeoutMillis; + private boolean done; + private ScheduledFuture timeoutFuture; + private Future handshakeCompleteFuture; + private GenericFutureListener channelCloseListener; + + public Handshake(Future channelClose, ChannelPromise handshake, long timeoutMillis) { + this.channelClose = channelClose; + this.handshake = handshake; + this.timeoutMillis = timeoutMillis; + } + + public void startTimeout() { + ChannelPromise h = handshake; + Channel channel = h.channel(); + if (done) { + return; + } + GenericFutureListener l = channelCloseListener = future -> onConnectionClose(); + channelClose.addListener(l); + if (done) { + return; + } + handshakeCompleteFuture = h.addListener(future -> onHandshakeComplete(future.cause())); + if (done) { + return; + } + timeoutFuture = channel.eventLoop().schedule(this::onTimeout, timeoutMillis, TimeUnit.MILLISECONDS); + } + + public void complete(Throwable e) { + onHandshakeComplete(e); + } + + public boolean isDone() { + return done; + } + + public ChannelFuture future() { + return handshake; + } + + private void onConnectionClose() { + if (!done) { + handshake.tryFailure(new ClosedChannelException()); + done(); + } + } + + private void onHandshakeComplete(Throwable cause) { + if (!done) { + if (cause != null) { + handshake.tryFailure(cause); + } else { + handshake.trySuccess(); + } + done(); + } + } + + private void onTimeout() { + if (!done) { + handshake.tryFailure(new TimeoutException()); + done(); + } + } + + private void done() { + done = true; + GenericFutureListener closeListener = channelCloseListener; + if (closeListener != null) { + channelClose.removeListener(closeListener); + } + cancel(handshakeCompleteFuture); + cancel(timeoutFuture); + } + + private void cancel(Future future) { + if (future != null) { + future.cancel(true); + } + } + } +} diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/PathAcceptor.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/PathAcceptor.java new file mode 100644 index 0000000..fc579ca --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/PathAcceptor.java @@ -0,0 +1,30 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.util.concurrent.Future; + +import java.util.List; + +public class PathAcceptor implements Http2WebSocketAcceptor { + + private final String path; + + private final ChannelHandler webSocketHandler; + + public PathAcceptor(String path, ChannelHandler webSocketHandler) { + this.path = path; + this.webSocketHandler = webSocketHandler; + } + + @Override + public Future accept(ChannelHandlerContext ctx, String path, List subprotocols, + Http2Headers request, Http2Headers response) { + if (subprotocols.isEmpty() && path.equals(this.path)) { + return ctx.executor().newSucceededFuture(webSocketHandler); + } + return ctx.executor().newFailedFuture(new WebSocketHandshakeException(String.format("Path not found: %s , subprotocols: %s", path, subprotocols))); + } +} diff --git a/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/PathSubprotocolAcceptor.java b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/PathSubprotocolAcceptor.java new file mode 100644 index 0000000..8c779a6 --- /dev/null +++ b/netty-http-server/src/main/java/org/xbib/netty/http/server/protocol/ws2/PathSubprotocolAcceptor.java @@ -0,0 +1,44 @@ +package org.xbib.netty.http.server.protocol.ws2; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2Headers; +import io.netty.util.concurrent.Future; + +import java.util.List; + +public class PathSubprotocolAcceptor implements Http2WebSocketAcceptor { + + private final ChannelHandler webSocketHandler; + + private final String path; + + private final String subprotocol; + + private final boolean acceptSubprotocol; + + public PathSubprotocolAcceptor(String path, String subprotocol, ChannelHandler webSocketHandler) { + this(path, subprotocol, webSocketHandler, true); + } + + public PathSubprotocolAcceptor(String path, String subprotocol, ChannelHandler webSocketHandler, boolean acceptSubprotocol) { + this.path = path; + this.subprotocol = subprotocol; + this.webSocketHandler = webSocketHandler; + this.acceptSubprotocol = acceptSubprotocol; + } + + @Override + public Future accept(ChannelHandlerContext ctx, + String path, List subprotocols, Http2Headers request, Http2Headers response) { + String subprotocol = this.subprotocol; + if (path.equals(this.path) && subprotocols.contains(subprotocol)) { + if (acceptSubprotocol) { + Subprotocol.accept(subprotocol, response); + } + return ctx.executor().newSucceededFuture(webSocketHandler); + } + return ctx.executor().newFailedFuture(new Http2WebSocketPathNotFoundException( + String.format("Path not found: %s , subprotocols: %s", path, subprotocols))); + } +} diff --git a/netty-http-server/src/test/java/org/xbib/netty/http/server/test/http1/BasicAuthTest.java b/netty-http-server/src/test/java/org/xbib/netty/http/server/test/http1/BasicAuthTest.java index a993449..dc167f1 100644 --- a/netty-http-server/src/test/java/org/xbib/netty/http/server/test/http1/BasicAuthTest.java +++ b/netty-http-server/src/test/java/org/xbib/netty/http/server/test/http1/BasicAuthTest.java @@ -2,39 +2,52 @@ package org.xbib.netty.http.server.test.http1; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.xbib.net.URL; import org.xbib.netty.http.client.Client; import org.xbib.netty.http.client.api.Request; import org.xbib.netty.http.client.api.ResponseListener; +import org.xbib.netty.http.common.HttpAddress; import org.xbib.netty.http.common.HttpResponse; +import org.xbib.netty.http.server.HttpServerDomain; +import org.xbib.netty.http.server.Server; + import java.nio.charset.StandardCharsets; import java.util.logging.Level; import java.util.logging.Logger; +import static org.junit.jupiter.api.Assertions.assertEquals; + public class BasicAuthTest { private static final Logger logger = Logger.getLogger(PostTest.class.getName()); - @Disabled + @Test void testBasicAuth() throws Exception { + HttpAddress httpAddress = HttpAddress.http1("localhost", 8008); + HttpServerDomain domain = HttpServerDomain.builder(httpAddress) + .singleEndpoint("/**", (request, response) -> { + String authorization = request.getHeader("Authorization"); + response.getBuilder().setStatus(HttpResponseStatus.OK.code()) + .setContentType("text/plain").build().write(authorization); + }) + .build(); + Server server = Server.builder(domain) + .build(); Client client = Client.builder() .build(); try { - ResponseListener responseListener = (resp) -> { - if (resp.getStatus().getCode() == HttpResponseStatus.OK.code()) { - logger.log(Level.INFO, "got response " + resp.getBodyAsString(StandardCharsets.UTF_8)); - } - }; - URL serverUrl = URL.from(""); - Request postRequest = Request.post().setVersion(HttpVersion.HTTP_1_1) - .url(serverUrl) - .addBasicAuthorization("", "") + server.accept(); + ResponseListener responseListener = (resp) -> + assertEquals("Basic aGVsbG86d29ybGQ=", resp.getBodyAsString(StandardCharsets.UTF_8)); + Request postRequest = Request.get() + .setVersion(HttpVersion.HTTP_1_1) + .url(server.getServerConfig().getAddress().base()) + .addBasicAuthorization("hello", "world") .setResponseListener(responseListener) .build(); client.execute(postRequest).get(); } finally { + server.shutdownGracefully(); client.shutdownGracefully(); logger.log(Level.INFO, "server and client shut down"); } diff --git a/netty-http-server/src/test/java/org/xbib/netty/http/server/test/ws1/EchoTest.java b/netty-http-server/src/test/java/org/xbib/netty/http/server/test/ws1/EchoTest.java new file mode 100644 index 0000000..f988410 --- /dev/null +++ b/netty-http-server/src/test/java/org/xbib/netty/http/server/test/ws1/EchoTest.java @@ -0,0 +1,55 @@ +package org.xbib.netty.http.server.test.ws1; + +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import org.junit.jupiter.api.Test; +import org.xbib.netty.http.client.Client; +import org.xbib.netty.http.client.api.Request; +import org.xbib.netty.http.client.api.ResponseListener; +import org.xbib.netty.http.common.HttpAddress; +import org.xbib.netty.http.common.HttpResponse; +import org.xbib.netty.http.server.HttpServerDomain; +import org.xbib.netty.http.server.Server; + +import java.nio.charset.StandardCharsets; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class EchoTest { + + private static final Logger logger = Logger.getLogger(EchoTest.class.getName()); + + @Test + void testBasicAuth() throws Exception { + HttpAddress httpAddress = HttpAddress.http1("localhost", 8008); + HttpServerDomain domain = HttpServerDomain.builder(httpAddress) + .singleEndpoint("/**", (request, response) -> { + String authorization = request.getHeader("Authorization"); + response.getBuilder().setStatus(HttpResponseStatus.OK.code()) + .setContentType("text/plain").build().write(authorization); + }) + .build(); + Server server = Server.builder(domain) + .build(); + Client client = Client.builder() + .build(); + try { + server.accept(); + ResponseListener responseListener = (resp) -> + assertEquals("Basic aGVsbG86d29ybGQ=", resp.getBodyAsString(StandardCharsets.UTF_8)); + Request postRequest = Request.get() + .setVersion(HttpVersion.HTTP_1_1) + .url(server.getServerConfig().getAddress().base()) + .addBasicAuthorization("hello", "world") + .setResponseListener(responseListener) + .build(); + client.execute(postRequest).get(); + } finally { + server.shutdownGracefully(); + client.shutdownGracefully(); + logger.log(Level.INFO, "server and client shut down"); + } + } +}