From 86ef150fecc6104eff9ed70035d08d03a3bc9cc5 Mon Sep 17 00:00:00 2001 From: lan Date: Tue, 10 Mar 2026 12:58:23 +0800 Subject: [PATCH] Replace websocket flow with SSE support in backend. Update handlers, services, router, and data conversion logic to support server-sent events and related message pipeline changes. Made-with: Cursor --- configs/config.yaml | 36 +- internal/dto/converter.go | 2 + internal/dto/dto.go | 2 + internal/handler/message_handler.go | 99 ++- internal/handler/post_handler.go | 32 +- internal/handler/websocket_handler.go | 849 -------------------------- internal/middleware/cors.go | 25 +- internal/model/post.go | 2 +- internal/pkg/sse/hub.go | 152 +++++ internal/pkg/websocket/websocket.go | 435 ------------- internal/repository/comment_repo.go | 68 ++- internal/repository/post_repo.go | 73 ++- internal/repository/user_repo.go | 6 +- internal/router/router.go | 16 +- internal/service/chat_service.go | 241 +++----- internal/service/comment_service.go | 38 +- internal/service/group_service.go | 110 ++-- internal/service/post_service.go | 44 +- internal/service/push_service.go | 178 ++---- 19 files changed, 689 insertions(+), 1719 deletions(-) delete mode 100644 internal/handler/websocket_handler.go create mode 100644 internal/pkg/sse/hub.go delete mode 100644 internal/pkg/websocket/websocket.go diff --git a/configs/config.yaml b/configs/config.yaml index d686eec..6105f30 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -36,9 +36,9 @@ database: redis: type: miniredis # miniredis 或 redis redis: - host: localhost + host: 1Panel-redis-dfmM port: 6379 - password: "" + password: "redis_j8CMza" db: 0 miniredis: host: localhost @@ -67,13 +67,13 @@ cache: # S3对象存储配置 # 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN s3: - endpoint: "" - access_key: "" - secret_key: "" - bucket: "" + endpoint: "files.littlelan.cn" + access_key: "E6bMcYkQzCldRTrtmhvi" + secret_key: "4R9yjmwKNoHphiBkv05Oa8WGEIFbnlZeTLXfSgx3" + bucket: "test" use_ssl: true region: us-east-1 - domain: "" + domain: "files.littlelan.cn" # JWT配置 # 环境变量: APP_JWT_SECRET jwt: @@ -130,12 +130,12 @@ audit: # Gorse推荐系统配置 # 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD gorse: - enabled: false - address: "" # Gorse server地址 + enabled: true + address: "http://111.170.19.33:8088" # Gorse server地址 api_key: "" # API密钥 dashboard: "" # Gorse dashboard地址 - import_password: "" # 导入数据密码 - embedding_api_key: "" + import_password: "lanyimin123" # 导入数据密码 + embedding_api_key: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD" embedding_url: "https://api.littlelan.cn/v1/embeddings" embedding_model: "BAAI/bge-m3" @@ -147,7 +147,7 @@ gorse: openai: enabled: true base_url: "https://api.littlelan.cn/" - api_key: "" + api_key: "sk-y7LOeKsNfzbZWTRSFsTs79jd8WYlezbIVgdVPgMvG4Xz2AlV" moderation_model: "qwen3.5-122b" moderation_max_images_per_request: 1 request_timeout: 30 @@ -160,12 +160,12 @@ openai: # APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME # APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT email: - enabled: false - host: "" - port: 587 - username: "" - password: "" - from_address: "" + enabled: true + host: "smtp.exmail.qq.com" + port: 465 + username: "no-reply@qczlit.cn" + password: "HbvwwVjRyiWg9gsK" + from_address: "no-reply@qczlit.cn" from_name: "Carrot BBS" use_tls: true insecure_skip_verify: false diff --git a/internal/dto/converter.go b/internal/dto/converter.go index 4b82c0d..dc37f89 100644 --- a/internal/dto/converter.go +++ b/internal/dto/converter.go @@ -284,6 +284,7 @@ func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostRes Title: post.Title, Content: post.Content, Images: images, + Status: string(post.Status), LikesCount: post.LikesCount, CommentsCount: post.CommentsCount, FavoritesCount: post.FavoritesCount, @@ -293,6 +294,7 @@ func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostRes IsLocked: post.IsLocked, IsVote: post.IsVote, CreatedAt: FormatTime(post.CreatedAt), + UpdatedAt: FormatTime(post.UpdatedAt), Author: author, IsLiked: isLiked, IsFavorited: isFavorited, diff --git a/internal/dto/dto.go b/internal/dto/dto.go index d0d13a3..36225b8 100644 --- a/internal/dto/dto.go +++ b/internal/dto/dto.go @@ -68,6 +68,7 @@ type PostResponse struct { Title string `json:"title"` Content string `json:"content"` Images []PostImageResponse `json:"images"` + Status string `json:"status,omitempty"` LikesCount int `json:"likes_count"` CommentsCount int `json:"comments_count"` FavoritesCount int `json:"favorites_count"` @@ -77,6 +78,7 @@ type PostResponse struct { IsLocked bool `json:"is_locked"` IsVote bool `json:"is_vote"` CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` Author *UserResponse `json:"author"` IsLiked bool `json:"is_liked"` IsFavorited bool `json:"is_favorited"` diff --git a/internal/handler/message_handler.go b/internal/handler/message_handler.go index bf8d964..27adb3d 100644 --- a/internal/handler/message_handler.go +++ b/internal/handler/message_handler.go @@ -2,12 +2,16 @@ package handler import ( "context" + "fmt" + "net/http" "strconv" + "time" "github.com/gin-gonic/gin" "carrot_bbs/internal/dto" "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/sse" "carrot_bbs/internal/pkg/response" "carrot_bbs/internal/service" ) @@ -18,18 +22,111 @@ type MessageHandler struct { messageService *service.MessageService userService *service.UserService groupService service.GroupService + sseHub *sse.Hub } // NewMessageHandler 创建消息处理器 -func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler { +func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService, sseHub *sse.Hub) *MessageHandler { return &MessageHandler{ chatService: chatService, messageService: messageService, userService: userService, groupService: groupService, + sseHub: sseHub, } } +// HandleSSE 实时消息订阅(SSE) +// GET /api/v1/realtime/sse +func (h *MessageHandler) HandleSSE(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + if h.sseHub == nil { + response.InternalServerError(c, "sse hub not available") + return + } + + lastID := sse.ParseEventID(c.GetHeader("Last-Event-ID")) + if lastID == 0 { + lastID = sse.ParseEventID(c.Query("last_event_id")) + } + ch, cancel, replay := h.sseHub.Subscribe(userID, lastID) + defer cancel() + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + response.InternalServerError(c, "streaming unsupported") + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + flusher.Flush() + + writeEvent := func(ev sse.Event) bool { + data, err := sse.EncodeData(ev) + if err != nil { + return false + } + if _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.ID, ev.Event, data); err != nil { + return false + } + flusher.Flush() + return true + } + + for _, ev := range replay { + if !writeEvent(ev) { + return + } + } + + heartbeat := time.NewTicker(25 * time.Second) + defer heartbeat.Stop() + + for { + select { + case <-c.Request.Context().Done(): + return + case ev, ok := <-ch: + if !ok || !writeEvent(ev) { + return + } + case <-heartbeat.C: + if _, err := fmt.Fprint(w, "event: heartbeat\ndata: {}\n\n"); err != nil { + return + } + flusher.Flush() + } + } +} + +// HandleTyping 输入状态上报 +// POST /api/v1/conversations/typing +func (h *MessageHandler) HandleTyping(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + var params struct { + ConversationID string `json:"conversation_id" binding:"required"` + } + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + h.chatService.SendTyping(c.Request.Context(), userID, params.ConversationID) + response.SuccessWithMessage(c, "typing sent", nil) +} + // GetConversations 获取会话列表 // GET /api/conversations func (h *MessageHandler) GetConversations(c *gin.Context) { diff --git a/internal/handler/post_handler.go b/internal/handler/post_handler.go index fe6ada1..d8a67d3 100644 --- a/internal/handler/post_handler.go +++ b/internal/handler/post_handler.go @@ -105,6 +105,7 @@ func (h *PostHandler) GetByID(c *gin.Context) { Title: post.Title, Content: post.Content, Images: dto.ConvertPostImagesToResponse(post.Images), + Status: string(post.Status), LikesCount: post.LikesCount, CommentsCount: post.CommentsCount, FavoritesCount: post.FavoritesCount, @@ -114,6 +115,7 @@ func (h *PostHandler) GetByID(c *gin.Context) { IsLocked: post.IsLocked, IsVote: post.IsVote, CreatedAt: dto.FormatTime(post.CreatedAt), + UpdatedAt: dto.FormatTime(post.UpdatedAt), Author: authorWithFollowStatus, IsLiked: isLiked, IsFavorited: isFavorited, @@ -175,10 +177,18 @@ func (h *PostHandler) List(c *gin.Context) { posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize) case "latest": // 最新帖子 - posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID) + if userID != "" && userID == currentUserID { + posts, total, err = h.postService.GetLatestPostsForOwner(c.Request.Context(), page, pageSize, userID) + } else { + posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID) + } default: // 默认获取最新帖子 - posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID) + if userID != "" && userID == currentUserID { + posts, total, err = h.postService.GetLatestPostsForOwner(c.Request.Context(), page, pageSize, userID) + } else { + posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID) + } } if err != nil { @@ -225,8 +235,9 @@ func (h *PostHandler) Update(c *gin.Context) { } type UpdateRequest struct { - Title string `json:"title"` - Content string `json:"content"` + Title string `json:"title"` + Content string `json:"content"` + Images *[]string `json:"images"` } var req UpdateRequest @@ -242,12 +253,18 @@ func (h *PostHandler) Update(c *gin.Context) { post.Content = req.Content } - err = h.postService.Update(c.Request.Context(), post) + err = h.postService.UpdateWithImages(c.Request.Context(), post, req.Images) if err != nil { response.InternalServerError(c, "failed to update post") return } + post, err = h.postService.GetByID(c.Request.Context(), post.ID) + if err != nil { + response.InternalServerError(c, "failed to get updated post") + return + } + currentUserID := c.GetString("user_id") var isLiked, isFavorited bool if currentUserID != "" { @@ -410,14 +427,15 @@ func (h *PostHandler) GetUserPosts(c *gin.Context) { page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) - posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize) + currentUserID := c.GetString("user_id") + includePending := currentUserID != "" && currentUserID == userID + posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize, includePending) if err != nil { response.InternalServerError(c, "failed to get user posts") return } // 获取当前用户ID用于判断点赞和收藏状态 - currentUserID := c.GetString("user_id") isLikedMap := make(map[string]bool) isFavoritedMap := make(map[string]bool) if currentUserID != "" { diff --git a/internal/handler/websocket_handler.go b/internal/handler/websocket_handler.go deleted file mode 100644 index 638f2cf..0000000 --- a/internal/handler/websocket_handler.go +++ /dev/null @@ -1,849 +0,0 @@ -package handler - -import ( - "context" - "encoding/json" - "log" - "net/http" - "strconv" - "strings" - "time" - - "carrot_bbs/internal/dto" - "carrot_bbs/internal/model" - ws "carrot_bbs/internal/pkg/websocket" - "carrot_bbs/internal/repository" - "carrot_bbs/internal/service" - - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" -) - -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true // 允许所有来源,生产环境应该限制 - }, -} - -// WebSocketHandler WebSocket处理器 -type WebSocketHandler struct { - jwtService *service.JWTService - chatService service.ChatService - groupService service.GroupService - groupRepo repository.GroupRepository - userRepo *repository.UserRepository - wsManager *ws.WebSocketManager -} - -// NewWebSocketHandler 创建WebSocket处理器 -func NewWebSocketHandler( - jwtService *service.JWTService, - chatService service.ChatService, - groupService service.GroupService, - groupRepo repository.GroupRepository, - userRepo *repository.UserRepository, - wsManager *ws.WebSocketManager, -) *WebSocketHandler { - return &WebSocketHandler{ - jwtService: jwtService, - chatService: chatService, - groupService: groupService, - groupRepo: groupRepo, - userRepo: userRepo, - wsManager: wsManager, - } -} - -// HandleWebSocket 处理WebSocket连接 -func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) { - // 调试:打印请求头信息 - log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path) - log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s", - c.GetHeader("Connection"), - c.GetHeader("Upgrade")) - log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s", - c.GetHeader("Sec-WebSocket-Key"), - c.GetHeader("Sec-WebSocket-Version")) - - // 从query参数获取token - token := c.Query("token") - if token == "" { - // 尝试从header获取 - authHeader := c.GetHeader("Authorization") - if strings.HasPrefix(authHeader, "Bearer ") { - token = strings.TrimPrefix(authHeader, "Bearer ") - } - } - - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"}) - return - } - - // 验证token - claims, err := h.jwtService.ParseToken(token) - if err != nil { - log.Printf("Invalid token: %v", err) - c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) - return - } - - userID := claims.UserID - if userID == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"}) - return - } - - // 升级HTTP连接为WebSocket连接 - conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - log.Printf("Failed to upgrade connection: %v", err) - log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s", - c.GetHeader("User-Agent"), - c.GetHeader("Content-Type")) - return - } - - // 创建客户端 - client := &ws.Client{ - ID: userID, - UserID: userID, - Conn: conn, - Send: make(chan []byte, 256), - Manager: h.wsManager, - } - - // 注册客户端 - h.wsManager.Register(client) - - // 启动读写协程 - go client.WritePump() - go h.handleMessages(client) - -} - -// handleMessages 处理客户端消息 -// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳 -func (h *WebSocketHandler) handleMessages(client *ws.Client) { - defer func() { - h.wsManager.Unregister(client) - client.Conn.Close() - }() - - client.Conn.SetReadLimit(512 * 1024) // 512KB - client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒 - client.Conn.SetPongHandler(func(string) error { - client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒 - return nil - }) - - // 心跳定时器 - 服务端主动 ping 间隔保持 30 秒 - pingTicker := time.NewTicker(30 * time.Second) - defer pingTicker.Stop() - - for { - select { - case <-pingTicker.C: - // 发送心跳 - if err := client.SendPing(); err != nil { - log.Printf("Failed to send ping: %v", err) - return - } - default: - _, message, err := client.Conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Printf("WebSocket error: %v", err) - } - return - } - - var wsMsg ws.WSMessage - if err := json.Unmarshal(message, &wsMsg); err != nil { - log.Printf("Failed to unmarshal message: %v", err) - continue - } - - h.processMessage(client, &wsMsg) - } - } -} - -// processMessage 处理消息 -func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) { - switch msg.Type { - case ws.MessageTypePing: - // 响应心跳 - if err := client.SendPong(); err != nil { - log.Printf("Failed to send pong: %v", err) - } - - case ws.MessageTypePong: - // 客户端响应心跳 - - case ws.MessageTypeMessage: - // 处理聊天消息 - h.handleChatMessage(client, msg) - - case ws.MessageTypeTyping: - // 处理正在输入状态 - h.handleTyping(client, msg) - - case ws.MessageTypeRead: - // 处理已读回执 - h.handleReadReceipt(client, msg) - - // 群组消息处理 - case ws.MessageTypeGroupMessage: - // 处理群消息 - h.handleGroupMessage(client, msg) - - case ws.MessageTypeGroupTyping: - // 处理群输入状态 - h.handleGroupTyping(client, msg) - - case ws.MessageTypeGroupRead: - // 处理群消息已读 - h.handleGroupReadReceipt(client, msg) - - case ws.MessageTypeGroupRecall: - // 处理群消息撤回 - h.handleGroupRecall(client, msg) - - default: - log.Printf("Unknown message type: %s", msg.Type) - } -} - -// handleChatMessage 处理聊天消息 -func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) { - data, ok := msg.Data.(map[string]interface{}) - if !ok { - log.Printf("Invalid message data format") - return - } - - conversationIDStr, _ := data["conversationId"].(string) - - if conversationIDStr == "" { - log.Printf("Missing conversationId") - return - } - - // 解析会话ID - conversationID, err := service.ParseConversationID(conversationIDStr) - if err != nil { - log.Printf("Invalid conversation ID: %v", err) - return - } - - // 解析 segments - var segments model.MessageSegments - if data["segments"] != nil { - segmentsBytes, err := json.Marshal(data["segments"]) - if err == nil { - json.Unmarshal(segmentsBytes, &segments) - } - } - - // 从 segments 中提取回复消息ID - replyToID := dto.GetReplyMessageID(segments) - var replyToIDPtr *string - if replyToID != "" { - replyToIDPtr = &replyToID - } - - // 发送消息 - 使用 segments - message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr) - if err != nil { - log.Printf("Failed to send message: %v", err) - // 发送错误消息 - errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{ - "error": "Failed to send message", - }) - if client.Send != nil { - msgBytes, _ := json.Marshal(errorMsg) - client.Send <- msgBytes - } - return - } - - // 发送确认消息(使用 meta 事件格式,包含完整的消息内容) - metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{ - "detail_type": ws.MetaDetailTypeAck, - "conversation_id": conversationID, - "id": message.ID, - "user_id": client.UserID, - "sender_id": client.UserID, - "seq": message.Seq, - "segments": message.Segments, - "created_at": message.CreatedAt.UnixMilli(), - }) - if client.Send != nil { - msgBytes, _ := json.Marshal(metaAckMsg) - client.Send <- msgBytes - } -} - -// handleTyping 处理正在输入状态 -func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) { - data, ok := msg.Data.(map[string]interface{}) - if !ok { - return - } - - conversationIDStr, _ := data["conversationId"].(string) - if conversationIDStr == "" { - return - } - - conversationID, err := service.ParseConversationID(conversationIDStr) - if err != nil { - return - } - - // 直接使用 string 类型的 userID - h.chatService.SendTyping(context.Background(), client.UserID, conversationID) -} - -// handleReadReceipt 处理已读回执 -func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) { - data, ok := msg.Data.(map[string]interface{}) - if !ok { - return - } - - conversationIDStr, _ := data["conversationId"].(string) - if conversationIDStr == "" { - return - } - - conversationID, err := service.ParseConversationID(conversationIDStr) - if err != nil { - return - } - - // 获取lastReadSeq - lastReadSeq, _ := data["lastReadSeq"].(float64) - if lastReadSeq == 0 { - return - } - - // 直接使用 string 类型的 userID 和 conversationID - if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil { - log.Printf("Failed to mark as read: %v", err) - } -} - -// ==================== 群组消息处理 ==================== - -// handleGroupMessage 处理群消息 -func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) { - // 打印接收到的消息类型和数据,用于调试 - log.Printf("[handleGroupMessage] Received message type: %s", msg.Type) - log.Printf("[handleGroupMessage] Message data: %+v", msg.Data) - - data, ok := msg.Data.(map[string]interface{}) - if !ok { - log.Printf("Invalid group message data format: data is not map[string]interface{}") - return - } - - // 解析群组ID(支持 camelCase 和 snake_case) - var groupIDFloat float64 - groupID := "" // 使用 groupID 作为最终变量名 - if val, ok := data["groupId"].(float64); ok { - groupIDFloat = val - groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64) - } else if val, ok := data["group_id"].(string); ok { - groupID = val - } - - if groupID == "" { - log.Printf("Missing groupId in group message") - return - } - - // 解析会话ID(支持 camelCase 和 snake_case) - var conversationID string - if val, ok := data["conversationId"].(string); ok { - conversationID = val - } else if val, ok := data["conversation_id"].(string); ok { - conversationID = val - } - if conversationID == "" { - log.Printf("Missing conversationId in group message") - return - } - - // 解析 segments - var segments model.MessageSegments - if data["segments"] != nil { - segmentsBytes, err := json.Marshal(data["segments"]) - if err == nil { - json.Unmarshal(segmentsBytes, &segments) - } - } - - // 解析@用户列表(支持 camelCase 和 snake_case) - var mentionUsers []uint64 - var mentionUsersInterface []interface{} - if val, ok := data["mentionUsers"].([]interface{}); ok { - mentionUsersInterface = val - } else if val, ok := data["mention_users"].([]interface{}); ok { - mentionUsersInterface = val - } - if len(mentionUsersInterface) > 0 { - for _, uid := range mentionUsersInterface { - if uidFloat, ok := uid.(float64); ok { - mentionUsers = append(mentionUsers, uint64(uidFloat)) - } else if uidStr, ok := uid.(string); ok { - // 处理字符串格式的用户ID - if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil { - mentionUsers = append(mentionUsers, uidInt) - } - } - } - } - - // 解析@所有人(支持 camelCase 和 snake_case) - var mentionAll bool - if val, ok := data["mentionAll"].(bool); ok { - mentionAll = val - } else if val, ok := data["mention_all"].(bool); ok { - mentionAll = val - } - - // 检查用户是否可以发送群消息(验证成员身份和禁言状态) - // client.UserID 已经是 string 格式的 UUID - if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil { - log.Printf("User cannot send group message: %v", err) - // 发送错误消息 - errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{ - "error": "Cannot send group message", - "reason": err.Error(), - "type": "group_message_error", - "groupId": groupID, - }) - if client.Send != nil { - msgBytes, _ := json.Marshal(errorMsg) - client.Send <- msgBytes - } - return - } - - // 检查@所有人权限(只有群主和管理员可以@所有人) - if mentionAll { - if !h.groupService.IsGroupAdmin(client.UserID, groupID) { - log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID) - mentionAll = false // 取消@所有人标记 - } - } - - // 创建消息 - message := &model.Message{ - ConversationID: conversationID, - SenderID: client.UserID, - Segments: segments, - Status: model.MessageStatusNormal, - MentionAll: mentionAll, - } - - // 序列化mentionUsers为JSON - if len(mentionUsers) > 0 { - mentionUsersJSON, _ := json.Marshal(mentionUsers) - message.MentionUsers = string(mentionUsersJSON) - } - - // 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播) - savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil) - if err != nil { - log.Printf("Failed to save group message: %v", err) - errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{ - "error": "Failed to save group message", - }) - if client.Send != nil { - msgBytes, _ := json.Marshal(errorMsg) - client.Send <- msgBytes - } - return - } - - // 更新消息的mention信息 - if len(mentionUsers) > 0 || mentionAll { - message.ID = savedMessage.ID - message.Seq = savedMessage.Seq - } - - // 构造群消息响应 - groupMsg := &ws.GroupMessage{ - ID: savedMessage.ID, - ConversationID: conversationID, - GroupID: groupID, - SenderID: client.UserID, - Seq: savedMessage.Seq, - Segments: segments, - MentionUsers: mentionUsers, - MentionAll: mentionAll, - CreatedAt: savedMessage.CreatedAt.UnixMilli(), - } - - // 广播消息给群组所有成员(排除发送者) - h.BroadcastGroupMessage(groupID, groupMsg, client.UserID) - - // 发送确认消息给发送者(作为meta事件) - // 使用 meta 事件格式发送 ack - metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{ - "detail_type": ws.MetaDetailTypeAck, - "conversation_id": conversationID, - "group_id": groupID, - "id": savedMessage.ID, - "user_id": client.UserID, - "sender_id": client.UserID, - "seq": savedMessage.Seq, - "segments": segments, - "created_at": savedMessage.CreatedAt.UnixMilli(), - }) - if client.Send != nil { - msgBytes, _ := json.Marshal(metaAckMsg) - client.Send <- msgBytes - } else { - log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID) - } - - // 处理@提及通知 - if len(mentionUsers) > 0 || mentionAll { - // 提取文本正文(不含 @ 部分) - textContent := dto.ExtractTextContentFromModel(segments) - // 在通知内容前拼接被@的真实昵称,通过群成员列表查找 - mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent) - h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll) - } -} - -// handleGroupTyping 处理群输入状态 -func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) { - data, ok := msg.Data.(map[string]interface{}) - if !ok { - return - } - - groupIDFloat, _ := data["groupId"].(float64) - if groupIDFloat == 0 { - return - } - groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64) - - isTyping, _ := data["isTyping"].(bool) - - // 验证用户是否是群成员 - // client.UserID 已经是 string 格式的 UUID - isMember, err := h.groupRepo.IsMember(groupID, client.UserID) - if err != nil || !isMember { - return - } - - // 获取用户信息 - user, err := h.userRepo.GetByID(client.UserID) - if err != nil { - return - } - - // 构造输入状态消息 - typingMsg := &ws.GroupTypingMessage{ - GroupID: groupID, - UserID: client.UserID, - Username: user.Username, - IsTyping: isTyping, - } - - // 广播给群组其他成员 - wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg) - h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID) -} - -// handleGroupReadReceipt 处理群消息已读回执 -func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) { - data, ok := msg.Data.(map[string]interface{}) - if !ok { - return - } - - conversationID, _ := data["conversationId"].(string) - if conversationID == "" { - return - } - - lastReadSeq, _ := data["lastReadSeq"].(float64) - if lastReadSeq == 0 { - return - } - - // 标记已读 - if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil { - log.Printf("Failed to mark group message as read: %v", err) - } -} - -// handleGroupRecall 处理群消息撤回 -func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) { - data, ok := msg.Data.(map[string]interface{}) - if !ok { - return - } - - messageID, _ := data["messageId"].(string) - if messageID == "" { - return - } - - groupIDFloat, _ := data["groupId"].(float64) - if groupIDFloat == 0 { - return - } - groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64) - - // 撤回消息 - if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil { - log.Printf("Failed to recall group message: %v", err) - errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{ - "error": "Failed to recall message", - }) - if client.Send != nil { - msgBytes, _ := json.Marshal(errorMsg) - client.Send <- msgBytes - } - return - } - - // 广播撤回通知给群组所有成员 - recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{ - "messageId": messageID, - "groupId": groupID, - "userId": client.UserID, - "timestamp": time.Now().UnixMilli(), - }) - h.BroadcastGroupNotice(groupID, recallNotice) -} - -// handleGroupMention 处理群消息@提及通知 -func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) { - // 如果@所有人,获取所有群成员 - if mentionAll { - members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) - if err != nil { - log.Printf("Failed to get group members for mention all: %v", err) - return - } - - for _, member := range members { - // 不通知发送者自己 - memberIDStr := member.UserID - if memberIDStr == senderID { - continue - } - - // 发送@提及通知 - mentionMsg := &ws.GroupMentionMessage{ - GroupID: groupID, - MessageID: messageID, - FromUserID: senderID, - Content: truncateContent(content, 50), - MentionAll: true, - CreatedAt: time.Now().UnixMilli(), - } - wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg) - h.wsManager.SendToUser(memberIDStr, wsMsg) - } - return - } - - // 处理特定用户的@提及 - for _, userID := range mentionUsers { - // userID 是 uint64,转换为 string - userIDStr := strconv.FormatUint(userID, 10) - if userIDStr == senderID { - continue // 不通知发送者自己 - } - - mentionMsg := &ws.GroupMentionMessage{ - GroupID: groupID, - MessageID: messageID, - FromUserID: senderID, - Content: truncateContent(content, 50), - MentionAll: false, - CreatedAt: time.Now().UnixMilli(), - } - wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg) - h.wsManager.SendToUser(userIDStr, wsMsg) - } -} - -// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称 -func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string { - var prefix string - if mentionAll { - prefix = "@所有人 " - } else if len(mentionUsers) > 0 { - // 查询群成员列表,找到被@用户的昵称 - members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) - if err == nil { - memberNickMap := make(map[string]string, len(members)) - for _, m := range members { - displayName := m.Nickname - if displayName == "" { - displayName = m.UserID - } - memberNickMap[m.UserID] = displayName - } - for _, uid := range mentionUsers { - uidStr := strconv.FormatUint(uid, 10) - if name, ok := memberNickMap[uidStr]; ok { - prefix += "@" + name + " " - } else { - prefix += "@某人 " - } - } - } else { - for range mentionUsers { - prefix += "@某人 " - } - } - } - return prefix + textBody -} - -// BroadcastGroupMessage 向群组所有成员广播消息 -func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) { - // 获取群组所有成员 - members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) - if err != nil { - log.Printf("Failed to get group members for broadcast: %v", err) - return - } - - // 创建WebSocket消息 - wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message) - - // 遍历成员,如果在线则发送消息 - for _, member := range members { - memberIDStr := member.UserID - - // 排除发送者 - if memberIDStr == excludeUserID { - continue - } - - // 发送消息 - h.wsManager.SendToUser(memberIDStr, wsMsg) - } -} - -// BroadcastGroupNotice 广播群组通知给所有成员 -func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) { - // 获取群组所有成员 - members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) - if err != nil { - log.Printf("Failed to get group members for notice broadcast: %v", err) - return - } - - // 遍历成员,如果在线则发送通知 - for _, member := range members { - memberIDStr := member.UserID - h.wsManager.SendToUser(memberIDStr, notice) - } -} - -// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户) -func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) { - // 获取群组所有成员 - members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) - if err != nil { - log.Printf("Failed to get group members for notice broadcast: %v", err) - return - } - - // 遍历成员,如果在线则发送通知 - for _, member := range members { - memberIDStr := member.UserID - if memberIDStr == excludeUserID { - continue - } - h.wsManager.SendToUser(memberIDStr, notice) - } -} - -// SendGroupMemberNotice 发送群成员变动通知 -func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) { - notice := &ws.GroupNoticeMessage{ - NoticeType: noticeType, - GroupID: groupID, - Data: data, - Timestamp: time.Now().UnixMilli(), - } - wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice) - h.BroadcastGroupNotice(groupID, wsMsg) -} - -// truncateContent 截断内容 -func truncateContent(content string, maxLen int) string { - if len(content) <= maxLen { - return content - } - return content[:maxLen] + "..." -} - -// BroadcastGroupTyping 向群组所有成员广播输入状态 -func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) { - // 获取群组所有成员 - members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) - if err != nil { - log.Printf("Failed to get group members for typing broadcast: %v", err) - return - } - - // 创建WebSocket消息 - wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg) - - // 遍历成员,如果在线则发送消息 - for _, member := range members { - memberIDStr := member.UserID - - // 排除指定用户 - if memberIDStr == excludeUserID { - continue - } - - // 发送消息 - h.wsManager.SendToUser(memberIDStr, wsMsg) - } -} - -// BroadcastGroupRead 向群组所有成员广播已读状态 -func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) { - // 获取群组所有成员 - members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) - if err != nil { - log.Printf("Failed to get group members for read broadcast: %v", err) - return - } - - // 创建WebSocket消息 - wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg) - - // 遍历成员,如果在线则发送消息 - for _, member := range members { - memberIDStr := member.UserID - - // 排除指定用户 - if memberIDStr == excludeUserID { - continue - } - - // 发送消息 - h.wsManager.SendToUser(memberIDStr, wsMsg) - } -} diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index d6039bb..96fcf51 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -1,18 +1,10 @@ package middleware -import ( - "log" - "strings" - - "github.com/gin-gonic/gin" -) +import "github.com/gin-gonic/gin" // CORS CORS中间件 func CORS() gin.HandlerFunc { return func(c *gin.Context) { - // 获取请求路径 - path := c.Request.URL.Path - c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") // 添加 WebSocket 升级所需的头 @@ -22,25 +14,10 @@ func CORS() gin.HandlerFunc { // 处理 WebSocket 升级请求的预检 if c.Request.Method == "OPTIONS" { - log.Printf("[CORS] OPTIONS 预检请求: %s", path) c.AbortWithStatus(204) return } - // 针对 WebSocket 路径的特殊处理 - if path == "/ws" { - connection := c.GetHeader("Connection") - upgrade := c.GetHeader("Upgrade") - log.Printf("[CORS] WebSocket 请求: Connection=%s, Upgrade=%s", connection, upgrade) - - // 检查是否是有效的 WebSocket 升级请求 - if strings.Contains(strings.ToLower(connection), "upgrade") && strings.ToLower(upgrade) == "websocket" { - log.Printf("[CORS] 有效的 WebSocket 升级请求") - } else { - log.Printf("[CORS] 警告: 不是有效的 WebSocket 升级请求!") - } - } - c.Next() } } diff --git a/internal/model/post.go b/internal/model/post.go index 1c21e7d..655e1b0 100644 --- a/internal/model/post.go +++ b/internal/model/post.go @@ -58,7 +58,7 @@ type Post struct { // 时间戳 CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"` - UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime:false"` } // BeforeCreate 创建前生成UUID diff --git a/internal/pkg/sse/hub.go b/internal/pkg/sse/hub.go new file mode 100644 index 0000000..57585f7 --- /dev/null +++ b/internal/pkg/sse/hub.go @@ -0,0 +1,152 @@ +package sse + +import ( + "encoding/json" + "strconv" + "sync" + "sync/atomic" + "time" +) + +const ( + defaultUserBufferSize = 128 + maxReplayEvents = 200 +) + +type Event struct { + ID uint64 `json:"event_id"` + Event string `json:"event"` + TS int64 `json:"ts"` + Payload interface{} `json:"payload"` +} + +type subscriber struct { + id uint64 + ch chan Event + quit chan struct{} +} + +type Hub struct { + seq uint64 + + mu sync.RWMutex + subscribers map[string]map[uint64]*subscriber + history map[string][]Event +} + +func NewHub() *Hub { + return &Hub{ + subscribers: make(map[string]map[uint64]*subscriber), + history: make(map[string][]Event), + } +} + +func (h *Hub) NextID() uint64 { + return atomic.AddUint64(&h.seq, 1) +} + +func ParseEventID(raw string) uint64 { + if raw == "" { + return 0 + } + id, err := strconv.ParseUint(raw, 10, 64) + if err != nil { + return 0 + } + return id +} + +func (h *Hub) Subscribe(userID string, afterID uint64) (chan Event, func(), []Event) { + subID := h.NextID() + sub := &subscriber{ + id: subID, + ch: make(chan Event, defaultUserBufferSize), + quit: make(chan struct{}), + } + + h.mu.Lock() + if _, ok := h.subscribers[userID]; !ok { + h.subscribers[userID] = make(map[uint64]*subscriber) + } + h.subscribers[userID][subID] = sub + + replay := make([]Event, 0) + for _, e := range h.history[userID] { + if e.ID > afterID { + replay = append(replay, e) + } + } + h.mu.Unlock() + + cancel := func() { + h.mu.Lock() + defer h.mu.Unlock() + if userSubs, ok := h.subscribers[userID]; ok { + if s, exists := userSubs[subID]; exists { + close(s.quit) + delete(userSubs, subID) + close(s.ch) + } + if len(userSubs) == 0 { + delete(h.subscribers, userID) + } + } + } + + return sub.ch, cancel, replay +} + +func (h *Hub) HasSubscribers(userID string) bool { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.subscribers[userID]) > 0 +} + +func (h *Hub) PublishToUser(userID string, eventName string, payload interface{}) Event { + ev := Event{ + ID: h.NextID(), + Event: eventName, + TS: time.Now().UnixMilli(), + Payload: payload, + } + h.publish(userID, ev) + return ev +} + +func (h *Hub) PublishToUsers(userIDs []string, eventName string, payload interface{}) { + for _, uid := range userIDs { + h.PublishToUser(uid, eventName, payload) + } +} + +func (h *Hub) publish(userID string, ev Event) { + h.mu.Lock() + history := append(h.history[userID], ev) + if len(history) > maxReplayEvents { + history = history[len(history)-maxReplayEvents:] + } + h.history[userID] = history + + targets := make([]*subscriber, 0, len(h.subscribers[userID])) + for _, s := range h.subscribers[userID] { + targets = append(targets, s) + } + h.mu.Unlock() + + for _, s := range targets { + select { + case <-s.quit: + case s.ch <- ev: + default: + // 慢消费者丢弃单条消息,客户端可通过 Last-Event-ID + HTTP 同步补偿 + } + } +} + +func EncodeData(ev Event) (string, error) { + body, err := json.Marshal(ev.Payload) + if err != nil { + return "", err + } + return string(body), nil +} diff --git a/internal/pkg/websocket/websocket.go b/internal/pkg/websocket/websocket.go deleted file mode 100644 index 6d67b7c..0000000 --- a/internal/pkg/websocket/websocket.go +++ /dev/null @@ -1,435 +0,0 @@ -package websocket - -import ( - "carrot_bbs/internal/model" - "encoding/json" - "log" - "sync" - "time" - - "github.com/gorilla/websocket" -) - -// WebSocket消息类型常量 -const ( - MessageTypePing = "ping" - MessageTypePong = "pong" - MessageTypeMessage = "message" - MessageTypeTyping = "typing" - MessageTypeRead = "read" - MessageTypeAck = "ack" - MessageTypeError = "error" - MessageTypeRecall = "recall" // 撤回消息 - MessageTypeSystem = "system" // 系统消息 - MessageTypeNotification = "notification" // 通知消息 - MessageTypeAnnouncement = "announcement" // 公告消息 - - // 群组相关消息类型 - MessageTypeGroupMessage = "group_message" // 群消息 - MessageTypeGroupTyping = "group_typing" // 群输入状态 - MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等) - MessageTypeGroupMention = "group_mention" // @提及通知 - MessageTypeGroupRead = "group_read" // 群消息已读 - MessageTypeGroupRecall = "group_recall" // 群消息撤回 - - // Meta事件详细类型 - MetaDetailTypeHeartbeat = "heartbeat" - MetaDetailTypeTyping = "typing" - MetaDetailTypeAck = "ack" // 消息发送确认 - MetaDetailTypeRead = "read" // 已读回执 -) - -// WSMessage WebSocket消息结构 -type WSMessage struct { - Type string `json:"type"` - Data interface{} `json:"data"` - Timestamp int64 `json:"timestamp"` -} - -// ChatMessage 聊天消息结构 -type ChatMessage struct { - ID string `json:"id"` - ConversationID string `json:"conversation_id"` - SenderID string `json:"sender_id"` - Seq int64 `json:"seq"` - Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组) - ReplyToID *string `json:"reply_to_id,omitempty"` - CreatedAt int64 `json:"created_at"` -} - -// SystemMessage 系统消息结构 -type SystemMessage struct { - ID string `json:"id"` // 消息ID - Type string `json:"type"` // 消息子类型(如:account_banned, post_approved等) - Title string `json:"title"` // 消息标题 - Content string `json:"content"` // 消息内容 - Data map[string]interface{} `json:"data"` // 额外数据 - CreatedAt int64 `json:"created_at"` // 创建时间戳 -} - -// NotificationMessage 通知消息结构 -type NotificationMessage struct { - ID string `json:"id"` // 通知ID - Type string `json:"type"` // 通知类型(like, comment, follow, mention等) - Title string `json:"title"` // 通知标题 - Content string `json:"content"` // 通知内容 - TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户 - ResourceType string `json:"resource_type"` // 资源类型(post, comment等) - ResourceID string `json:"resource_id"` // 资源ID - Extra map[string]interface{} `json:"extra"` // 额外数据 - CreatedAt int64 `json:"created_at"` // 创建时间戳 -} - -// NotificationUser 通知中的用户信息 -type NotificationUser struct { - ID string `json:"id"` - Username string `json:"username"` - Avatar string `json:"avatar"` -} - -// AnnouncementMessage 公告消息结构 -type AnnouncementMessage struct { - ID string `json:"id"` // 公告ID - Title string `json:"title"` // 公告标题 - Content string `json:"content"` // 公告内容 - Priority int `json:"priority"` // 优先级(1-10) - CreatedAt int64 `json:"created_at"` // 创建时间戳 -} - -// GroupMessage 群消息结构 -type GroupMessage struct { - ID string `json:"id"` // 消息ID - ConversationID string `json:"conversation_id"` // 会话ID(群聊会话) - GroupID string `json:"group_id"` // 群组ID - SenderID string `json:"sender_id"` // 发送者ID - Seq int64 `json:"seq"` // 消息序号 - Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组) - ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID - MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表 - MentionAll bool `json:"mention_all"` // 是否@所有人 - CreatedAt int64 `json:"created_at"` // 创建时间戳 -} - -// GroupTypingMessage 群输入状态消息 -type GroupTypingMessage struct { - GroupID string `json:"group_id"` // 群组ID - UserID string `json:"user_id"` // 用户ID - Username string `json:"username"` // 用户名 - IsTyping bool `json:"is_typing"` // 是否正在输入 -} - -// GroupNoticeMessage 群组通知消息 -type GroupNoticeMessage struct { - NoticeType string `json:"notice_type"` // 通知类型:member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved - GroupID string `json:"group_id"` // 群组ID - Data interface{} `json:"data"` // 通知数据 - Timestamp int64 `json:"timestamp"` // 时间戳 - MessageID string `json:"message_id,omitempty"` // 消息ID(如果通知保存为消息) - Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息) -} - -// GroupNoticeData 通知数据结构 -type GroupNoticeData struct { - // 成员变动 - UserID string `json:"user_id,omitempty"` // 变动的用户ID - Username string `json:"username,omitempty"` // 用户名 - OperatorID string `json:"operator_id,omitempty"` // 操作者ID - OpName string `json:"op_name,omitempty"` // 操作者名称 - NewRole string `json:"new_role,omitempty"` // 新角色 - OldRole string `json:"old_role,omitempty"` // 旧角色 - MemberCount int `json:"member_count,omitempty"` // 当前成员数 - - // 群设置变更 - MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态 -} - -// GroupMentionMessage @提及通知消息 -type GroupMentionMessage struct { - GroupID string `json:"group_id"` // 群组ID - MessageID string `json:"message_id"` // 消息ID - FromUserID string `json:"from_user_id"` // 发送者ID - FromName string `json:"from_name"` // 发送者名称 - Content string `json:"content"` // 消息内容预览 - MentionAll bool `json:"mention_all"` // 是否@所有人 - CreatedAt int64 `json:"created_at"` // 创建时间戳 -} - -// AckMessage 消息发送确认结构 -type AckMessage struct { - ConversationID string `json:"conversation_id"` // 会话ID - GroupID string `json:"group_id,omitempty"` // 群组ID(群聊时) - ID string `json:"id"` // 消息ID - SenderID string `json:"sender_id"` // 发送者ID - Seq int64 `json:"seq"` // 消息序号 - Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组) - CreatedAt int64 `json:"created_at"` // 创建时间戳 -} - -// Client WebSocket客户端 -type Client struct { - ID string - UserID string - Conn *websocket.Conn - Send chan []byte - Manager *WebSocketManager - IsClosed bool - Mu sync.Mutex - closeOnce sync.Once // 确保 Send channel 只关闭一次 -} - -// WebSocketManager WebSocket连接管理器 -type WebSocketManager struct { - clients map[string]*Client // userID -> Client - register chan *Client - unregister chan *Client - broadcast chan *BroadcastMessage - mutex sync.RWMutex -} - -// BroadcastMessage 广播消息 -type BroadcastMessage struct { - Message *WSMessage - ExcludeUser string // 排除的用户ID,为空表示不排除 - TargetUser string // 目标用户ID,为空表示广播给所有用户 -} - -// NewWebSocketManager 创建WebSocket管理器 -func NewWebSocketManager() *WebSocketManager { - return &WebSocketManager{ - clients: make(map[string]*Client), - register: make(chan *Client, 100), - unregister: make(chan *Client, 100), - broadcast: make(chan *BroadcastMessage, 100), - } -} - -// Start 启动管理器 -func (m *WebSocketManager) Start() { - go func() { - for { - select { - case client := <-m.register: - m.mutex.Lock() - m.clients[client.UserID] = client - m.mutex.Unlock() - log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients)) - - case client := <-m.unregister: - m.mutex.Lock() - if _, ok := m.clients[client.UserID]; ok { - delete(m.clients, client.UserID) - // 使用 closeOnce 确保 channel 只关闭一次,避免 panic - client.closeOnce.Do(func() { - close(client.Send) - }) - log.Printf("WebSocket client disconnected: userID=%s", client.UserID) - } - m.mutex.Unlock() - - case broadcast := <-m.broadcast: - m.sendMessage(broadcast) - } - } - }() -} - -// Register 注册客户端 -func (m *WebSocketManager) Register(client *Client) { - m.register <- client -} - -// Unregister 注销客户端 -func (m *WebSocketManager) Unregister(client *Client) { - m.unregister <- client -} - -// Broadcast 广播消息给所有用户 -func (m *WebSocketManager) Broadcast(msg *WSMessage) { - m.broadcast <- &BroadcastMessage{ - Message: msg, - TargetUser: "", - } -} - -// SendToUser 发送消息给指定用户 -func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) { - m.broadcast <- &BroadcastMessage{ - Message: msg, - TargetUser: userID, - } -} - -// SendToUsers 发送消息给指定用户列表 -func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) { - for _, userID := range userIDs { - m.SendToUser(userID, msg) - } -} - -// GetClient 获取客户端 -func (m *WebSocketManager) GetClient(userID string) (*Client, bool) { - m.mutex.RLock() - defer m.mutex.RUnlock() - client, ok := m.clients[userID] - return client, ok -} - -// GetAllClients 获取所有客户端 -func (m *WebSocketManager) GetAllClients() map[string]*Client { - m.mutex.RLock() - defer m.mutex.RUnlock() - return m.clients -} - -// GetClientCount 获取在线用户数量 -func (m *WebSocketManager) GetClientCount() int { - m.mutex.RLock() - defer m.mutex.RUnlock() - return len(m.clients) -} - -// IsUserOnline 检查用户是否在线 -func (m *WebSocketManager) IsUserOnline(userID string) bool { - m.mutex.RLock() - defer m.mutex.RUnlock() - _, ok := m.clients[userID] - return ok -} - -// sendMessage 发送消息 -func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) { - msgBytes, err := json.Marshal(broadcast.Message) - if err != nil { - log.Printf("Failed to marshal message: %v", err) - return - } - - m.mutex.RLock() - defer m.mutex.RUnlock() - - for userID, client := range m.clients { - // 如果指定了目标用户,只发送给目标用户 - if broadcast.TargetUser != "" && userID != broadcast.TargetUser { - continue - } - - // 如果指定了排除用户,跳过 - if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser { - continue - } - - select { - case client.Send <- msgBytes: - default: - log.Printf("Failed to send message to user %s: channel full", userID) - } - } -} - -// SendPing 发送心跳 -func (c *Client) SendPing() error { - c.Mu.Lock() - defer c.Mu.Unlock() - if c.IsClosed { - return nil - } - msg := WSMessage{ - Type: MessageTypePing, - Data: nil, - Timestamp: time.Now().UnixMilli(), - } - msgBytes, _ := json.Marshal(msg) - return c.Conn.WriteMessage(websocket.TextMessage, msgBytes) -} - -// SendPong 发送Pong响应 -func (c *Client) SendPong() error { - c.Mu.Lock() - defer c.Mu.Unlock() - if c.IsClosed { - return nil - } - msg := WSMessage{ - Type: MessageTypePong, - Data: nil, - Timestamp: time.Now().UnixMilli(), - } - msgBytes, _ := json.Marshal(msg) - return c.Conn.WriteMessage(websocket.TextMessage, msgBytes) -} - -// WritePump 写入泵,将消息从Manager发送到客户端 -func (c *Client) WritePump() { - defer func() { - c.Conn.Close() - c.Mu.Lock() - c.IsClosed = true - c.Mu.Unlock() - }() - - for { - message, ok := <-c.Send - if !ok { - c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) - return - } - - c.Mu.Lock() - if c.IsClosed { - c.Mu.Unlock() - return - } - err := c.Conn.WriteMessage(websocket.TextMessage, message) - c.Mu.Unlock() - - if err != nil { - log.Printf("Write error: %v", err) - return - } - } -} - -// ReadPump 读取泵,从客户端读取消息 -func (c *Client) ReadPump(handler func(msg *WSMessage)) { - defer func() { - c.Manager.Unregister(c) - c.Conn.Close() - c.Mu.Lock() - c.IsClosed = true - c.Mu.Unlock() - }() - - c.Conn.SetReadLimit(512 * 1024) // 512KB - c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - c.Conn.SetPongHandler(func(string) error { - c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - return nil - }) - - for { - _, message, err := c.Conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Printf("WebSocket error: %v", err) - } - break - } - - var wsMsg WSMessage - if err := json.Unmarshal(message, &wsMsg); err != nil { - log.Printf("Failed to unmarshal message: %v", err) - continue - } - - handler(&wsMsg) - } -} - -// CreateWSMessage 创建WebSocket消息 -func CreateWSMessage(msgType string, data interface{}) *WSMessage { - return &WSMessage{ - Type: msgType, - Data: data, - Timestamp: time.Now().UnixMilli(), - } -} diff --git a/internal/repository/comment_repo.go b/internal/repository/comment_repo.go index 9836c0c..934a7ed 100644 --- a/internal/repository/comment_repo.go +++ b/internal/repository/comment_repo.go @@ -18,32 +18,7 @@ func NewCommentRepository(db *gorm.DB) *CommentRepository { // Create 创建评论 func (r *CommentRepository) Create(comment *model.Comment) error { - return r.db.Transaction(func(tx *gorm.DB) error { - // 创建评论 - err := tx.Create(comment).Error - if err != nil { - return err - } - - // 增加帖子的评论数并同步热度分 - if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID). - Updates(map[string]interface{}{ - "comments_count": gorm.Expr("comments_count + 1"), - "hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"), - }).Error; err != nil { - return err - } - - // 如果是回复,增加父评论的回复数 - if comment.ParentID != nil && *comment.ParentID != "" { - if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID). - UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil { - return err - } - } - - return nil - }) + return r.db.Create(comment).Error } // GetByID 根据ID获取评论 @@ -87,23 +62,52 @@ func (r *CommentRepository) Delete(id string) error { return err } - // 减少帖子的评论数并同步热度分 + // 仅已发布评论才参与统计,避免 pending/rejected 影响计数 + if comment.Status == model.CommentStatusPublished { + // 减少帖子的评论数并同步热度分 + if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID). + Updates(map[string]interface{}{ + "comments_count": gorm.Expr("comments_count - 1"), + "hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"), + }).Error; err != nil { + return err + } + + // 如果是回复,减少父评论的回复数 + if comment.ParentID != nil && *comment.ParentID != "" { + if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID). + UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil { + return err + } + } + } + + return nil + }) +} + +// ApplyPublishedStats 在评论审核通过后更新帖子评论数/回复数 +func (r *CommentRepository) ApplyPublishedStats(comment *model.Comment) error { + if comment == nil { + return nil + } + return r.db.Transaction(func(tx *gorm.DB) error { + // 增加帖子的评论数并同步热度分 if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID). Updates(map[string]interface{}{ - "comments_count": gorm.Expr("comments_count - 1"), - "hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"), + "comments_count": gorm.Expr("comments_count + 1"), + "hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"), }).Error; err != nil { return err } - // 如果是回复,减少父评论的回复数 + // 如果是回复,增加父评论的回复数 if comment.ParentID != nil && *comment.ParentID != "" { if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID). - UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil { + UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil { return err } } - return nil }) } diff --git a/internal/repository/post_repo.go b/internal/repository/post_repo.go index 56a8574..f141d02 100644 --- a/internal/repository/post_repo.go +++ b/internal/repository/post_repo.go @@ -2,6 +2,7 @@ package repository import ( "carrot_bbs/internal/model" + "time" "gorm.io/gorm" ) @@ -52,9 +53,41 @@ func (r *PostRepository) GetByID(id string) (*model.Post, error) { // Update 更新帖子 func (r *PostRepository) Update(post *model.Post) error { + post.UpdatedAt = time.Now() return r.db.Save(post).Error } +// UpdateWithImages 更新帖子及其图片(images=nil 表示不更新图片) +func (r *PostRepository) UpdateWithImages(post *model.Post, images *[]string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + post.UpdatedAt = time.Now() + if err := tx.Save(post).Error; err != nil { + return err + } + + if images == nil { + return nil + } + + if err := tx.Where("post_id = ?", post.ID).Delete(&model.PostImage{}).Error; err != nil { + return err + } + + for i, url := range *images { + image := &model.PostImage{ + PostID: post.ID, + URL: url, + SortOrder: i, + } + if err := tx.Create(image).Error; err != nil { + return err + } + } + + return nil + }) +} + // UpdateModerationStatus 更新帖子审核状态 func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error { updates := map[string]interface{}{ @@ -100,15 +133,24 @@ func (r *PostRepository) Delete(id string) error { } // List 分页获取帖子列表 -func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) { +// includePending=true 时,仅在指定 userID 下额外返回 pending(用于作者查看自己待审核帖子) +func (r *PostRepository) List(page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) { var posts []*model.Post var total int64 - query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished) + query := r.db.Model(&model.Post{}) if userID != "" { query = query.Where("user_id = ?", userID) } + if includePending && userID != "" { + query = query.Where("status IN ?", []model.PostStatus{ + model.PostStatusPublished, + model.PostStatusPending, + }) + } else { + query = query.Where("status = ?", model.PostStatusPublished) + } query.Count(&total) @@ -119,14 +161,32 @@ func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, } // GetUserPosts 获取用户帖子 -func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) { +func (r *PostRepository) GetUserPosts(userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) { var posts []*model.Post var total int64 - r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total) + statusQuery := r.db.Model(&model.Post{}).Where("user_id = ?", userID) + if includePending { + statusQuery = statusQuery.Where("status IN ?", []model.PostStatus{ + model.PostStatusPublished, + model.PostStatusPending, + }) + } else { + statusQuery = statusQuery.Where("status = ?", model.PostStatusPublished) + } + statusQuery.Count(&total) offset := (page - 1) * pageSize - err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error + listQuery := r.db.Where("user_id = ?", userID) + if includePending { + listQuery = listQuery.Where("status IN ?", []model.PostStatus{ + model.PostStatusPublished, + model.PostStatusPending, + }) + } else { + listQuery = listQuery.Where("status = ?", model.PostStatusPublished) + } + err := listQuery.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error return posts, total, err } @@ -256,7 +316,8 @@ func (r *PostRepository) IsFavorited(postID, userID string) bool { // IncrementViews 增加帖子观看量 func (r *PostRepository) IncrementViews(postID string) error { return r.db.Model(&model.Post{}).Where("id = ?", postID). - Updates(map[string]interface{}{ + // 浏览量属于统计字段,不应影响帖子内容更新时间(updated_at) + UpdateColumns(map[string]interface{}{ "views_count": gorm.Expr("views_count + 1"), "hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"), }).Error diff --git a/internal/repository/user_repo.go b/internal/repository/user_repo.go index 4db8437..788e9e9 100644 --- a/internal/repository/user_repo.go +++ b/internal/repository/user_repo.go @@ -177,7 +177,9 @@ func (r *UserRepository) RefreshFollowersCount(userID string) error { // GetPostsCount 获取用户帖子数(实时计算) func (r *UserRepository) GetPostsCount(userID string) (int64, error) { var count int64 - err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error + err := r.db.Model(&model.Post{}). + Where("user_id = ? AND status = ?", userID, model.PostStatusPublished). + Count(&count).Error return count, err } @@ -202,7 +204,7 @@ func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, var counts []CountResult err := r.db.Model(&model.Post{}). Select("user_id, count(*) as count"). - Where("user_id IN ?", userIDs). + Where("user_id IN ? AND status = ?", userIDs, model.PostStatusPublished). Group("user_id"). Scan(&counts).Error if err != nil { diff --git a/internal/router/router.go b/internal/router/router.go index d617acb..9484015 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -17,7 +17,6 @@ type Router struct { messageHandler *handler.MessageHandler notificationHandler *handler.NotificationHandler uploadHandler *handler.UploadHandler - wsHandler *handler.WebSocketHandler pushHandler *handler.PushHandler systemMessageHandler *handler.SystemMessageHandler groupHandler *handler.GroupHandler @@ -36,7 +35,6 @@ func New( notificationHandler *handler.NotificationHandler, uploadHandler *handler.UploadHandler, jwtService *service.JWTService, - wsHandler *handler.WebSocketHandler, pushHandler *handler.PushHandler, systemMessageHandler *handler.SystemMessageHandler, groupHandler *handler.GroupHandler, @@ -55,7 +53,6 @@ func New( messageHandler: messageHandler, notificationHandler: notificationHandler, uploadHandler: uploadHandler, - wsHandler: wsHandler, pushHandler: pushHandler, systemMessageHandler: systemMessageHandler, groupHandler: groupHandler, @@ -79,11 +76,6 @@ func (r *Router) setupRoutes() { c.JSON(200, gin.H{"status": "ok"}) }) - // WebSocket 路由 - if r.wsHandler != nil { - r.engine.GET("/ws", r.wsHandler.HandleWebSocket) - } - // API v1 v1 := r.engine.Group("/api/v1") { @@ -210,10 +202,18 @@ func (r *Router) setupRoutes() { conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned) // 获取未读消息总数 conversations.GET("/unread/count", r.messageHandler.GetUnreadCount) + // 上报输入状态 + conversations.POST("/typing", r.messageHandler.HandleTyping) // 仅自己删除会话 conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf) } + realtime := v1.Group("/realtime") + realtime.Use(authMiddleware) + { + realtime.GET("/sse", r.messageHandler.HandleSSE) + } + // 消息操作路由 messages := v1.Group("/messages") messages.Use(authMiddleware) diff --git a/internal/service/chat_service.go b/internal/service/chat_service.go index e81b8e9..776a74c 100644 --- a/internal/service/chat_service.go +++ b/internal/service/chat_service.go @@ -4,11 +4,11 @@ import ( "context" "errors" "fmt" - "log" "time" + "carrot_bbs/internal/dto" "carrot_bbs/internal/model" - "carrot_bbs/internal/pkg/websocket" + "carrot_bbs/internal/pkg/sse" "carrot_bbs/internal/repository" "gorm.io/gorm" @@ -41,17 +41,13 @@ type ChatService interface { RecallMessage(ctx context.Context, messageID string, userID string) error DeleteMessage(ctx context.Context, messageID string, userID string) error - // WebSocket相关 + // 实时事件相关 SendTyping(ctx context.Context, senderID string, conversationID string) - BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) - // 系统消息推送 + // 在线状态 IsUserOnline(userID string) bool - PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error - PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error - PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error - // 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用) + // 仅保存消息到数据库,不发送实时推送(供群聊等自行推送的场景使用) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) } @@ -61,7 +57,7 @@ type chatServiceImpl struct { repo *repository.MessageRepository userRepo *repository.UserRepository sensitive SensitiveService - wsManager *websocket.WebSocketManager + sseHub *sse.Hub } // NewChatService 创建聊天服务 @@ -70,17 +66,24 @@ func NewChatService( repo *repository.MessageRepository, userRepo *repository.UserRepository, sensitive SensitiveService, - wsManager *websocket.WebSocketManager, + sseHub *sse.Hub, ) ChatService { return &chatServiceImpl{ db: db, repo: repo, userRepo: userRepo, sensitive: sensitive, - wsManager: wsManager, + sseHub: sseHub, } } +func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payload interface{}) { + if s.sseHub == nil || len(userIDs) == 0 { + return + } + s.sseHub.PublishToUsers(userIDs, event, payload) +} + // GetOrCreateConversation 获取或创建私聊会话 func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) { return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID) @@ -228,30 +231,30 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv return nil, fmt.Errorf("failed to save message: %w", err) } - // 发送消息给接收者 - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{ - ID: message.ID, - ConversationID: message.ConversationID, - SenderID: senderID, - Segments: message.Segments, - Seq: message.Seq, - CreatedAt: message.CreatedAt.UnixMilli(), - }) - - // 获取会话中的其他参与者 + // 获取会话中的参与者并发送 SSE participants, err := s.repo.GetConversationParticipants(conversationID) if err == nil { + targetIDs := make([]string, 0, len(participants)) + for _, p := range participants { + targetIDs = append(targetIDs, p.UserID) + } + detailType := "private" + if conv.Type == model.ConversationTypeGroup { + detailType = "group" + } + s.publishSSEToUsers(targetIDs, "chat_message", map[string]interface{}{ + "detail_type": detailType, + "message": dto.ConvertMessageToResponse(message), + }) for _, p := range participants { - // 不发给自己 if p.UserID == senderID { continue } - // 如果接收者在线,发送实时消息 - if s.wsManager != nil { - isOnline := s.wsManager.IsUserOnline(p.UserID) - if isOnline { - s.wsManager.SendToUser(p.UserID, wsMsg) - } + if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil { + s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{ + "conversation_id": conversationID, + "total_unread": totalUnread, + }) } } } @@ -337,25 +340,33 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, return fmt.Errorf("failed to update last read seq: %w", err) } - // 发送已读回执(作为 meta 事件) - if s.wsManager != nil { - wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{ - "detail_type": websocket.MetaDetailTypeRead, - "conversation_id": conversationID, - "seq": seq, - "user_id": userID, - }) - - // 获取会话中的所有参与者 - participants, err := s.repo.GetConversationParticipants(conversationID) - if err == nil { - // 推送给会话中的所有参与者(包括自己) - for _, p := range participants { - if s.wsManager.IsUserOnline(p.UserID) { - s.wsManager.SendToUser(p.UserID, wsMsg) - } + participants, pErr := s.repo.GetConversationParticipants(conversationID) + if pErr == nil { + detailType := "private" + groupID := "" + if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { + detailType = "group" + if conv.GroupID != nil { + groupID = *conv.GroupID } } + targetIDs := make([]string, 0, len(participants)) + for _, p := range participants { + targetIDs = append(targetIDs, p.UserID) + } + s.publishSSEToUsers(targetIDs, "message_read", map[string]interface{}{ + "detail_type": detailType, + "conversation_id": conversationID, + "group_id": groupID, + "user_id": userID, + "seq": seq, + }) + } + if totalUnread, uErr := s.repo.GetAllUnreadCount(userID); uErr == nil { + s.publishSSEToUsers([]string{userID}, "conversation_unread", map[string]interface{}{ + "conversation_id": conversationID, + "total_unread": totalUnread, + }) } return nil @@ -407,29 +418,35 @@ func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, u return errors.New("message recall timeout (2 minutes)") } - // 更新消息状态为已撤回 - err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error + // 更新消息状态为已撤回,并清空原始消息内容,仅保留撤回占位 + err = s.db.Model(&message).Updates(map[string]interface{}{ + "status": model.MessageStatusRecalled, + "segments": model.MessageSegments{}, + }).Error if err != nil { return fmt.Errorf("failed to recall message: %w", err) } - // 发送撤回通知 - if s.wsManager != nil { - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{ - "messageId": messageID, - "conversationId": message.ConversationID, - "senderId": userID, - }) - - // 通知会话中的所有参与者 - participants, err := s.repo.GetConversationParticipants(message.ConversationID) - if err == nil { - for _, p := range participants { - if s.wsManager.IsUserOnline(p.UserID) { - s.wsManager.SendToUser(p.UserID, wsMsg) - } + if participants, pErr := s.repo.GetConversationParticipants(message.ConversationID); pErr == nil { + detailType := "private" + groupID := "" + if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { + detailType = "group" + if conv.GroupID != nil { + groupID = *conv.GroupID } } + targetIDs := make([]string, 0, len(participants)) + for _, p := range participants { + targetIDs = append(targetIDs, p.UserID) + } + s.publishSSEToUsers(targetIDs, "message_recall", map[string]interface{}{ + "detail_type": detailType, + "conversation_id": message.ConversationID, + "group_id": groupID, + "message_id": messageID, + "sender_id": userID, + }) } return nil @@ -473,7 +490,7 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u // SendTyping 发送正在输入状态 func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) { - if s.wsManager == nil { + if s.sseHub == nil { return } @@ -489,98 +506,34 @@ func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conve return } + detailType := "private" + if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { + detailType = "group" + } for _, p := range participants { if p.UserID == senderID { continue } - // 发送正在输入状态 - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{ - "conversationId": conversationID, - "senderId": senderID, - }) - - if s.wsManager.IsUserOnline(p.UserID) { - s.wsManager.SendToUser(p.UserID, wsMsg) + if s.sseHub != nil { + s.sseHub.PublishToUser(p.UserID, "typing", map[string]interface{}{ + "detail_type": detailType, + "conversation_id": conversationID, + "user_id": senderID, + "is_typing": true, + }) } } } -// BroadcastMessage 广播消息给用户 -func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) { - if s.wsManager != nil { - s.wsManager.SendToUser(targetUser, msg) - } -} - // IsUserOnline 检查用户是否在线 func (s *chatServiceImpl) IsUserOnline(userID string) bool { - if s.wsManager == nil { - return false + if s.sseHub != nil { + return s.sseHub.HasSubscribers(userID) } - return s.wsManager.IsUserOnline(userID) + return false } -// PushSystemMessage 推送系统消息给指定用户 -func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error { - if s.wsManager == nil { - return errors.New("websocket manager not available") - } - - if !s.wsManager.IsUserOnline(userID) { - return errors.New("user is offline") - } - - sysMsg := &websocket.SystemMessage{ - ID: "", // 由调用方生成 - Type: msgType, - Title: title, - Content: content, - Data: data, - CreatedAt: time.Now().UnixMilli(), - } - - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg) - s.wsManager.SendToUser(userID, wsMsg) - return nil -} - -// PushNotificationMessage 推送通知消息给指定用户 -func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error { - if s.wsManager == nil { - return errors.New("websocket manager not available") - } - - if !s.wsManager.IsUserOnline(userID) { - return errors.New("user is offline") - } - - // 确保时间戳已设置 - if notification.CreatedAt == 0 { - notification.CreatedAt = time.Now().UnixMilli() - } - - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification) - s.wsManager.SendToUser(userID, wsMsg) - return nil -} - -// PushAnnouncementMessage 广播公告消息给所有在线用户 -func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error { - if s.wsManager == nil { - return errors.New("websocket manager not available") - } - - // 确保时间戳已设置 - if announcement.CreatedAt == 0 { - announcement.CreatedAt = time.Now().UnixMilli() - } - - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement) - s.wsManager.Broadcast(wsMsg) - return nil -} - -// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送 +// SaveMessage 仅保存消息到数据库,不发送实时推送 // 适用于群聊等由调用方自行负责推送的场景 func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) { // 验证会话是否存在 diff --git a/internal/service/comment_service.go b/internal/service/comment_service.go index 89a897f..0d5ebf3 100644 --- a/internal/service/comment_service.go +++ b/internal/service/comment_service.go @@ -7,6 +7,7 @@ import ( "log" "strings" + "carrot_bbs/internal/cache" "carrot_bbs/internal/model" "carrot_bbs/internal/pkg/gorse" "carrot_bbs/internal/repository" @@ -17,6 +18,7 @@ type CommentService struct { commentRepo *repository.CommentRepository postRepo *repository.PostRepository systemMessageService SystemMessageService + cache cache.Cache gorseClient gorse.Client postAIService *PostAIService } @@ -27,6 +29,7 @@ func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repo commentRepo: commentRepo, postRepo: postRepo, systemMessageService: systemMessageService, + cache: cache.GetCache(), gorseClient: gorseClient, postAIService: postAIService, } @@ -96,6 +99,10 @@ func (s *CommentService) reviewCommentAsync( log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err) return } + if err := s.applyCommentPublishedStats(commentID); err != nil { + log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, err) + } + s.invalidatePostCaches(postID) s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID) return } @@ -116,6 +123,10 @@ func (s *CommentService) reviewCommentAsync( log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr) return } + if statsErr := s.applyCommentPublishedStats(commentID); statsErr != nil { + log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, statsErr) + } + s.invalidatePostCaches(postID) log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err) s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID) return @@ -125,9 +136,26 @@ func (s *CommentService) reviewCommentAsync( log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr) return } + if statsErr := s.applyCommentPublishedStats(commentID); statsErr != nil { + log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, statsErr) + } + s.invalidatePostCaches(postID) s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID) } +func (s *CommentService) applyCommentPublishedStats(commentID string) error { + comment, err := s.commentRepo.GetByID(commentID) + if err != nil { + return err + } + return s.commentRepo.ApplyPublishedStats(comment) +} + +func (s *CommentService) invalidatePostCaches(postID string) { + cache.InvalidatePostDetail(s.cache, postID) + cache.InvalidatePostList(s.cache) +} + func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) { // 发送系统消息通知 if s.systemMessageService != nil { @@ -212,7 +240,15 @@ func (s *CommentService) Update(ctx context.Context, comment *model.Comment) err // Delete 删除评论 func (s *CommentService) Delete(ctx context.Context, id string) error { - return s.commentRepo.Delete(id) + comment, err := s.commentRepo.GetByID(id) + if err != nil { + return err + } + if err := s.commentRepo.Delete(id); err != nil { + return err + } + s.invalidatePostCaches(comment.PostID) + return nil } // Like 点赞评论 diff --git a/internal/service/group_service.go b/internal/service/group_service.go index 226eae8..91aafb0 100644 --- a/internal/service/group_service.go +++ b/internal/service/group_service.go @@ -9,8 +9,8 @@ import ( "carrot_bbs/internal/cache" "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/sse" "carrot_bbs/internal/pkg/utils" - "carrot_bbs/internal/pkg/websocket" "carrot_bbs/internal/repository" "gorm.io/gorm" @@ -18,7 +18,7 @@ import ( // 缓存TTL常量 const ( - GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒 + GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒 GroupMembersNullTTL = 5 * time.Second GroupCacheJitter = 0.1 ) @@ -99,12 +99,12 @@ type groupService struct { messageRepo *repository.MessageRepository requestRepo repository.GroupJoinRequestRepository notifyRepo *repository.SystemNotificationRepository - wsManager *websocket.WebSocketManager + sseHub *sse.Hub cache cache.Cache } // NewGroupService 创建群组服务 -func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, wsManager *websocket.WebSocketManager) GroupService { +func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, sseHub *sse.Hub) GroupService { return &groupService{ db: db, groupRepo: groupRepo, @@ -112,11 +112,39 @@ func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo messageRepo: messageRepo, requestRepo: repository.NewGroupJoinRequestRepository(db), notifyRepo: repository.NewSystemNotificationRepository(db), - wsManager: wsManager, + sseHub: sseHub, cache: cache.GetCache(), } } +type groupNoticeData struct { + UserID string `json:"user_id,omitempty"` + Username string `json:"username,omitempty"` + OperatorID string `json:"operator_id,omitempty"` +} + +type groupNoticeMessage struct { + NoticeType string `json:"notice_type"` + GroupID string `json:"group_id"` + Data groupNoticeData `json:"data"` + Timestamp int64 `json:"timestamp"` + MessageID string `json:"message_id,omitempty"` + Seq int64 `json:"seq,omitempty"` +} + +func (s *groupService) publishGroupNotice(groupID string, notice groupNoticeMessage) { + members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000) + if err != nil { + log.Printf("[groupService] 获取群成员失败: groupID=%s, err=%v", groupID, err) + return + } + if s.sseHub != nil { + for _, m := range members { + s.sseHub.PublishToUser(m.UserID, "group_notice", notice) + } + } +} + // ==================== 群组管理 ==================== // CreateGroup 创建群组 @@ -422,14 +450,10 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st } } - if s.wsManager == nil { - return - } - - noticeMsg := websocket.GroupNoticeMessage{ + noticeMsg := groupNoticeMessage{ NoticeType: "member_join", GroupID: groupID, - Data: websocket.GroupNoticeData{ + Data: groupNoticeData{ UserID: targetUserID, Username: targetUserName, OperatorID: operatorID, @@ -441,17 +465,7 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st noticeMsg.Seq = savedMessage.Seq } - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg) - members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000) - if err != nil { - log.Printf("[broadcastMemberJoinNotice] 获取群成员失败: groupID=%s, err=%v", groupID, err) - return - } - for _, m := range members { - if s.wsManager.IsUserOnline(m.UserID) { - s.wsManager.SendToUser(m.UserID, wsMsg) - } - } + s.publishGroupNotice(groupID, noticeMsg) } func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userID string, operatorID string) error { @@ -1282,46 +1296,20 @@ func (s *groupService) MuteMember(userID string, groupID string, targetUserID st } } - // 发送WebSocket通知给群成员 - if s.wsManager != nil { - log.Printf("[MuteMember] 准备发送禁言通知: groupID=%s, targetUserID=%s, noticeType=%s, operatorID=%s", groupID, targetUserID, noticeType, userID) - - // 构建通知消息,包含保存的消息信息 - noticeMsg := websocket.GroupNoticeMessage{ - NoticeType: noticeType, - GroupID: groupID, - Data: websocket.GroupNoticeData{ - UserID: targetUserID, - OperatorID: userID, - }, - Timestamp: time.Now().UnixMilli(), - } - - // 如果消息已保存,添加消息ID和seq - if savedMessage != nil { - noticeMsg.MessageID = savedMessage.ID - noticeMsg.Seq = savedMessage.Seq - } - - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg) - log.Printf("[MuteMember] 创建的WebSocket消息: Type=%s, Data=%+v", wsMsg.Type, wsMsg.Data) - - // 获取所有群成员并发送通知 - members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000) - if err == nil { - log.Printf("[MuteMember] 获取到群成员数量: %d", len(members)) - for _, m := range members { - isOnline := s.wsManager.IsUserOnline(m.UserID) - log.Printf("[MuteMember] 成员 %s 在线状态: %v", m.UserID, isOnline) - if isOnline { - s.wsManager.SendToUser(m.UserID, wsMsg) - log.Printf("[MuteMember] 已发送通知给成员: %s", m.UserID) - } - } - } else { - log.Printf("[MuteMember] 获取群成员失败: %v", err) - } + noticeMsg := groupNoticeMessage{ + NoticeType: noticeType, + GroupID: groupID, + Data: groupNoticeData{ + UserID: targetUserID, + OperatorID: userID, + }, + Timestamp: time.Now().UnixMilli(), } + if savedMessage != nil { + noticeMsg.MessageID = savedMessage.ID + noticeMsg.Seq = savedMessage.Seq + } + s.publishGroupNotice(groupID, noticeMsg) // 失效群组成员缓存 cache.InvalidateGroupMembers(s.cache, groupID) diff --git a/internal/service/post_service.go b/internal/service/post_service.go index e796edc..6945ccc 100644 --- a/internal/service/post_service.go +++ b/internal/service/post_service.go @@ -77,6 +77,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima if s.postAIService == nil || !s.postAIService.IsEnabled() { if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil { log.Printf("[WARN] Failed to publish post without AI moderation: %v", err) + } else { + s.invalidatePostCaches(postID) } return } @@ -87,6 +89,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima if errors.As(err, &rejectedErr) { if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil { log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr) + } else { + s.invalidatePostCaches(postID) } s.notifyModerationRejected(userID, rejectedErr.Reason) return @@ -95,6 +99,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima // 规则审核不可用时,降级为发布,避免长时间pending if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil { log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr) + } else { + s.invalidatePostCaches(postID) } log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err) return @@ -104,6 +110,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima log.Printf("[WARN] Failed to publish post %s: %v", postID, err) return } + s.invalidatePostCaches(postID) if s.gorseClient.IsEnabled() { post, getErr := s.postRepo.GetByID(postID) @@ -120,6 +127,11 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima } } +func (s *PostService) invalidatePostCaches(postID string) { + cache.InvalidatePostDetail(s.cache, postID) + cache.InvalidatePostList(s.cache) +} + func (s *PostService) notifyModerationRejected(userID, reason string) { if s.systemMessageService == nil || strings.TrimSpace(userID) == "" { return @@ -149,7 +161,12 @@ func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, erro // Update 更新帖子 func (s *PostService) Update(ctx context.Context, post *model.Post) error { - err := s.postRepo.Update(post) + return s.UpdateWithImages(ctx, post, nil) +} + +// UpdateWithImages 更新帖子并可选更新图片(images=nil 表示不更新图片) +func (s *PostService) UpdateWithImages(ctx context.Context, post *model.Post, images *[]string) error { + err := s.postRepo.UpdateWithImages(post, images) if err != nil { return err } @@ -185,7 +202,7 @@ func (s *PostService) Delete(ctx context.Context, id string) error { } // List 获取帖子列表(带缓存) -func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) { +func (s *PostService) List(ctx context.Context, page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) { cacheSettings := cache.GetSettings() postListTTL := cacheSettings.PostListTTL if postListTTL <= 0 { @@ -200,8 +217,12 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin jitter = PostListJitterRatio } - // 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染) - cacheKey := cache.PostListKey("latest", userID, page, pageSize) + // 生成缓存键(包含 userID 维度与可见性维度,避免作者视角污染公开视角) + visibilityUserKey := userID + if includePending && userID != "" { + visibilityUserKey = "owner:" + userID + } + cacheKey := cache.PostListKey("latest", visibilityUserKey, page, pageSize) result, err := cache.GetOrLoadTyped[*PostListResult]( s.cache, @@ -210,7 +231,7 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin jitter, nullTTL, func() (*PostListResult, error) { - posts, total, err := s.postRepo.List(page, pageSize, userID) + posts, total, err := s.postRepo.List(page, pageSize, userID, includePending) if err != nil { return nil, err } @@ -234,7 +255,7 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin } } if missingAuthor { - posts, total, loadErr := s.postRepo.List(page, pageSize, userID) + posts, total, loadErr := s.postRepo.List(page, pageSize, userID, includePending) if loadErr != nil { return nil, 0, loadErr } @@ -247,12 +268,17 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin // GetLatestPosts 获取最新帖子(语义化别名) func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) { - return s.List(ctx, page, pageSize, userID) + return s.List(ctx, page, pageSize, userID, false) +} + +// GetLatestPostsForOwner 获取作者视角帖子列表(包含待审核) +func (s *PostService) GetLatestPostsForOwner(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) { + return s.List(ctx, page, pageSize, userID, true) } // GetUserPosts 获取用户帖子 -func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) { - return s.postRepo.GetUserPosts(userID, page, pageSize) +func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) { + return s.postRepo.GetUserPosts(userID, page, pageSize, includePending) } // Like 点赞 diff --git a/internal/service/push_service.go b/internal/service/push_service.go index 80fc1c1..b1a53df 100644 --- a/internal/service/push_service.go +++ b/internal/service/push_service.go @@ -8,7 +8,7 @@ import ( "carrot_bbs/internal/dto" "carrot_bbs/internal/model" - "carrot_bbs/internal/pkg/websocket" + "carrot_bbs/internal/pkg/sse" "carrot_bbs/internal/repository" ) @@ -42,8 +42,6 @@ type PushService interface { // 系统消息推送 PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error - PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error - PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error // 系统通知推送(新接口,使用独立的 SystemNotification 模型) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error @@ -67,7 +65,7 @@ type pushServiceImpl struct { pushRepo *repository.PushRecordRepository deviceRepo *repository.DeviceTokenRepository messageRepo *repository.MessageRepository - wsManager *websocket.WebSocketManager + sseHub *sse.Hub // 推送队列 pushQueue chan *pushTask @@ -86,13 +84,13 @@ func NewPushService( pushRepo *repository.PushRecordRepository, deviceRepo *repository.DeviceTokenRepository, messageRepo *repository.MessageRepository, - wsManager *websocket.WebSocketManager, + sseHub *sse.Hub, ) PushService { return &pushServiceImpl{ pushRepo: pushRepo, deviceRepo: deviceRepo, messageRepo: messageRepo, - wsManager: wsManager, + sseHub: sseHub, pushQueue: make(chan *pushTask, PushQueueSize), stopChan: make(chan struct{}), } @@ -140,11 +138,7 @@ func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message // pushViaWebSocket 通过WebSocket推送消息 // 返回true表示推送成功,false表示用户不在线 func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool { - if s.wsManager == nil { - return false - } - - if !s.wsManager.IsUserOnline(userID) { + if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) { return false } @@ -154,36 +148,33 @@ func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, m // 从 segments 中提取文本内容 content := dto.ExtractTextContentFromModel(message.Segments) - notification := &websocket.NotificationMessage{ - ID: fmt.Sprintf("%s", message.ID), - Type: string(message.SystemType), - Content: content, - Extra: make(map[string]interface{}), - CreatedAt: message.CreatedAt.UnixMilli(), + notification := map[string]interface{}{ + "id": fmt.Sprintf("%s", message.ID), + "type": string(message.SystemType), + "content": content, + "extra": map[string]interface{}{}, + "created_at": message.CreatedAt.UnixMilli(), } // 填充额外数据 if message.ExtraData != nil { - notification.Extra["actor_id"] = message.ExtraData.ActorID - notification.Extra["actor_name"] = message.ExtraData.ActorName - notification.Extra["avatar_url"] = message.ExtraData.AvatarURL - notification.Extra["target_id"] = message.ExtraData.TargetID - notification.Extra["target_type"] = message.ExtraData.TargetType - notification.Extra["action_url"] = message.ExtraData.ActionURL - notification.Extra["action_time"] = message.ExtraData.ActionTime - - // 设置触发用户信息 + extra := notification["extra"].(map[string]interface{}) + extra["actor_id"] = message.ExtraData.ActorID + extra["actor_name"] = message.ExtraData.ActorName + extra["avatar_url"] = message.ExtraData.AvatarURL + extra["target_id"] = message.ExtraData.TargetID + extra["target_type"] = message.ExtraData.TargetType + extra["action_url"] = message.ExtraData.ActionURL + extra["action_time"] = message.ExtraData.ActionTime if message.ExtraData.ActorID > 0 { - notification.TriggerUser = &websocket.NotificationUser{ - ID: fmt.Sprintf("%d", message.ExtraData.ActorID), - Username: message.ExtraData.ActorName, - Avatar: message.ExtraData.AvatarURL, + notification["trigger_user"] = map[string]interface{}{ + "id": fmt.Sprintf("%d", message.ExtraData.ActorID), + "username": message.ExtraData.ActorName, + "avatar": message.ExtraData.AvatarURL, } } } - - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification) - s.wsManager.SendToUser(userID, wsMsg) + s.sseHub.PublishToUser(userID, "system_notification", notification) return true } @@ -208,8 +199,10 @@ func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, m SenderID: message.SenderID, } - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event) - s.wsManager.SendToUser(userID, wsMsg) + s.sseHub.PublishToUser(userID, "chat_message", map[string]interface{}{ + "detail_type": detailType, + "message": event, + }) return true } @@ -451,73 +444,21 @@ func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string, // pushSystemViaWebSocket 通过WebSocket推送系统消息 func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool { - if s.wsManager == nil { + if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) { return false } - if !s.wsManager.IsUserOnline(userID) { - return false + sysMsg := map[string]interface{}{ + "type": msgType, + "title": title, + "content": content, + "data": data, + "created_at": time.Now().UnixMilli(), } - - sysMsg := &websocket.SystemMessage{ - Type: msgType, - Title: title, - Content: content, - Data: data, - CreatedAt: time.Now().UnixMilli(), - } - - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg) - s.wsManager.SendToUser(userID, wsMsg) + s.sseHub.PublishToUser(userID, "system_notification", sysMsg) return true } -// PushNotification 推送通知消息 -func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error { - // 首先尝试WebSocket推送 - if s.pushNotificationViaWebSocket(ctx, userID, notification) { - return nil - } - - // 用户不在线,创建待推送记录 - // 通知消息可以等用户上线后拉取 - return errors.New("user is offline, notification will be available on next sync") -} - -// pushNotificationViaWebSocket 通过WebSocket推送通知消息 -func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool { - if s.wsManager == nil { - return false - } - - if !s.wsManager.IsUserOnline(userID) { - return false - } - - if notification.CreatedAt == 0 { - notification.CreatedAt = time.Now().UnixMilli() - } - - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification) - s.wsManager.SendToUser(userID, wsMsg) - return true -} - -// PushAnnouncement 广播公告消息 -func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error { - if s.wsManager == nil { - return errors.New("websocket manager not available") - } - - if announcement.CreatedAt == 0 { - announcement.CreatedAt = time.Now().UnixMilli() - } - - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement) - s.wsManager.Broadcast(wsMsg) - return nil -} - // PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型) func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error { // 首先尝试WebSocket推送 @@ -531,45 +472,40 @@ func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID str // pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知 func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool { - if s.wsManager == nil { + if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) { return false } - if !s.wsManager.IsUserOnline(userID) { - return false - } - - // 构建 WebSocket 通知消息 - wsNotification := &websocket.NotificationMessage{ - ID: fmt.Sprintf("%d", notification.ID), - Type: string(notification.Type), - Title: notification.Title, - Content: notification.Content, - Extra: make(map[string]interface{}), - CreatedAt: notification.CreatedAt.UnixMilli(), + sseNotification := map[string]interface{}{ + "id": fmt.Sprintf("%d", notification.ID), + "type": string(notification.Type), + "title": notification.Title, + "content": notification.Content, + "extra": map[string]interface{}{}, + "created_at": notification.CreatedAt.UnixMilli(), } // 填充额外数据 if notification.ExtraData != nil { - wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr - wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName - wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL - wsNotification.Extra["target_id"] = notification.ExtraData.TargetID - wsNotification.Extra["target_type"] = notification.ExtraData.TargetType - wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL - wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime + extra := sseNotification["extra"].(map[string]interface{}) + extra["actor_id_str"] = notification.ExtraData.ActorIDStr + extra["actor_name"] = notification.ExtraData.ActorName + extra["avatar_url"] = notification.ExtraData.AvatarURL + extra["target_id"] = notification.ExtraData.TargetID + extra["target_type"] = notification.ExtraData.TargetType + extra["action_url"] = notification.ExtraData.ActionURL + extra["action_time"] = notification.ExtraData.ActionTime // 设置触发用户信息 if notification.ExtraData.ActorIDStr != "" { - wsNotification.TriggerUser = &websocket.NotificationUser{ - ID: notification.ExtraData.ActorIDStr, - Username: notification.ExtraData.ActorName, - Avatar: notification.ExtraData.AvatarURL, + sseNotification["trigger_user"] = map[string]interface{}{ + "id": notification.ExtraData.ActorIDStr, + "username": notification.ExtraData.ActorName, + "avatar": notification.ExtraData.AvatarURL, } } } - wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification) - s.wsManager.SendToUser(userID, wsMsg) + s.sseHub.PublishToUser(userID, "system_notification", sseNotification) return true }