SpringBoot集成websocket

本文以代码实战的方式，介绍 Spring Boot 如何集成 WebSocket。

握手拦截器（HandshakeInterceptor）

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/**
 * WebSocket handshake interceptor.
 *
 * <p>Copies the {@code taskId} query parameter and the client IP address from the HTTP upgrade
 * request into the WebSocket session attributes, so the message handler can read them later.
 *
 * @author zhouheng
 * @date 2021-04-20 18:17
 */
@Slf4j
@Component
public class XXXHandshakeInterceptor implements HandshakeInterceptor {

    /** Headers checked, in order, for the originating client address when running behind proxies. */
    private static final String[] CLIENT_IP_HEADERS = {
        "x-forwarded-for",
        "X-Real-IP",
        "Proxy-Client-IP",
        "WL-Proxy-Client-IP",
        "HTTP_CLIENT_IP",
        "HTTP_X_FORWARDED_FOR"
    };

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse serverHttpResponse,
            WebSocketHandler webSocketHandler, Map<String, Object> attributes) throws Exception {
        log.info("Handle before webSocket connected. ");
        if (request instanceof ServletServerHttpRequest) {
            HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
            // ws://host/path?taskId=xxx
            String taskId = servletRequest.getParameter(Constants.TASK_ID);
            String ipAddress = getIpAddress(servletRequest);
            if (ipAddress == null) {
                ipAddress = servletRequest.getRemoteAddr();
            }
            attributes.put(Constants.IP, ipAddress);
            attributes.put(Constants.TASK_ID, taskId);
        }
        // The handshake is never rejected here; the handler decides what to do with a missing taskId.
        return true;
    }

    @Override
    public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse,
            WebSocketHandler webSocketHandler, Exception e) {
        // nothing to do after the handshake
    }

    /**
     * Resolves the client IP address for a request that may have passed through reverse proxies
     * (e.g. nginx).
     *
     * <p>The headers in {@link #CLIENT_IP_HEADERS} are checked in order. A header value may be a
     * comma-separated list ("client, proxy1, proxy2", as produced by nginx's
     * {@code $proxy_add_x_forwarded_for}); its first entry that is neither blank nor
     * {@code Constants.UNKNOWN} is used. If no header yields an address,
     * {@link HttpServletRequest#getRemoteAddr()} is used. The IPv6 loopback address is reported
     * as {@code 127.0.0.1}.
     *
     * <p>NOTE: these headers are client-controlled unless every proxy overwrites them, so the
     * result is suitable for logging but must not be trusted for authorization.
     *
     * @param request the servlet request of the handshake
     * @return the best-effort client IP, or {@code null} only if the container reports no remote address
     */
    public static String getIpAddress(HttpServletRequest request) {
        for (String header : CLIENT_IP_HEADERS) {
            String ip = firstUsableAddress(request.getHeader(header));
            if (ip != null) {
                return normalizeLoopback(ip);
            }
        }
        return normalizeLoopback(request.getRemoteAddr());
    }

    /** Returns the first non-blank, non-"unknown" entry of a comma-separated header value, or null. */
    private static String firstUsableAddress(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        for (String candidate : headerValue.split(",")) {
            String ip = candidate.trim();
            if (!ip.isEmpty() && !Constants.UNKNOWN.equalsIgnoreCase(ip)) {
                return ip;
            }
        }
        return null;
    }

    /** Maps the IPv6 loopback address (full or compressed form) to 127.0.0.1; null-safe. */
    private static String normalizeLoopback(String ip) {
        if ("0:0:0:0:0:0:0:1".equals(ip) || "::1".equals(ip)) {
            return "127.0.0.1";
        }
        return ip;
    }

}

消息处理器（WebSocketHandler）

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

