Replace websocket flow with SSE support in backend.

Update handlers, services, router, and data conversion logic to support server-sent events and related message pipeline changes.

Made-with: Cursor
This commit is contained in:
2026-03-10 12:58:23 +08:00
parent 4c0177149a
commit 86ef150fec
19 changed files with 689 additions and 1719 deletions

View File

@@ -36,9 +36,9 @@ database:
redis:
type: miniredis # miniredis 或 redis
redis:
host: localhost
host: 1Panel-redis-dfmM
port: 6379
password: ""
password: "redis_j8CMza"
db: 0
miniredis:
host: localhost
@@ -67,13 +67,13 @@ cache:
# S3对象存储配置
# 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN
s3:
endpoint: ""
access_key: ""
secret_key: ""
bucket: ""
endpoint: "files.littlelan.cn"
access_key: "E6bMcYkQzCldRTrtmhvi"
secret_key: "4R9yjmwKNoHphiBkv05Oa8WGEIFbnlZeTLXfSgx3"
bucket: "test"
use_ssl: true
region: us-east-1
domain: ""
domain: "files.littlelan.cn"
# JWT配置
# 环境变量: APP_JWT_SECRET
jwt:
@@ -130,12 +130,12 @@ audit:
# Gorse推荐系统配置
# 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD
gorse:
enabled: false
address: "" # Gorse server地址
enabled: true
address: "http://111.170.19.33:8088" # Gorse server地址
api_key: "" # API密钥
dashboard: "" # Gorse dashboard地址
import_password: "" # 导入数据密码
embedding_api_key: ""
import_password: "lanyimin123" # 导入数据密码
embedding_api_key: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD"
embedding_url: "https://api.littlelan.cn/v1/embeddings"
embedding_model: "BAAI/bge-m3"
@@ -147,7 +147,7 @@ gorse:
openai:
enabled: true
base_url: "https://api.littlelan.cn/"
api_key: ""
api_key: "sk-y7LOeKsNfzbZWTRSFsTs79jd8WYlezbIVgdVPgMvG4Xz2AlV"
moderation_model: "qwen3.5-122b"
moderation_max_images_per_request: 1
request_timeout: 30
@@ -160,12 +160,12 @@ openai:
# APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME
# APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT
email:
enabled: false
host: ""
port: 587
username: ""
password: ""
from_address: ""
enabled: true
host: "smtp.exmail.qq.com"
port: 465
username: "no-reply@qczlit.cn"
password: "HbvwwVjRyiWg9gsK"
from_address: "no-reply@qczlit.cn"
from_name: "Carrot BBS"
use_tls: true
insecure_skip_verify: false

View File

@@ -284,6 +284,7 @@ func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostRes
Title: post.Title,
Content: post.Content,
Images: images,
Status: string(post.Status),
LikesCount: post.LikesCount,
CommentsCount: post.CommentsCount,
FavoritesCount: post.FavoritesCount,
@@ -293,6 +294,7 @@ func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostRes
IsLocked: post.IsLocked,
IsVote: post.IsVote,
CreatedAt: FormatTime(post.CreatedAt),
UpdatedAt: FormatTime(post.UpdatedAt),
Author: author,
IsLiked: isLiked,
IsFavorited: isFavorited,

View File

