如题,第一次用websocket,做了个这玩意,只做了上下文的聊天,没做流式。
中间还有个低级报错但卡了好久,具体可以看【错误记录】websocket连接失败,但后端毫无反应,还有【错误记录】ruoyi-vue@Autowired注入自定义mapper时为null解决
,感兴趣可前往观看。
实际上我后端用的是ruoyi-vue,前端用的ruoyi-app,但不重要。因为功能就是基于websocket和文心一言千帆大模型的接口,完全可以独立出来。
每个新建的账号会送一张20元的代金券,期限一个月内。而聊天服务接口单价约1分/千token,总之用来练手肯定够用了。

参考

文档中心-ERNIE-Bot-turbo
百度文心一言接入教程
若依插件-集成websocket实现简单通信

先看看效果

大致这样。
在这里插入图片描述

2023.10.13更新:昨天和朋友聊了一下,发现他的想法和我的不同——根本不用实体类去保存解析复杂的json,直接保存消息内容。有一说一,在这个小demo这里,确实可以更快更简单的实现,因为这个demo最耗时的就是看又臭又长的参数,然后写请求体和返回值的实体类,至少请求体实体类是可以不写的。

下面进入正题。

文心千帆创建应用

  1. 文心一言,大概是这里,先创建个账号,进控制台创建一个应用(有一个apikey和secretkey,有用),开通一个聊天服务(我开通的是ErnieBot-turbo),就可以了。具体有点忘了,大家可以参考其他博客。
  2. 其次官方有给一些参考,API调用指南在线测试平台,第二个链接可以对自己开通的聊天服务进行测试。其中也有一个分类是“技术文档”和“示例代码”,技术文档里边有普通/流式的请求/响应的参数和示例(如果比较小不容易看,文档中心-ERNIE-Bot-turbo也有),示例代码就是请求的各个语言的示例代码。

思路

有三个角色,大模型 ←→ 后端 ←→ 前端。

大模型:接受后端发过来的消息,返回响应消息
后端:接受前端发过来的消息,封装发给大模型;接收大模型返回的消息,回给后端;发送的消息和返回的消息都要保存到数据库
前端:发送消息,接受后端返回的响应消息,实时回显在聊天页面。

显然,websocket用在前后端之间进行交互,后端类似一个中间人,前端是一个用户,大模型是ai服务。

