1) 技术选型
webssh需要实时的双向数据交互,因此选择长连接的WebSocket;为了开发方便,框架选用SpringBoot;另外还调研了Java中用于连接SSH的Apache MINA SSHD,以及用于实现前端shell页面的xterm.js。
2) 添加maven依赖
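<dependency>
    <groupId>org.apache.sshd</groupId>
    <artifactId>sshd-core</artifactId>
    <version>2.9.2</version>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
    <version>2.7.6</version>
</dependency>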
3) websocket配置
package cn.cloud.common.config;import cn.cloud.common.handler.WebSSHWebSocketHandler;import cn.cloud.common.interceptor.WebSocketInterceptor;import org.springframework.context.annotation.Configuration;import org.springframework.web.socket.config.annotation.EnableWebSocket;import org.springframework.web.socket.config.annotation.WebSocketConfigurer;import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;import javax.annotation.Resource;@Configuration@EnableWebSocketpublic class WebSSHWebSocketConfig implements WebSocketConfigurer { @Resource private WebSSHWebSocketHandler webSSHWebSocketHandler; @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) { //socket通道 //指定处理器和路径 webSocketHandlerRegistry.addHandler(webSSHWebSocketHandler, "/ws/webssh") .addInterceptors(new WebSocketInterceptor()) .setAllowedOrigins("*"); }}
4) websocket处理器配置
package cn.cloud.common.handler;import cn.cloud.common.pojo.OperateConstant;import cn.cloud.common.service.WebSSHService;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.data.redis.core.StringRedisTemplate;import org.springframework.stereotype.Component;import org.springframework.web.socket.*;import javax.annotation.Resource;import java.io.IOException;@Componentpublic class WebSSHWebSocketHandler implements WebSocketHandler { @Resource private WebSSHService webSSHService; @Resource private StringRedisTemplate stringRedisTemplate; private final Logger LOGGER = LoggerFactory.getLogger(WebSSHWebSocketHandler.class); @Override public void afterConnectionEstablished(WebSocketSession webSocketSession) { LOGGER.info("与{}建立websocket连接", webSocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY)); // 调用初始化ssh连接 webSSHService.initConnection(webSocketSession); } @Override public void handleMessage(WebSocketSession webSocketSession, WebSocketMessage> webSocketMessage) throws Exception { if (webSocketMessage instanceof TextMessage) { // 处理前端消息 webSSHService.commandHandler(((TextMessage) webSocketMessage).getPayload(), webSocketSession); } else { LOGGER.error("Unexpected WebSocket message type: " + webSocketMessage); } } @Override public void handleTransportError(WebSocketSession webSocketSession, Throwable throwable) throws Exception { LOGGER.error("数据传输错误"); } @Override public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus closeStatus) throws IOException { LOGGER.info("与{}断开websocket连接", webSocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY)); // 关闭连接 webSSHService.closeConnection(webSocketSession); // websocket连接关闭后ip限制连接数随之变化 updateIpCount(webSocketSession); } private void updateIpCount(WebSocketSession webSocketSession) { String ip = String.valueOf(webSocketSession.getAttributes().get(OperateConstant.IP)); int count = 
Integer.parseInt(String.valueOf(stringRedisTemplate.opsForHash().get(OperateConstant.IP, ip))); stringRedisTemplate.opsForHash().put(OperateConstant.IP, ip, count - 1); } @Override public boolean supportsPartialMessages() { return false; }}
5) websocket拦截器配置
package cn.cloud.common.interceptor;import cn.cloud.common.pojo.OperateConstant;import cn.cloud.common.util.RedisUtil;import org.springframework.http.server.ServerHttpRequest;import org.springframework.http.server.ServerHttpResponse;import org.springframework.http.server.ServletServerHttpRequest;import org.springframework.web.socket.WebSocketHandler;import org.springframework.web.socket.server.HandshakeInterceptor;import javax.servlet.http.HttpServletRequest;import java.util.Map;import java.util.UUID;public class WebSocketInterceptor implements HandshakeInterceptor { private static final int MAX_REQUESTS_PER_SECOND = 10; private RedisUtil redisUtil = new RedisUtil(); @Override public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Mapmap) throws Exception { if (serverHttpRequest instanceof ServletServerHttpRequest) { ServletServerHttpRequest request = (ServletServerHttpRequest) serverHttpRequest; HttpServletRequest servletRequest = request.getServletRequest(); String ip = getIpAddress(servletRequest); if (isLimitExceededRedis(ip)) { return false; } // 当某个 IP 的请求数超过指定的闽值时则拒绝建立websocket链接return false; updateIpRequestCountRedis(ip); // 生成一个UUID String user = UUID.randomUUID().toString().replace("-", ""); // 将uuid放到websocket session中 map.put(OperateConstant.USER_UUID_KEY, user); // 将ip放到websocket session中 map.put(OperateConstant.IP, ip); return true; } else { return false; } } @Override public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception e) {} private String getIpAddress(HttpServletRequest request) { String ip = request.getHeader("X-Forwarded-For"); if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("X-Real-Ip"); } if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("Proxy-Client-Ip"); } if (ip == null 
|| ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("WL-Proxy-Client-Ip"); } if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("HTTP_CLIENT_IP"); } if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("HTTP_X_FORWARDED_FOR"); } if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getRemoteAddr(); } return ip; } private boolean isLimitExceededRedis(String ip) { // 存在redis 中这样后续断开链接的时候可以直接读值 if (redisUtil.hget(OperateConstant.IP, ip) == null) { redisUtil.hput(OperateConstant.IP, ip, 0); redisUtil.expire(OperateConstant.IP, 24 * 60 * 60); } return Integer.parseInt(String.valueOf(redisUtil.hget(OperateConstant.IP, ip))) > MAX_REQUESTS_PER_SECOND; } private void updateIpRequestCountRedis(String ip) { redisUtil.hput(OperateConstant.IP, ip, Integer.parseInt(String.valueOf(redisUtil.hget(OperateConstant.IP, ip))) + 1); }}
6) mina sshd + websocket 核心业务逻辑实现
package cn.cloud.common.service.impl;import cn.cloud.common.pojo.OperateConstant;import cn.cloud.common.pojo.WebSSHConfig;import cn.cloud.common.pojo.WebSSHData;import cn.cloud.common.pojo.WebSSHInfo;import cn.cloud.common.service.WebSSHService;import com.fasterxml.jackson.databind.ObjectMapper;import org.apache.commons.lang.StringUtils;import org.apache.sshd.client.SshClient;import org.apache.sshd.client.channel.ChannelShell;import org.apache.sshd.client.channel.ClientChannelEvent;import org.apache.sshd.client.future.ConnectFuture;import org.apache.sshd.client.session.ClientSession;import org.apache.sshd.common.keyprovider.FileKeyPairProvider;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.stereotype.Service;import org.springframework.web.socket.TextMessage;import org.springframework.web.socket.WebSocketSession;import java.io.File;import java.io.IOException;import java.io.InputStream;import java.nio.file.Path;import java.util.Arrays;import java.util.Collections;import java.util.Map;import java.util.Objects;import java.util.concurrent.ConcurrentHashMap;import java.util.concurrent.ExecutorService;import java.util.concurrent.Executors;import java.util.stream.Collectors;import java.util.stream.Stream;@Servicepublic class SSHServiceImpl implements WebSSHService { // 存放ssh连接信息的map private static final MapsshMap = new ConcurrentHashMap<>(); private final Logger LOGGER = LoggerFactory.getLogger(SSHServiceImpl.class); // 线程池 private final ExecutorService executorService = Executors.newCachedThreadPool(); @Override public void initConnection(WebSocketSession webSocketSession) { try { SshClient sshClient = SshClient.setUpDefaultClient(); sshClient.open(); WebSSHInfo webSSHInfo = new WebSSHInfo(); webSSHInfo.setSshClient(sshClient); webSSHInfo.setWebSocketSession(webSocketSession); String uuid = String.valueOf(webSocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY)); // 将这个ssh连接信息放入map中 sshMap.put(uuid, webSSHInfo); } catch 
(Exception e) { LOGGER.info(e.getMessage()); } } @Override public void commandHandler(String buffer, WebSocketSession webSocketSession) throws IOException { ObjectMapper objectMapper = new ObjectMapper(); WebSSHData webSSHData = null; try { webSSHData = objectMapper.readValue(buffer, WebSSHData.class); } catch (IOException e) { LOGGER.error("WebSSHData Json转换异常:{}", e.getMessage()); // 主动向前端推送msg sendMessage(webSocketSession, "Connection-Closed".getBytes()); return; } // uuid String userId = String.valueOf(webSocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY)); // 找到刚才存储的ssh连接对象 WebSSHInfo webSSHInfo = (WebSSHInfo) sshMap.get(userId); if (webSSHInfo == null) { return; } // connect if (OperateConstant.WEBSSH_OPERATE_CONNECT.equals(webSSHData.getOperate())) { // 启动线程异步处理 WebSSHData finalWebSSHData = webSSHData; executorService.execute(new Runnable() { @Override public void run() { try { connectToSSH(webSSHInfo, finalWebSSHData, webSocketSession); } catch (Exception e) { LOGGER.error("connect to ssh error : {}", e.getMessage()); closeConnection(webSocketSession); try { // 主动向前端推送msg sendMessage(webSocketSession, "Connection-Refused".getBytes()); } catch (Exception ee) { ee.printStackTrace(); } } } }); } // command else if (OperateConstant.WEBSSH_OPERATE_COMMAND.equals(webSSHData.getOperate())) { String command = webSSHData.getCommand(); try { transToSSH(webSSHInfo.getChannelShell(), command); } catch (Exception e) { LOGGER.error("trans to ssh error : {}", e.getMessage()); closeConnection(webSocketSession); // 用户登录设备后如果长时间没有进行操作,可以配置此命令将长时间连接始终处于空闲状态,系统将自动断开该连接。 sendMessage(webSocketSession, "Connection-IdleTimeout".getBytes()); } } else { LOGGER.error("不支持的操作"); closeConnection(webSocketSession); } } @Override public void closeConnection(WebSocketSession session) { String userId = String.valueOf(session.getAttributes().get(OperateConstant.USER_UUID_KEY)); WebSSHInfo webSSHInfo = (WebSSHInfo) sshMap.get(userId); if (webSSHInfo != null) { // 断开shell连接 if 
(webSSHInfo.getChannelShell() != null && !webSSHInfo.getChannelShell().isClosed()) { webSSHInfo.getChannelShell().close(false); LOGGER.info("ChannelShell Closed..."); } // 断开exec连接 if (webSSHInfo.getChannelExec() != null) { webSSHInfo.getChannelExec().close(false); LOGGER.info("ChannelExec Closed..."); } if (webSSHInfo.getSshClient() != null && !webSSHInfo.getSshClient().isClosed()) { webSSHInfo.getSshClient().close(false); LOGGER.info("SshClient Closed..."); } //map中移除 sshMap.remove(userId); } } private void connectToSSH(WebSSHInfo webSSHInfo, WebSSHData webSSHData, WebSocketSession webSocketSession) throws Exception { // verify Session ConnectFuture verifySession = webSSHInfo.getSshClient() .connect(webSSHData.getUsername(), webSSHData.getHost(), webSSHData.getPort()) .verify(WebSSHConfig.connectTimeout); if (!verifySession.isConnected()) { LOGGER.error("Session connect failed after {} mill seconds", WebSSHConfig.connectTimeout); throw new Exception( "Session connect failed after " + WebSSHConfig.connectTimeout + " mill seconds."); } ClientSession clientSession = verifySession.getSession(); if (OperateConstant.KEYPAIR.equalsIgnoreCase(webSSHData.getAuthType())) { Path pathPrivate = null; Path pathPublic = null; if (StringUtils.isNotBlank(webSSHData.getPrivateKey())) { pathPrivate = new File(webSSHData.getPrivateKey()).toPath(); } if (StringUtils.isNotBlank(webSSHData.getPublicKey())) { pathPublic = new File(webSSHData.getPublicKey()).toPath(); } if (pathPrivate != null || pathPublic != null) { clientSession.addPublicKeyIdentity(new FileKeyPairProvider(Stream.of(pathPrivate, pathPublic).filter(Objects::nonNull).collect(Collectors.toList())).loadKey(clientSession, webSSHData.getKeypairType())); } } else if (OperateConstant.PASSWORD.equalsIgnoreCase(webSSHData.getAuthType())) { clientSession.addPasswordIdentity(webSSHData.getPassword()); } else { throw new Exception("Unknown ssh auth type: " + webSSHData.getAuthType()); } // authentication 
clientSession.auth().verify(WebSSHConfig.authTimeout); sendMessage(webSocketSession, "Authentication-Success".getBytes()); ChannelShell cs = clientSession.createShellChannel(); cs.setRedirectErrorStream(true); cs.open(); cs.waitFor(Collections.singletonList(ClientChannelEvent.CLOSED), WebSSHConfig.executeTimeout); webSSHInfo.setChannelShell(cs); //读取终端返回的信息流 InputStream out = cs.getInvertedOut(); try { //循环读取 byte[] buffer = new byte[1024]; int i = 0; //如果没有数据来,线程会一直阻塞在这个地方等待数据。 while ((i = out.read(buffer)) != -1) { sendMessage(webSocketSession, Arrays.copyOfRange(buffer, 0, i)); } } finally { // 断开连接后关闭会话-channel也随之关闭 clientSession.close(); if (clientSession.isClosed()) { LOGGER.info("clientSession closed..."); } if (!cs.isClosed()) { cs.close(); } if (out != null) { out.close(); } } } private void transToSSH(ChannelShell channel, String command) throws IOException { if (channel != null) { channel.getInvertedIn().write(command.getBytes()); channel.getInvertedIn().flush(); } } private void sendMessage(WebSocketSession session, byte[] buffer) throws IOException { session.sendMessage(new TextMessage(buffer)); } private void startClientSessionHeartCheck(ClientSession clientSession, WebSocketSession websocketSession, ChannelShell channelshell) { Thread thread = new Thread(() -> { if (clientSession != null) { while (clientSession.isOpen()) { LOGGER.info(websocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY) + " clientSession is normal"); try { Thread.sleep(1000 * 60 * 2); } catch (Exception e) { LOGGER.error("心跳检测异常:", e); } // 停止线程 if (clientSession.isClosed() || clientSession.isClosed()) { // 告知前端session被关闭了 try { sendMessage(websocketSession, "Connection-closed".getBytes()); } catch (IOException ee) { ee.printStackTrace(); } Thread.currentThread().interrupt(); } } } else if (channelshell != null) { while (channelshell.isOpen()) { LOGGER.info(websocketSession.getAttributes().get(OperateConstant.USER_UUID_KEY) + " clientSession is normal"); try { 
Thread.sleep(1000 * 60 * 2); } catch (Exception e) { LOGGER.error("心跳检测异常:", e); } // 停止线程 if (channelshell.isClosed() || channelshell.isClosed()) { // 告知前端session被关闭了 try { sendMessage(websocketSession, "Connection-closed".getBytes()); } catch (IOException ee) { ee.printStackTrace(); } Thread.currentThread().interrupt(); } } } }); }}
其中涉及到的pojo及util
@1 OperateConstant
package cn.cloud.common.pojo;

/**
 * String constants shared by the WebSSH WebSocket, interceptor and service code.
 */
public interface OperateConstant {

    /** WebSocket session attribute holding the per-connection UUID generated at handshake time. */
    public static final String USER_UUID_KEY = "user_uuid";

    /** Value of the request's {@code operate} field that opens the SSH connection. */
    public static final String WEBSSH_OPERATE_CONNECT = "connect";

    /** Value of the request's {@code operate} field that forwards input to the shell. */
    public static final String WEBSSH_OPERATE_COMMAND = "command";

    /** Password authentication type (matched case-insensitively against the request's authType). */
    public static final String PASSWORD = "PASSWORD";

    /** Key-pair authentication type (matched case-insensitively against the request's authType). */
    public static final String KEYPAIR = "KEYPAIR";

    /** Used both as the Redis hash name for per-IP connection counters and as the session attribute holding the client IP. */
    public static final String IP = "websocket_ip";
}
@2 WebSSHConfig
package cn.cloud.common.pojo;public interface WebSSHConfig { Long connectTimeout = 5000L; Long authTimeout = 5000L; Long executeTimeout = 3000L; Integer maxIdle = 8; Integer maxTotal = 15; Integer minIdle = 2; Boolean testWhileIdle = true; Boolean testOnCreate = false; Boolean testOnBorrow = false; Boolean testOnReturn = false; Long minEvictableIdleTimeMillis = 300000L; Long timeBetweenEvictionRunsMillis = 30000L; Boolean blockWhenExhausted = true; Long maxWaitMillis = 30000L;}
@3 WebSSHData
package cn.cloud.common.pojo;import org.apache.commons.io.FilenameUtils;public class WebSSHData { //操作-connect、command private String operate; private String host; //端口号默认为22 private Integer port = 22; private String username; // 认证类型:PASSWORD/KEYPAIR private String authType = "PASSWORD"; private String password; private String command = ""; private String keypairType = "ssh-rsa"; private String publicKey = FilenameUtils.concat(System.getProperty("user.home"), ".ssh/id_rsa.pub"); private String privateKey = FilenameUtils.concat(System.getProperty("user.home"), ".ssh/id_rsa"); public String getOperate() { return operate; } public void setOperate(String operate) { this.operate = operate; } public String getHost() { return host; } public void setHost(String host) { this.host = host; } public Integer getPort() { return port; } public void setPort(Integer port) { this.port = port; } public String getUsername() { return username; } public void setUsername(String username) { this.username = username; } public String getPassword() { return password; } public void setPassword(String password) { this.password = password; } public String getCommand() { return command; } public void setCommand(String command) { this.command = command; } public String getAuthType() { return authType; } public void setAuthType(String authType) { this.authType = authType; } public String getKeypairType() { return keypairType; } public void setKeypairType(String keypairType) { this.keypairType = keypairType; } public String getPublicKey() { return publicKey; } public void setPublicKey(String publicKey) { this.publicKey = publicKey; } public String getPrivateKey() { return privateKey; } public void setPrivateKey(String privateKey) { this.privateKey = privateKey; }}
@4 WebSSHInfo
package cn.cloud.common.pojo;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.channel.ChannelExec;
import org.apache.sshd.client.channel.ChannelShell;
import org.springframework.web.socket.WebSocketSession;

/**
 * Mutable holder for the SSH resources tied to one WebSocket connection.
 *
 * <p>Lombok generates getters/setters, equals/hashCode/toString, a no-arg constructor and an
 * all-args constructor whose parameter order follows the field declaration order below, so
 * reordering the fields changes that constructor's signature.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class WebSSHInfo {
    // SSH client
    private SshClient sshClient;
    // WebSocket connection
    private WebSocketSession webSocketSession;
    // Linux: interactive shell channel
    private ChannelShell channelShell;
    // Windows: exec channel (NOTE(review): only ever closed in the visible code, never assigned — confirm it is still used)
    private ChannelExec channelExec;
}
@5 RedisUtil
package cn.cloud.common.util;import org.springframework.beans.factory.annotation.Autowired;import org.springframework.data.redis.core.RedisTemplate;import org.springframework.stereotype.Component;import java.util.concurrent.TimeUnit;@Componentpublic class RedisUtil { @Autowired public void setRedisTemplate(RedisTemplate redisTemplate) { this.redisTemplate = redisTemplate; } private static RedisTemplateredisTemplate; public Object hget(String key, String item){ return redisTemplate.opsForHash().get(key, item); } public void hput(String key, String item, Object value){ redisTemplate.opsForHash().put(key, item, value); } public void expire(String key, long time){ redisTemplate.expire(key, time, TimeUnit.SECONDS); }}
简单的xterm案例
xterm.js是一个运行在浏览器中的终端组件,通常与WebSocket配合使用,可以帮助我们在前端实现命令行的界面,就像我们平常在用SecureCRT或者XShell连接服务器时一样。
下面是官网上的入门案例: