summaryrefslogtreecommitdiff
path: root/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java
diff options
context:
space:
mode:
Diffstat (limited to 'spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java')
-rw-r--r--spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java158
1 files changed, 158 insertions, 0 deletions
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java
new file mode 100644
index 00000000..813b7b5f
--- /dev/null
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketHandshakeTests.java
@@ -0,0 +1,158 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.socket;
+
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
+import org.springframework.web.socket.client.standard.StandardWebSocketClient;
+import org.springframework.web.socket.config.annotation.EnableWebSocket;
+import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
+import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
+import org.springframework.web.socket.handler.AbstractWebSocketHandler;
+import org.springframework.web.socket.handler.TextWebSocketHandler;
+import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
+
+import static org.junit.Assert.*;
+
+/**
+ * Client and server-side WebSocket integration tests.
+ *
+ * @author Rossen Stoyanchev
+ * @author Juergen Hoeller
+ */
+@RunWith(Parameterized.class)
+public class WebSocketHandshakeTests extends AbstractWebSocketIntegrationTests {
+
+ @Parameters(name = "server [{0}], client [{1}]")
+ public static Iterable<Object[]> arguments() {
+ return Arrays.asList(new Object[][] {
+ {new JettyWebSocketTestServer(), new JettyWebSocketClient()},
+ {new TomcatWebSocketTestServer(), new StandardWebSocketClient()},
+ {new UndertowTestServer(), new JettyWebSocketClient()}
+ });
+ }
+
+
+ @Override
+ protected Class<?>[] getAnnotatedConfigClasses() {
+ return new Class<?>[] {TestConfig.class};
+ }
+
+ @Test
+ public void subProtocolNegotiation() throws Exception {
+ WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
+ headers.setSecWebSocketProtocol("foo");
+ URI url = new URI(getWsBaseUrl() + "/ws");
+ WebSocketSession session = this.webSocketClient.doHandshake(new TextWebSocketHandler(), headers, url).get();
+ assertEquals("foo", session.getAcceptedProtocol());
+ session.close();
+ }
+
+ @Test // SPR-12727
+ public void unsolicitedPongWithEmptyPayload() throws Exception {
+ String url = getWsBaseUrl() + "/ws";
+ WebSocketSession session = this.webSocketClient.doHandshake(new AbstractWebSocketHandler() {}, url).get();
+
+ TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class);
+ serverHandler.setWaitMessageCount(1);
+
+ session.sendMessage(new PongMessage());
+
+ serverHandler.await();
+ assertNull(serverHandler.getTransportError());
+ assertEquals(1, serverHandler.getReceivedMessages().size());
+ assertEquals(PongMessage.class, serverHandler.getReceivedMessages().get(0).getClass());
+ }
+
+
+ @Configuration
+ @EnableWebSocket
+ static class TestConfig implements WebSocketConfigurer {
+
+ @Autowired
+ private DefaultHandshakeHandler handshakeHandler; // can't rely on classpath for server detection
+
+ @Override
+ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
+ this.handshakeHandler.setSupportedProtocols("foo", "bar", "baz");
+ registry.addHandler(handler(), "/ws").setHandshakeHandler(this.handshakeHandler);
+ }
+
+ @Bean
+ public TestWebSocketHandler handler() {
+ return new TestWebSocketHandler();
+ }
+
+ }
+
+ @SuppressWarnings("rawtypes")
+ private static class TestWebSocketHandler extends AbstractWebSocketHandler {
+
+ private List<WebSocketMessage> receivedMessages = new ArrayList<>();
+
+ private int waitMessageCount;
+
+ private final CountDownLatch latch = new CountDownLatch(1);
+
+ private Throwable transportError;
+
+ public void setWaitMessageCount(int waitMessageCount) {
+ this.waitMessageCount = waitMessageCount;
+ }
+
+ public List<WebSocketMessage> getReceivedMessages() {
+ return this.receivedMessages;
+ }
+
+ public Throwable getTransportError() {
+ return this.transportError;
+ }
+
+ @Override
+ public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
+ this.receivedMessages.add(message);
+ if (this.receivedMessages.size() >= this.waitMessageCount) {
+ this.latch.countDown();
+ }
+ }
+
+ @Override
+ public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
+ this.transportError = exception;
+ this.latch.countDown();
+ }
+
+ public void await() throws InterruptedException {
+ this.latch.await(5, TimeUnit.SECONDS);
+ }
+ }
+
+}