步骤与代码

  1. 实现websocket相关
    1.1 注册到spring
    @Configuration
    public class WebSocketConfig {
        @Bean
        public ServerEndpointExporter serverEndpointExporter() {
            return new ServerEndpointExporter();
        }
    }
    
    1.2 实现一个WebSocket的服务(别看这么长,其实参考了若依插件-集成websocket实现简单通信,但没涉及信号量之类所以没什么用,除了onMessage外,其他如onOpen打印一条消息就行了,更多如WebSocketUsers可以去链接那下载)
    @CrossOrigin
    @Component
    @ServerEndpoint("/websocket/message")
    public class WebSocketServer {
       private ChatRecordMapper chatRecordMapper = SpringUtils.getBean(ChatRecordMapper.class);
       /**
        * WebSocketServer 日志控制器
        */
       private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);
    
       /**
        * 默认最多允许同时在线人数100
        */
       public static int socketMaxOnlineCount = 100;
    
       private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);
    
       /**
        * 连接建立成功调用的方法
        */
       @OnOpen
       public void onOpen(Session session) throws Exception {
           boolean semaphoreFlag = false;
           // 尝试获取信号量
           semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
           if (!semaphoreFlag) {
               // 未获取到信号量
               LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
               WebSocketUsers.sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
               session.close();
           } else {
               // 添加用户
               WebSocketUsers.put(session.getId(), session);
               LOGGER.info("\n 建立连接 - {}", session);
               LOGGER.info("\n 当前人数 - {}", WebSocketUsers.getUsers().size());
               WebSocketUsers.sendMessageToUserByText(session, "连接成功");
           }
       }
    
       /**
        * 连接关闭时处理
        */
       @OnClose
       public void onClose(Session session) {
           LOGGER.info("\n 关闭连接 - {}", session);
           // 移除用户
           WebSocketUsers.remove(session.getId());
           // 获取到信号量则需释放
           SemaphoreUtils.release(socketSemaphore);
       }
    
       /**
        * 抛出异常时处理
        */
       @OnError
       public void onError(Session session, Throwable exception) throws Exception {
           if (session.isOpen()) {
               // 关闭连接
               session.close();
           }
           String sessionId = session.getId();
           LOGGER.info("\n 连接异常 - {}", sessionId);
           LOGGER.info("\n 异常信息 - {}", exception);
           // 移出用户
           WebSocketUsers.remove(sessionId);
           // 获取到信号量则需释放
           SemaphoreUtils.release(socketSemaphore);
       }
    
       /**
        * 服务器接收到客户端消息时调用的方法
        */
       @OnMessage
       public void onMessage(String message, Session session) {
           // 首先,接收到一条消息
           LOGGER.info("\n 收到消息 - {}", message);
           // 1. 调用大模型API,把上下文和这次问题传入,得到回复
           BigModelService bigModelService = new BigModelService();
           TurboResponse response = bigModelService.callModelAPI(session.getId(),message);
           if (response == null) {
               WebSocketUsers.sendMessageToUserByText(session, "抱歉,似乎出了点问题,请联系管理员");
               return;
           }
           WebSocketUsers.sendMessageToUserByText(session, response.getResult());
       }
    }
    
  2. 实现请求接口相关
    2.1 先写实体类,包括BaiduChatMessage(最基本的聊天消息)、ErnieBotTurboParam(ErnieBot-Turbo的请求参数,包括了List<BaiduChatMessage>)TurboResponse(请求返回结果对应的实体类)
    @Data
    @SuperBuilder
    @NoArgsConstructor
    @AllArgsConstructor
    public class BaiduChatMessage implements Serializable {
        private String role;
        private String content;
    }
    
    @Data
    @SuperBuilder
    public class ErnieBotTurboParam implements Serializable {
        /**
         * 聊天上下文信息。说明:
         * (1)messages成员不能为空,1个成员表示单轮对话,多个成员表示多轮对话
         * (2)最后一个message为当前请求的信息,前面的message为历史对话信息
         * (3)必须为奇数个成员,成员中message的role必须依次为user、assistant
         * (4)最后一个message的content长度(即此轮对话的问题)不能超过2000个字符;如果messages中content总长度大于2000字符,系统会依次遗忘最早的历史会话,直到content的总长度不超过2000个字符
         */
        protected List<BaiduChatMessage> messages;
    
        /**
         * 是否以流式接口的形式返回数据,默认false
         */
        protected Boolean stream;
    
        /**
         * 表示最终用户的唯一标识符,可以监视和检测滥用行为,防止接口恶意调用
         */
        protected String user_id;
    
        public boolean isStream() {
            return Objects.equals(this.stream, true);
        }
    
        public ErnieBotTurboParam(){}
    }
    
    @Data
    public class TurboResponse implements Serializable {
        private String id;
        private String object;
        private Integer created;
    
        private String sentence_id;
        private Boolean is_end;
        private Boolean is_truncated;
        private String result;
        private Boolean need_clear_history;
    
        private Usage usage;
    
        @Data
        public static class Usage implements Serializable {
            private Integer prompt_tokens;
            private Integer completion_tokens;
            private Integer total_tokens;
        }
    }
    
    2.2 请求接口实现(注释很详细就不多说了)
    public class BigModelService {
        private ChatRecordMapper chatRecordMapper = SpringUtils.getBean(ChatRecordMapper.class);
        private static final Logger LOGGER = LoggerFactory.getLogger(BigModelService.class);
        private static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
        public static final String API_KEY = "你的apikey";
        public static final String SECRET_KEY = "你的secretkey";
    
        static String getAccessToken() throws IOException {
            MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
            RequestBody body = RequestBody.create(mediaType, "grant_type=client_credentials&client_id=" + API_KEY
                    + "&client_secret=" + SECRET_KEY);
            Request request = new Request.Builder()
                    .url("https://aip.baidubce.com/oauth/2.0/token")
                    .method("POST", body)
                    .addHeader("Content-Type", "application/x-www-form-urlencoded")
                    .build();
            Response response = HTTP_CLIENT.newCall(request).execute();
            // 解析返回的access_token
            JSONObject jsonObject = JSONObject.parseObject(response.body().string());
            return jsonObject.getString("access_token");
        }
    
        public TurboResponse callModelAPI(String sessionId, String message) {
            // 1. 构建请求体
            // 1.1 调用大模型API,要从数据库去查询上下文
            ChatRecord cr = chatRecordMapper.selectChatRecordBySessionId(sessionId);
            String records = cr == null ? "{}" : cr.getRecords();
            // 1.2 把message加进请求体
            // 1.2.1 解析上下文,获取聊天记录,把新的message封装加入到聊天记录中
            ErnieBotTurboParam param = JSONObject.parseObject(records, ErnieBotTurboParam.class);
            List<BaiduChatMessage> messages = param.getMessages() == null ? new ArrayList<>() : param.getMessages();
            messages.add(BaiduChatMessage.builder().role("user").content(message).build());
            // 1.2.2 把messages重新设置到param中
            param.setMessages(messages);
            try {
                // 2. 发出请求,调用大模型API
                RequestBody body = RequestBody.create(MediaType.parse("application/json"), JSONObject.toJSONString(param));
                Request request = new Request.Builder()
                        .url("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + getAccessToken())
                        .method("POST", body)
                        .addHeader("Content-Type", "application/json")
                        .build();
                Response response = HTTP_CLIENT.newCall(request).execute();
    
                if (response.isSuccessful()) {
                    // 3. 如果调用成功,
                    // 3.1 解析返回的聊天回复结果
                    TurboResponse turboResponse = JSONObject.parseObject(response.body().string(), TurboResponse.class);
                    LOGGER.info("调用大模型API成功: {}", turboResponse.toString());
                    // 3.2 将聊天回复结果存入数据库
                    // 3.2.1 先根据sessionId查询数据库
                    int count = chatRecordMapper.selectRecordCountBySessionId(sessionId);
                    // 将ai刚返回的回复追加到param中,再填入chatRecord
                    BaiduChatMessage aiMessage = BaiduChatMessage.builder()
                            .role("assistant").content(turboResponse.getResult()).build();
                    messages.add(aiMessage);
                    param.setMessages(messages);
                    if (count == 0) {
                        // 3.2.2 如果没有记录,则插入
                        ChatRecord chatRecord = new ChatRecord();
                        chatRecord.setRecordId(turboResponse.getId());
                        chatRecord.setSessionId(sessionId);
                        chatRecord.setRecords(JSONObject.toJSONString(param));
                        // 插入时应填入create_time字段
                        chatRecord.setCreateTime(LocalDateTime.now());
                        chatRecordMapper.insertChatRecord(chatRecord);
                    } else {
                        // 3.2.3 如果有记录,则更新
                        // 3.2.4 先查询出原来的记录的create_time,判断是否超过15min
                        ChatRecord chat = chatRecordMapper.selectChatRecordBySessionId(sessionId);
                        LocalDateTime createTime = chat.getCreateTime();
                        if (LocalDateTime.now().isAfter(createTime.plusMinutes(15))) {
                            // 2.2.2 如果超过15min,则清除records字段,将新的对话记录追加到records字段中
                            ChatRecord chatRecord = new ChatRecord();
                            chatRecord.setRecordId(turboResponse.getId());
                            chatRecord.setSessionId(sessionId);
                            chatRecord.setRecords(messages.toString());
                            // 更新时应填入create_time字段
                            chatRecord.setCreateTime(LocalDateTime.now());
                            chatRecordMapper.insertChatRecord(chatRecord);
                        } else {
                            ChatRecord chat0 = chatRecordMapper.selectChatRecordBySessionId(sessionId);
                            // 如果没有超过15min,则将新的对话记录追加到records字段中
                            ChatRecord chatRecord = new ChatRecord();
                            chatRecord.setRecordId(turboResponse.getId());
                            chatRecord.setSessionId(sessionId);
                            // 解析出原来的records
                            ChatRecord oldchat = chatRecordMapper.selectChatRecordBySessionId(sessionId);
                            ErnieBotTurboParam records1 = JSONObject.parseObject(oldchat.getRecords(), ErnieBotTurboParam.class);
                            // 将新的对话记录追加到records字段中
                            records1.setMessages(messages);
                            chatRecord.setRecords(JSONObject.toJSONString(records1));
                            // 没有15min就不更新create_time字段
                            // 更新chat_record
                            chatRecordMapper.updateChatRecord(chatRecord);
                        }
                    }
                    return turboResponse;
                } else {
                    LOGGER.error("调用大模型API失败: {}", response.message());
                }
            } catch (IOException e) {
                LOGGER.error("调用大模型API发生异常:", e);
            }
            return null;
        }
    }
    
  3. 持久化
    3.1 数据库里建个表
    CREATE TABLE `chat_record` (
      `record_id` varchar(20) NOT NULL COMMENT '记录id',
      `session_id` varchar(10) DEFAULT NULL COMMENT '所属用户',
      `records` json DEFAULT NULL COMMENT '聊天记录',
      `create_time` datetime DEFAULT NULL COMMENT '创建时间(判断过期)',
      PRIMARY KEY (`record_id`)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='聊天记录表';
    
    3.2 对应实体类
    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public class ChatRecord {
        private String recordId;
        private String sessionId;
        private String records;
        private LocalDateTime createTime;
    }
    
    3.3 再写个mapper就行了
    @Mapper
    public interface ChatRecordMapper {
    
        @Insert("INSERT INTO chat_record (record_id, session_id, records, create_time) " +
                "VALUES (#{recordId}, #{sessionId}, #{records}, #{createTime})")
        void insertChatRecord(ChatRecord chatRecord);
    
        @Select("SELECT COUNT(*) FROM chat_record WHERE session_id = #{sessionId}")
        int selectRecordCountBySessionId(String sessionId);
    
        @Results({ // id是用来给@ResultMap注解引用的,到时候在xml中可以直接使用@ResultMap(value = "chatRecord")
                @Result(property = "recordId", column = "record_id"),
                @Result(property = "sessionId", column = "session_id"),
                @Result(property = "records", column = "records"),
                @Result(property = "createTime", column = "create_time")
        })
        @Select("SELECT * FROM chat_record WHERE session_id = #{sessionId}")
        ChatRecord selectChatRecordBySessionId(String sessionId);
    
    
        @Update("UPDATE chat_record SET records = #{records} WHERE session_id = #{sessionId}")
        void updateChatRecord(ChatRecord chatRecord);
    }
    
  4. 前端聊天页面与实时的回显
    4.1 聊天页面写一个(这里前端是uniapp,样式用到了些colorUI)
    <template>
    	<view>
    		<!-- 聊天消息界面 -->
    		<view class="cu-chat">
    			<view v-for="(message, index) in chatMessages" :key="index"
    				:class="message.type === 'user' ? 'cu-item self' : 'cu-item'">
    
    				<!-- 头像代码放在消息前面当消息类型为'ai' -->
    				<view class="cu-avatar round" v-if="message.type === 'ai'"
    					:style="{'background-image': 'url(https://img2.baidu.com/it/u=3652671026,3326768653&fm=253&fmt=auto&app=120&f=JPEG?w=171&h=171)'}">
    				</view>
    
    				<view class="main">
    					<view class="content shadow">
    						<text>{{ message.content }}</text>
    					</view>
    				</view>
    
    				<!-- 头像代码放在消息后面当消息类型为'user' -->
    				<view class="cu-avatar round" v-if="message.type === 'user'"
    					:style="{'background-image': 'url(https://img2.baidu.com/it/u=2435295423,1880375459&fm=253&fmt=auto&app=138&f=JPEG?w=519&h=500)'}">
    				</view>
    
    				<view class="date">{{ message.time }}</view>
    			</view>
    		</view>
    
    
    		<!-- 底部输入框 -->
    		<view class="cu-bar foot input" :style="[{bottom:InputBottom+'px'}]">
    			<view class="action">
    				<text class="cuIcon-sound text-grey"></text>
    			</view>
    			<input class="solid-bottom" :adjust-position="false" :focus="false" maxlength="300" cursor-spacing="10"
    				@focus="InputFocus" @blur="InputBlur" v-model="userMessage"></input>
    			<view class="action">
    				<text class="cuIcon-emojifill text-grey"></text>
    			</view>
    			<button class="cu-btn bg-green shadow" @click="sendMessage">发送</button>
    		</view>
    	</view>
    </template>
    
    <script>
    	export default {
    		data() {
    			return {
    				status: "",
    				ws: null,
    				InputBottom: 0,
    				userMessage: "", // 聊天框的内容,待发送的消息
    				chatMessages: [], // 用于存储聊天消息的数组
    			};
    		},
    		created() {
    			this.connect();
    		},
    		methods: {
    			InputFocus(e) {
    				this.InputBottom = e.detail.height;
    			},
    			InputBlur(e) {
    				this.InputBottom = 0;
    			},
    			connect() {
    				console.info("开始连接……")
    				this.ws = new WebSocket("ws://127.0.0.1:8080/websocket/message");
    				const self = this;
    				this.ws.onopen = function(event) {
    					console.info("连接成功")
    				};
    				this.ws.onmessage = function(event) {
    					console.info("收到服务端消息:", event.data);
    					// 收到的消息也保存到聊天数组中
    					if (event.data != "连接成功") {
    						self.chatMessages.push({
    							content: event.data,
    							type: "ai",
    							time: self.formatTime()
    						})
    					}
    				};
    				this.ws.onclose = function(event) {
    					console.info("关闭连接")
    				};
    			},
    			onunload() {
    				if (this.ws) {
    					this.ws.close();
    					this.ws = null;
    				}
    			},
    			formatTime() {
    				const now = new Date();
    				const year = now.getFullYear();
    				const month = String(now.getMonth() + 1).padStart(2, '0');
    				const day = String(now.getDate()).padStart(2, '0');
    				const hours = String(now.getHours()).padStart(2, '0');
    				const minutes = String(now.getMinutes()).padStart(2, '0');
    				const seconds = String(now.getSeconds()).padStart(2, '0');
    				return `${year}${month}${day}${hours}:${minutes}:${seconds}`;
    			},
    			sendMessage() {
    				if (this.ws) {
    					// 点击发送,把输入框内容添加到聊天数组
    					this.chatMessages.push({
    						content: this.userMessage,
    						type: "user", // 自己的消息
    						time:this.formatTime(),
    					});
    					// 发送消息
    					this.ws.send(this.userMessage);
    					// 清空输入框
    					this.userMessage = "";
    				} else {
    					alert("未连接到服务器");
    				}
    			},
    		},
    	};
    </script>
    
    <style>
    	page {
    		padding-bottom: 100upx;
    	}
    </style>
    
    4.2 js里写一个websocket(见上4.1的connect())

以上就大功告成了,这玩意还有很多缺漏和细节没做,像现在还是根据会话id去做,没有匹配用户id,15min清除聊天记录,但前端那没清……不过能跑能动就行,本来就是一个小任务,也懒得继续花时间调整。
记录一下,有问题可以交流

Logo

学AI,认准AI Studio!GPU算力,限时免费领,邀请好友解锁更多惊喜福利 >>>

更多推荐