@@ -68,6 +68,7 @@ type PostResponse struct {
Title string `json:"title"`
Content string `json:"content"`
Images []PostImageResponse `json:"images"`
Status string `json:"status,omitempty"`
LikesCount int `json:"likes_count"`
CommentsCount int `json:"comments_count"`
FavoritesCount int `json:"favorites_count"`
@@ -77,6 +78,7 @@ type PostResponse struct {
IsLocked bool `json:"is_locked"`
IsVote bool `json:"is_vote"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
Author *UserResponse `json:"author"`
IsLiked bool `json:"is_liked"`
IsFavorited bool `json:"is_favorited"`

View File

@@ -2,12 +2,16 @@ package handler
import (
"context"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/sse"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
@@ -18,18 +22,111 @@ type MessageHandler struct {
messageService *service.MessageService
userService *service.UserService
groupService service.GroupService
sseHub *sse.Hub
}
// NewMessageHandler 创建消息处理器
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler {
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService, sseHub *sse.Hub) *MessageHandler {
return &MessageHandler{
chatService: chatService,
messageService: messageService,
userService: userService,
groupService: groupService,
sseHub: sseHub,
}
}
// HandleSSE 实时消息订阅SSE
// GET /api/v1/realtime/sse
func (h *MessageHandler) HandleSSE(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
if h.sseHub == nil {
response.InternalServerError(c, "sse hub not available")
return
}
lastID := sse.ParseEventID(c.GetHeader("Last-Event-ID"))
if lastID == 0 {
lastID = sse.ParseEventID(c.Query("last_event_id"))
}
ch, cancel, replay := h.sseHub.Subscribe(userID, lastID)
defer cancel()
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
response.InternalServerError(c, "streaming unsupported")
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
flusher.Flush()
writeEvent := func(ev sse.Event) bool {
data, err := sse.EncodeData(ev)
if err != nil {
return false
}
if _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.ID, ev.Event, data); err != nil {
return false
}
flusher.Flush()
return true
}
for _, ev := range replay {
if !writeEvent(ev) {
return
}
}
heartbeat := time.NewTicker(25 * time.Second)
defer heartbeat.Stop()
for {
select {
case <-c.Request.Context().Done():
return
case ev, ok := <-ch:
if !ok || !writeEvent(ev) {
return
}
case <-heartbeat.C:
if _, err := fmt.Fprint(w, "event: heartbeat\ndata: {}\n\n"); err != nil {
return
}
flusher.Flush()
}
}
}
// HandleTyping 输入状态上报
// POST /api/v1/conversations/typing
func (h *MessageHandler) HandleTyping(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params struct {
ConversationID string `json:"conversation_id" binding:"required"`
}
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
h.chatService.SendTyping(c.Request.Context(), userID, params.ConversationID)
response.SuccessWithMessage(c, "typing sent", nil)
}
// GetConversations 获取会话列表
// GET /api/conversations
func (h *MessageHandler) GetConversations(c *gin.Context) {

View File

@@ -105,6 +105,7 @@ func (h *PostHandler) GetByID(c *gin.Context) {
Title: post.Title,
Content: post.Content,
Images: dto.ConvertPostImagesToResponse(post.Images),
Status: string(post.Status),
LikesCount: post.LikesCount,
CommentsCount: post.CommentsCount,
FavoritesCount: post.FavoritesCount,
@@ -114,6 +115,7 @@ func (h *PostHandler) GetByID(c *gin.Context) {
IsLocked: post.IsLocked,
IsVote: post.IsVote,
CreatedAt: dto.FormatTime(post.CreatedAt),
UpdatedAt: dto.FormatTime(post.UpdatedAt),
Author: authorWithFollowStatus,
IsLiked: isLiked,
IsFavorited: isFavorited,
@@ -175,10 +177,18 @@ func (h *PostHandler) List(c *gin.Context) {
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
case "latest":
// 最新帖子
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
if userID != "" && userID == currentUserID {
posts, total, err = h.postService.GetLatestPostsForOwner(c.Request.Context(), page, pageSize, userID)
} else {
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
}
default:
// 默认获取最新帖子
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
if userID != "" && userID == currentUserID {
posts, total, err = h.postService.GetLatestPostsForOwner(c.Request.Context(), page, pageSize, userID)
} else {
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
}
}
if err != nil {
@@ -225,8 +235,9 @@ func (h *PostHandler) Update(c *gin.Context) {
}
type UpdateRequest struct {
Title string `json:"title"`
Content string `json:"content"`
Title string `json:"title"`
Content string `json:"content"`
Images *[]string `json:"images"`
}
var req UpdateRequest
@@ -242,12 +253,18 @@ func (h *PostHandler) Update(c *gin.Context) {
post.Content = req.Content
}
err = h.postService.Update(c.Request.Context(), post)
err = h.postService.UpdateWithImages(c.Request.Context(), post, req.Images)
if err != nil {
response.InternalServerError(c, "failed to update post")
return
}
post, err = h.postService.GetByID(c.Request.Context(), post.ID)
if err != nil {
response.InternalServerError(c, "failed to get updated post")
return
}
currentUserID := c.GetString("user_id")
var isLiked, isFavorited bool
if currentUserID != "" {
@@ -410,14 +427,15 @@ func (h *PostHandler) GetUserPosts(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize)
currentUserID := c.GetString("user_id")
includePending := currentUserID != "" && currentUserID == userID
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize, includePending)
if err != nil {
response.InternalServerError(c, "failed to get user posts")
return
}
// 获取当前用户ID用于判断点赞和收藏状态
currentUserID := c.GetString("user_id")
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {

View File

@@ -1,849 +0,0 @@
package handler
import (
"context"
"encoding/json"
"log"
"net/http"
"strconv"
"strings"
"time"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
ws "carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
"carrot_bbs/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // 允许所有来源,生产环境应该限制
},
}
// WebSocketHandler WebSocket处理器
type WebSocketHandler struct {
jwtService *service.JWTService
chatService service.ChatService
groupService service.GroupService
groupRepo repository.GroupRepository
userRepo *repository.UserRepository
wsManager *ws.WebSocketManager
}
// NewWebSocketHandler 创建WebSocket处理器
func NewWebSocketHandler(
jwtService *service.JWTService,
chatService service.ChatService,
groupService service.GroupService,
groupRepo repository.GroupRepository,
userRepo *repository.UserRepository,
wsManager *ws.WebSocketManager,
) *WebSocketHandler {
return &WebSocketHandler{
jwtService: jwtService,
chatService: chatService,
groupService: groupService,
groupRepo: groupRepo,
userRepo: userRepo,
wsManager: wsManager,
}
}
// HandleWebSocket 处理WebSocket连接
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
// 调试:打印请求头信息
log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
c.GetHeader("Connection"),
c.GetHeader("Upgrade"))
log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
c.GetHeader("Sec-WebSocket-Key"),
c.GetHeader("Sec-WebSocket-Version"))
// 从query参数获取token
token := c.Query("token")
if token == "" {
// 尝试从header获取
authHeader := c.GetHeader("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
}
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
return
}
// 验证token
claims, err := h.jwtService.ParseToken(token)
if err != nil {
log.Printf("Invalid token: %v", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
userID := claims.UserID
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
return
}
// 升级HTTP连接为WebSocket连接
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("Failed to upgrade connection: %v", err)
log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
c.GetHeader("User-Agent"),
c.GetHeader("Content-Type"))
return
}
// 创建客户端
client := &ws.Client{
ID: userID,
UserID: userID,
Conn: conn,
Send: make(chan []byte, 256),
Manager: h.wsManager,
}
// 注册客户端
h.wsManager.Register(client)
// 启动读写协程
go client.WritePump()
go h.handleMessages(client)
}
// handleMessages 处理客户端消息
// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
defer func() {
h.wsManager.Unregister(client)
client.Conn.Close()
}()
client.Conn.SetReadLimit(512 * 1024) // 512KB
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
client.Conn.SetPongHandler(func(string) error {
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
return nil
})
// 心跳定时器 - 服务端主动 ping 间隔保持 30 秒
pingTicker := time.NewTicker(30 * time.Second)
defer pingTicker.Stop()
for {
select {
case <-pingTicker.C:
// 发送心跳
if err := client.SendPing(); err != nil {
log.Printf("Failed to send ping: %v", err)
return
}
default:
_, message, err := client.Conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error: %v", err)
}
return
}
var wsMsg ws.WSMessage
if err := json.Unmarshal(message, &wsMsg); err != nil {
log.Printf("Failed to unmarshal message: %v", err)
continue
}
h.processMessage(client, &wsMsg)
}
}
}
// processMessage 处理消息
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
switch msg.Type {
case ws.MessageTypePing:
// 响应心跳
if err := client.SendPong(); err != nil {
log.Printf("Failed to send pong: %v", err)
}
case ws.MessageTypePong:
// 客户端响应心跳
case ws.MessageTypeMessage:
// 处理聊天消息
h.handleChatMessage(client, msg)
case ws.MessageTypeTyping:
// 处理正在输入状态
h.handleTyping(client, msg)
case ws.MessageTypeRead:
// 处理已读回执
h.handleReadReceipt(client, msg)
// 群组消息处理
case ws.MessageTypeGroupMessage:
// 处理群消息
h.handleGroupMessage(client, msg)
case ws.MessageTypeGroupTyping:
// 处理群输入状态
h.handleGroupTyping(client, msg)
case ws.MessageTypeGroupRead:
// 处理群消息已读
h.handleGroupReadReceipt(client, msg)
case ws.MessageTypeGroupRecall:
// 处理群消息撤回
h.handleGroupRecall(client, msg)
default:
log.Printf("Unknown message type: %s", msg.Type)
}
}
// handleChatMessage 处理聊天消息
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
log.Printf("Invalid message data format")
return
}
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
log.Printf("Missing conversationId")
return
}
// 解析会话ID
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
log.Printf("Invalid conversation ID: %v", err)
return
}
// 解析 segments
var segments model.MessageSegments
if data["segments"] != nil {
segmentsBytes, err := json.Marshal(data["segments"])
if err == nil {
json.Unmarshal(segmentsBytes, &segments)
}
}
// 从 segments 中提取回复消息ID
replyToID := dto.GetReplyMessageID(segments)
var replyToIDPtr *string
if replyToID != "" {
replyToIDPtr = &replyToID
}
// 发送消息 - 使用 segments
message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
if err != nil {
log.Printf("Failed to send message: %v", err)
// 发送错误消息
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to send message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 发送确认消息(使用 meta 事件格式,包含完整的消息内容)
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
"detail_type": ws.MetaDetailTypeAck,
"conversation_id": conversationID,
"id": message.ID,
"user_id": client.UserID,
"sender_id": client.UserID,
"seq": message.Seq,
"segments": message.Segments,
"created_at": message.CreatedAt.UnixMilli(),
})
if client.Send != nil {
msgBytes, _ := json.Marshal(metaAckMsg)
client.Send <- msgBytes
}
}
// handleTyping 处理正在输入状态
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
return
}
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
return
}
// 直接使用 string 类型的 userID
h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
}
// handleReadReceipt 处理已读回执
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
return
}
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
return
}
// 获取lastReadSeq
lastReadSeq, _ := data["lastReadSeq"].(float64)
if lastReadSeq == 0 {
return
}
// 直接使用 string 类型的 userID 和 conversationID
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
log.Printf("Failed to mark as read: %v", err)
}
}
// ==================== 群组消息处理 ====================
// handleGroupMessage 处理群消息
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
// 打印接收到的消息类型和数据,用于调试
log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)
data, ok := msg.Data.(map[string]interface{})
if !ok {
log.Printf("Invalid group message data format: data is not map[string]interface{}")
return
}
// 解析群组ID支持 camelCase 和 snake_case
var groupIDFloat float64
groupID := "" // 使用 groupID 作为最终变量名
if val, ok := data["groupId"].(float64); ok {
groupIDFloat = val
groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
} else if val, ok := data["group_id"].(string); ok {
groupID = val
}
if groupID == "" {
log.Printf("Missing groupId in group message")
return
}
// 解析会话ID支持 camelCase 和 snake_case
var conversationID string
if val, ok := data["conversationId"].(string); ok {
conversationID = val
} else if val, ok := data["conversation_id"].(string); ok {
conversationID = val
}
if conversationID == "" {
log.Printf("Missing conversationId in group message")
return
}
// 解析 segments
var segments model.MessageSegments
if data["segments"] != nil {
segmentsBytes, err := json.Marshal(data["segments"])
if err == nil {
json.Unmarshal(segmentsBytes, &segments)
}
}
// 解析@用户列表(支持 camelCase 和 snake_case
var mentionUsers []uint64
var mentionUsersInterface []interface{}
if val, ok := data["mentionUsers"].([]interface{}); ok {
mentionUsersInterface = val
} else if val, ok := data["mention_users"].([]interface{}); ok {
mentionUsersInterface = val
}
if len(mentionUsersInterface) > 0 {
for _, uid := range mentionUsersInterface {
if uidFloat, ok := uid.(float64); ok {
mentionUsers = append(mentionUsers, uint64(uidFloat))
} else if uidStr, ok := uid.(string); ok {
// 处理字符串格式的用户ID
if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil {
mentionUsers = append(mentionUsers, uidInt)
}
}
}
}
// 解析@所有人(支持 camelCase 和 snake_case
var mentionAll bool
if val, ok := data["mentionAll"].(bool); ok {
mentionAll = val
} else if val, ok := data["mention_all"].(bool); ok {
mentionAll = val
}
// 检查用户是否可以发送群消息(验证成员身份和禁言状态)
// client.UserID 已经是 string 格式的 UUID
if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
log.Printf("User cannot send group message: %v", err)
// 发送错误消息
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Cannot send group message",
"reason": err.Error(),
"type": "group_message_error",
"groupId": groupID,
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 检查@所有人权限(只有群主和管理员可以@所有人)
if mentionAll {
if !h.groupService.IsGroupAdmin(client.UserID, groupID) {
log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
mentionAll = false // 取消@所有人标记
}
}
// 创建消息
message := &model.Message{
ConversationID: conversationID,
SenderID: client.UserID,
Segments: segments,
Status: model.MessageStatusNormal,
MentionAll: mentionAll,
}
// 序列化mentionUsers为JSON
if len(mentionUsers) > 0 {
mentionUsersJSON, _ := json.Marshal(mentionUsers)
message.MentionUsers = string(mentionUsersJSON)
}
// 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播)
savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
if err != nil {
log.Printf("Failed to save group message: %v", err)
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to save group message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 更新消息的mention信息
if len(mentionUsers) > 0 || mentionAll {
message.ID = savedMessage.ID
message.Seq = savedMessage.Seq
}
// 构造群消息响应
groupMsg := &ws.GroupMessage{
ID: savedMessage.ID,
ConversationID: conversationID,
GroupID: groupID,
SenderID: client.UserID,
Seq: savedMessage.Seq,
Segments: segments,
MentionUsers: mentionUsers,
MentionAll: mentionAll,
CreatedAt: savedMessage.CreatedAt.UnixMilli(),
}
// 广播消息给群组所有成员(排除发送者)
h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)
// 发送确认消息给发送者作为meta事件
// 使用 meta 事件格式发送 ack
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
"detail_type": ws.MetaDetailTypeAck,
"conversation_id": conversationID,
"group_id": groupID,
"id": savedMessage.ID,
"user_id": client.UserID,
"sender_id": client.UserID,
"seq": savedMessage.Seq,
"segments": segments,
"created_at": savedMessage.CreatedAt.UnixMilli(),
})
if client.Send != nil {
msgBytes, _ := json.Marshal(metaAckMsg)
client.Send <- msgBytes
} else {
log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
}
// 处理@提及通知
if len(mentionUsers) > 0 || mentionAll {
// 提取文本正文(不含 @ 部分)
textContent := dto.ExtractTextContentFromModel(segments)
// 在通知内容前拼接被@的真实昵称,通过群成员列表查找
mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
}
}
// handleGroupTyping 处理群输入状态
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
groupIDFloat, _ := data["groupId"].(float64)
if groupIDFloat == 0 {
return
}
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
isTyping, _ := data["isTyping"].(bool)
// 验证用户是否是群成员
// client.UserID 已经是 string 格式的 UUID
isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
if err != nil || !isMember {
return
}
// 获取用户信息
user, err := h.userRepo.GetByID(client.UserID)
if err != nil {
return
}
// 构造输入状态消息
typingMsg := &ws.GroupTypingMessage{
GroupID: groupID,
UserID: client.UserID,
Username: user.Username,
IsTyping: isTyping,
}
// 广播给群组其他成员
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
}
// handleGroupReadReceipt 处理群消息已读回执
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationID, _ := data["conversationId"].(string)
if conversationID == "" {
return
}
lastReadSeq, _ := data["lastReadSeq"].(float64)
if lastReadSeq == 0 {
return
}
// 标记已读
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
log.Printf("Failed to mark group message as read: %v", err)
}
}
// handleGroupRecall 处理群消息撤回
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
messageID, _ := data["messageId"].(string)
if messageID == "" {
return
}
groupIDFloat, _ := data["groupId"].(float64)
if groupIDFloat == 0 {
return
}
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
// 撤回消息
if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
log.Printf("Failed to recall group message: %v", err)
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to recall message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 广播撤回通知给群组所有成员
recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
"messageId": messageID,
"groupId": groupID,
"userId": client.UserID,
"timestamp": time.Now().UnixMilli(),
})
h.BroadcastGroupNotice(groupID, recallNotice)
}
// handleGroupMention 处理群消息@提及通知
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
// 如果@所有人,获取所有群成员
if mentionAll {
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for mention all: %v", err)
return
}
for _, member := range members {
// 不通知发送者自己
memberIDStr := member.UserID
if memberIDStr == senderID {
continue
}
// 发送@提及通知
mentionMsg := &ws.GroupMentionMessage{
GroupID: groupID,
MessageID: messageID,
FromUserID: senderID,
Content: truncateContent(content, 50),
MentionAll: true,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
return
}
// 处理特定用户的@提及
for _, userID := range mentionUsers {
// userID 是 uint64转换为 string
userIDStr := strconv.FormatUint(userID, 10)
if userIDStr == senderID {
continue // 不通知发送者自己
}
mentionMsg := &ws.GroupMentionMessage{
GroupID: groupID,
MessageID: messageID,
FromUserID: senderID,
Content: truncateContent(content, 50),
MentionAll: false,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
h.wsManager.SendToUser(userIDStr, wsMsg)
}
}
// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
var prefix string
if mentionAll {
prefix = "@所有人 "
} else if len(mentionUsers) > 0 {
// 查询群成员列表,找到被@用户的昵称
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err == nil {
memberNickMap := make(map[string]string, len(members))
for _, m := range members {
displayName := m.Nickname
if displayName == "" {
displayName = m.UserID
}
memberNickMap[m.UserID] = displayName
}
for _, uid := range mentionUsers {
uidStr := strconv.FormatUint(uid, 10)
if name, ok := memberNickMap[uidStr]; ok {
prefix += "@" + name + " "
} else {
prefix += "@某人 "
}
}
} else {
for range mentionUsers {
prefix += "@某人 "
}
}
}
return prefix + textBody
}
// BroadcastGroupMessage 向群组所有成员广播消息
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除发送者
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}
// BroadcastGroupNotice 广播群组通知给所有成员
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for notice broadcast: %v", err)
return
}
// 遍历成员,如果在线则发送通知
for _, member := range members {
memberIDStr := member.UserID
h.wsManager.SendToUser(memberIDStr, notice)
}
}
// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户)
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for notice broadcast: %v", err)
return
}
// 遍历成员,如果在线则发送通知
for _, member := range members {
memberIDStr := member.UserID
if memberIDStr == excludeUserID {
continue
}
h.wsManager.SendToUser(memberIDStr, notice)
}
}
// SendGroupMemberNotice 发送群成员变动通知
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
notice := &ws.GroupNoticeMessage{
NoticeType: noticeType,
GroupID: groupID,
Data: data,
Timestamp: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
h.BroadcastGroupNotice(groupID, wsMsg)
}
// truncateContent 截断内容
func truncateContent(content string, maxLen int) string {
if len(content) <= maxLen {
return content
}
return content[:maxLen] + "..."
}
// BroadcastGroupTyping 向群组所有成员广播输入状态
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for typing broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除指定用户
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}
// BroadcastGroupRead 向群组所有成员广播已读状态
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for read broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除指定用户
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}

View File

@@ -1,18 +1,10 @@
package middleware
import (
"log"
"strings"
"github.com/gin-gonic/gin"
)
import "github.com/gin-gonic/gin"
// CORS CORS中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取请求路径
path := c.Request.URL.Path
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// 添加 WebSocket 升级所需的头
@@ -22,25 +14,10 @@ func CORS() gin.HandlerFunc {
// 处理 WebSocket 升级请求的预检
if c.Request.Method == "OPTIONS" {
log.Printf("[CORS] OPTIONS 预检请求: %s", path)
c.AbortWithStatus(204)
return
}
// 针对 WebSocket 路径的特殊处理
if path == "/ws" {
connection := c.GetHeader("Connection")
upgrade := c.GetHeader("Upgrade")
log.Printf("[CORS] WebSocket 请求: Connection=%s, Upgrade=%s", connection, upgrade)
// 检查是否是有效的 WebSocket 升级请求
if strings.Contains(strings.ToLower(connection), "upgrade") && strings.ToLower(upgrade) == "websocket" {
log.Printf("[CORS] 有效的 WebSocket 升级请求")
} else {
log.Printf("[CORS] 警告: 不是有效的 WebSocket 升级请求!")
}
}
c.Next()
}
}

View File

@@ -58,7 +58,7 @@ type Post struct {
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime:false"`
}
// BeforeCreate 创建前生成UUID

152
internal/pkg/sse/hub.go Normal file
View File

@@ -0,0 +1,152 @@
package sse
import (
"encoding/json"
"strconv"
"sync"
"sync/atomic"
"time"
)
const (
defaultUserBufferSize = 128
maxReplayEvents = 200
)
type Event struct {
ID uint64 `json:"event_id"`
Event string `json:"event"`
TS int64 `json:"ts"`
Payload interface{} `json:"payload"`
}
type subscriber struct {
id uint64
ch chan Event
quit chan struct{}
}
type Hub struct {
seq uint64
mu sync.RWMutex
subscribers map[string]map[uint64]*subscriber
history map[string][]Event
}
func NewHub() *Hub {
return &Hub{
subscribers: make(map[string]map[uint64]*subscriber),
history: make(map[string][]Event),
}
}
func (h *Hub) NextID() uint64 {
return atomic.AddUint64(&h.seq, 1)
}
func ParseEventID(raw string) uint64 {
if raw == "" {
return 0
}
id, err := strconv.ParseUint(raw, 10, 64)
if err != nil {
return 0
}
return id
}
func (h *Hub) Subscribe(userID string, afterID uint64) (chan Event, func(), []Event) {
subID := h.NextID()
sub := &subscriber{
id: subID,
ch: make(chan Event, defaultUserBufferSize),
quit: make(chan struct{}),
}
h.mu.Lock()
if _, ok := h.subscribers[userID]; !ok {
h.subscribers[userID] = make(map[uint64]*subscriber)
}
h.subscribers[userID][subID] = sub
replay := make([]Event, 0)
for _, e := range h.history[userID] {
if e.ID > afterID {
replay = append(replay, e)
}
}
h.mu.Unlock()
cancel := func() {
h.mu.Lock()
defer h.mu.Unlock()
if userSubs, ok := h.subscribers[userID]; ok {
if s, exists := userSubs[subID]; exists {
close(s.quit)
delete(userSubs, subID)
close(s.ch)
}
if len(userSubs) == 0 {
delete(h.subscribers, userID)
}
}
}
return sub.ch, cancel, replay
}
func (h *Hub) HasSubscribers(userID string) bool {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.subscribers[userID]) > 0
}
func (h *Hub) PublishToUser(userID string, eventName string, payload interface{}) Event {
ev := Event{
ID: h.NextID(),
Event: eventName,
TS: time.Now().UnixMilli(),
Payload: payload,
}
h.publish(userID, ev)
return ev
}
func (h *Hub) PublishToUsers(userIDs []string, eventName string, payload interface{}) {
for _, uid := range userIDs {
h.PublishToUser(uid, eventName, payload)
}
}
func (h *Hub) publish(userID string, ev Event) {
h.mu.Lock()
history := append(h.history[userID], ev)
if len(history) > maxReplayEvents {
history = history[len(history)-maxReplayEvents:]
}
h.history[userID] = history
targets := make([]*subscriber, 0, len(h.subscribers[userID]))
for _, s := range h.subscribers[userID] {
targets = append(targets, s)
}
h.mu.Unlock()
for _, s := range targets {
select {
case <-s.quit:
case s.ch <- ev:
default:
// 慢消费者丢弃单条消息,客户端可通过 Last-Event-ID + HTTP 同步补偿
}
}
}
func EncodeData(ev Event) (string, error) {
body, err := json.Marshal(ev.Payload)
if err != nil {
return "", err
}
return string(body), nil
}

