netty/patches/13757.patch

1313 lines
63 KiB
Diff

From 0b6ce4c1d2106815a80f48d7903160009c09c03a Mon Sep 17 00:00:00 2001
From: Norman Maurer <norman_maurer@apple.com>
Date: Fri, 29 Dec 2023 17:03:56 +0100
Subject: [PATCH 1/4] Retry the query via TCP if it failed because of a
timeout when using UDP
Motivation:
We should retry the query via TCP if the query failed because of a timeout when using UDP.
Modifications:
- Move all the TCP retry code into DnsQueryContext so the same code is reused both for retrying a truncated UDP response and for retrying a timed-out UDP query.
- Retry a timed-out UDP query via TCP if TCP fallback is configured
- Add unit tests
Result:
More robust resolver
---
.../resolver/dns/DatagramDnsQueryContext.java | 7 +-
.../netty/resolver/dns/DnsNameResolver.java | 223 ++------------
.../netty/resolver/dns/DnsQueryContext.java | 278 ++++++++++++++++--
.../resolver/dns/TcpDnsQueryContext.java | 4 +-
.../generated/handlers/reflect-config.json | 15 +-
.../resolver/dns/DnsNameResolverTest.java | 155 ++++++++--
6 files changed, 428 insertions(+), 254 deletions(-)
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java
index 4cea712833e2..4c5487069079 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java
@@ -15,6 +15,7 @@
*/
package io.netty.resolver.dns;
+import io.netty.bootstrap.Bootstrap;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.handler.codec.dns.DatagramDnsQuery;
@@ -33,10 +34,12 @@ final class DatagramDnsQueryContext extends DnsQueryContext {
InetSocketAddress nameServerAddr,
DnsQueryContextManager queryContextManager,
int maxPayLoadSize, boolean recursionDesired,
+ long queryTimeoutMillis,
DnsQuestion question, DnsRecord[] additionals,
- Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
+ Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
+ Bootstrap socketBootstrap) {
super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired,
- question, additionals, promise);
+ queryTimeoutMillis, question, additionals, promise, socketBootstrap);
}
@Override
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
index 535b87cee39f..6939177e540c 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
@@ -23,6 +23,8 @@
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
@@ -43,8 +45,6 @@
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
-import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
-import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
import io.netty.resolver.DefaultHostsFileEntriesResolver;
import io.netty.resolver.HostsFileEntries;
import io.netty.resolver.HostsFileEntriesResolver;
@@ -121,6 +121,13 @@ public class DnsNameResolver extends InetNameResolver {
private static final InternetProtocolFamily[] IPV6_PREFERRED_RESOLVED_PROTOCOL_FAMILIES =
{InternetProtocolFamily.IPv6, InternetProtocolFamily.IPv4};
+ private static final ChannelHandler NOOP_HANDLER = new ChannelHandlerAdapter() {
+ @Override
+ public boolean isSharable() {
+ return true;
+ }
+ };
+
static final ResolvedAddressTypes DEFAULT_RESOLVE_ADDRESS_TYPES;
static final String[] DEFAULT_SEARCH_DOMAINS;
private static final UnixResolverOptions DEFAULT_OPTIONS;
@@ -227,7 +234,6 @@ protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket p
}
};
private static final DatagramDnsQueryEncoder DATAGRAM_ENCODER = new DatagramDnsQueryEncoder();
- private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();
private final Promise<Channel> channelReadyPromise;
private final Channel ch;
@@ -465,7 +471,12 @@ public DnsNameResolver(
.group(executor())
.channelFactory(socketChannelFactory)
.attr(DNS_PIPELINE_ATTRIBUTE, Boolean.TRUE)
- .handler(TCP_ENCODER);
+ .handler(NOOP_HANDLER);
+ if (queryTimeoutMillis > 0) {
+ // Set the connect timeout to the same as queryTimeout as otherwise it might take a long
+ // time to fail the original query if the connect times out.
+ socketBootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) queryTimeoutMillis);
+ }
}
switch (this.resolvedAddressTypes) {
case IPV4_ONLY:
@@ -1349,8 +1360,9 @@ final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0;
try {
DnsQueryContext queryContext = new DatagramDnsQueryContext(ch, channelReadyPromise, nameServerAddr,
- queryContextManager, payloadSize, isRecursionDesired(), question, additionals, castPromise);
- ChannelFuture future = queryContext.writeQuery(queryTimeoutMillis(), flush);
+ queryContextManager, payloadSize, isRecursionDesired(), queryTimeoutMillis(), question, additionals,
+ castPromise, socketBootstrap);
+ ChannelFuture future = queryContext.writeQuery(flush);
queryLifecycleObserver.queryWritten(nameServerAddr, future);
return castPromise;
} catch (Exception e) {
@@ -1395,94 +1407,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
return;
}
- // Check if the response was truncated and if we can fallback to TCP to retry.
- if (!res.isTruncated() || socketBootstrap == null) {
- qCtx.finishSuccess(res);
- return;
- }
-
- socketBootstrap.connect(res.sender()).addListener(new ChannelFutureListener() {
- @Override
- public void operationComplete(ChannelFuture future) {
- if (!future.isSuccess()) {
- logger.debug("{} Unable to fallback to TCP [{}: {}]",
- ch, queryId, res.sender(), future.cause());
-
- // TCP fallback failed, just use the truncated response.
- qCtx.finishSuccess(res);
- return;
- }
- final Channel tcpCh = future.channel();
-
- Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
- tcpCh.eventLoop().newPromise();
- final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0;
- final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh, channelReadyPromise,
- (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, payloadSize,
- isRecursionDesired(), qCtx.question(), EMPTY_ADDITIONALS, promise);
-
- tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
- tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
- @Override
- public void channelRead(ChannelHandlerContext ctx, Object msg) {
- Channel tcpCh = ctx.channel();
- DnsResponse response = (DnsResponse) msg;
- int queryId = response.id();
-
- if (logger.isDebugEnabled()) {
- logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
- tcpCh.remoteAddress(), response);
- }
-
- DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId);
- if (foundCtx != null && foundCtx.isDone()) {
- logger.debug("{} Received a DNS response for a query that was timed out or cancelled " +
- ": TCP [{}: {}]", tcpCh, queryId, res.sender());
- response.release();
- } else if (foundCtx == tcpCtx) {
- tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
- (InetSocketAddress) ctx.channel().remoteAddress(),
- (InetSocketAddress) ctx.channel().localAddress(),
- response));
- } else {
- response.release();
- tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
- if (logger.isDebugEnabled()) {
- logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
- tcpCh, queryId, tcpCh.remoteAddress());
- }
- }
- }
-
- @Override
- public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
- if (tcpCtx.finishFailure(
- "TCP fallback error", cause, false) && logger.isDebugEnabled()) {
- logger.debug("{} Error during processing response: TCP [{}: {}]",
- ctx.channel(), queryId,
- ctx.channel().remoteAddress(), cause);
- }
- }
- });
-
- promise.addListener(
- new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
- @Override
- public void operationComplete(
- Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
- if (future.isSuccess()) {
- qCtx.finishSuccess(future.getNow());
- res.release();
- } else {
- // TCP fallback failed, just use the truncated response.
- qCtx.finishSuccess(res);
- }
- tcpCh.close();
- }
- });
- tcpCtx.writeQuery(queryTimeoutMillis(), true);
- }
- });
+ // The context will handle truncation itself.
+ qCtx.finishSuccess(res, res.isTruncated());
}
@Override
@@ -1500,113 +1426,4 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}
}
-
- private static final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
- private final InetSocketAddress sender;
- private final InetSocketAddress recipient;
- private final DnsResponse response;
-
- AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
- this.sender = sender;
- this.recipient = recipient;
- this.response = response;
- }
-
- @Override
- public DnsResponse content() {
- return response;
- }
-
- @Override
- public InetSocketAddress sender() {
- return sender;
- }
-
- @Override
- public InetSocketAddress recipient() {
- return recipient;
- }
-
- @Override
- public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
- response.retain();
- return this;
- }
-
- @Override
- public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
- response.retain(increment);
- return this;
- }
-
- @Override
- public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
- response.touch();
- return this;
- }
-
- @Override
- public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
- response.touch(hint);
- return this;
- }
-
- @Override
- public int refCnt() {
- return response.refCnt();
- }
-
- @Override
- public boolean release() {
- return response.release();
- }
-
- @Override
- public boolean release(int decrement) {
- return response.release(decrement);
- }
-
- @Override
- public boolean equals(Object obj) {
- if (this == obj) {
- return true;
- }
-
- if (!(obj instanceof AddressedEnvelope)) {
- return false;
- }
-
- @SuppressWarnings("unchecked")
- final AddressedEnvelope<?, SocketAddress> that = (AddressedEnvelope<?, SocketAddress>) obj;
- if (sender() == null) {
- if (that.sender() != null) {
- return false;
- }
- } else if (!sender().equals(that.sender())) {
- return false;
- }
-
- if (recipient() == null) {
- if (that.recipient() != null) {
- return false;
- }
- } else if (!recipient().equals(that.recipient())) {
- return false;
- }
-
- return response.equals(obj);
- }
-
- @Override
- public int hashCode() {
- int hashCode = response.hashCode();
- if (sender() != null) {
- hashCode = hashCode * 31 + sender().hashCode();
- }
- if (recipient() != null) {
- hashCode = hashCode * 31 + recipient().hashCode();
- }
- return hashCode;
- }
- }
}
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java
index e3db5f807f6d..1741e25dd543 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java
@@ -15,10 +15,13 @@
*/
package io.netty.resolver.dns;
+import io.netty.bootstrap.Bootstrap;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
import io.netty.handler.codec.dns.DnsQuery;
@@ -27,16 +30,20 @@
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsSection;
+import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
+import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.SystemPropertyUtil;
+import io.netty.util.internal.ThrowableUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.InetSocketAddress;
+import java.net.SocketAddress;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit;
@@ -53,6 +60,8 @@ abstract class DnsQueryContext {
logger.debug("-Dio.netty.resolver.dns.idReuseOnTimeoutDelayMillis: {}", ID_REUSE_ON_TIMEOUT_DELAY_MILLIS);
}
+ private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();
+
private final Future<? extends Channel> channelReadyFuture;
private final Channel channel;
private final InetSocketAddress nameServerAddr;
@@ -64,6 +73,10 @@ abstract class DnsQueryContext {
private final DnsRecord optResource;
private final boolean recursionDesired;
+
+ private final Bootstrap socketBootstrap;
+ private final long queryTimeoutMillis;
+
private volatile Future<?> timeoutFuture;
private int id = -1;
@@ -74,9 +87,10 @@ abstract class DnsQueryContext {
DnsQueryContextManager queryContextManager,
int maxPayLoadSize,
boolean recursionDesired,
+ long queryTimeoutMillis,
DnsQuestion question,
DnsRecord[] additionals,
- Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
+ Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise, Bootstrap socketBootstrap) {
this.channel = checkNotNull(channel, "channel");
this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager");
this.channelReadyFuture = checkNotNull(channelReadyFuture, "channelReadyFuture");
@@ -85,6 +99,8 @@ abstract class DnsQueryContext {
this.additionals = checkNotNull(additionals, "additionals");
this.promise = checkNotNull(promise, "promise");
this.recursionDesired = recursionDesired;
+ this.queryTimeoutMillis = queryTimeoutMillis;
+ this.socketBootstrap = socketBootstrap;
if (maxPayLoadSize > 0 &&
// Only add the extra OPT record if there is not already one. This is required as only one is allowed
@@ -147,12 +163,10 @@ final DnsQuestion question() {
/**
* Write the query and return the {@link ChannelFuture} that is completed once the write completes.
*
- * @param queryTimeoutMillis the timeout after which the query is considered timeout and the original
- * {@link Promise} will be failed.
* @param flush {@code true} if {@link Channel#flush()} should be called as well.
* @return the {@link ChannelFuture} that is notified once once the write completes.
*/
- final ChannelFuture writeQuery(long queryTimeoutMillis, boolean flush) {
+ final ChannelFuture writeQuery(boolean flush) {
assert id == -1 : this.getClass().getSimpleName() + ".writeQuery(...) can only be executed once.";
id = queryContextManager.add(nameServerAddr, this);
@@ -205,7 +219,7 @@ public void run() {
channel, protocol(), id, nameServerAddr, question);
}
- return sendQuery(nameServerAddr, query, queryTimeoutMillis, flush);
+ return sendQuery(query, flush);
}
private void removeFromContextManager(InetSocketAddress nameServerAddr) {
@@ -214,11 +228,10 @@ private void removeFromContextManager(InetSocketAddress nameServerAddr) {
assert self == this : "Removed DnsQueryContext is not the correct instance";
}
- private ChannelFuture sendQuery(final InetSocketAddress nameServerAddr, final DnsQuery query,
- final long queryTimeoutMillis, final boolean flush) {
+ private ChannelFuture sendQuery(final DnsQuery query, final boolean flush) {
final ChannelPromise writePromise = channel.newPromise();
if (channelReadyFuture.isSuccess()) {
- writeQuery(nameServerAddr, query, queryTimeoutMillis, flush, writePromise);
+ writeQuery(query, flush, writePromise);
} else {
Throwable cause = channelReadyFuture.cause();
if (cause != null) {
@@ -233,7 +246,7 @@ public void operationComplete(Future<? super Channel> future) {
// If the query is done in a late fashion (as the channel was not ready yet) we always flush
// to ensure we did not race with a previous flush() that was done when the Channel was not
// ready yet.
- writeQuery(nameServerAddr, query, queryTimeoutMillis, true, writePromise);
+ writeQuery(query, true, writePromise);
} else {
Throwable cause = future.cause();
failQuery(query, cause, writePromise);
@@ -254,7 +267,7 @@ private void failQuery(DnsQuery query, Throwable cause, ChannelPromise writeProm
}
}
- private void writeQuery(final InetSocketAddress nameServerAddr, final DnsQuery query, final long queryTimeoutMillis,
+ private void writeQuery(final DnsQuery query,
final boolean flush, ChannelPromise promise) {
final ChannelFuture writeFuture = flush ? channel.writeAndFlush(query, promise) :
channel.write(query, promise);
@@ -298,18 +311,21 @@ public void run() {
* Notifies the original {@link Promise} that the response for the query was received.
* This method takes ownership of passed {@link AddressedEnvelope}.
*/
- void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
- final DnsResponse res = envelope.content();
- if (res.count(DnsSection.QUESTION) != 1) {
- logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
- channel, envelope);
- } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
- logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}",
- channel, question(), envelope);
- } else if (trySuccess(envelope)) {
- return; // Ownership transferred, don't release
+ void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope, boolean truncated) {
+ // Check if the response was not truncated or if a fallback to TCP is possible.
+ if (!truncated || !retryWithTCP(envelope)) {
+ final DnsResponse res = envelope.content();
+ if (res.count(DnsSection.QUESTION) != 1) {
+ logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
+ channel, envelope);
+ } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
+ logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}",
+ channel, question(), envelope);
+ } else if (trySuccess(envelope)) {
+ return; // Ownership transferred, don't release
+ }
+ envelope.release();
}
- envelope.release();
}
@SuppressWarnings("unchecked")
@@ -342,9 +358,229 @@ final boolean finishFailure(String message, Throwable cause, boolean timeout) {
// This was caused by a timeout so use DnsNameResolverTimeoutException to allow the user to
// handle it special (like retry the query).
e = new DnsNameResolverTimeoutException(nameServerAddr, question, buf.toString());
+ if (retryWithTCP(e)) {
+ // We did successfully retry with TCP.
+ return false;
+ }
} else {
e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause);
}
return promise.tryFailure(e);
}
+
+ /**
+ * Retry the original query with TCP if possible.
+ *
+ * @param originalResult the result of the original {@link DnsQueryContext}.
+ * @return {@code true} if retry via TCP is supported and so the ownership of
+ * {@code originalResult} was transferred, {@code false} otherwise.
+ */
+ private boolean retryWithTCP(final Object originalResult) {
+ if (socketBootstrap == null) {
+ return false;
+ }
+
+ socketBootstrap.connect(nameServerAddr).addListener(new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) {
+ if (!future.isSuccess()) {
+ logger.debug("{} Unable to fallback to TCP [{}: {}]",
+ future.channel(), id, nameServerAddr, future.cause());
+
+ // TCP fallback failed, just use the truncated response or error.
+ finishOriginal(originalResult, future);
+ return;
+ }
+ final Channel tcpCh = future.channel();
+ Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
+ tcpCh.eventLoop().newPromise();
+ final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh, channelReadyFuture,
+ (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, 0,
+ recursionDesired, queryTimeoutMillis, question(), additionals, promise);
+ tcpCh.pipeline().addLast(TCP_ENCODER);
+ tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
+ tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) {
+ Channel tcpCh = ctx.channel();
+ DnsResponse response = (DnsResponse) msg;
+ int queryId = response.id();
+
+ if (logger.isDebugEnabled()) {
+ logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
+ tcpCh.remoteAddress(), response);
+ }
+
+ DnsQueryContext foundCtx = queryContextManager.get(nameServerAddr, queryId);
+ if (foundCtx != null && foundCtx.isDone()) {
+ logger.debug("{} Received a DNS response for a query that was timed out or cancelled " +
+ ": TCP [{}: {}]", tcpCh, queryId, nameServerAddr);
+ response.release();
+ } else if (foundCtx == tcpCtx) {
+ tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
+ (InetSocketAddress) ctx.channel().remoteAddress(),
+ (InetSocketAddress) ctx.channel().localAddress(),
+ response), false);
+ } else {
+ response.release();
+ tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
+ if (logger.isDebugEnabled()) {
+ logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
+ tcpCh, queryId, tcpCh.remoteAddress());
+ }
+ }
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ if (tcpCtx.finishFailure(
+ "TCP fallback error", cause, false) && logger.isDebugEnabled()) {
+ logger.debug("{} Error during processing response: TCP [{}: {}]",
+ ctx.channel(), id,
+ ctx.channel().remoteAddress(), cause);
+ }
+ }
+ });
+
+ promise.addListener(
+ new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
+ @Override
+ public void operationComplete(
+ Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
+ if (future.isSuccess()) {
+ finishSuccess(future.getNow(), false);
+ // Release the original result.
+ ReferenceCountUtil.release(originalResult);
+ } else {
+ // TCP fallback failed, just use the truncated response or error.
+ finishOriginal(originalResult, future);
+ }
+ tcpCh.close();
+ }
+ });
+ tcpCtx.writeQuery(true);
+ }
+ });
+ return true;
+ }
+
+ @SuppressWarnings("unchecked")
+ private void finishOriginal(Object originalResult, Future<?> future) {
+ if (originalResult instanceof Throwable) {
+ Throwable error = (Throwable) originalResult;
+ ThrowableUtil.addSuppressed(error, future.cause());
+ promise.tryFailure(error);
+ } else {
+ finishSuccess((AddressedEnvelope<? extends DnsResponse, InetSocketAddress>) originalResult, false);
+ }
+ }
+
+ private static final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
+ private final InetSocketAddress sender;
+ private final InetSocketAddress recipient;
+ private final DnsResponse response;
+
+ AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
+ this.sender = sender;
+ this.recipient = recipient;
+ this.response = response;
+ }
+
+ @Override
+ public DnsResponse content() {
+ return response;
+ }
+
+ @Override
+ public InetSocketAddress sender() {
+ return sender;
+ }
+
+ @Override
+ public InetSocketAddress recipient() {
+ return recipient;
+ }
+
+ @Override
+ public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
+ response.retain();
+ return this;
+ }
+
+ @Override
+ public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
+ response.retain(increment);
+ return this;
+ }
+
+ @Override
+ public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
+ response.touch();
+ return this;
+ }
+
+ @Override
+ public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
+ response.touch(hint);
+ return this;
+ }
+
+ @Override
+ public int refCnt() {
+ return response.refCnt();
+ }
+
+ @Override
+ public boolean release() {
+ return response.release();
+ }
+
+ @Override
+ public boolean release(int decrement) {
+ return response.release(decrement);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+
+ if (!(obj instanceof AddressedEnvelope)) {
+ return false;
+ }
+
+ @SuppressWarnings("unchecked")
+ final AddressedEnvelope<?, SocketAddress> that = (AddressedEnvelope<?, SocketAddress>) obj;
+ if (sender() == null) {
+ if (that.sender() != null) {
+ return false;
+ }
+ } else if (!sender().equals(that.sender())) {
+ return false;
+ }
+
+ if (recipient() == null) {
+ if (that.recipient() != null) {
+ return false;
+ }
+ } else if (!recipient().equals(that.recipient())) {
+ return false;
+ }
+
+ return response.equals(obj);
+ }
+
+ @Override
+ public int hashCode() {
+ int hashCode = response.hashCode();
+ if (sender() != null) {
+ hashCode = hashCode * 31 + sender().hashCode();
+ }
+ if (recipient() != null) {
+ hashCode = hashCode * 31 + recipient().hashCode();
+ }
+ return hashCode;
+ }
+ }
}
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java
index 8f25ab8664e7..2111b67aad9c 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java
@@ -33,10 +33,12 @@ final class TcpDnsQueryContext extends DnsQueryContext {
InetSocketAddress nameServerAddr,
DnsQueryContextManager queryContextManager,
int maxPayLoadSize, boolean recursionDesired,
+ long queryTimeoutMillis,
DnsQuestion question, DnsRecord[] additionals,
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired,
- question, additionals, promise);
+ // No retry via TCP.
+ queryTimeoutMillis, question, additionals, promise, null);
}
@Override
diff --git a/resolver-dns/src/main/resources/META-INF/native-image/io.netty/netty-resolver-dns/generated/handlers/reflect-config.json b/resolver-dns/src/main/resources/META-INF/native-image/io.netty/netty-resolver-dns/generated/handlers/reflect-config.json
index 5960b0047047..68508d9de7b0 100644
--- a/resolver-dns/src/main/resources/META-INF/native-image/io.netty/netty-resolver-dns/generated/handlers/reflect-config.json
+++ b/resolver-dns/src/main/resources/META-INF/native-image/io.netty/netty-resolver-dns/generated/handlers/reflect-config.json
@@ -7,9 +7,16 @@
"queryAllPublicMethods": true
},
{
- "name": "io.netty.resolver.dns.DnsNameResolver$3",
+ "name": "io.netty.resolver.dns.DnsNameResolver$2",
"condition": {
- "typeReachable": "io.netty.resolver.dns.DnsNameResolver$3"
+ "typeReachable": "io.netty.resolver.dns.DnsNameResolver$2"
+ },
+ "queryAllPublicMethods": true
+ },
+ {
+ "name": "io.netty.resolver.dns.DnsNameResolver$4",
+ "condition": {
+ "typeReachable": "io.netty.resolver.dns.DnsNameResolver$4"
},
"queryAllPublicMethods": true
},
@@ -21,9 +28,9 @@
"queryAllPublicMethods": true
},
{
- "name": "io.netty.resolver.dns.DnsNameResolver$DnsResponseHandler$1$1",
+ "name": "io.netty.resolver.dns.DnsQueryContext$6$1",
"condition": {
- "typeReachable": "io.netty.resolver.dns.DnsNameResolver$DnsResponseHandler$1$1"
+ "typeReachable": "io.netty.resolver.dns.DnsQueryContext$6$1"
},
"queryAllPublicMethods": true
}
diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
index 4977981a53b3..30277cb9a6c7 100644
--- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
+++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
@@ -3293,30 +3293,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
if (tcpFallback) {
// If we are configured to use TCP as a fallback lets replay the dns message over TCP
Socket socket = serverSocket.accept();
+ responseViaSocket(socket, messageRef.get());
- InputStream in = socket.getInputStream();
- assertTrue((in.read() << 8 | (in.read() & 0xff)) > 2); // skip length field
- int txnId = in.read() << 8 | (in.read() & 0xff);
-
- IoBuffer ioBuffer = IoBuffer.allocate(1024);
- // Must replace the transactionId with the one from the TCP request
- DnsMessageModifier modifier = modifierFrom(messageRef.get());
- modifier.setTransactionId(txnId);
- new DnsMessageEncoder().encode(ioBuffer, modifier.getDnsMessage());
- ioBuffer.flip();
-
- ByteBuffer lenBuffer = ByteBuffer.allocate(2);
- lenBuffer.putShort((short) ioBuffer.remaining());
- lenBuffer.flip();
-
- while (lenBuffer.hasRemaining()) {
- socket.getOutputStream().write(lenBuffer.get());
- }
-
- while (ioBuffer.hasRemaining()) {
- socket.getOutputStream().write(ioBuffer.get());
- }
- socket.getOutputStream().flush();
// Let's wait until we received the envelope before closing the socket.
envelopeFuture.syncUninterruptibly();
@@ -3352,6 +3330,137 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
}
}
+ private static void responseViaSocket(Socket socket, DnsMessage message) throws IOException {
+ InputStream in = socket.getInputStream();
+ assertTrue((in.read() << 8 | (in.read() & 0xff)) > 2); // skip length field
+ int txnId = in.read() << 8 | (in.read() & 0xff);
+
+ IoBuffer ioBuffer = IoBuffer.allocate(1024);
+ // Must replace the transactionId with the one from the TCP request
+ DnsMessageModifier modifier = modifierFrom(message);
+ modifier.setTransactionId(txnId);
+ new DnsMessageEncoder().encode(ioBuffer, modifier.getDnsMessage());
+ ioBuffer.flip();
+
+ ByteBuffer lenBuffer = ByteBuffer.allocate(2);
+ lenBuffer.putShort((short) ioBuffer.remaining());
+ lenBuffer.flip();
+
+ while (lenBuffer.hasRemaining()) {
+ socket.getOutputStream().write(lenBuffer.get());
+ }
+
+ while (ioBuffer.hasRemaining()) {
+ socket.getOutputStream().write(ioBuffer.get());
+ }
+ socket.getOutputStream().flush();
+ }
+
+ @Test
+ public void testTcpFallbackWhenTimeout() throws IOException {
+ testTcpFallbackWhenTimeout(true);
+ }
+
+ @Test
+ public void testTcpFallbackFailedWhenTimeout() throws IOException {
+ testTcpFallbackWhenTimeout(false);
+ }
+
+ private void testTcpFallbackWhenTimeout(boolean tcpSuccess) throws IOException {
+ ServerSocket serverSocket = new ServerSocket();
+ serverSocket.setReuseAddress(true);
+ serverSocket.bind(new InetSocketAddress(NetUtil.LOCALHOST4, 0));
+
+ final String host = "somehost.netty.io";
+ final String txt = "this is a txt record";
+ final AtomicReference<DnsMessage> messageRef = new AtomicReference<DnsMessage>();
+
+ TestDnsServer dnsServer2 = new TestDnsServer(new RecordStore() {
+ @Override
+ public Set<ResourceRecord> getRecords(QuestionRecord question) {
+ String name = question.getDomainName();
+ if (name.equals(host)) {
+ return Collections.<ResourceRecord>singleton(
+ new TestDnsServer.TestResourceRecord(name, RecordType.TXT,
+ Collections.<String, Object>singletonMap(
+ DnsAttribute.CHARACTER_STRING.toLowerCase(), txt)));
+ }
+ return null;
+ }
+ }) {
+ @Override
+ protected DnsMessage filterMessage(DnsMessage message) {
+ // Store a original message so we can replay it later on.
+ messageRef.set(message);
+ return null;
+ }
+ };
+ DnsNameResolver resolver = null;
+ try {
+ DnsNameResolverBuilder builder = newResolver();
+ final DatagramChannel datagramChannel = new NioDatagramChannel();
+ ChannelFactory<DatagramChannel> channelFactory = new ChannelFactory<DatagramChannel>() {
+ @Override
+ public DatagramChannel newChannel() {
+ return datagramChannel;
+ }
+ };
+ builder.channelFactory(channelFactory);
+ dnsServer2.start(null, (InetSocketAddress) serverSocket.getLocalSocketAddress());
+ // If we are configured to use TCP as a fallback also bind a TCP socket
+ builder.socketChannelType(NioSocketChannel.class);
+
+ builder.queryTimeoutMillis(1000)
+ .resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED)
+ .maxQueriesPerResolve(16)
+ .nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress()));
+ resolver = builder.build();
+ Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> envelopeFuture = resolver.query(
+ new DefaultDnsQuestion(host, DnsRecordType.TXT));
+
+ // If we are configured to use TCP as a fallback lets replay the dns message over TCP
+ Socket socket = serverSocket.accept();
+
+ if (tcpSuccess) {
+ responseViaSocket(socket, messageRef.get());
+ socket.close();
+
+ // Let's wait until we received the envelope before closing the socket.
+ envelopeFuture.syncUninterruptibly();
+
+ AddressedEnvelope<DnsResponse, InetSocketAddress> envelope =
+ envelopeFuture.syncUninterruptibly().getNow();
+ assertNotNull(envelope.sender());
+
+ DnsResponse response = envelope.content();
+ assertNotNull(response);
+
+ assertEquals(DnsResponseCode.NOERROR, response.code());
+ int count = response.count(DnsSection.ANSWER);
+
+ assertEquals(1, count);
+ List<String> texts = decodeTxt(response.recordAt(DnsSection.ANSWER, 0));
+ assertEquals(1, texts.size());
+ assertEquals(txt, texts.get(0));
+
+ assertFalse(envelope.content().isTruncated());
+ assertTrue(envelope.release());
+ } else {
+ // Just close the socket. This should cause the original exception to be used.
+ socket.close();
+ Throwable error = envelopeFuture.awaitUninterruptibly().cause();
+ assertThat(error, instanceOf(DnsNameResolverTimeoutException.class));
+ assertThat(error.getSuppressed().length, greaterThanOrEqualTo(1));
+ }
+ } finally {
+ dnsServer2.stop();
+ if (resolver != null) {
+ resolver.close();
+ }
+ serverSocket.close();
+ }
+ }
+
@Test
public void testCancelPromise() throws Exception {
final EventLoop eventLoop = group.next();
From fea7225dff68c16be94451e2c1c44ec442547e5a Mon Sep 17 00:00:00 2001
From: Norman Maurer <norman_maurer@apple.com>
Date: Fri, 29 Dec 2023 22:05:08 +0100
Subject: [PATCH 2/4] Fix race in test by closing the socket only after the response was received
---
.../test/java/io/netty/resolver/dns/DnsNameResolverTest.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
index 30277cb9a6c7..2cd9e6a4d9f9 100644
--- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
+++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
@@ -3423,10 +3423,10 @@ public DatagramChannel newChannel() {
if (tcpSuccess) {
responseViaSocket(socket, messageRef.get());
- socket.close();
// Let's wait until we received the envelope before closing the socket.
envelopeFuture.syncUninterruptibly();
+ socket.close();
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope =
envelopeFuture.syncUninterruptibly().getNow();
From 6785c2d865fad37cf2a12963be99c8789aa82b15 Mon Sep 17 00:00:00 2001
From: Norman Maurer <norman_maurer@apple.com>
Date: Fri, 12 Jan 2024 12:23:06 +0100
Subject: [PATCH 3/4] Fix compile error
---
.../src/main/java/io/netty/resolver/dns/DnsNameResolver.java | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
index 9a63aa185dc9..4391b689e1a7 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
@@ -470,8 +470,7 @@ public DnsNameResolver(
socketBootstrap.option(ChannelOption.SO_REUSEADDR, true)
.group(executor())
.channelFactory(socketChannelFactory)
- .attr(DNS_PIPELINE_ATTRIBUTE, Boolean.TRUE)
- .handler(TCP_ENCODER);
+ .attr(DNS_PIPELINE_ATTRIBUTE, Boolean.TRUE);
if (queryTimeoutMillis > 0 && queryTimeoutMillis <= Integer.MAX_VALUE) {
// Set the connect timeout to the same as queryTimeout as otherwise it might take a long
// time for the query to fail in case of a connection timeout.
From 3c8df3a4d4b0ffe8116e254de29ef52a94f9adc5 Mon Sep 17 00:00:00 2001
From: Norman Maurer <norman_maurer@apple.com>
Date: Fri, 12 Jan 2024 12:54:35 +0100
Subject: [PATCH 4/4] Make TCP fallback on timeout configurable
---
.../resolver/dns/DatagramDnsQueryContext.java | 4 +-
.../netty/resolver/dns/DnsNameResolver.java | 41 ++++----------
.../resolver/dns/DnsNameResolverBuilder.java | 54 ++++++++++++++++---
.../netty/resolver/dns/DnsQueryContext.java | 13 +++--
.../resolver/dns/TcpDnsQueryContext.java | 2 +-
.../resolver/dns/DnsNameResolverTest.java | 2 +-
6 files changed, 70 insertions(+), 46 deletions(-)
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java
index 4c5487069079..ca382dcbb114 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java
@@ -37,9 +37,9 @@ final class DatagramDnsQueryContext extends DnsQueryContext {
long queryTimeoutMillis,
DnsQuestion question, DnsRecord[] additionals,
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
- Bootstrap socketBootstrap) {
+ Bootstrap socketBootstrap, boolean retryWithTcpOnTimeout) {
super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired,
- queryTimeoutMillis, question, additionals, promise, socketBootstrap);
+ queryTimeoutMillis, question, additionals, promise, socketBootstrap, retryWithTcpOnTimeout);
}
@Override
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
index 4391b689e1a7..d286bac27832 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
@@ -279,6 +279,7 @@ protected DnsServerAddressStream initialValue() {
private final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory;
private final boolean completeOncePreferredResolved;
private final Bootstrap socketBootstrap;
+ private final boolean retryWithTcpOnTimeout;
private final int maxNumConsolidation;
private final Map<String, Future<List<InetAddress>>> inflightLookups;
@@ -382,44 +383,18 @@ public DnsNameResolver(
String[] searchDomains,
int ndots,
boolean decodeIdn) {
- this(eventLoop, channelFactory, null, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache,
+ this(eventLoop, channelFactory, null, false, resolveCache,
+ NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache, null,
dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes, recursionDesired,
maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, hostsFileEntriesResolver,
- dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false);
- }
-
- DnsNameResolver(
- EventLoop eventLoop,
- ChannelFactory<? extends DatagramChannel> channelFactory,
- ChannelFactory<? extends SocketChannel> socketChannelFactory,
- final DnsCache resolveCache,
- final DnsCnameCache cnameCache,
- final AuthoritativeDnsServerCache authoritativeDnsServerCache,
- DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory,
- long queryTimeoutMillis,
- ResolvedAddressTypes resolvedAddressTypes,
- boolean recursionDesired,
- int maxQueriesPerResolve,
- boolean traceEnabled,
- int maxPayloadSize,
- boolean optResourceEnabled,
- HostsFileEntriesResolver hostsFileEntriesResolver,
- DnsServerAddressStreamProvider dnsServerAddressStreamProvider,
- String[] searchDomains,
- int ndots,
- boolean decodeIdn,
- boolean completeOncePreferredResolved) {
- this(eventLoop, channelFactory, socketChannelFactory, resolveCache, cnameCache, authoritativeDnsServerCache,
- null, dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes,
- recursionDesired, maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled,
- hostsFileEntriesResolver, dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn,
- completeOncePreferredResolved, 0);
+ dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false, 0);
}
DnsNameResolver(
EventLoop eventLoop,
ChannelFactory<? extends DatagramChannel> channelFactory,
ChannelFactory<? extends SocketChannel> socketChannelFactory,
+ boolean retryWithTcpOnTimeout,
final DnsCache resolveCache,
final DnsCnameCache cnameCache,
final AuthoritativeDnsServerCache authoritativeDnsServerCache,
@@ -463,6 +438,7 @@ public DnsNameResolver(
this.ndots = ndots >= 0 ? ndots : DEFAULT_OPTIONS.ndots();
this.decodeIdn = decodeIdn;
this.completeOncePreferredResolved = completeOncePreferredResolved;
+ this.retryWithTcpOnTimeout = retryWithTcpOnTimeout;
if (socketChannelFactory == null) {
socketBootstrap = null;
} else {
@@ -470,7 +446,8 @@ public DnsNameResolver(
socketBootstrap.option(ChannelOption.SO_REUSEADDR, true)
.group(executor())
.channelFactory(socketChannelFactory)
- .attr(DNS_PIPELINE_ATTRIBUTE, Boolean.TRUE);
+ .attr(DNS_PIPELINE_ATTRIBUTE, Boolean.TRUE)
+ .handler(NOOP_HANDLER);
if (queryTimeoutMillis > 0 && queryTimeoutMillis <= Integer.MAX_VALUE) {
// Set the connect timeout to the same as queryTimeout as otherwise it might take a long
// time for the query to fail in case of a connection timeout.
@@ -1360,7 +1337,7 @@ final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
try {
DnsQueryContext queryContext = new DatagramDnsQueryContext(ch, channelReadyPromise, nameServerAddr,
queryContextManager, payloadSize, isRecursionDesired(), queryTimeoutMillis(), question, additionals,
- castPromise, socketBootstrap);
+ castPromise, socketBootstrap, retryWithTcpOnTimeout);
ChannelFuture future = queryContext.writeQuery(flush);
queryLifecycleObserver.queryWritten(nameServerAddr, future);
return castPromise;
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java
index ddd26d501b9b..0745ac45c285 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java
@@ -47,6 +47,8 @@ public final class DnsNameResolverBuilder {
volatile EventLoop eventLoop;
private ChannelFactory<? extends DatagramChannel> channelFactory;
private ChannelFactory<? extends SocketChannel> socketChannelFactory;
+ private boolean retryOnTimeout;
+
private DnsCache resolveCache;
private DnsCnameCache cnameCache;
private AuthoritativeDnsServerCache authoritativeDnsServerCache;
@@ -143,7 +145,44 @@ public DnsNameResolverBuilder channelType(Class<? extends DatagramChannel> chann
* @return {@code this}
*/
public DnsNameResolverBuilder socketChannelFactory(ChannelFactory<? extends SocketChannel> channelFactory) {
+ return socketChannelFactory(channelFactory, false);
+ }
+
+ /**
+ * Sets the {@link ChannelFactory} as a {@link ReflectiveChannelFactory} of this type for
+ * <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a> if needed.
+ * Use as an alternative to {@link #socketChannelFactory(ChannelFactory)}.
+ *
+ * TCP fallback is <strong>not</strong> enabled by default and must be enabled by providing a non-null
+ * {@code channelType} for this method.
+ *
+ * @param channelType the type or {@code null} if <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a>
+ * should not be supported. By default, TCP fallback is not enabled.
+ * @return {@code this}
+ */
+ public DnsNameResolverBuilder socketChannelType(Class<? extends SocketChannel> channelType) {
+ return socketChannelType(channelType, false);
+ }
+
+ /**
+ * Sets the {@link ChannelFactory} that will create a {@link SocketChannel} for
+ * <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a> if needed.
+ *
+ * TCP fallback is <strong>not</strong> enabled by default and must be enabled by providing a non-null
+ * {@link ChannelFactory} for this method.
+ *
+ * @param channelFactory the {@link ChannelFactory} or {@code null}
+ * if <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a> should not be supported.
+ * By default, TCP fallback is not enabled.
+ * @param retryOnTimeout if {@code true} the {@link DnsNameResolver} will also fall back to TCP if a timeout
+ * was detected, if {@code false} it will only try to use TCP if the response was marked
+ * as truncated.
+ * @return {@code this}
+ */
+ public DnsNameResolverBuilder socketChannelFactory(
+ ChannelFactory<? extends SocketChannel> channelFactory, boolean retryOnTimeout) {
this.socketChannelFactory = channelFactory;
+ this.retryOnTimeout = retryOnTimeout;
return this;
}
@@ -157,13 +196,17 @@ public DnsNameResolverBuilder socketChannelFactory(ChannelFactory<? extends Sock
*
* @param channelType the type or {@code null} if <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a>
* should not be supported. By default, TCP fallback is not enabled.
+ * @param retryOnTimeout if {@code true} the {@link DnsNameResolver} will also fall back to TCP if a timeout
+ * was detected, if {@code false} it will only try to use TCP if the response was marked
+ * as truncated.
* @return {@code this}
*/
- public DnsNameResolverBuilder socketChannelType(Class<? extends SocketChannel> channelType) {
+ public DnsNameResolverBuilder socketChannelType(
+ Class<? extends SocketChannel> channelType, boolean retryOnTimeout) {
if (channelType == null) {
- return socketChannelFactory(null);
+ return socketChannelFactory(null, retryOnTimeout);
}
- return socketChannelFactory(new ReflectiveChannelFactory<SocketChannel>(channelType));
+ return socketChannelFactory(new ReflectiveChannelFactory<SocketChannel>(channelType), retryOnTimeout);
}
/**
@@ -528,6 +571,7 @@ public DnsNameResolver build() {
eventLoop,
channelFactory,
socketChannelFactory,
+ retryOnTimeout,
resolveCache,
cnameCache,
authoritativeDnsServerCache,
@@ -565,9 +609,7 @@ public DnsNameResolverBuilder copy() {
copiedBuilder.channelFactory(channelFactory);
}
- if (socketChannelFactory != null) {
- copiedBuilder.socketChannelFactory(socketChannelFactory);
- }
+ copiedBuilder.socketChannelFactory(socketChannelFactory, retryOnTimeout);
if (resolveCache != null) {
copiedBuilder.resolveCache(resolveCache);
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java
index 1741e25dd543..da1091b1d160 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java
@@ -75,6 +75,8 @@ abstract class DnsQueryContext {
private final boolean recursionDesired;
private final Bootstrap socketBootstrap;
+
+ private final boolean retryWithTcpOnTimeout;
private final long queryTimeoutMillis;
private volatile Future<?> timeoutFuture;
@@ -90,7 +92,9 @@ abstract class DnsQueryContext {
long queryTimeoutMillis,
DnsQuestion question,
DnsRecord[] additionals,
- Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise, Bootstrap socketBootstrap) {
+ Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
+ Bootstrap socketBootstrap,
+ boolean retryWithTcpOnTimeout) {
this.channel = checkNotNull(channel, "channel");
this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager");
this.channelReadyFuture = checkNotNull(channelReadyFuture, "channelReadyFuture");
@@ -101,6 +105,7 @@ abstract class DnsQueryContext {
this.recursionDesired = recursionDesired;
this.queryTimeoutMillis = queryTimeoutMillis;
this.socketBootstrap = socketBootstrap;
+ this.retryWithTcpOnTimeout = retryWithTcpOnTimeout;
if (maxPayLoadSize > 0 &&
// Only add the extra OPT record if there is not already one. This is required as only one is allowed
@@ -313,7 +318,7 @@ public void run() {
*/
void finishSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope, boolean truncated) {
// Check if the response was not truncated or if a fallback to TCP is possible.
- if (!truncated || !retryWithTCP(envelope)) {
+ if (!truncated || !retryWithTcp(envelope)) {
final DnsResponse res = envelope.content();
if (res.count(DnsSection.QUESTION) != 1) {
logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}",
@@ -358,7 +363,7 @@ final boolean finishFailure(String message, Throwable cause, boolean timeout) {
// This was caused by a timeout so use DnsNameResolverTimeoutException to allow the user to
// handle it special (like retry the query).
e = new DnsNameResolverTimeoutException(nameServerAddr, question, buf.toString());
- if (retryWithTCP(e)) {
+ if (retryWithTcpOnTimeout && retryWithTcp(e)) {
// We did successfully retry with TCP.
return false;
}
@@ -375,7 +380,7 @@ final boolean finishFailure(String message, Throwable cause, boolean timeout) {
* @return {@code true} if retry via TCP is supported and so the ownership of
* {@code originalResult} was transferred, {@code false} otherwise.
*/
- private boolean retryWithTCP(final Object originalResult) {
+ private boolean retryWithTcp(final Object originalResult) {
if (socketBootstrap == null) {
return false;
}
diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java
index 2111b67aad9c..f8022a8c6a77 100644
--- a/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java
+++ b/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java
@@ -38,7 +38,7 @@ final class TcpDnsQueryContext extends DnsQueryContext {
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired,
// No retry via TCP.
- queryTimeoutMillis, question, additionals, promise, null);
+ queryTimeoutMillis, question, additionals, promise, null, false);
}
@Override
diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
index 2cd9e6a4d9f9..b873a2f55510 100644
--- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
+++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java
@@ -3408,7 +3408,7 @@ public DatagramChannel newChannel() {
builder.channelFactory(channelFactory);
dnsServer2.start(null, (InetSocketAddress) serverSocket.getLocalSocketAddress());
// If we are configured to use TCP as a fallback also bind a TCP socket
- builder.socketChannelType(NioSocketChannel.class);
+ builder.socketChannelType(NioSocketChannel.class, true);
builder.queryTimeoutMillis(1000)
.resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED)