websocket 发表于 2024-1-6 16:43:13

用java编写一个websocket端口转发工具

编写背景
工作原因,需要使用到内网,项目组的几个同事都要用内网,然后,甲方问上级申请了内网服务器。但是这个内网服务器需要通过堡垒机访问,而堡垒机只能用指定的两个内网ip访问。本来,两个ip是够用的,但是甲方有个领导是技术出身,就让我们在他的电脑上给他配置堡垒机的访问环境。然后,两个ip就被甲方领导占用了一个。然后,就导致内网ip不够用,我和同事经常等对方用完。
而资源池对内网开放了应用服务器的端口来供内网访问应用,堡垒机上只能编辑nginx的配置,虽然nginx支持tcp协议代理,但是需要安装单独的模块,内网显然是无法实现的。
所以,我就在网上搜索,有没有用websocket转发端口的程序,查了一下之后,发现一个go写的,又因为,资源池的服务器是没办法装go环境的,所以,我选择自己用java实现。

实现原理
我先用客户端监听本地的某一个端口,在有连接请求之后,客户端会生成一个随机的密钥,用公钥加密之后发送到服务器端,然后,服务器端和客户端之间的所有信息都会使用密钥加密通讯。socketServer收到的所有信息都加密后通过webSocket发送到服务端,而服务端又将数据解密后发送到目标ip和端口,再将从目标ip端口接收到的数据都加密了之后通过webSocket发送给客户端,客户端解密数据之后将数据反馈到本地的端口,然后就实现了资源池的端口加密映射到本地端口。数据库一类的端口映射出来后,就可以通过工具连接了,网页也可以映射,测试也可以正常访问。常见的应用的网络通讯,基本都是基于tcp协议的,所以,这样映射,基本可以满足日常开放。也就不用和同事争ip了。

核心代码

服务端

@Component
@Slf4j
@ServerEndpoint("/ws")
public class WSTServer {
       
       
        private Session session;
       
        private int step = 0;
       
        private String k;
        private String host;
        private int port;
       
        private Socket socket;
       
        private OutputStream out;
       
        private DES des;
       
        @OnOpen
        public void onOpen(Session session) {
                this.session = session;
                this.step = 0;
                log.info("连接建立");
        }
       
        @OnMessage
        public void strMsg(String msg) {
               
                switch(step) {
                       
                case 0:
                        //密钥交换
                        swKey(msg);
                        break;
                case 1:
                        //建立隧道
                        buildLink(msg);
                        break;
                }
        }
       
        @OnMessage
        public void byteMsg(byte[] b) {
                if( this.step == 2 ) {
                        byte[] sendByte = this.des.decrypt(b);
                        try {
                                this.out.write(sendByte);
                        } catch (IOException e) {
                                e.printStackTrace();
                        }
                }
        }
       
        @OnClose
        public void onClose() {
                if(this.socket != null) {
                        try {
                                this.session.close();
                        } catch (IOException e) {
                                e.printStackTrace();
                        }
                }
        }
       
        private void buildLink(String msg) {
               
                try {
                       
                        String url = des.decryptStr(msg);
                        List<String> pms = StrUtil.split(url, ":");
                       
                        this.host = pms.get(0);
                        this.port = Integer.valueOf(pms.get(1));
                       
                        this.socket = new Socket(host, this.port);
                       
                        this.out = socket.getOutputStream();
                       
                        new Thread(()->{
                               
                                try {
                                       
                                       
                                        try(
                                                InputStream input = socket.getInputStream();
                                        ){
                                               
                                                while(socket.isConnected()) {
                                                       
                                                        byte[] b = new byte;
                                                        int len = 0;
                                                        while( ( len = input.read(b) ) > 0) {
                                                                byte[] send = new byte;
                                                                System.arraycopy(b, 0, send, 0, len);
                                                                this.session.getBasicRemote().sendBinary(ByteBuffer.wrap(des.encrypt(send)));
                                                        }
                                                }
                                               
                                        }
                                       
                                }catch (IOException e) {
                                        e.printStackTrace();
                                }
                               
                        }).start();
                       
                        sendMsg("success");
                        this.step++;
                       
                        log.info("隧道建立成功");
                }catch (Exception e) {
                        e.printStackTrace();
                        log.info("隧道构建失败");
                        close();
                }
               
        }
       
