diff --git a/graphql-spring-boot-test/src/main/java/com/graphql/spring/boot/test/GraphQLTestSubscription.java b/graphql-spring-boot-test/src/main/java/com/graphql/spring/boot/test/GraphQLTestSubscription.java index 05b1d02f..b3e11d74 100644 --- a/graphql-spring-boot-test/src/main/java/com/graphql/spring/boot/test/GraphQLTestSubscription.java +++ b/graphql-spring-boot-test/src/main/java/com/graphql/spring/boot/test/GraphQLTestSubscription.java @@ -16,6 +16,7 @@ import org.springframework.web.util.UriBuilderFactory; import javax.websocket.ClientEndpointConfig; +import javax.websocket.CloseReason; import javax.websocket.ContainerProvider; import javax.websocket.Endpoint; import javax.websocket.EndpointConfig; @@ -33,8 +34,8 @@ import java.util.Map; import java.util.Optional; import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.fail; @@ -46,29 +47,42 @@ @Slf4j public class GraphQLTestSubscription { + private static final WebSocketContainer WEB_SOCKET_CONTAINER = ContainerProvider.getWebSocketContainer(); private static final int SLEEP_INTERVAL_MS = 100; - private static final int ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT = 6000000; + private static final int ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT = 60000; private static final AtomicInteger ID_COUNTER = new AtomicInteger(1); private static final UriBuilderFactory URI_BUILDER_FACTORY = new DefaultUriBuilderFactory(); + private static final Object STATE_LOCK = new Object(); @Getter private Session session; - - @Getter - private boolean initialized = false; - @Getter - private boolean acknowledged = false; - @Getter - private boolean started = false; - @Getter - private boolean stopped = false; + private SubscriptionState state = SubscriptionState.builder() + .id(ID_COUNTER.incrementAndGet()) + .build(); private final Environment environment; private final ObjectMapper objectMapper; private final String subscriptionPath; - private final Queue responses = new ConcurrentLinkedQueue<>(); - private int id = ID_COUNTER.getAndIncrement(); + public boolean isInitialized() { + return state.isInitialized(); + } + + public boolean isAcknowledged() { + return state.isAcknowledged(); + } + + public boolean isStarted() { + return state.isStarted(); + } + + public boolean isStopped() { + return state.isStopped(); + } + + public boolean isCompleted() { + return state.isCompleted(); + } /** * Sends the "connection_init" message to the GraphQL server without a payload. @@ -85,7 +99,7 @@ public GraphQLTestSubscription init() { * @return self reference */ public GraphQLTestSubscription init(@Nullable final Object payload) { - if (initialized) { + if (isInitialized()) { fail("Subscription already initialized."); } try { @@ -97,8 +111,9 @@ public GraphQLTestSubscription init(@Nullable final Object payload) { message.put("type", "connection_init"); message.set("payload", getFinalPayload(payload)); sendMessage(message); - initialized = true; + state.setInitialized(true); awaitAcknowledgement(); + log.debug("Subscription successfully initialized."); return this; } @@ -120,20 +135,21 @@ public GraphQLTestSubscription start(@NonNull final String graphQLResource) { * @return self reference */ public GraphQLTestSubscription start(@NonNull final String graphGLResource, @Nullable final Object variables) { - if (!initialized) { + if (!isInitialized()) { init(); } - if (started) { + if (isStarted()) { fail("Start message already sent. To start a new subscription, please call reset first."); } - started = true; + state.setStarted(true); ObjectNode payload = objectMapper.createObjectNode(); payload.put("query", loadQuery(graphGLResource)); payload.set("variables", getFinalPayload(variables)); ObjectNode message = objectMapper.createObjectNode(); message.put("type", "start"); - message.put("id", id); + message.put("id", state.getId()); message.set("payload", payload); + log.debug("Sending start message."); sendMessage(message); return this; } @@ -143,24 +159,25 @@ public GraphQLTestSubscription start(@NonNull final String graphGLResource, @Nul * @return self reference */ public GraphQLTestSubscription stop() { - if (!initialized) { + if (!isInitialized()) { fail("Subscription not yet initialized."); } - if (stopped) { + if (isStopped()) { fail("Subscription already stopped."); } final ObjectNode message = objectMapper.createObjectNode(); message.put("type", "stop"); - message.put("id", id); + message.put("id", state.getId()); + log.debug("Sending stop message."); sendMessage(message); - stopped = true; try { + log.debug("Closing web socket session."); session.close(); - session = null; + awaitStop(); + log.debug("Web socket session closed."); } catch (IOException e) { fail("Could not close web socket session", e); } - log.debug("Subscription stopped."); return this; } @@ -169,20 +186,12 @@ public GraphQLTestSubscription stop() { * ensure that the bean is reusable between tests. */ public void reset() { - if (initialized && !stopped) { + if (isInitialized() && !isStopped()) { stop(); } - if (stopped) { - id = ID_COUNTER.getAndIncrement(); - } - initialized = false; - started = false; - stopped = false; - acknowledged = false; + state = SubscriptionState.builder().id(ID_COUNTER.incrementAndGet()).build(); session = null; - synchronized (responses) { - responses.clear(); - } + log.debug("Test subscription client reset."); } /** @@ -264,15 +273,15 @@ public List awaitAndGetNextResponses( final int numExpectedResponses, final boolean stopAfter ) { - if (!started) { + if (!isStarted()) { fail("Start message not sent. Please send start message first."); } - if (stopped) { + if (isStopped()) { fail("Subscription already stopped. Forgot to call reset after test case?"); } int elapsedTime = 0; while ( - ((responses.size() < numExpectedResponses) || numExpectedResponses <= 0) + ((state.getResponses().size() < numExpectedResponses) || numExpectedResponses <= 0) && elapsedTime < timeout ) { try { @@ -282,10 +291,11 @@ public List awaitAndGetNextResponses( fail("Test execution error - Thread.sleep failed.", e); } } - synchronized (responses) { - if (stopAfter) { - stop(); - } + if (stopAfter) { + stop(); + } + synchronized (STATE_LOCK) { + final Queue responses = state.getResponses(); int responsesToPoll = responses.size(); if (numExpectedResponses == 0) { assertThat(responses) @@ -336,16 +346,15 @@ public GraphQLTestSubscription waitAndExpectNoResponse(final int timeToWait) { * @return the remaining responses. */ public List getRemainingResponses() { - if (!stopped) { + if (!isStopped()) { fail("getRemainingResponses should only be called after the subscription was stopped."); } - final ArrayList graphQLResponses = new ArrayList<>(responses); - responses.clear(); + final ArrayList graphQLResponses = new ArrayList<>(state.getResponses()); + state.getResponses().clear(); return graphQLResponses; } private void initClient() throws Exception { - final WebSocketContainer webSocketContainer = ContainerProvider.getWebSocketContainer(); final String port = environment.getProperty("local.server.port"); final URI uri = URI_BUILDER_FACTORY.builder().scheme("ws").host("localhost").port(port).path(subscriptionPath) .build(); @@ -355,8 +364,8 @@ private void initClient() throws Exception { .build(); clientEndpointConfig.getUserProperties().put("org.apache.tomcat.websocket.IO_TIMEOUT_MS", String.valueOf(ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT)); - session = webSocketContainer.connectToServer(TestWebSocketClient.class, clientEndpointConfig, uri); - session.addMessageHandler(new TestMessageHandler()); + session = WEB_SOCKET_CONTAINER.connectToServer(new TestWebSocketClient(state), clientEndpointConfig, uri); + session.addMessageHandler(new TestMessageHandler(objectMapper, state)); } private JsonNode getFinalPayload(final Object variables) { @@ -384,8 +393,16 @@ private void sendMessage(final Object message) { } private void awaitAcknowledgement() { + await(GraphQLTestSubscription::isAcknowledged, "Connection was not acknowledged by the GraphQL server."); + } + + private void awaitStop() { + await(GraphQLTestSubscription::isStopped, "Connection was not stopped in time."); + } + + private void await(final Predicate condition, final String timeoutDescription) { int elapsedTime = 0; - while(!acknowledged && elapsedTime < ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT) { + while(!condition.test(this) && elapsedTime < ACKNOWLEDGEMENT_AND_CONNECTION_TIMEOUT) { try { Thread.sleep(SLEEP_INTERVAL_MS); elapsedTime += SLEEP_INTERVAL_MS; @@ -394,22 +411,31 @@ private void awaitAcknowledgement() { } } - if (!acknowledged) { - fail("Timeout: Connection was not acknowledged by the GraphQL server."); + if (!condition.test(this)) { + fail(String.format("Timeout: " + timeoutDescription)); } } - class TestMessageHandler implements MessageHandler.Whole { + @RequiredArgsConstructor + static class TestMessageHandler implements MessageHandler.Whole { + + private final ObjectMapper objectMapper; + private final SubscriptionState state; + @Override public void onMessage(final String message) { try { log.debug("Received message from web socket: {}", message); final JsonNode jsonNode = objectMapper.readTree(message); final JsonNode typeNode = jsonNode.get("type"); - assertThat(typeNode.isNull()).as("GraphQL messages should have a type field.").isFalse(); + assertThat(typeNode).as("GraphQL messages should have a type field.").isNotNull(); + assertThat(typeNode.isNull()).as("GraphQL messages type should not be null.").isFalse(); final String type = typeNode.asText(); - if (type.equals("connection_ack")) { - acknowledged = true; + if (type.equals("complete")) { + state.setCompleted(true); + log.debug("Subscription completed."); + } else if (type.equals("connection_ack")) { + state.setAcknowledged(true); log.debug("WebSocket connection acknowledged by the GraphQL Server."); } else if (type.equals("data") || type.equals("error")) { final JsonNode payload = jsonNode.get("payload"); @@ -417,8 +443,13 @@ public void onMessage(final String message) { final String payloadString = objectMapper.writeValueAsString(payload); final GraphQLResponse graphQLResponse = new GraphQLResponse(ResponseEntity.ok(payloadString), objectMapper); - synchronized (responses) { - responses.add(graphQLResponse); + if (state.isStopped() || state.isCompleted()) { + log.debug("Response discarded because subscription was stopped or completed in the meanwhile."); + } else { + synchronized (STATE_LOCK) { + state.getResponses().add(graphQLResponse); + } + log.debug("New response recorded."); } } } catch (JsonProcessingException e) { @@ -427,11 +458,21 @@ public void onMessage(final String message) { } } - public static class TestWebSocketClient extends Endpoint { + @RequiredArgsConstructor + private static class TestWebSocketClient extends Endpoint { + + private final SubscriptionState state; + @Override public void onOpen(final Session session, final EndpointConfig config) { log.debug("Connection established."); } + + @Override + public void onClose(Session session, CloseReason closeReason) { + super.onClose(session, closeReason); + state.setStopped(true); + } } static class TestWebSocketClientConfigurator extends ClientEndpointConfig.Configurator { diff --git a/graphql-spring-boot-test/src/main/java/com/graphql/spring/boot/test/SubscriptionState.java b/graphql-spring-boot-test/src/main/java/com/graphql/spring/boot/test/SubscriptionState.java new file mode 100644 index 00000000..8ed3a72c --- /dev/null +++ b/graphql-spring-boot-test/src/main/java/com/graphql/spring/boot/test/SubscriptionState.java @@ -0,0 +1,25 @@ +package com.graphql.spring.boot.test; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +class SubscriptionState { + + private boolean initialized; + private boolean acknowledged; + private boolean started; + private boolean stopped; + private boolean completed; + @Builder.Default + private Queue responses = new ConcurrentLinkedQueue<>(); + private int id; +} diff --git a/graphql-spring-boot-test/src/test/java/com/graphql/spring/boot/test/GraphQLTestSubscriptionResetTest.java b/graphql-spring-boot-test/src/test/java/com/graphql/spring/boot/test/GraphQLTestSubscriptionResetTest.java index a6809bc7..71622407 100644 --- a/graphql-spring-boot-test/src/test/java/com/graphql/spring/boot/test/GraphQLTestSubscriptionResetTest.java +++ b/graphql-spring-boot-test/src/test/java/com/graphql/spring/boot/test/GraphQLTestSubscriptionResetTest.java @@ -19,7 +19,7 @@ void shouldWorkIfSubscriptionWasNotStarted() { graphQLTestSubscription.reset(); // THEN assertThatSubscriptionWasReset(); - assertThatExistingIdWasRetained(firstId); + assertThatNewIdWasGenerated(firstId); } @Test @@ -52,11 +52,18 @@ void shouldWorkIfSubscriptionWasAlreadyStopped() { @DisplayName("Should allow starting a new subscription after reset.") void shouldAllowStartingNewSubscriptionAfterReset() { // GIVEN - graphQLTestSubscription.start(TIMER_SUBSCRIPTION_RESOURCE); + startAndAssertThatNewSubscriptionWorks(); // WHEN graphQLTestSubscription.reset(); // THEN - graphQLTestSubscription.start(TIMER_SUBSCRIPTION_RESOURCE); + startAndAssertThatNewSubscriptionWorks(); + } + + private void startAndAssertThatNewSubscriptionWorks() { + final Integer actual = graphQLTestSubscription.start(TIMER_SUBSCRIPTION_RESOURCE) + .awaitAndGetNextResponse(TIMEOUT) + .get("$.data.timer", Integer.class); + assertThat(actual).isZero(); } private void assertThatSubscriptionWasReset() { @@ -64,8 +71,9 @@ private void assertThatSubscriptionWasReset() { assertThat(graphQLTestSubscription.isAcknowledged()).isFalse(); assertThat(graphQLTestSubscription.isStarted()).isFalse(); assertThat(graphQLTestSubscription.isStopped()).isFalse(); - assertThat((Queue) ReflectionTestUtils.getField(graphQLTestSubscription, GraphQLTestSubscription.class, - "responses")).isEmpty(); + assertThat(graphQLTestSubscription.isCompleted()).isFalse(); + assertThat(((SubscriptionState) ReflectionTestUtils.getField(graphQLTestSubscription, GraphQLTestSubscription.class, + "state")).getResponses()).isEmpty(); assertThat(graphQLTestSubscription.getSession()).isNull(); } @@ -73,11 +81,7 @@ private void assertThatNewIdWasGenerated(int previousId) { assertThat(getSubscriptionId()).isEqualTo(previousId + 1); } - private void assertThatExistingIdWasRetained(int previousId) { - assertThat(getSubscriptionId()).isEqualTo(previousId); - } - private int getSubscriptionId() { - return (int) ReflectionTestUtils.getField(graphQLTestSubscription, GraphQLTestSubscription.class, "id"); + return ((SubscriptionState) ReflectionTestUtils.getField(graphQLTestSubscription, GraphQLTestSubscription.class, "state")).getId(); } }