View File

@@ -1,435 +0,0 @@
package websocket
import (
"carrot_bbs/internal/model"
"encoding/json"
"log"
"sync"
"time"
"github.com/gorilla/websocket"
)
// WebSocket消息类型常量
const (
MessageTypePing = "ping"
MessageTypePong = "pong"
MessageTypeMessage = "message"
MessageTypeTyping = "typing"
MessageTypeRead = "read"
MessageTypeAck = "ack"
MessageTypeError = "error"
MessageTypeRecall = "recall" // 撤回消息
MessageTypeSystem = "system" // 系统消息
MessageTypeNotification = "notification" // 通知消息
MessageTypeAnnouncement = "announcement" // 公告消息
// 群组相关消息类型
MessageTypeGroupMessage = "group_message" // 群消息
MessageTypeGroupTyping = "group_typing" // 群输入状态
MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等)
MessageTypeGroupMention = "group_mention" // @提及通知
MessageTypeGroupRead = "group_read" // 群消息已读
MessageTypeGroupRecall = "group_recall" // 群消息撤回
// Meta事件详细类型
MetaDetailTypeHeartbeat = "heartbeat"
MetaDetailTypeTyping = "typing"
MetaDetailTypeAck = "ack" // 消息发送确认
MetaDetailTypeRead = "read" // 已读回执
)
// WSMessage WebSocket消息结构
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
Timestamp int64 `json:"timestamp"`
}
// ChatMessage 聊天消息结构
type ChatMessage struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
SenderID string `json:"sender_id"`
Seq int64 `json:"seq"`
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
ReplyToID *string `json:"reply_to_id,omitempty"`
CreatedAt int64 `json:"created_at"`
}
// SystemMessage 系统消息结构
type SystemMessage struct {
ID string `json:"id"` // 消息ID
Type string `json:"type"` // 消息子类型account_banned, post_approved等
Title string `json:"title"` // 消息标题
Content string `json:"content"` // 消息内容
Data map[string]interface{} `json:"data"` // 额外数据
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// NotificationMessage 通知消息结构
type NotificationMessage struct {
ID string `json:"id"` // 通知ID
Type string `json:"type"` // 通知类型like, comment, follow, mention等
Title string `json:"title"` // 通知标题
Content string `json:"content"` // 通知内容
TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户
ResourceType string `json:"resource_type"` // 资源类型post, comment等
ResourceID string `json:"resource_id"` // 资源ID
Extra map[string]interface{} `json:"extra"` // 额外数据
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// NotificationUser 通知中的用户信息
type NotificationUser struct {
ID string `json:"id"`
Username string `json:"username"`
Avatar string `json:"avatar"`
}
// AnnouncementMessage 公告消息结构
type AnnouncementMessage struct {
ID string `json:"id"` // 公告ID
Title string `json:"title"` // 公告标题
Content string `json:"content"` // 公告内容
Priority int `json:"priority"` // 优先级1-10
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// GroupMessage 群消息结构
type GroupMessage struct {
ID string `json:"id"` // 消息ID
ConversationID string `json:"conversation_id"` // 会话ID群聊会话
GroupID string `json:"group_id"` // 群组ID
SenderID string `json:"sender_id"` // 发送者ID
Seq int64 `json:"seq"` // 消息序号
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表
MentionAll bool `json:"mention_all"` // 是否@所有人
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// GroupTypingMessage 群输入状态消息
type GroupTypingMessage struct {
GroupID string `json:"group_id"` // 群组ID
UserID string `json:"user_id"` // 用户ID
Username string `json:"username"` // 用户名
IsTyping bool `json:"is_typing"` // 是否正在输入
}
// GroupNoticeMessage 群组通知消息
type GroupNoticeMessage struct {
NoticeType string `json:"notice_type"` // 通知类型member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved
GroupID string `json:"group_id"` // 群组ID
Data interface{} `json:"data"` // 通知数据
Timestamp int64 `json:"timestamp"` // 时间戳
MessageID string `json:"message_id,omitempty"` // 消息ID如果通知保存为消息
Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息)
}
// GroupNoticeData 通知数据结构
type GroupNoticeData struct {
// 成员变动
UserID string `json:"user_id,omitempty"` // 变动的用户ID
Username string `json:"username,omitempty"` // 用户名
OperatorID string `json:"operator_id,omitempty"` // 操作者ID
OpName string `json:"op_name,omitempty"` // 操作者名称
NewRole string `json:"new_role,omitempty"` // 新角色
OldRole string `json:"old_role,omitempty"` // 旧角色
MemberCount int `json:"member_count,omitempty"` // 当前成员数
// 群设置变更
MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态
}
// GroupMentionMessage @提及通知消息
type GroupMentionMessage struct {
GroupID string `json:"group_id"` // 群组ID
MessageID string `json:"message_id"` // 消息ID
FromUserID string `json:"from_user_id"` // 发送者ID
FromName string `json:"from_name"` // 发送者名称
Content string `json:"content"` // 消息内容预览
MentionAll bool `json:"mention_all"` // 是否@所有人
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// AckMessage 消息发送确认结构
type AckMessage struct {
ConversationID string `json:"conversation_id"` // 会话ID
GroupID string `json:"group_id,omitempty"` // 群组ID群聊时
ID string `json:"id"` // 消息ID
SenderID string `json:"sender_id"` // 发送者ID
Seq int64 `json:"seq"` // 消息序号
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// Client WebSocket客户端
type Client struct {
ID string
UserID string
Conn *websocket.Conn
Send chan []byte
Manager *WebSocketManager
IsClosed bool
Mu sync.Mutex
closeOnce sync.Once // 确保 Send channel 只关闭一次
}
// WebSocketManager WebSocket连接管理器
type WebSocketManager struct {
clients map[string]*Client // userID -> Client
register chan *Client
unregister chan *Client
broadcast chan *BroadcastMessage
mutex sync.RWMutex
}
// BroadcastMessage 广播消息
type BroadcastMessage struct {
Message *WSMessage
ExcludeUser string // 排除的用户ID为空表示不排除
TargetUser string // 目标用户ID为空表示广播给所有用户
}
// NewWebSocketManager 创建WebSocket管理器
func NewWebSocketManager() *WebSocketManager {
return &WebSocketManager{
clients: make(map[string]*Client),
register: make(chan *Client, 100),
unregister: make(chan *Client, 100),
broadcast: make(chan *BroadcastMessage, 100),
}
}
// Start 启动管理器
func (m *WebSocketManager) Start() {
go func() {
for {
select {
case client := <-m.register:
m.mutex.Lock()
m.clients[client.UserID] = client
m.mutex.Unlock()
log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients))
case client := <-m.unregister:
m.mutex.Lock()
if _, ok := m.clients[client.UserID]; ok {
delete(m.clients, client.UserID)
// 使用 closeOnce 确保 channel 只关闭一次,避免 panic
client.closeOnce.Do(func() {
close(client.Send)
})
log.Printf("WebSocket client disconnected: userID=%s", client.UserID)
}
m.mutex.Unlock()
case broadcast := <-m.broadcast:
m.sendMessage(broadcast)
}
}
}()
}
// Register 注册客户端
func (m *WebSocketManager) Register(client *Client) {
m.register <- client
}
// Unregister 注销客户端
func (m *WebSocketManager) Unregister(client *Client) {
m.unregister <- client
}
// Broadcast 广播消息给所有用户
func (m *WebSocketManager) Broadcast(msg *WSMessage) {
m.broadcast <- &BroadcastMessage{
Message: msg,
TargetUser: "",
}
}
// SendToUser 发送消息给指定用户
func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) {
m.broadcast <- &BroadcastMessage{
Message: msg,
TargetUser: userID,
}
}
// SendToUsers 发送消息给指定用户列表
func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) {
for _, userID := range userIDs {
m.SendToUser(userID, msg)
}
}
// GetClient 获取客户端
func (m *WebSocketManager) GetClient(userID string) (*Client, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
client, ok := m.clients[userID]
return client, ok
}
// GetAllClients 获取所有客户端
func (m *WebSocketManager) GetAllClients() map[string]*Client {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.clients
}
// GetClientCount 获取在线用户数量
func (m *WebSocketManager) GetClientCount() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.clients)
}
// IsUserOnline 检查用户是否在线
func (m *WebSocketManager) IsUserOnline(userID string) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
_, ok := m.clients[userID]
return ok
}
// sendMessage 发送消息
func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) {
msgBytes, err := json.Marshal(broadcast.Message)
if err != nil {
log.Printf("Failed to marshal message: %v", err)
return
}
m.mutex.RLock()
defer m.mutex.RUnlock()
for userID, client := range m.clients {
// 如果指定了目标用户,只发送给目标用户
if broadcast.TargetUser != "" && userID != broadcast.TargetUser {
continue
}
// 如果指定了排除用户,跳过
if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser {
continue
}
select {
case client.Send <- msgBytes:
default:
log.Printf("Failed to send message to user %s: channel full", userID)
}
}
}
// SendPing 发送心跳
func (c *Client) SendPing() error {
c.Mu.Lock()
defer c.Mu.Unlock()
if c.IsClosed {
return nil
}
msg := WSMessage{
Type: MessageTypePing,
Data: nil,
Timestamp: time.Now().UnixMilli(),
}
msgBytes, _ := json.Marshal(msg)
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
}
// SendPong 发送Pong响应
func (c *Client) SendPong() error {
c.Mu.Lock()
defer c.Mu.Unlock()
if c.IsClosed {
return nil
}
msg := WSMessage{
Type: MessageTypePong,
Data: nil,
Timestamp: time.Now().UnixMilli(),
}
msgBytes, _ := json.Marshal(msg)
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
}
// WritePump 写入泵将消息从Manager发送到客户端
func (c *Client) WritePump() {
defer func() {
c.Conn.Close()
c.Mu.Lock()
c.IsClosed = true
c.Mu.Unlock()
}()
for {
message, ok := <-c.Send
if !ok {
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
c.Mu.Lock()
if c.IsClosed {
c.Mu.Unlock()
return
}
err := c.Conn.WriteMessage(websocket.TextMessage, message)
c.Mu.Unlock()
if err != nil {
log.Printf("Write error: %v", err)
return
}
}
}
// ReadPump 读取泵,从客户端读取消息
func (c *Client) ReadPump(handler func(msg *WSMessage)) {
defer func() {
c.Manager.Unregister(c)
c.Conn.Close()
c.Mu.Lock()
c.IsClosed = true
c.Mu.Unlock()
}()
c.Conn.SetReadLimit(512 * 1024) // 512KB
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.Conn.SetPongHandler(func(string) error {
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
for {
_, message, err := c.Conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error: %v", err)
}
break
}
var wsMsg WSMessage
if err := json.Unmarshal(message, &wsMsg); err != nil {
log.Printf("Failed to unmarshal message: %v", err)
continue
}
handler(&wsMsg)
}
}
// CreateWSMessage 创建WebSocket消息
func CreateWSMessage(msgType string, data interface{}) *WSMessage {
return &WSMessage{
Type: msgType,
Data: data,
Timestamp: time.Now().UnixMilli(),
}
}

View File

@@ -18,32 +18,7 @@ func NewCommentRepository(db *gorm.DB) *CommentRepository {
// Create 创建评论
func (r *CommentRepository) Create(comment *model.Comment) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 创建评论
err := tx.Create(comment).Error
if err != nil {
return err
}
// 增加帖子的评论数并同步热度分
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
Updates(map[string]interface{}{
"comments_count": gorm.Expr("comments_count + 1"),
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
}).Error; err != nil {
return err
}
// 如果是回复,增加父评论的回复数
if comment.ParentID != nil && *comment.ParentID != "" {
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
return err
}
}
return nil
})
return r.db.Create(comment).Error
}
// GetByID 根据ID获取评论
@@ -87,23 +62,52 @@ func (r *CommentRepository) Delete(id string) error {
return err
}
// 减少帖子的评论数并同步热度分
// 仅已发布评论才参与统计,避免 pending/rejected 影响计数
if comment.Status == model.CommentStatusPublished {
// 减少帖子的评论数并同步热度分
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
Updates(map[string]interface{}{
"comments_count": gorm.Expr("comments_count - 1"),
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
}).Error; err != nil {
return err
}
// 如果是回复,减少父评论的回复数
if comment.ParentID != nil && *comment.ParentID != "" {
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
return err
}
}
}
return nil
})
}
// ApplyPublishedStats 在评论审核通过后更新帖子评论数/回复数
func (r *CommentRepository) ApplyPublishedStats(comment *model.Comment) error {
if comment == nil {
return nil
}
return r.db.Transaction(func(tx *gorm.DB) error {
// 增加帖子的评论数并同步热度分
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
Updates(map[string]interface{}{
"comments_count": gorm.Expr("comments_count - 1"),
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
"comments_count": gorm.Expr("comments_count + 1"),
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
}).Error; err != nil {
return err
}
// 如果是回复,减少父评论的回复数
// 如果是回复,增加父评论的回复数
if comment.ParentID != nil && *comment.ParentID != "" {
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
return err
}
}
return nil
})
}