        private void swKey(String msg) {
                try {
                        //获取客户端发来的密钥
                        this.k = RsaUtil.decodeHex(msg);
                        this.des = new DES(HexUtil.decodeHex(this.k));
                        this.step++;
                        sendMsg("success");
                        log.info("密钥获取完毕");
                }catch (Exception e) {
                        // 密钥获取失败,中断连接
                        e.printStackTrace();
                        log.info("密钥获取失败");
                        close();
                }
        }
       
        private void sendMsg(String msg) {
                if(this.des != null) {
                        try {
                                this.session.getBasicRemote().sendText(this.des.encryptHex(msg));
                        } catch (IOException e) {
                                e.printStackTrace();
                        }
                }
        }
       
        private void close() {
                try {
                        this.session.close();
                } catch (IOException e1) {}
        }

客户端
public class WSClient {

        public static void main(String[] args) {
               
                String pub = FileUtil.readString("pub", CharsetUtil.CHARSET_GBK);
                RSA rsa = new RSA(null,pub);
               
                try (
                                ServerSocket server = new ServerSocket(Integer.valueOf(args));
                        ){
                       
                        while(true) {
                               
                                Socket skt = server.accept();
                               
                                InputStream in = skt.getInputStream();
                                OutputStream out = skt.getOutputStream();
                               
                                new Thread( () -> {
                                        buildLink(args,args,in, out, new DES(), rsa);
                                } ).start();
                        }
                       
                } catch (IOException e) {
                        e.printStackTrace();
                }
               
        }

        public static void buildLink(String tar,String ws,InputStream input, OutputStream out,DES des,RSA rsa) {
               
                try {
                       
                        AtomicBoolean wconnect = new AtomicBoolean(true);
                       
                        WebSocketClient client = new WebSocketClient(new URI(ws)) {

                                @Override
                                public void onOpen(ServerHandshake handshakedata) {
                                       
                                }

                                @Override
                                public void onMessage(String message) {
                                        //接收到回复后,发送需要建立的连接
                                        wconnect.set(false);
                                }
                               
                                @Override
                                public void onMessage(ByteBuffer bytes) {
                                        byte[] b = des.decrypt(bytes.array());
                                        try {
                                                out.write(b);
                                        } catch (IOException e) {
                                                e.printStackTrace();
                                        }
                                }

                                @Override
                                public void onError(Exception ex) {

                                }

                                @Override
                                public void onClose(int code, String reason, boolean remote) {

                                }
                        };
                       
                        client.connectBlocking();
                       
                        //建立连接后,发送信息
                        client.send(rsa.encryptHex(HexUtil.encodeHexStr(des.getSecretKey().getEncoded()), KeyType.PublicKey));
                       
                        while(wconnect.get()) {
                                Thread.sleep(200);
                        }
                       
                        //收到回复后,发送连接目标
                        client.send(des.encryptHex(tar));
                        wconnect.set(true);
                       
                        while(wconnect.get()) {
                                Thread.sleep(200);
                        }
                       
                        //发送心跳包,防止会话过期
                        new Thread( () -> {
                                while(true) {
                                        client.send("123");
                                        ThreadUtil.safeSleep(1000*5);
                                }
                        }).start();
                       
                        //收到连接建立后,建立本地连接
                        while(true) {
                               
                                byte[] b = new byte;
                                int len = 0;
                               
                                while((len = input.read(b) ) > 0) {
                                       
                                        byte[] send = new byte;
                                        System.arraycopy(b, 0, send, 0, len);
                                       
                                        client.send(des.encrypt(send));
                                }
                        }
                       
                } catch (URISyntaxException e) {
                        e.printStackTrace();
                } catch (InterruptedException e) {
                        e.printStackTrace();
                } catch (IOException e) {
                        e.printStackTrace();
                }
        }
}




页: [1]
查看完整版本: 用java编写一个websocket端口转发工具