/**
*
* 消息处理handler
* @author zhouheng
* @date 2021-04-20 19:58
*/
@Service
@Slf4j
public class XXXWebSocketHandler implements WebSocketHandler {

private static ITaskService taskService;
private Long taskId;
private WebSocketSession session;
private static final CopyOnWriteArraySet<XXXWebSocketHandler> sockets = new CopyOnWriteArraySet<>();
private String ip;

@Autowired
public void setTaskService(ITaskService taskService) {
XXXWebSocketHandler.taskService = taskService;
}

@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
if (null == session.getAttributes().get(Constants.IP)) {
log.info("websocket连接未获取到ip");
this.ip = "not get ip";
} else {
this.ip = session.getAttributes().get(Constants.IP).toString();
}

if (null == session.getAttributes().get(Constants.TASK_ID)) {
log.error("session [{}] , ip [{}] websocket连接未获取到taskId",session.getId(),ip);
return;
}
this.taskId = Long.valueOf(session.getAttributes().get(Constants.TASK_ID).toString());
this.session = session;
sockets.add(this);
log.info("session [{}],taskId [{}], ip [{}] connect websocket!", session.getId(), taskId, ip);
}

@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
log.info("websocket server receive message: {} from session [{}] taskId [{}] ip [{}] ", message.getPayload(),
session.getId(), taskId, ip);
}

@Override
public void handleTransportError(WebSocketSession session, Throwable e) throws Exception {
e.printStackTrace();
log.error("websocket session [{}] taskId [{}] ip [{}] has error!", session.getId(), taskId, ip);
}

@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
try {
XXXWebSocketHandler.sockets.remove(this);
taskService.closeStream(this.taskId);
log.info("websocket session[{}] taskId[{}] ip [{}] close ", session.getId(), taskId, ip);
} catch (Exception e) {
log.error(e.getMessage());
}
}

@Override
public boolean supportsPartialMessages() {
return false;
}

// 发送消息带业务逻辑
public static void sendMessage(AlarmVO vo) {
if (sockets.isEmpty()) {
return;
}
for (XXXWebSocketHandler socket : sockets) {
try {
if (vo.getTaskId() != null && vo.getTaskId().equals(socket.taskId)) {
synchronized (socket.getSession()) {
socket.getSession().sendMessage(new TextMessage(JsonUtil.toJson(vo)));
}
}

} catch (Exception e) {
e.printStackTrace();
log.error("send message has error, session [{}],taskId [{}] ,ip [{}]", socket.session, socket.taskId,
socket.ip);
}
}
}

public WebSocketSession getSession() {
return this.session;
}
}

WebSocket 配置（WebSocketConfig）

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/**
 * Registers the WebSocket endpoint {@code /api/v1/ws} with its handler and handshake interceptor.
 *
 * @author zhouheng 2021-03-26
 */
@Configuration
@EnableWebSocket // enable WebSocket support
public class WebSocketConfig implements WebSocketConfigurer {

    @Resource
    private XXXWebSocketHandler xxxWebSocketHandler;

    // The interceptor is a Spring @Component; inject the managed bean instead of creating a new instance.
    @Resource
    private XXXHandshakeInterceptor xxxHandshakeInterceptor;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // register the handler and its connection URL
        registry.addHandler(xxxWebSocketHandler, "/api/v1/ws")
                // register the handshake interceptor
                .addInterceptors(xxxHandshakeInterceptor)
                // allow cross-origin connections.
                // NOTE(review): "*" accepts upgrade requests from any Origin (cross-site WebSocket
                // hijacking risk); restrict this to the known front-end origins in production.
                .setAllowedOrigins("*");
    }
}

nginx配置

1
2
3
4
5
6
7
8
9
# Reverse-proxy the WebSocket endpoint to the Spring Boot application.
location /api/v1/ws {
    proxy_pass http://127.0.0.1:1103;
    # The WebSocket upgrade needs HTTP/1.1 and explicit forwarding of the hop-by-hop Upgrade/Connection headers.
    proxy_http_version 1.1;
    proxy_set_header Upgrade $http_upgrade;
    proxy_set_header Connection "upgrade";
    proxy_set_header Host $host;
    # Forward the client address so the application can recover the real client IP.
    proxy_set_header X-Real-IP $remote_addr;
    proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
    # By default nginx closes the proxied connection after 60s without data from the backend;
    # pushes may be infrequent, so allow longer idle periods (or have the server send WebSocket pings).
    proxy_read_timeout 3600s;
}

实战截图