View File

@@ -2,6 +2,7 @@ package repository
import (
"carrot_bbs/internal/model"
"time"
"gorm.io/gorm"
)
@@ -52,9 +53,41 @@ func (r *PostRepository) GetByID(id string) (*model.Post, error) {
// Update 更新帖子
func (r *PostRepository) Update(post *model.Post) error {
post.UpdatedAt = time.Now()
return r.db.Save(post).Error
}
// UpdateWithImages 更新帖子及其图片images=nil 表示不更新图片)
func (r *PostRepository) UpdateWithImages(post *model.Post, images *[]string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
post.UpdatedAt = time.Now()
if err := tx.Save(post).Error; err != nil {
return err
}
if images == nil {
return nil
}
if err := tx.Where("post_id = ?", post.ID).Delete(&model.PostImage{}).Error; err != nil {
return err
}
for i, url := range *images {
image := &model.PostImage{
PostID: post.ID,
URL: url,
SortOrder: i,
}
if err := tx.Create(image).Error; err != nil {
return err
}
}
return nil
})
}
// UpdateModerationStatus 更新帖子审核状态
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
updates := map[string]interface{}{
@@ -100,15 +133,24 @@ func (r *PostRepository) Delete(id string) error {
}
// List 分页获取帖子列表
func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) {
// includePending=true 时,仅在指定 userID 下额外返回 pending用于作者查看自己待审核帖子
func (r *PostRepository) List(page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) {
var posts []*model.Post
var total int64
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
query := r.db.Model(&model.Post{})
if userID != "" {
query = query.Where("user_id = ?", userID)
}
if includePending && userID != "" {
query = query.Where("status IN ?", []model.PostStatus{
model.PostStatusPublished,
model.PostStatusPending,
})
} else {
query = query.Where("status = ?", model.PostStatusPublished)
}
query.Count(&total)
@@ -119,14 +161,32 @@ func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post,
}
// GetUserPosts 获取用户帖子
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) {
var posts []*model.Post
var total int64
r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total)
statusQuery := r.db.Model(&model.Post{}).Where("user_id = ?", userID)
if includePending {
statusQuery = statusQuery.Where("status IN ?", []model.PostStatus{
model.PostStatusPublished,
model.PostStatusPending,
})
} else {
statusQuery = statusQuery.Where("status = ?", model.PostStatusPublished)
}
statusQuery.Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
listQuery := r.db.Where("user_id = ?", userID)
if includePending {
listQuery = listQuery.Where("status IN ?", []model.PostStatus{
model.PostStatusPublished,
model.PostStatusPending,
})
} else {
listQuery = listQuery.Where("status = ?", model.PostStatusPublished)
}
err := listQuery.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
return posts, total, err
}
@@ -256,7 +316,8 @@ func (r *PostRepository) IsFavorited(postID, userID string) bool {
// IncrementViews 增加帖子观看量
func (r *PostRepository) IncrementViews(postID string) error {
return r.db.Model(&model.Post{}).Where("id = ?", postID).
Updates(map[string]interface{}{
// 浏览量属于统计字段不应影响帖子内容更新时间updated_at
UpdateColumns(map[string]interface{}{
"views_count": gorm.Expr("views_count + 1"),
"hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
}).Error

View File

@@ -177,7 +177,9 @@ func (r *UserRepository) RefreshFollowersCount(userID string) error {
// GetPostsCount 获取用户帖子数(实时计算)
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
var count int64
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
err := r.db.Model(&model.Post{}).
Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).
Count(&count).Error
return count, err
}
@@ -202,7 +204,7 @@ func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64,
var counts []CountResult
err := r.db.Model(&model.Post{}).
Select("user_id, count(*) as count").
Where("user_id IN ?", userIDs).
Where("user_id IN ? AND status = ?", userIDs, model.PostStatusPublished).
Group("user_id").
Scan(&counts).Error
if err != nil {

View File

@@ -17,7 +17,6 @@ type Router struct {
messageHandler *handler.MessageHandler
notificationHandler *handler.NotificationHandler
uploadHandler *handler.UploadHandler
wsHandler *handler.WebSocketHandler
pushHandler *handler.PushHandler
systemMessageHandler *handler.SystemMessageHandler
groupHandler *handler.GroupHandler
@@ -36,7 +35,6 @@ func New(
notificationHandler *handler.NotificationHandler,
uploadHandler *handler.UploadHandler,
jwtService *service.JWTService,
wsHandler *handler.WebSocketHandler,
pushHandler *handler.PushHandler,
systemMessageHandler *handler.SystemMessageHandler,
groupHandler *handler.GroupHandler,
@@ -55,7 +53,6 @@ func New(
messageHandler: messageHandler,
notificationHandler: notificationHandler,
uploadHandler: uploadHandler,
wsHandler: wsHandler,
pushHandler: pushHandler,
systemMessageHandler: systemMessageHandler,
groupHandler: groupHandler,
@@ -79,11 +76,6 @@ func (r *Router) setupRoutes() {
c.JSON(200, gin.H{"status": "ok"})
})
// WebSocket 路由
if r.wsHandler != nil {
r.engine.GET("/ws", r.wsHandler.HandleWebSocket)
}
// API v1
v1 := r.engine.Group("/api/v1")
{
@@ -210,10 +202,18 @@ func (r *Router) setupRoutes() {
conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned)
// 获取未读消息总数
conversations.GET("/unread/count", r.messageHandler.GetUnreadCount)
// 上报输入状态
conversations.POST("/typing", r.messageHandler.HandleTyping)
// 仅自己删除会话
conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf)
}
realtime := v1.Group("/realtime")
realtime.Use(authMiddleware)
{
realtime.GET("/sse", r.messageHandler.HandleSSE)
}
// 消息操作路由
messages := v1.Group("/messages")
messages.Use(authMiddleware)

View File

@@ -4,11 +4,11 @@ import (
"context"
"errors"
"fmt"
"log"
"time"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/pkg/sse"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
@@ -41,17 +41,13 @@ type ChatService interface {
RecallMessage(ctx context.Context, messageID string, userID string) error
DeleteMessage(ctx context.Context, messageID string, userID string) error
// WebSocket相关
// 实时事件相关
SendTyping(ctx context.Context, senderID string, conversationID string)
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
// 系统消息推送
// 在线状态
IsUserOnline(userID string) bool
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
// 仅保存消息到数据库,不发送实时推送(供群聊等自行推送的场景使用)
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
}
@@ -61,7 +57,7 @@ type chatServiceImpl struct {
repo *repository.MessageRepository
userRepo *repository.UserRepository
sensitive SensitiveService
wsManager *websocket.WebSocketManager
sseHub *sse.Hub
}
// NewChatService 创建聊天服务
@@ -70,17 +66,24 @@ func NewChatService(
repo *repository.MessageRepository,
userRepo *repository.UserRepository,
sensitive SensitiveService,
wsManager *websocket.WebSocketManager,
sseHub *sse.Hub,
) ChatService {
return &chatServiceImpl{
db: db,
repo: repo,
userRepo: userRepo,
sensitive: sensitive,
wsManager: wsManager,
sseHub: sseHub,
}
}
func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payload interface{}) {
if s.sseHub == nil || len(userIDs) == 0 {
return
}
s.sseHub.PublishToUsers(userIDs, event, payload)
}
// GetOrCreateConversation 获取或创建私聊会话
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
@@ -228,30 +231,30 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
return nil, fmt.Errorf("failed to save message: %w", err)
}
// 发送消息给接收者
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
ID: message.ID,
ConversationID: message.ConversationID,
SenderID: senderID,
Segments: message.Segments,
Seq: message.Seq,
CreatedAt: message.CreatedAt.UnixMilli(),
})
// 获取会话中的其他参与者
// 获取会话中的参与者并发送 SSE
participants, err := s.repo.GetConversationParticipants(conversationID)
if err == nil {
targetIDs := make([]string, 0, len(participants))
for _, p := range participants {
targetIDs = append(targetIDs, p.UserID)
}
detailType := "private"
if conv.Type == model.ConversationTypeGroup {
detailType = "group"
}
s.publishSSEToUsers(targetIDs, "chat_message", map[string]interface{}{
"detail_type": detailType,
"message": dto.ConvertMessageToResponse(message),
})
for _, p := range participants {
// 不发给自己
if p.UserID == senderID {
continue
}
// 如果接收者在线,发送实时消息
if s.wsManager != nil {
isOnline := s.wsManager.IsUserOnline(p.UserID)
if isOnline {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil {
s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
"conversation_id": conversationID,
"total_unread": totalUnread,
})
}
}
}
@@ -337,25 +340,33 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string,
return fmt.Errorf("failed to update last read seq: %w", err)
}
// 发送已读回执(作为 meta 事件)
if s.wsManager != nil {
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
"detail_type": websocket.MetaDetailTypeRead,
"conversation_id": conversationID,
"seq": seq,
"user_id": userID,
})
// 获取会话中的所有参与者
participants, err := s.repo.GetConversationParticipants(conversationID)
if err == nil {
// 推送给会话中的所有参与者(包括自己)
for _, p := range participants {
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
participants, pErr := s.repo.GetConversationParticipants(conversationID)
if pErr == nil {
detailType := "private"
groupID := ""
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
detailType = "group"
if conv.GroupID != nil {
groupID = *conv.GroupID
}
}
targetIDs := make([]string, 0, len(participants))
for _, p := range participants {
targetIDs = append(targetIDs, p.UserID)
}
s.publishSSEToUsers(targetIDs, "message_read", map[string]interface{}{
"detail_type": detailType,
"conversation_id": conversationID,
"group_id": groupID,
"user_id": userID,
"seq": seq,
})
}
if totalUnread, uErr := s.repo.GetAllUnreadCount(userID); uErr == nil {
s.publishSSEToUsers([]string{userID}, "conversation_unread", map[string]interface{}{
"conversation_id": conversationID,
"total_unread": totalUnread,
})
}
return nil
@@ -407,29 +418,35 @@ func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, u
return errors.New("message recall timeout (2 minutes)")
}
// 更新消息状态为已撤回
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
// 更新消息状态为已撤回,并清空原始消息内容,仅保留撤回占位
err = s.db.Model(&message).Updates(map[string]interface{}{
"status": model.MessageStatusRecalled,
"segments": model.MessageSegments{},
}).Error
if err != nil {
return fmt.Errorf("failed to recall message: %w", err)
}
// 发送撤回通知
if s.wsManager != nil {
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
"messageId": messageID,
"conversationId": message.ConversationID,
"senderId": userID,
})
// 通知会话中的所有参与者
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
if err == nil {
for _, p := range participants {
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
if participants, pErr := s.repo.GetConversationParticipants(message.ConversationID); pErr == nil {
detailType := "private"
groupID := ""
if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
detailType = "group"
if conv.GroupID != nil {
groupID = *conv.GroupID
}
}
targetIDs := make([]string, 0, len(participants))
for _, p := range participants {
targetIDs = append(targetIDs, p.UserID)
}
s.publishSSEToUsers(targetIDs, "message_recall", map[string]interface{}{
"detail_type": detailType,
"conversation_id": message.ConversationID,
"group_id": groupID,
"message_id": messageID,
"sender_id": userID,
})
}
return nil
@@ -473,7 +490,7 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u
// SendTyping 发送正在输入状态
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
if s.wsManager == nil {
if s.sseHub == nil {
return
}
@@ -489,98 +506,34 @@ func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conve
return
}
detailType := "private"
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
detailType = "group"
}
for _, p := range participants {
if p.UserID == senderID {
continue
}
// 发送正在输入状态
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
"conversationId": conversationID,
"senderId": senderID,
})
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
if s.sseHub != nil {
s.sseHub.PublishToUser(p.UserID, "typing", map[string]interface{}{
"detail_type": detailType,
"conversation_id": conversationID,
"user_id": senderID,
"is_typing": true,
})
}
}
}
// BroadcastMessage 广播消息给用户
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
if s.wsManager != nil {
s.wsManager.SendToUser(targetUser, msg)
}
}
// IsUserOnline 检查用户是否在线
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
if s.wsManager == nil {
return false
if s.sseHub != nil {
return s.sseHub.HasSubscribers(userID)
}
return s.wsManager.IsUserOnline(userID)
return false
}
// PushSystemMessage 推送系统消息给指定用户
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if !s.wsManager.IsUserOnline(userID) {
return errors.New("user is offline")
}
sysMsg := &websocket.SystemMessage{
ID: "", // 由调用方生成
Type: msgType,
Title: title,
Content: content,
Data: data,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
s.wsManager.SendToUser(userID, wsMsg)
return nil
}
// PushNotificationMessage 推送通知消息给指定用户
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if !s.wsManager.IsUserOnline(userID) {
return errors.New("user is offline")
}
// 确保时间戳已设置
if notification.CreatedAt == 0 {
notification.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
return nil
}
// PushAnnouncementMessage 广播公告消息给所有在线用户
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
// 确保时间戳已设置
if announcement.CreatedAt == 0 {
announcement.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
s.wsManager.Broadcast(wsMsg)
return nil
}
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
// SaveMessage 仅保存消息到数据库,不发送实时推送
// 适用于群聊等由调用方自行负责推送的场景
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
// 验证会话是否存在

