Skip to content

Check qos, heartbeat, max channel are unsigned shorts #641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/main/java/com/rabbitmq/client/Channel.java
Original file line number Diff line number Diff line change
@@ -193,42 +193,49 @@ public interface Channel extends ShutdownNotifier, AutoCloseable {

/**
* Request specific "quality of service" settings.
*
* <p>
* These settings impose limits on the amount of data the server
* will deliver to consumers before requiring acknowledgements.
* Thus they provide a means of consumer-initiated flow control.
* @see com.rabbitmq.client.AMQP.Basic.Qos
* @param prefetchSize maximum amount of content (measured in
* octets) that the server will deliver, 0 if unlimited
* <p>
* Note the prefetch count must be between 0 and 65535 (unsigned short in AMQP 0-9-1).
*
* @param prefetchSize maximum amount of content (measured in
* octets) that the server will deliver, 0 if unlimited
* @param prefetchCount maximum number of messages that the server
* will deliver, 0 if unlimited
* @param global true if the settings should be applied to the
* entire channel rather than each consumer
* will deliver, 0 if unlimited
* @param global true if the settings should be applied to the
* entire channel rather than each consumer
* @throws java.io.IOException if an error is encountered
* @see com.rabbitmq.client.AMQP.Basic.Qos
*/
void basicQos(int prefetchSize, int prefetchCount, boolean global) throws IOException;

/**
* Request a specific prefetchCount "quality of service" settings
* for this channel.
* <p>
* Note the prefetch count must be between 0 and 65535 (unsigned short in AMQP 0-9-1).
*
* @see #basicQos(int, int, boolean)
* @param prefetchCount maximum number of messages that the server
* will deliver, 0 if unlimited
* @param global true if the settings should be applied to the
* entire channel rather than each consumer
* will deliver, 0 if unlimited
* @param global true if the settings should be applied to the
* entire channel rather than each consumer
* @throws java.io.IOException if an error is encountered
* @see #basicQos(int, int, boolean)
*/
void basicQos(int prefetchCount, boolean global) throws IOException;

/**
* Request a specific prefetchCount "quality of service" settings
* for this channel.
* <p>
* Note the prefetch count must be between 0 and 65535 (unsigned short in AMQP 0-9-1).
*
* @see #basicQos(int, int, boolean)
* @param prefetchCount maximum number of messages that the server
* will deliver, 0 if unlimited
* will deliver, 0 if unlimited
* @throws java.io.IOException if an error is encountered
* @see #basicQos(int, int, boolean)
*/
void basicQos(int prefetchCount) throws IOException;

16 changes: 15 additions & 1 deletion src/main/java/com/rabbitmq/client/ConnectionFactory.java
Original file line number Diff line number Diff line change
@@ -47,6 +47,8 @@
*/
public class ConnectionFactory implements Cloneable {

private static final int MAX_UNSIGNED_SHORT = 65535;

/** Default user name */
public static final String DEFAULT_USER = "guest";
/** Default password */
@@ -384,10 +386,16 @@ public int getRequestedChannelMax() {
}

/**
* Set the requested maximum channel number
* Set the requested maximum channel number.
* <p>
* Note the value must be between 0 and 65535 (unsigned short in AMQP 0-9-1).
*
* @param requestedChannelMax initially requested maximum channel number; zero for unlimited
*/
public void setRequestedChannelMax(int requestedChannelMax) {
if (requestedChannelMax < 0 || requestedChannelMax > MAX_UNSIGNED_SHORT) {
throw new IllegalArgumentException("Requested channel max must be between 0 and " + MAX_UNSIGNED_SHORT);
}
this.requestedChannelMax = requestedChannelMax;
}

@@ -477,10 +485,16 @@ public int getShutdownTimeout() {
* Set the requested heartbeat timeout. Heartbeat frames will be sent at about 1/2 the timeout interval.
* If server heartbeat timeout is configured to a non-zero value, this method can only be used
* to lower the value; otherwise any value provided by the client will be used.
* <p>
* Note the value must be between 0 and 65535 (unsigned short in AMQP 0-9-1).
*
* @param requestedHeartbeat the initially requested heartbeat timeout, in seconds; zero for none
* @see <a href="https://rabbitmq.com/heartbeats.html">RabbitMQ Heartbeats Guide</a>
*/
public void setRequestedHeartbeat(int requestedHeartbeat) {
if (requestedHeartbeat < 0 || requestedHeartbeat > MAX_UNSIGNED_SHORT) {
throw new IllegalArgumentException("Requested heartbeat must be between 0 and " + MAX_UNSIGNED_SHORT);
}
this.requestedHeartbeat = requestedHeartbeat;
}

18 changes: 16 additions & 2 deletions src/main/java/com/rabbitmq/client/impl/AMQConnection.java
Original file line number Diff line number Diff line change
@@ -15,13 +15,12 @@

package com.rabbitmq.client.impl;

import com.rabbitmq.client.*;
import com.rabbitmq.client.Method;
import com.rabbitmq.client.*;
import com.rabbitmq.client.impl.AMQChannel.BlockingRpcContinuation;
import com.rabbitmq.client.impl.recovery.RecoveryCanBeginListener;
import com.rabbitmq.utility.BlockingCell;
import com.rabbitmq.utility.Utility;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -47,6 +46,8 @@ final class Copyright {
*/
public class AMQConnection extends ShutdownNotifierComponent implements Connection, NetworkConnection {

private static final int MAX_UNSIGNED_SHORT = 65535;

private static final Logger LOGGER = LoggerFactory.getLogger(AMQConnection.class);
// we want socket write and channel shutdown timeouts to kick in after
// the heartbeat one, so we use a value of 105% of the effective heartbeat timeout
@@ -399,6 +400,11 @@ public void start()
int channelMax =
negotiateChannelMax(this.requestedChannelMax,
connTune.getChannelMax());

if (!checkUnsignedShort(channelMax)) {
throw new IllegalArgumentException("Negotiated channel max must be between 0 and " + MAX_UNSIGNED_SHORT + ": " + channelMax);
}

_channelManager = instantiateChannelManager(channelMax, threadFactory);

int frameMax =
@@ -410,6 +416,10 @@ public void start()
negotiatedMaxValue(this.requestedHeartbeat,
connTune.getHeartbeat());

if (!checkUnsignedShort(heartbeat)) {
throw new IllegalArgumentException("Negotiated heartbeat must be between 0 and " + MAX_UNSIGNED_SHORT + ": " + heartbeat);
}

setHeartbeat(heartbeat);

_channel0.transmit(new AMQP.Connection.TuneOk.Builder()
@@ -626,6 +636,10 @@ private static int negotiatedMaxValue(int clientValue, int serverValue) {
Math.min(clientValue, serverValue);
}

private static boolean checkUnsignedShort(int value) {
return value >= 0 && value <= MAX_UNSIGNED_SHORT;
}

private class MainLoop implements Runnable {

/**
32 changes: 15 additions & 17 deletions src/main/java/com/rabbitmq/client/impl/ChannelN.java
Original file line number Diff line number Diff line change
@@ -15,30 +15,24 @@

package com.rabbitmq.client.impl;

import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.*;

import com.rabbitmq.client.ConfirmCallback;
import com.rabbitmq.client.*;
import com.rabbitmq.client.AMQP.BasicProperties;
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.Method;
import com.rabbitmq.client.impl.AMQImpl.Basic;
import com.rabbitmq.client.AMQP.BasicProperties;
import com.rabbitmq.client.impl.AMQImpl.Channel;
import com.rabbitmq.client.impl.AMQImpl.Confirm;
import com.rabbitmq.client.impl.AMQImpl.Exchange;
import com.rabbitmq.client.impl.AMQImpl.Queue;
import com.rabbitmq.client.impl.AMQImpl.Tx;
import com.rabbitmq.client.impl.AMQImpl.*;
import com.rabbitmq.utility.Utility;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeoutException;

/**
* Main interface to AMQP protocol functionality. Public API -
* Implementation of all AMQChannels except channel zero.
@@ -50,6 +44,7 @@
* </pre>
*/
public class ChannelN extends AMQChannel implements com.rabbitmq.client.Channel {
private static final int MAX_UNSIGNED_SHORT = 65535;
private static final String UNSPECIFIED_OUT_OF_BAND = "";
private static final Logger LOGGER = LoggerFactory.getLogger(ChannelN.class);

@@ -647,7 +642,10 @@ public AMQCommand transformReply(AMQCommand command) {
public void basicQos(int prefetchSize, int prefetchCount, boolean global)
throws IOException
{
exnWrappingRpc(new Basic.Qos(prefetchSize, prefetchCount, global));
if (prefetchCount < 0 || prefetchCount > MAX_UNSIGNED_SHORT) {
throw new IllegalArgumentException("Prefetch count must be between 0 and " + MAX_UNSIGNED_SHORT);
}
exnWrappingRpc(new Basic.Qos(prefetchSize, prefetchCount, global));
}

/** Public API - {@inheritDoc} */
31 changes: 31 additions & 0 deletions src/test/java/com/rabbitmq/client/test/ChannelNTest.java
Original file line number Diff line number Diff line change
@@ -24,6 +24,9 @@

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class ChannelNTest {

@@ -57,4 +60,32 @@ public void callingBasicCancelForUnknownConsumerDoesNotThrowException() throws E
channel.basicCancel("does-not-exist");
}

@Test
public void qosShouldBeUnsignedShort() {
AMQConnection connection = Mockito.mock(AMQConnection.class);
ChannelN channel = new ChannelN(connection, 1, consumerWorkService);
class TestConfig {
int value;
Consumer call;

public TestConfig(int value, Consumer call) {
this.value = value;
this.call = call;
}
}
Consumer qos = value -> channel.basicQos(value);
Consumer qosGlobal = value -> channel.basicQos(value, true);
Consumer qosPrefetchSize = value -> channel.basicQos(10, value, true);
Stream.of(
new TestConfig(-1, qos), new TestConfig(65536, qos)
).flatMap(config -> Stream.of(config, new TestConfig(config.value, qosGlobal), new TestConfig(config.value, qosPrefetchSize)))
.forEach(config -> assertThatThrownBy(() -> config.call.apply(config.value)).isInstanceOf(IllegalArgumentException.class));
}

interface Consumer {

void apply(int value) throws Exception;

}

}
1 change: 0 additions & 1 deletion src/test/java/com/rabbitmq/client/test/ClientTests.java
Original file line number Diff line number Diff line change
@@ -52,7 +52,6 @@
ConnectionFactoryTest.class,
RecoveryAwareAMQConnectionFactoryTest.class,
RpcTest.class,
SslContextFactoryTest.class,
LambdaCallbackTest.class,
ChannelAsyncCompletableFutureTest.class,
RecoveryDelayHandlerTest.class,
99 changes: 63 additions & 36 deletions src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java
Original file line number Diff line number Diff line change
@@ -15,19 +15,8 @@

package com.rabbitmq.client.test;

import com.rabbitmq.client.Address;
import com.rabbitmq.client.AddressResolver;
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.ConnectionFactory;
import com.rabbitmq.client.DnsRecordIpAddressResolver;
import com.rabbitmq.client.ListAddressResolver;
import com.rabbitmq.client.MetricsCollector;
import com.rabbitmq.client.impl.AMQConnection;
import com.rabbitmq.client.impl.ConnectionParams;
import com.rabbitmq.client.impl.CredentialsProvider;
import com.rabbitmq.client.impl.FrameHandler;
import com.rabbitmq.client.impl.FrameHandlerFactory;
import org.junit.AfterClass;
import com.rabbitmq.client.*;
import com.rabbitmq.client.impl.*;
import org.junit.Test;

import java.io.IOException;
@@ -37,17 +26,18 @@
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Stream;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.*;

public class ConnectionFactoryTest {

// see https://github.com/rabbitmq/rabbitmq-java-client/issues/262
@Test public void tryNextAddressIfTimeoutExceptionNoAutoRecovery() throws IOException, TimeoutException {
@Test
public void tryNextAddressIfTimeoutExceptionNoAutoRecovery() throws IOException, TimeoutException {
final AMQConnection connectionThatThrowsTimeout = mock(AMQConnection.class);
final AMQConnection connectionThatSucceeds = mock(AMQConnection.class);
final Queue<AMQConnection> connections = new ArrayBlockingQueue<AMQConnection>(10);
@@ -69,22 +59,23 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() {
doThrow(TimeoutException.class).when(connectionThatThrowsTimeout).start();
doNothing().when(connectionThatSucceeds).start();
Connection returnedConnection = connectionFactory.newConnection(
new Address[] { new Address("host1"), new Address("host2") }
new Address[]{new Address("host1"), new Address("host2")}
);
assertSame(connectionThatSucceeds, returnedConnection);
assertThat(returnedConnection).isSameAs(connectionThatSucceeds);
}

// see https://github.com/rabbitmq/rabbitmq-java-client/pull/350
@Test public void customizeCredentialsProvider() throws Exception {
@Test
public void customizeCredentialsProvider() throws Exception {
final CredentialsProvider provider = mock(CredentialsProvider.class);
final AMQConnection connection = mock(AMQConnection.class);
final AtomicBoolean createCalled = new AtomicBoolean(false);

ConnectionFactory connectionFactory = new ConnectionFactory() {
@Override
protected AMQConnection createConnection(ConnectionParams params, FrameHandler frameHandler,
MetricsCollector metricsCollector) {
assertSame(provider, params.getCredentialsProvider());
MetricsCollector metricsCollector) {
assertThat(provider).isSameAs(params.getCredentialsProvider());
createCalled.set(true);
return connection;
}
@@ -96,22 +87,23 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() {
};
connectionFactory.setCredentialsProvider(provider);
connectionFactory.setAutomaticRecoveryEnabled(false);

doNothing().when(connection).start();

Connection returnedConnection = connectionFactory.newConnection();
assertSame(returnedConnection, connection);
assertTrue(createCalled.get());
assertThat(returnedConnection).isSameAs(connection);
assertThat(createCalled).isTrue();
}

@Test public void shouldNotUseDnsResolutionWhenOneAddressAndNoTls() throws Exception {
@Test
public void shouldNotUseDnsResolutionWhenOneAddressAndNoTls() throws Exception {
AMQConnection connection = mock(AMQConnection.class);
AtomicReference<AddressResolver> addressResolver = new AtomicReference<>();

ConnectionFactory connectionFactory = new ConnectionFactory() {
@Override
protected AMQConnection createConnection(ConnectionParams params, FrameHandler frameHandler,
MetricsCollector metricsCollector) {
MetricsCollector metricsCollector) {
return connection;
}

@@ -131,18 +123,18 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() {

doNothing().when(connection).start();
connectionFactory.newConnection();

assertThat(addressResolver.get(), allOf(notNullValue(), instanceOf(ListAddressResolver.class)));
assertThat(addressResolver.get()).isNotNull().isInstanceOf(ListAddressResolver.class);
}

@Test public void shouldNotUseDnsResolutionWhenOneAddressAndTls() throws Exception {
@Test
public void shouldNotUseDnsResolutionWhenOneAddressAndTls() throws Exception {
AMQConnection connection = mock(AMQConnection.class);
AtomicReference<AddressResolver> addressResolver = new AtomicReference<>();

ConnectionFactory connectionFactory = new ConnectionFactory() {
@Override
protected AMQConnection createConnection(ConnectionParams params, FrameHandler frameHandler,
MetricsCollector metricsCollector) {
MetricsCollector metricsCollector) {
return connection;
}

@@ -164,7 +156,42 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() {
connectionFactory.useSslProtocol();
connectionFactory.newConnection();

assertThat(addressResolver.get(), allOf(notNullValue(), instanceOf(ListAddressResolver.class)));
assertThat(addressResolver.get()).isNotNull().isInstanceOf(ListAddressResolver.class);
}

@Test
public void heartbeatAndChannelMaxMustBeUnsignedShorts() {
class TestConfig {
int value;
Consumer<Integer> call;
boolean expectException;

public TestConfig(int value, Consumer<Integer> call, boolean expectException) {
this.value = value;
this.call = call;
this.expectException = expectException;
}
}

ConnectionFactory cf = new ConnectionFactory();
Consumer<Integer> setHeartbeat = cf::setRequestedHeartbeat;
Consumer<Integer> setChannelMax = cf::setRequestedChannelMax;

Stream.of(
new TestConfig(0, setHeartbeat, false),
new TestConfig(10, setHeartbeat, false),
new TestConfig(65535, setHeartbeat, false),
new TestConfig(-1, setHeartbeat, true),
new TestConfig(65536, setHeartbeat, true))
.flatMap(config -> Stream.of(config, new TestConfig(config.value, setChannelMax, config.expectException)))
.forEach(config -> {
if (config.expectException) {
assertThatThrownBy(() -> config.call.accept(config.value)).isInstanceOf(IllegalArgumentException.class);
} else {
config.call.accept(config.value);
}
});

}

}
4 changes: 3 additions & 1 deletion src/test/java/com/rabbitmq/client/test/ssl/SSLTests.java
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
package com.rabbitmq.client.test.ssl;

import com.rabbitmq.client.test.AbstractRMQTestSuite;
import com.rabbitmq.client.test.SslContextFactoryTest;
import org.junit.runner.RunWith;
import org.junit.runner.Runner;
import org.junit.runners.Suite;
@@ -34,7 +35,8 @@
ConnectionFactoryDefaultTlsVersion.class,
NioTlsUnverifiedConnection.class,
HostnameVerification.class,
TlsConnectionLogging.class
TlsConnectionLogging.class,
SslContextFactoryTest.class
})
public class SSLTests {