View File

@@ -7,6 +7,7 @@ import (
"log"
"strings"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/gorse"
"carrot_bbs/internal/repository"
@@ -17,6 +18,7 @@ type CommentService struct {
commentRepo *repository.CommentRepository
postRepo *repository.PostRepository
systemMessageService SystemMessageService
cache cache.Cache
gorseClient gorse.Client
postAIService *PostAIService
}
@@ -27,6 +29,7 @@ func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repo
commentRepo: commentRepo,
postRepo: postRepo,
systemMessageService: systemMessageService,
cache: cache.GetCache(),
gorseClient: gorseClient,
postAIService: postAIService,
}
@@ -96,6 +99,10 @@ func (s *CommentService) reviewCommentAsync(
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
return
}
if err := s.applyCommentPublishedStats(commentID); err != nil {
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, err)
}
s.invalidatePostCaches(postID)
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
return
}
@@ -116,6 +123,10 @@ func (s *CommentService) reviewCommentAsync(
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
return
}
if statsErr := s.applyCommentPublishedStats(commentID); statsErr != nil {
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, statsErr)
}
s.invalidatePostCaches(postID)
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
return
@@ -125,9 +136,26 @@ func (s *CommentService) reviewCommentAsync(
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
return
}
if statsErr := s.applyCommentPublishedStats(commentID); statsErr != nil {
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, statsErr)
}
s.invalidatePostCaches(postID)
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
}
func (s *CommentService) applyCommentPublishedStats(commentID string) error {
comment, err := s.commentRepo.GetByID(commentID)
if err != nil {
return err
}
return s.commentRepo.ApplyPublishedStats(comment)
}
func (s *CommentService) invalidatePostCaches(postID string) {
cache.InvalidatePostDetail(s.cache, postID)
cache.InvalidatePostList(s.cache)
}
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
// 发送系统消息通知
if s.systemMessageService != nil {
@@ -212,7 +240,15 @@ func (s *CommentService) Update(ctx context.Context, comment *model.Comment) err
// Delete 删除评论
func (s *CommentService) Delete(ctx context.Context, id string) error {
return s.commentRepo.Delete(id)
comment, err := s.commentRepo.GetByID(id)
if err != nil {
return err
}
if err := s.commentRepo.Delete(id); err != nil {
return err
}
s.invalidatePostCaches(comment.PostID)
return nil
}
// Like 点赞评论

View File

@@ -9,8 +9,8 @@ import (
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/sse"
"carrot_bbs/internal/pkg/utils"
"carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
@@ -18,7 +18,7 @@ import (
// 缓存TTL常量
const (
GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒
GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒
GroupMembersNullTTL = 5 * time.Second
GroupCacheJitter = 0.1
)
@@ -99,12 +99,12 @@ type groupService struct {
messageRepo *repository.MessageRepository
requestRepo repository.GroupJoinRequestRepository
notifyRepo *repository.SystemNotificationRepository
wsManager *websocket.WebSocketManager
sseHub *sse.Hub
cache cache.Cache
}
// NewGroupService 创建群组服务
func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, wsManager *websocket.WebSocketManager) GroupService {
func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, sseHub *sse.Hub) GroupService {
return &groupService{
db: db,
groupRepo: groupRepo,
@@ -112,11 +112,39 @@ func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo
messageRepo: messageRepo,
requestRepo: repository.NewGroupJoinRequestRepository(db),
notifyRepo: repository.NewSystemNotificationRepository(db),
wsManager: wsManager,
sseHub: sseHub,
cache: cache.GetCache(),
}
}
type groupNoticeData struct {
UserID string `json:"user_id,omitempty"`
Username string `json:"username,omitempty"`
OperatorID string `json:"operator_id,omitempty"`
}
type groupNoticeMessage struct {
NoticeType string `json:"notice_type"`
GroupID string `json:"group_id"`
Data groupNoticeData `json:"data"`
Timestamp int64 `json:"timestamp"`
MessageID string `json:"message_id,omitempty"`
Seq int64 `json:"seq,omitempty"`
}
func (s *groupService) publishGroupNotice(groupID string, notice groupNoticeMessage) {
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("[groupService] 获取群成员失败: groupID=%s, err=%v", groupID, err)
return
}
if s.sseHub != nil {
for _, m := range members {
s.sseHub.PublishToUser(m.UserID, "group_notice", notice)
}
}
}
// ==================== 群组管理 ====================
// CreateGroup 创建群组
@@ -422,14 +450,10 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st
}
}
if s.wsManager == nil {
return
}
noticeMsg := websocket.GroupNoticeMessage{
noticeMsg := groupNoticeMessage{
NoticeType: "member_join",
GroupID: groupID,
Data: websocket.GroupNoticeData{
Data: groupNoticeData{
UserID: targetUserID,
Username: targetUserName,
OperatorID: operatorID,
@@ -441,17 +465,7 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st
noticeMsg.Seq = savedMessage.Seq
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg)
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("[broadcastMemberJoinNotice] 获取群成员失败: groupID=%s, err=%v", groupID, err)
return
}
for _, m := range members {
if s.wsManager.IsUserOnline(m.UserID) {
s.wsManager.SendToUser(m.UserID, wsMsg)
}
}
s.publishGroupNotice(groupID, noticeMsg)
}
func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userID string, operatorID string) error {
@@ -1282,46 +1296,20 @@ func (s *groupService) MuteMember(userID string, groupID string, targetUserID st
}
}
// 发送WebSocket通知给群成员
if s.wsManager != nil {
log.Printf("[MuteMember] 准备发送禁言通知: groupID=%s, targetUserID=%s, noticeType=%s, operatorID=%s", groupID, targetUserID, noticeType, userID)
// 构建通知消息,包含保存的消息信息
noticeMsg := websocket.GroupNoticeMessage{
NoticeType: noticeType,
GroupID: groupID,
Data: websocket.GroupNoticeData{
UserID: targetUserID,
OperatorID: userID,
},
Timestamp: time.Now().UnixMilli(),
}
// 如果消息已保存添加消息ID和seq
if savedMessage != nil {
noticeMsg.MessageID = savedMessage.ID
noticeMsg.Seq = savedMessage.Seq
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg)
log.Printf("[MuteMember] 创建的WebSocket消息: Type=%s, Data=%+v", wsMsg.Type, wsMsg.Data)
// 获取所有群成员并发送通知
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
if err == nil {
log.Printf("[MuteMember] 获取到群成员数量: %d", len(members))
for _, m := range members {
isOnline := s.wsManager.IsUserOnline(m.UserID)
log.Printf("[MuteMember] 成员 %s 在线状态: %v", m.UserID, isOnline)
if isOnline {
s.wsManager.SendToUser(m.UserID, wsMsg)
log.Printf("[MuteMember] 已发送通知给成员: %s", m.UserID)
}
}
} else {
log.Printf("[MuteMember] 获取群成员失败: %v", err)
}
noticeMsg := groupNoticeMessage{
NoticeType: noticeType,
GroupID: groupID,
Data: groupNoticeData{
UserID: targetUserID,
OperatorID: userID,
},
Timestamp: time.Now().UnixMilli(),
}
if savedMessage != nil {
noticeMsg.MessageID = savedMessage.ID
noticeMsg.Seq = savedMessage.Seq
}
s.publishGroupNotice(groupID, noticeMsg)
// 失效群组成员缓存
cache.InvalidateGroupMembers(s.cache, groupID)

View File

@@ -77,6 +77,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
if s.postAIService == nil || !s.postAIService.IsEnabled() {
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
} else {
s.invalidatePostCaches(postID)
}
return
}
@@ -87,6 +89,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
if errors.As(err, &rejectedErr) {
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
} else {
s.invalidatePostCaches(postID)
}
s.notifyModerationRejected(userID, rejectedErr.Reason)
return
@@ -95,6 +99,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
// 规则审核不可用时降级为发布避免长时间pending
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
} else {
s.invalidatePostCaches(postID)
}
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
return
@@ -104,6 +110,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
return
}
s.invalidatePostCaches(postID)
if s.gorseClient.IsEnabled() {
post, getErr := s.postRepo.GetByID(postID)
@@ -120,6 +127,11 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
}
}
func (s *PostService) invalidatePostCaches(postID string) {
cache.InvalidatePostDetail(s.cache, postID)
cache.InvalidatePostList(s.cache)
}
func (s *PostService) notifyModerationRejected(userID, reason string) {
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
return
@@ -149,7 +161,12 @@ func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, erro
// Update 更新帖子
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
err := s.postRepo.Update(post)
return s.UpdateWithImages(ctx, post, nil)
}
// UpdateWithImages 更新帖子并可选更新图片images=nil 表示不更新图片)
func (s *PostService) UpdateWithImages(ctx context.Context, post *model.Post, images *[]string) error {
err := s.postRepo.UpdateWithImages(post, images)
if err != nil {
return err
}
@@ -185,7 +202,7 @@ func (s *PostService) Delete(ctx context.Context, id string) error {
}
// List 获取帖子列表(带缓存)
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) {
cacheSettings := cache.GetSettings()
postListTTL := cacheSettings.PostListTTL
if postListTTL <= 0 {
@@ -200,8 +217,12 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
jitter = PostListJitterRatio
}
// 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染
cacheKey := cache.PostListKey("latest", userID, page, pageSize)
// 生成缓存键(包含 userID 维度与可见性维度,避免作者视角污染公开视角
visibilityUserKey := userID
if includePending && userID != "" {
visibilityUserKey = "owner:" + userID
}
cacheKey := cache.PostListKey("latest", visibilityUserKey, page, pageSize)
result, err := cache.GetOrLoadTyped[*PostListResult](
s.cache,
@@ -210,7 +231,7 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
jitter,
nullTTL,
func() (*PostListResult, error) {
posts, total, err := s.postRepo.List(page, pageSize, userID)
posts, total, err := s.postRepo.List(page, pageSize, userID, includePending)
if err != nil {
return nil, err
}
@@ -234,7 +255,7 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
}
}
if missingAuthor {
posts, total, loadErr := s.postRepo.List(page, pageSize, userID)
posts, total, loadErr := s.postRepo.List(page, pageSize, userID, includePending)
if loadErr != nil {
return nil, 0, loadErr
}
@@ -247,12 +268,17 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
// GetLatestPosts 获取最新帖子(语义化别名)
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
return s.List(ctx, page, pageSize, userID)
return s.List(ctx, page, pageSize, userID, false)
}
// GetLatestPostsForOwner 获取作者视角帖子列表(包含待审核)
func (s *PostService) GetLatestPostsForOwner(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
return s.List(ctx, page, pageSize, userID, true)
}
// GetUserPosts 获取用户帖子
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
return s.postRepo.GetUserPosts(userID, page, pageSize)
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) {
return s.postRepo.GetUserPosts(userID, page, pageSize, includePending)
}
// Like 点赞

View File

@@ -8,7 +8,7 @@ import (
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/pkg/sse"
"carrot_bbs/internal/repository"
)
@@ -42,8 +42,6 @@ type PushService interface {
// 系统消息推送
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error
PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
@@ -67,7 +65,7 @@ type pushServiceImpl struct {
pushRepo *repository.PushRecordRepository
deviceRepo *repository.DeviceTokenRepository
messageRepo *repository.MessageRepository
wsManager *websocket.WebSocketManager
sseHub *sse.Hub
// 推送队列
pushQueue chan *pushTask
@@ -86,13 +84,13 @@ func NewPushService(
pushRepo *repository.PushRecordRepository,
deviceRepo *repository.DeviceTokenRepository,
messageRepo *repository.MessageRepository,
wsManager *websocket.WebSocketManager,
sseHub *sse.Hub,
) PushService {
return &pushServiceImpl{
pushRepo: pushRepo,
deviceRepo: deviceRepo,
messageRepo: messageRepo,
wsManager: wsManager,
sseHub: sseHub,
pushQueue: make(chan *pushTask, PushQueueSize),
stopChan: make(chan struct{}),
}
@@ -140,11 +138,7 @@ func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message
// pushViaWebSocket 通过WebSocket推送消息
// 返回true表示推送成功false表示用户不在线
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
return false
}
@@ -154,36 +148,33 @@ func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, m
// 从 segments 中提取文本内容
content := dto.ExtractTextContentFromModel(message.Segments)
notification := &websocket.NotificationMessage{
ID: fmt.Sprintf("%s", message.ID),
Type: string(message.SystemType),
Content: content,
Extra: make(map[string]interface{}),
CreatedAt: message.CreatedAt.UnixMilli(),
notification := map[string]interface{}{
"id": fmt.Sprintf("%s", message.ID),
"type": string(message.SystemType),
"content": content,
"extra": map[string]interface{}{},
"created_at": message.CreatedAt.UnixMilli(),
}
// 填充额外数据
if message.ExtraData != nil {
notification.Extra["actor_id"] = message.ExtraData.ActorID
notification.Extra["actor_name"] = message.ExtraData.ActorName
notification.Extra["avatar_url"] = message.ExtraData.AvatarURL
notification.Extra["target_id"] = message.ExtraData.TargetID
notification.Extra["target_type"] = message.ExtraData.TargetType
notification.Extra["action_url"] = message.ExtraData.ActionURL
notification.Extra["action_time"] = message.ExtraData.ActionTime
// 设置触发用户信息
extra := notification["extra"].(map[string]interface{})
extra["actor_id"] = message.ExtraData.ActorID
extra["actor_name"] = message.ExtraData.ActorName
extra["avatar_url"] = message.ExtraData.AvatarURL
extra["target_id"] = message.ExtraData.TargetID
extra["target_type"] = message.ExtraData.TargetType
extra["action_url"] = message.ExtraData.ActionURL
extra["action_time"] = message.ExtraData.ActionTime
if message.ExtraData.ActorID > 0 {
notification.TriggerUser = &websocket.NotificationUser{
ID: fmt.Sprintf("%d", message.ExtraData.ActorID),
Username: message.ExtraData.ActorName,
Avatar: message.ExtraData.AvatarURL,
notification["trigger_user"] = map[string]interface{}{
"id": fmt.Sprintf("%d", message.ExtraData.ActorID),
"username": message.ExtraData.ActorName,
"avatar": message.ExtraData.AvatarURL,
}
}
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
s.sseHub.PublishToUser(userID, "system_notification", notification)
return true
}
@@ -208,8 +199,10 @@ func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, m
SenderID: message.SenderID,
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event)
s.wsManager.SendToUser(userID, wsMsg)
s.sseHub.PublishToUser(userID, "chat_message", map[string]interface{}{
"detail_type": detailType,
"message": event,
})
return true
}
@@ -451,73 +444,21 @@ func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string,
// pushSystemViaWebSocket 通过WebSocket推送系统消息
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
if s.wsManager == nil {
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
sysMsg := map[string]interface{}{
"type": msgType,
"title": title,
"content": content,
"data": data,
"created_at": time.Now().UnixMilli(),
}
sysMsg := &websocket.SystemMessage{
Type: msgType,
Title: title,
Content: content,
Data: data,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
s.wsManager.SendToUser(userID, wsMsg)
s.sseHub.PublishToUser(userID, "system_notification", sysMsg)
return true
}
// PushNotification 推送通知消息
func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error {
// 首先尝试WebSocket推送
if s.pushNotificationViaWebSocket(ctx, userID, notification) {
return nil
}
// 用户不在线,创建待推送记录
// 通知消息可以等用户上线后拉取
return errors.New("user is offline, notification will be available on next sync")
}
// pushNotificationViaWebSocket 通过WebSocket推送通知消息
func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
if notification.CreatedAt == 0 {
notification.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// PushAnnouncement 广播公告消息
func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if announcement.CreatedAt == 0 {
announcement.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
s.wsManager.Broadcast(wsMsg)
return nil
}
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
// 首先尝试WebSocket推送
@@ -531,45 +472,40 @@ func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID str
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
if s.wsManager == nil {
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
// 构建 WebSocket 通知消息
wsNotification := &websocket.NotificationMessage{
ID: fmt.Sprintf("%d", notification.ID),
Type: string(notification.Type),
Title: notification.Title,
Content: notification.Content,
Extra: make(map[string]interface{}),
CreatedAt: notification.CreatedAt.UnixMilli(),
sseNotification := map[string]interface{}{
"id": fmt.Sprintf("%d", notification.ID),
"type": string(notification.Type),
"title": notification.Title,
"content": notification.Content,
"extra": map[string]interface{}{},
"created_at": notification.CreatedAt.UnixMilli(),
}
// 填充额外数据
if notification.ExtraData != nil {
wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr
wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName
wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL
wsNotification.Extra["target_id"] = notification.ExtraData.TargetID
wsNotification.Extra["target_type"] = notification.ExtraData.TargetType
wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL
wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime
extra := sseNotification["extra"].(map[string]interface{})
extra["actor_id_str"] = notification.ExtraData.ActorIDStr
extra["actor_name"] = notification.ExtraData.ActorName
extra["avatar_url"] = notification.ExtraData.AvatarURL
extra["target_id"] = notification.ExtraData.TargetID
extra["target_type"] = notification.ExtraData.TargetType
extra["action_url"] = notification.ExtraData.ActionURL
extra["action_time"] = notification.ExtraData.ActionTime
// 设置触发用户信息
if notification.ExtraData.ActorIDStr != "" {
wsNotification.TriggerUser = &websocket.NotificationUser{
ID: notification.ExtraData.ActorIDStr,
Username: notification.ExtraData.ActorName,
Avatar: notification.ExtraData.AvatarURL,
sseNotification["trigger_user"] = map[string]interface{}{
"id": notification.ExtraData.ActorIDStr,
"username": notification.ExtraData.ActorName,
"avatar": notification.ExtraData.AvatarURL,
}
}
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification)
s.wsManager.SendToUser(userID, wsMsg)
s.sseHub.PublishToUser(userID, "system_notification", sseNotification)
return true
}