Replace websocket flow with SSE support in backend.
Update handlers, services, router, and data conversion logic to support server-sent events and related message pipeline changes. Made-with: Cursor
This commit is contained in:
@@ -36,9 +36,9 @@ database:
|
||||
redis:
|
||||
type: miniredis # miniredis 或 redis
|
||||
redis:
|
||||
host: localhost
|
||||
host: 1Panel-redis-dfmM
|
||||
port: 6379
|
||||
password: ""
|
||||
password: "redis_j8CMza"
|
||||
db: 0
|
||||
miniredis:
|
||||
host: localhost
|
||||
@@ -67,13 +67,13 @@ cache:
|
||||
# S3对象存储配置
|
||||
# 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN
|
||||
s3:
|
||||
endpoint: ""
|
||||
access_key: ""
|
||||
secret_key: ""
|
||||
bucket: ""
|
||||
endpoint: "files.littlelan.cn"
|
||||
access_key: "E6bMcYkQzCldRTrtmhvi"
|
||||
secret_key: "4R9yjmwKNoHphiBkv05Oa8WGEIFbnlZeTLXfSgx3"
|
||||
bucket: "test"
|
||||
use_ssl: true
|
||||
region: us-east-1
|
||||
domain: ""
|
||||
domain: "files.littlelan.cn"
|
||||
# JWT配置
|
||||
# 环境变量: APP_JWT_SECRET
|
||||
jwt:
|
||||
@@ -130,12 +130,12 @@ audit:
|
||||
# Gorse推荐系统配置
|
||||
# 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD
|
||||
gorse:
|
||||
enabled: false
|
||||
address: "" # Gorse server地址
|
||||
enabled: true
|
||||
address: "http://111.170.19.33:8088" # Gorse server地址
|
||||
api_key: "" # API密钥
|
||||
dashboard: "" # Gorse dashboard地址
|
||||
import_password: "" # 导入数据密码
|
||||
embedding_api_key: ""
|
||||
import_password: "lanyimin123" # 导入数据密码
|
||||
embedding_api_key: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD"
|
||||
embedding_url: "https://api.littlelan.cn/v1/embeddings"
|
||||
embedding_model: "BAAI/bge-m3"
|
||||
|
||||
@@ -147,7 +147,7 @@ gorse:
|
||||
openai:
|
||||
enabled: true
|
||||
base_url: "https://api.littlelan.cn/"
|
||||
api_key: ""
|
||||
api_key: "sk-y7LOeKsNfzbZWTRSFsTs79jd8WYlezbIVgdVPgMvG4Xz2AlV"
|
||||
moderation_model: "qwen3.5-122b"
|
||||
moderation_max_images_per_request: 1
|
||||
request_timeout: 30
|
||||
@@ -160,12 +160,12 @@ openai:
|
||||
# APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME
|
||||
# APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT
|
||||
email:
|
||||
enabled: false
|
||||
host: ""
|
||||
port: 587
|
||||
username: ""
|
||||
password: ""
|
||||
from_address: ""
|
||||
enabled: true
|
||||
host: "smtp.exmail.qq.com"
|
||||
port: 465
|
||||
username: "no-reply@qczlit.cn"
|
||||
password: "HbvwwVjRyiWg9gsK"
|
||||
from_address: "no-reply@qczlit.cn"
|
||||
from_name: "Carrot BBS"
|
||||
use_tls: true
|
||||
insecure_skip_verify: false
|
||||
|
||||
@@ -284,6 +284,7 @@ func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostRes
|
||||
Title: post.Title,
|
||||
Content: post.Content,
|
||||
Images: images,
|
||||
Status: string(post.Status),
|
||||
LikesCount: post.LikesCount,
|
||||
CommentsCount: post.CommentsCount,
|
||||
FavoritesCount: post.FavoritesCount,
|
||||
@@ -293,6 +294,7 @@ func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostRes
|
||||
IsLocked: post.IsLocked,
|
||||
IsVote: post.IsVote,
|
||||
CreatedAt: FormatTime(post.CreatedAt),
|
||||
UpdatedAt: FormatTime(post.UpdatedAt),
|
||||
Author: author,
|
||||
IsLiked: isLiked,
|
||||
IsFavorited: isFavorited,
|
||||
|
||||
@@ -68,6 +68,7 @@ type PostResponse struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Images []PostImageResponse `json:"images"`
|
||||
Status string `json:"status,omitempty"`
|
||||
LikesCount int `json:"likes_count"`
|
||||
CommentsCount int `json:"comments_count"`
|
||||
FavoritesCount int `json:"favorites_count"`
|
||||
@@ -77,6 +78,7 @@ type PostResponse struct {
|
||||
IsLocked bool `json:"is_locked"`
|
||||
IsVote bool `json:"is_vote"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Author *UserResponse `json:"author"`
|
||||
IsLiked bool `json:"is_liked"`
|
||||
IsFavorited bool `json:"is_favorited"`
|
||||
|
||||
@@ -2,12 +2,16 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/sse"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
@@ -18,18 +22,111 @@ type MessageHandler struct {
|
||||
messageService *service.MessageService
|
||||
userService *service.UserService
|
||||
groupService service.GroupService
|
||||
sseHub *sse.Hub
|
||||
}
|
||||
|
||||
// NewMessageHandler 创建消息处理器
|
||||
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler {
|
||||
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService, sseHub *sse.Hub) *MessageHandler {
|
||||
return &MessageHandler{
|
||||
chatService: chatService,
|
||||
messageService: messageService,
|
||||
userService: userService,
|
||||
groupService: groupService,
|
||||
sseHub: sseHub,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSSE 实时消息订阅(SSE)
|
||||
// GET /api/v1/realtime/sse
|
||||
func (h *MessageHandler) HandleSSE(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
if h.sseHub == nil {
|
||||
response.InternalServerError(c, "sse hub not available")
|
||||
return
|
||||
}
|
||||
|
||||
lastID := sse.ParseEventID(c.GetHeader("Last-Event-ID"))
|
||||
if lastID == 0 {
|
||||
lastID = sse.ParseEventID(c.Query("last_event_id"))
|
||||
}
|
||||
ch, cancel, replay := h.sseHub.Subscribe(userID, lastID)
|
||||
defer cancel()
|
||||
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
response.InternalServerError(c, "streaming unsupported")
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
flusher.Flush()
|
||||
|
||||
writeEvent := func(ev sse.Event) bool {
|
||||
data, err := sse.EncodeData(ev)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.ID, ev.Event, data); err != nil {
|
||||
return false
|
||||
}
|
||||
flusher.Flush()
|
||||
return true
|
||||
}
|
||||
|
||||
for _, ev := range replay {
|
||||
if !writeEvent(ev) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
heartbeat := time.NewTicker(25 * time.Second)
|
||||
defer heartbeat.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case ev, ok := <-ch:
|
||||
if !ok || !writeEvent(ev) {
|
||||
return
|
||||
}
|
||||
case <-heartbeat.C:
|
||||
if _, err := fmt.Fprint(w, "event: heartbeat\ndata: {}\n\n"); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandleTyping 输入状态上报
|
||||
// POST /api/v1/conversations/typing
|
||||
func (h *MessageHandler) HandleTyping(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
var params struct {
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
h.chatService.SendTyping(c.Request.Context(), userID, params.ConversationID)
|
||||
response.SuccessWithMessage(c, "typing sent", nil)
|
||||
}
|
||||
|
||||
// GetConversations 获取会话列表
|
||||
// GET /api/conversations
|
||||
func (h *MessageHandler) GetConversations(c *gin.Context) {
|
||||
|
||||
@@ -105,6 +105,7 @@ func (h *PostHandler) GetByID(c *gin.Context) {
|
||||
Title: post.Title,
|
||||
Content: post.Content,
|
||||
Images: dto.ConvertPostImagesToResponse(post.Images),
|
||||
Status: string(post.Status),
|
||||
LikesCount: post.LikesCount,
|
||||
CommentsCount: post.CommentsCount,
|
||||
FavoritesCount: post.FavoritesCount,
|
||||
@@ -114,6 +115,7 @@ func (h *PostHandler) GetByID(c *gin.Context) {
|
||||
IsLocked: post.IsLocked,
|
||||
IsVote: post.IsVote,
|
||||
CreatedAt: dto.FormatTime(post.CreatedAt),
|
||||
UpdatedAt: dto.FormatTime(post.UpdatedAt),
|
||||
Author: authorWithFollowStatus,
|
||||
IsLiked: isLiked,
|
||||
IsFavorited: isFavorited,
|
||||
@@ -175,10 +177,18 @@ func (h *PostHandler) List(c *gin.Context) {
|
||||
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
|
||||
case "latest":
|
||||
// 最新帖子
|
||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||
if userID != "" && userID == currentUserID {
|
||||
posts, total, err = h.postService.GetLatestPostsForOwner(c.Request.Context(), page, pageSize, userID)
|
||||
} else {
|
||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||
}
|
||||
default:
|
||||
// 默认获取最新帖子
|
||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||
if userID != "" && userID == currentUserID {
|
||||
posts, total, err = h.postService.GetLatestPostsForOwner(c.Request.Context(), page, pageSize, userID)
|
||||
} else {
|
||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -225,8 +235,9 @@ func (h *PostHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Images *[]string `json:"images"`
|
||||
}
|
||||
|
||||
var req UpdateRequest
|
||||
@@ -242,12 +253,18 @@ func (h *PostHandler) Update(c *gin.Context) {
|
||||
post.Content = req.Content
|
||||
}
|
||||
|
||||
err = h.postService.Update(c.Request.Context(), post)
|
||||
err = h.postService.UpdateWithImages(c.Request.Context(), post, req.Images)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update post")
|
||||
return
|
||||
}
|
||||
|
||||
post, err = h.postService.GetByID(c.Request.Context(), post.ID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get updated post")
|
||||
return
|
||||
}
|
||||
|
||||
currentUserID := c.GetString("user_id")
|
||||
var isLiked, isFavorited bool
|
||||
if currentUserID != "" {
|
||||
@@ -410,14 +427,15 @@ func (h *PostHandler) GetUserPosts(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize)
|
||||
currentUserID := c.GetString("user_id")
|
||||
includePending := currentUserID != "" && currentUserID == userID
|
||||
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize, includePending)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get user posts")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
currentUserID := c.GetString("user_id")
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
|
||||
@@ -1,849 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
ws "carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/repository"
|
||||
"carrot_bbs/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许所有来源,生产环境应该限制
|
||||
},
|
||||
}
|
||||
|
||||
// WebSocketHandler WebSocket处理器
|
||||
type WebSocketHandler struct {
|
||||
jwtService *service.JWTService
|
||||
chatService service.ChatService
|
||||
groupService service.GroupService
|
||||
groupRepo repository.GroupRepository
|
||||
userRepo *repository.UserRepository
|
||||
wsManager *ws.WebSocketManager
|
||||
}
|
||||
|
||||
// NewWebSocketHandler 创建WebSocket处理器
|
||||
func NewWebSocketHandler(
|
||||
jwtService *service.JWTService,
|
||||
chatService service.ChatService,
|
||||
groupService service.GroupService,
|
||||
groupRepo repository.GroupRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
wsManager *ws.WebSocketManager,
|
||||
) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
jwtService: jwtService,
|
||||
chatService: chatService,
|
||||
groupService: groupService,
|
||||
groupRepo: groupRepo,
|
||||
userRepo: userRepo,
|
||||
wsManager: wsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket 处理WebSocket连接
|
||||
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
// 调试:打印请求头信息
|
||||
log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
|
||||
log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
|
||||
c.GetHeader("Connection"),
|
||||
c.GetHeader("Upgrade"))
|
||||
log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
|
||||
c.GetHeader("Sec-WebSocket-Key"),
|
||||
c.GetHeader("Sec-WebSocket-Version"))
|
||||
|
||||
// 从query参数获取token
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
// 尝试从header获取
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := h.jwtService.ParseToken(token)
|
||||
if err != nil {
|
||||
log.Printf("Invalid token: %v", err)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := claims.UserID
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
|
||||
return
|
||||
}
|
||||
|
||||
// 升级HTTP连接为WebSocket连接
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to upgrade connection: %v", err)
|
||||
log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
|
||||
c.GetHeader("User-Agent"),
|
||||
c.GetHeader("Content-Type"))
|
||||
return
|
||||
}
|
||||
|
||||
// 创建客户端
|
||||
client := &ws.Client{
|
||||
ID: userID,
|
||||
UserID: userID,
|
||||
Conn: conn,
|
||||
Send: make(chan []byte, 256),
|
||||
Manager: h.wsManager,
|
||||
}
|
||||
|
||||
// 注册客户端
|
||||
h.wsManager.Register(client)
|
||||
|
||||
// 启动读写协程
|
||||
go client.WritePump()
|
||||
go h.handleMessages(client)
|
||||
|
||||
}
|
||||
|
||||
// handleMessages 处理客户端消息
|
||||
// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳
|
||||
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
|
||||
defer func() {
|
||||
h.wsManager.Unregister(client)
|
||||
client.Conn.Close()
|
||||
}()
|
||||
|
||||
client.Conn.SetReadLimit(512 * 1024) // 512KB
|
||||
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
||||
client.Conn.SetPongHandler(func(string) error {
|
||||
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
||||
return nil
|
||||
})
|
||||
|
||||
// 心跳定时器 - 服务端主动 ping 间隔保持 30 秒
|
||||
pingTicker := time.NewTicker(30 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pingTicker.C:
|
||||
// 发送心跳
|
||||
if err := client.SendPing(); err != nil {
|
||||
log.Printf("Failed to send ping: %v", err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
_, message, err := client.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var wsMsg ws.WSMessage
|
||||
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
||||
log.Printf("Failed to unmarshal message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
h.processMessage(client, &wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processMessage 处理消息
|
||||
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
switch msg.Type {
|
||||
case ws.MessageTypePing:
|
||||
// 响应心跳
|
||||
if err := client.SendPong(); err != nil {
|
||||
log.Printf("Failed to send pong: %v", err)
|
||||
}
|
||||
|
||||
case ws.MessageTypePong:
|
||||
// 客户端响应心跳
|
||||
|
||||
case ws.MessageTypeMessage:
|
||||
// 处理聊天消息
|
||||
h.handleChatMessage(client, msg)
|
||||
|
||||
case ws.MessageTypeTyping:
|
||||
// 处理正在输入状态
|
||||
h.handleTyping(client, msg)
|
||||
|
||||
case ws.MessageTypeRead:
|
||||
// 处理已读回执
|
||||
h.handleReadReceipt(client, msg)
|
||||
|
||||
// 群组消息处理
|
||||
case ws.MessageTypeGroupMessage:
|
||||
// 处理群消息
|
||||
h.handleGroupMessage(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupTyping:
|
||||
// 处理群输入状态
|
||||
h.handleGroupTyping(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupRead:
|
||||
// 处理群消息已读
|
||||
h.handleGroupReadReceipt(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupRecall:
|
||||
// 处理群消息撤回
|
||||
h.handleGroupRecall(client, msg)
|
||||
|
||||
default:
|
||||
log.Printf("Unknown message type: %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleChatMessage 处理聊天消息
|
||||
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
log.Printf("Invalid message data format")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
|
||||
if conversationIDStr == "" {
|
||||
log.Printf("Missing conversationId")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析会话ID
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
log.Printf("Invalid conversation ID: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 segments
|
||||
var segments model.MessageSegments
|
||||
if data["segments"] != nil {
|
||||
segmentsBytes, err := json.Marshal(data["segments"])
|
||||
if err == nil {
|
||||
json.Unmarshal(segmentsBytes, &segments)
|
||||
}
|
||||
}
|
||||
|
||||
// 从 segments 中提取回复消息ID
|
||||
replyToID := dto.GetReplyMessageID(segments)
|
||||
var replyToIDPtr *string
|
||||
if replyToID != "" {
|
||||
replyToIDPtr = &replyToID
|
||||
}
|
||||
|
||||
// 发送消息 - 使用 segments
|
||||
message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to send message: %v", err)
|
||||
// 发送错误消息
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to send message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送确认消息(使用 meta 事件格式,包含完整的消息内容)
|
||||
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": ws.MetaDetailTypeAck,
|
||||
"conversation_id": conversationID,
|
||||
"id": message.ID,
|
||||
"user_id": client.UserID,
|
||||
"sender_id": client.UserID,
|
||||
"seq": message.Seq,
|
||||
"segments": message.Segments,
|
||||
"created_at": message.CreatedAt.UnixMilli(),
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(metaAckMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
}
|
||||
|
||||
// handleTyping 处理正在输入状态
|
||||
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
if conversationIDStr == "" {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用 string 类型的 userID
|
||||
h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
|
||||
}
|
||||
|
||||
// handleReadReceipt 处理已读回执
|
||||
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
if conversationIDStr == "" {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取lastReadSeq
|
||||
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
||||
if lastReadSeq == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用 string 类型的 userID 和 conversationID
|
||||
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
||||
log.Printf("Failed to mark as read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 群组消息处理 ====================
|
||||
|
||||
// handleGroupMessage 处理群消息
|
||||
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
// 打印接收到的消息类型和数据,用于调试
|
||||
log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
|
||||
log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)
|
||||
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
log.Printf("Invalid group message data format: data is not map[string]interface{}")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析群组ID(支持 camelCase 和 snake_case)
|
||||
var groupIDFloat float64
|
||||
groupID := "" // 使用 groupID 作为最终变量名
|
||||
if val, ok := data["groupId"].(float64); ok {
|
||||
groupIDFloat = val
|
||||
groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
} else if val, ok := data["group_id"].(string); ok {
|
||||
groupID = val
|
||||
}
|
||||
|
||||
if groupID == "" {
|
||||
log.Printf("Missing groupId in group message")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析会话ID(支持 camelCase 和 snake_case)
|
||||
var conversationID string
|
||||
if val, ok := data["conversationId"].(string); ok {
|
||||
conversationID = val
|
||||
} else if val, ok := data["conversation_id"].(string); ok {
|
||||
conversationID = val
|
||||
}
|
||||
if conversationID == "" {
|
||||
log.Printf("Missing conversationId in group message")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 segments
|
||||
var segments model.MessageSegments
|
||||
if data["segments"] != nil {
|
||||
segmentsBytes, err := json.Marshal(data["segments"])
|
||||
if err == nil {
|
||||
json.Unmarshal(segmentsBytes, &segments)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析@用户列表(支持 camelCase 和 snake_case)
|
||||
var mentionUsers []uint64
|
||||
var mentionUsersInterface []interface{}
|
||||
if val, ok := data["mentionUsers"].([]interface{}); ok {
|
||||
mentionUsersInterface = val
|
||||
} else if val, ok := data["mention_users"].([]interface{}); ok {
|
||||
mentionUsersInterface = val
|
||||
}
|
||||
if len(mentionUsersInterface) > 0 {
|
||||
for _, uid := range mentionUsersInterface {
|
||||
if uidFloat, ok := uid.(float64); ok {
|
||||
mentionUsers = append(mentionUsers, uint64(uidFloat))
|
||||
} else if uidStr, ok := uid.(string); ok {
|
||||
// 处理字符串格式的用户ID
|
||||
if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil {
|
||||
mentionUsers = append(mentionUsers, uidInt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析@所有人(支持 camelCase 和 snake_case)
|
||||
var mentionAll bool
|
||||
if val, ok := data["mentionAll"].(bool); ok {
|
||||
mentionAll = val
|
||||
} else if val, ok := data["mention_all"].(bool); ok {
|
||||
mentionAll = val
|
||||
}
|
||||
|
||||
// 检查用户是否可以发送群消息(验证成员身份和禁言状态)
|
||||
// client.UserID 已经是 string 格式的 UUID
|
||||
if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
|
||||
log.Printf("User cannot send group message: %v", err)
|
||||
// 发送错误消息
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Cannot send group message",
|
||||
"reason": err.Error(),
|
||||
"type": "group_message_error",
|
||||
"groupId": groupID,
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 检查@所有人权限(只有群主和管理员可以@所有人)
|
||||
if mentionAll {
|
||||
if !h.groupService.IsGroupAdmin(client.UserID, groupID) {
|
||||
log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
|
||||
mentionAll = false // 取消@所有人标记
|
||||
}
|
||||
}
|
||||
|
||||
// 创建消息
|
||||
message := &model.Message{
|
||||
ConversationID: conversationID,
|
||||
SenderID: client.UserID,
|
||||
Segments: segments,
|
||||
Status: model.MessageStatusNormal,
|
||||
MentionAll: mentionAll,
|
||||
}
|
||||
|
||||
// 序列化mentionUsers为JSON
|
||||
if len(mentionUsers) > 0 {
|
||||
mentionUsersJSON, _ := json.Marshal(mentionUsers)
|
||||
message.MentionUsers = string(mentionUsersJSON)
|
||||
}
|
||||
|
||||
// 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播)
|
||||
savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to save group message: %v", err)
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to save group message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 更新消息的mention信息
|
||||
if len(mentionUsers) > 0 || mentionAll {
|
||||
message.ID = savedMessage.ID
|
||||
message.Seq = savedMessage.Seq
|
||||
}
|
||||
|
||||
// 构造群消息响应
|
||||
groupMsg := &ws.GroupMessage{
|
||||
ID: savedMessage.ID,
|
||||
ConversationID: conversationID,
|
||||
GroupID: groupID,
|
||||
SenderID: client.UserID,
|
||||
Seq: savedMessage.Seq,
|
||||
Segments: segments,
|
||||
MentionUsers: mentionUsers,
|
||||
MentionAll: mentionAll,
|
||||
CreatedAt: savedMessage.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 广播消息给群组所有成员(排除发送者)
|
||||
h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)
|
||||
|
||||
// 发送确认消息给发送者(作为meta事件)
|
||||
// 使用 meta 事件格式发送 ack
|
||||
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": ws.MetaDetailTypeAck,
|
||||
"conversation_id": conversationID,
|
||||
"group_id": groupID,
|
||||
"id": savedMessage.ID,
|
||||
"user_id": client.UserID,
|
||||
"sender_id": client.UserID,
|
||||
"seq": savedMessage.Seq,
|
||||
"segments": segments,
|
||||
"created_at": savedMessage.CreatedAt.UnixMilli(),
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(metaAckMsg)
|
||||
client.Send <- msgBytes
|
||||
} else {
|
||||
log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
|
||||
}
|
||||
|
||||
// 处理@提及通知
|
||||
if len(mentionUsers) > 0 || mentionAll {
|
||||
// 提取文本正文(不含 @ 部分)
|
||||
textContent := dto.ExtractTextContentFromModel(segments)
|
||||
// 在通知内容前拼接被@的真实昵称,通过群成员列表查找
|
||||
mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
|
||||
h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGroupTyping 处理群输入状态
|
||||
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
groupIDFloat, _ := data["groupId"].(float64)
|
||||
if groupIDFloat == 0 {
|
||||
return
|
||||
}
|
||||
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
|
||||
isTyping, _ := data["isTyping"].(bool)
|
||||
|
||||
// 验证用户是否是群成员
|
||||
// client.UserID 已经是 string 格式的 UUID
|
||||
isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
|
||||
if err != nil || !isMember {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := h.userRepo.GetByID(client.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 构造输入状态消息
|
||||
typingMsg := &ws.GroupTypingMessage{
|
||||
GroupID: groupID,
|
||||
UserID: client.UserID,
|
||||
Username: user.Username,
|
||||
IsTyping: isTyping,
|
||||
}
|
||||
|
||||
// 广播给群组其他成员
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
||||
h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
|
||||
}
|
||||
|
||||
// handleGroupReadReceipt 处理群消息已读回执
|
||||
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, _ := data["conversationId"].(string)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
||||
if lastReadSeq == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 标记已读
|
||||
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
||||
log.Printf("Failed to mark group message as read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGroupRecall 处理群消息撤回
|
||||
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
messageID, _ := data["messageId"].(string)
|
||||
if messageID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
groupIDFloat, _ := data["groupId"].(float64)
|
||||
if groupIDFloat == 0 {
|
||||
return
|
||||
}
|
||||
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
|
||||
// 撤回消息
|
||||
if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
|
||||
log.Printf("Failed to recall group message: %v", err)
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to recall message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 广播撤回通知给群组所有成员
|
||||
recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
|
||||
"messageId": messageID,
|
||||
"groupId": groupID,
|
||||
"userId": client.UserID,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
})
|
||||
h.BroadcastGroupNotice(groupID, recallNotice)
|
||||
}
|
||||
|
||||
// handleGroupMention 处理群消息@提及通知
|
||||
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
|
||||
// 如果@所有人,获取所有群成员
|
||||
if mentionAll {
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for mention all: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, member := range members {
|
||||
// 不通知发送者自己
|
||||
memberIDStr := member.UserID
|
||||
if memberIDStr == senderID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送@提及通知
|
||||
mentionMsg := &ws.GroupMentionMessage{
|
||||
GroupID: groupID,
|
||||
MessageID: messageID,
|
||||
FromUserID: senderID,
|
||||
Content: truncateContent(content, 50),
|
||||
MentionAll: true,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 处理特定用户的@提及
|
||||
for _, userID := range mentionUsers {
|
||||
// userID 是 uint64,转换为 string
|
||||
userIDStr := strconv.FormatUint(userID, 10)
|
||||
if userIDStr == senderID {
|
||||
continue // 不通知发送者自己
|
||||
}
|
||||
|
||||
mentionMsg := &ws.GroupMentionMessage{
|
||||
GroupID: groupID,
|
||||
MessageID: messageID,
|
||||
FromUserID: senderID,
|
||||
Content: truncateContent(content, 50),
|
||||
MentionAll: false,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
||||
h.wsManager.SendToUser(userIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称
|
||||
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
|
||||
var prefix string
|
||||
if mentionAll {
|
||||
prefix = "@所有人 "
|
||||
} else if len(mentionUsers) > 0 {
|
||||
// 查询群成员列表,找到被@用户的昵称
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err == nil {
|
||||
memberNickMap := make(map[string]string, len(members))
|
||||
for _, m := range members {
|
||||
displayName := m.Nickname
|
||||
if displayName == "" {
|
||||
displayName = m.UserID
|
||||
}
|
||||
memberNickMap[m.UserID] = displayName
|
||||
}
|
||||
for _, uid := range mentionUsers {
|
||||
uidStr := strconv.FormatUint(uid, 10)
|
||||
if name, ok := memberNickMap[uidStr]; ok {
|
||||
prefix += "@" + name + " "
|
||||
} else {
|
||||
prefix += "@某人 "
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for range mentionUsers {
|
||||
prefix += "@某人 "
|
||||
}
|
||||
}
|
||||
}
|
||||
return prefix + textBody
|
||||
}
|
||||
|
||||
// BroadcastGroupMessage 向群组所有成员广播消息
|
||||
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除发送者
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupNotice 广播群组通知给所有成员
|
||||
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历成员,如果在线则发送通知
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
h.wsManager.SendToUser(memberIDStr, notice)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户)
|
||||
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历成员,如果在线则发送通知
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
h.wsManager.SendToUser(memberIDStr, notice)
|
||||
}
|
||||
}
|
||||
|
||||
// SendGroupMemberNotice 发送群成员变动通知
|
||||
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
|
||||
notice := &ws.GroupNoticeMessage{
|
||||
NoticeType: noticeType,
|
||||
GroupID: groupID,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
|
||||
h.BroadcastGroupNotice(groupID, wsMsg)
|
||||
}
|
||||
|
||||
// truncateContent 截断内容
|
||||
func truncateContent(content string, maxLen int) string {
|
||||
if len(content) <= maxLen {
|
||||
return content
|
||||
}
|
||||
return content[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// BroadcastGroupTyping 向群组所有成员广播输入状态
|
||||
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for typing broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除指定用户
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupRead 向群组所有成员广播已读状态
|
||||
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for read broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除指定用户
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// CORS CORS中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取请求路径
|
||||
path := c.Request.URL.Path
|
||||
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
// 添加 WebSocket 升级所需的头
|
||||
@@ -22,25 +14,10 @@ func CORS() gin.HandlerFunc {
|
||||
|
||||
// 处理 WebSocket 升级请求的预检
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
log.Printf("[CORS] OPTIONS 预检请求: %s", path)
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
// 针对 WebSocket 路径的特殊处理
|
||||
if path == "/ws" {
|
||||
connection := c.GetHeader("Connection")
|
||||
upgrade := c.GetHeader("Upgrade")
|
||||
log.Printf("[CORS] WebSocket 请求: Connection=%s, Upgrade=%s", connection, upgrade)
|
||||
|
||||
// 检查是否是有效的 WebSocket 升级请求
|
||||
if strings.Contains(strings.ToLower(connection), "upgrade") && strings.ToLower(upgrade) == "websocket" {
|
||||
log.Printf("[CORS] 有效的 WebSocket 升级请求")
|
||||
} else {
|
||||
log.Printf("[CORS] 警告: 不是有效的 WebSocket 升级请求!")
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ type Post struct {
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime:false"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
|
||||
152
internal/pkg/sse/hub.go
Normal file
152
internal/pkg/sse/hub.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultUserBufferSize = 128
|
||||
maxReplayEvents = 200
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
ID uint64 `json:"event_id"`
|
||||
Event string `json:"event"`
|
||||
TS int64 `json:"ts"`
|
||||
Payload interface{} `json:"payload"`
|
||||
}
|
||||
|
||||
type subscriber struct {
|
||||
id uint64
|
||||
ch chan Event
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
seq uint64
|
||||
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]map[uint64]*subscriber
|
||||
history map[string][]Event
|
||||
}
|
||||
|
||||
func NewHub() *Hub {
|
||||
return &Hub{
|
||||
subscribers: make(map[string]map[uint64]*subscriber),
|
||||
history: make(map[string][]Event),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) NextID() uint64 {
|
||||
return atomic.AddUint64(&h.seq, 1)
|
||||
}
|
||||
|
||||
func ParseEventID(raw string) uint64 {
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
id, err := strconv.ParseUint(raw, 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (h *Hub) Subscribe(userID string, afterID uint64) (chan Event, func(), []Event) {
|
||||
subID := h.NextID()
|
||||
sub := &subscriber{
|
||||
id: subID,
|
||||
ch: make(chan Event, defaultUserBufferSize),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
if _, ok := h.subscribers[userID]; !ok {
|
||||
h.subscribers[userID] = make(map[uint64]*subscriber)
|
||||
}
|
||||
h.subscribers[userID][subID] = sub
|
||||
|
||||
replay := make([]Event, 0)
|
||||
for _, e := range h.history[userID] {
|
||||
if e.ID > afterID {
|
||||
replay = append(replay, e)
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
cancel := func() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if userSubs, ok := h.subscribers[userID]; ok {
|
||||
if s, exists := userSubs[subID]; exists {
|
||||
close(s.quit)
|
||||
delete(userSubs, subID)
|
||||
close(s.ch)
|
||||
}
|
||||
if len(userSubs) == 0 {
|
||||
delete(h.subscribers, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sub.ch, cancel, replay
|
||||
}
|
||||
|
||||
func (h *Hub) HasSubscribers(userID string) bool {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.subscribers[userID]) > 0
|
||||
}
|
||||
|
||||
func (h *Hub) PublishToUser(userID string, eventName string, payload interface{}) Event {
|
||||
ev := Event{
|
||||
ID: h.NextID(),
|
||||
Event: eventName,
|
||||
TS: time.Now().UnixMilli(),
|
||||
Payload: payload,
|
||||
}
|
||||
h.publish(userID, ev)
|
||||
return ev
|
||||
}
|
||||
|
||||
func (h *Hub) PublishToUsers(userIDs []string, eventName string, payload interface{}) {
|
||||
for _, uid := range userIDs {
|
||||
h.PublishToUser(uid, eventName, payload)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) publish(userID string, ev Event) {
|
||||
h.mu.Lock()
|
||||
history := append(h.history[userID], ev)
|
||||
if len(history) > maxReplayEvents {
|
||||
history = history[len(history)-maxReplayEvents:]
|
||||
}
|
||||
h.history[userID] = history
|
||||
|
||||
targets := make([]*subscriber, 0, len(h.subscribers[userID]))
|
||||
for _, s := range h.subscribers[userID] {
|
||||
targets = append(targets, s)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
for _, s := range targets {
|
||||
select {
|
||||
case <-s.quit:
|
||||
case s.ch <- ev:
|
||||
default:
|
||||
// 慢消费者丢弃单条消息,客户端可通过 Last-Event-ID + HTTP 同步补偿
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func EncodeData(ev Event) (string, error) {
|
||||
body, err := json.Marshal(ev.Payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
@@ -1,435 +0,0 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WebSocket消息类型常量
|
||||
const (
|
||||
MessageTypePing = "ping"
|
||||
MessageTypePong = "pong"
|
||||
MessageTypeMessage = "message"
|
||||
MessageTypeTyping = "typing"
|
||||
MessageTypeRead = "read"
|
||||
MessageTypeAck = "ack"
|
||||
MessageTypeError = "error"
|
||||
MessageTypeRecall = "recall" // 撤回消息
|
||||
MessageTypeSystem = "system" // 系统消息
|
||||
MessageTypeNotification = "notification" // 通知消息
|
||||
MessageTypeAnnouncement = "announcement" // 公告消息
|
||||
|
||||
// 群组相关消息类型
|
||||
MessageTypeGroupMessage = "group_message" // 群消息
|
||||
MessageTypeGroupTyping = "group_typing" // 群输入状态
|
||||
MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等)
|
||||
MessageTypeGroupMention = "group_mention" // @提及通知
|
||||
MessageTypeGroupRead = "group_read" // 群消息已读
|
||||
MessageTypeGroupRecall = "group_recall" // 群消息撤回
|
||||
|
||||
// Meta事件详细类型
|
||||
MetaDetailTypeHeartbeat = "heartbeat"
|
||||
MetaDetailTypeTyping = "typing"
|
||||
MetaDetailTypeAck = "ack" // 消息发送确认
|
||||
MetaDetailTypeRead = "read" // 已读回执
|
||||
)
|
||||
|
||||
// WSMessage WebSocket消息结构
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// ChatMessage 聊天消息结构
|
||||
type ChatMessage struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
SenderID string `json:"sender_id"`
|
||||
Seq int64 `json:"seq"`
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
// SystemMessage 系统消息结构
|
||||
type SystemMessage struct {
|
||||
ID string `json:"id"` // 消息ID
|
||||
Type string `json:"type"` // 消息子类型(如:account_banned, post_approved等)
|
||||
Title string `json:"title"` // 消息标题
|
||||
Content string `json:"content"` // 消息内容
|
||||
Data map[string]interface{} `json:"data"` // 额外数据
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// NotificationMessage 通知消息结构
|
||||
type NotificationMessage struct {
|
||||
ID string `json:"id"` // 通知ID
|
||||
Type string `json:"type"` // 通知类型(like, comment, follow, mention等)
|
||||
Title string `json:"title"` // 通知标题
|
||||
Content string `json:"content"` // 通知内容
|
||||
TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户
|
||||
ResourceType string `json:"resource_type"` // 资源类型(post, comment等)
|
||||
ResourceID string `json:"resource_id"` // 资源ID
|
||||
Extra map[string]interface{} `json:"extra"` // 额外数据
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// NotificationUser 通知中的用户信息
|
||||
type NotificationUser struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Avatar string `json:"avatar"`
|
||||
}
|
||||
|
||||
// AnnouncementMessage 公告消息结构
|
||||
type AnnouncementMessage struct {
|
||||
ID string `json:"id"` // 公告ID
|
||||
Title string `json:"title"` // 公告标题
|
||||
Content string `json:"content"` // 公告内容
|
||||
Priority int `json:"priority"` // 优先级(1-10)
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// GroupMessage 群消息结构
|
||||
type GroupMessage struct {
|
||||
ID string `json:"id"` // 消息ID
|
||||
ConversationID string `json:"conversation_id"` // 会话ID(群聊会话)
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
SenderID string `json:"sender_id"` // 发送者ID
|
||||
Seq int64 `json:"seq"` // 消息序号
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
|
||||
MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表
|
||||
MentionAll bool `json:"mention_all"` // 是否@所有人
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// GroupTypingMessage 群输入状态消息
|
||||
type GroupTypingMessage struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
UserID string `json:"user_id"` // 用户ID
|
||||
Username string `json:"username"` // 用户名
|
||||
IsTyping bool `json:"is_typing"` // 是否正在输入
|
||||
}
|
||||
|
||||
// GroupNoticeMessage 群组通知消息
|
||||
type GroupNoticeMessage struct {
|
||||
NoticeType string `json:"notice_type"` // 通知类型:member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Data interface{} `json:"data"` // 通知数据
|
||||
Timestamp int64 `json:"timestamp"` // 时间戳
|
||||
MessageID string `json:"message_id,omitempty"` // 消息ID(如果通知保存为消息)
|
||||
Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息)
|
||||
}
|
||||
|
||||
// GroupNoticeData 通知数据结构
|
||||
type GroupNoticeData struct {
|
||||
// 成员变动
|
||||
UserID string `json:"user_id,omitempty"` // 变动的用户ID
|
||||
Username string `json:"username,omitempty"` // 用户名
|
||||
OperatorID string `json:"operator_id,omitempty"` // 操作者ID
|
||||
OpName string `json:"op_name,omitempty"` // 操作者名称
|
||||
NewRole string `json:"new_role,omitempty"` // 新角色
|
||||
OldRole string `json:"old_role,omitempty"` // 旧角色
|
||||
MemberCount int `json:"member_count,omitempty"` // 当前成员数
|
||||
|
||||
// 群设置变更
|
||||
MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态
|
||||
}
|
||||
|
||||
// GroupMentionMessage @提及通知消息
|
||||
type GroupMentionMessage struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
MessageID string `json:"message_id"` // 消息ID
|
||||
FromUserID string `json:"from_user_id"` // 发送者ID
|
||||
FromName string `json:"from_name"` // 发送者名称
|
||||
Content string `json:"content"` // 消息内容预览
|
||||
MentionAll bool `json:"mention_all"` // 是否@所有人
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// AckMessage 消息发送确认结构
|
||||
type AckMessage struct {
|
||||
ConversationID string `json:"conversation_id"` // 会话ID
|
||||
GroupID string `json:"group_id,omitempty"` // 群组ID(群聊时)
|
||||
ID string `json:"id"` // 消息ID
|
||||
SenderID string `json:"sender_id"` // 发送者ID
|
||||
Seq int64 `json:"seq"` // 消息序号
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// Client WebSocket客户端
|
||||
type Client struct {
|
||||
ID string
|
||||
UserID string
|
||||
Conn *websocket.Conn
|
||||
Send chan []byte
|
||||
Manager *WebSocketManager
|
||||
IsClosed bool
|
||||
Mu sync.Mutex
|
||||
closeOnce sync.Once // 确保 Send channel 只关闭一次
|
||||
}
|
||||
|
||||
// WebSocketManager WebSocket连接管理器
|
||||
type WebSocketManager struct {
|
||||
clients map[string]*Client // userID -> Client
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
broadcast chan *BroadcastMessage
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// BroadcastMessage 广播消息
|
||||
type BroadcastMessage struct {
|
||||
Message *WSMessage
|
||||
ExcludeUser string // 排除的用户ID,为空表示不排除
|
||||
TargetUser string // 目标用户ID,为空表示广播给所有用户
|
||||
}
|
||||
|
||||
// NewWebSocketManager 创建WebSocket管理器
|
||||
func NewWebSocketManager() *WebSocketManager {
|
||||
return &WebSocketManager{
|
||||
clients: make(map[string]*Client),
|
||||
register: make(chan *Client, 100),
|
||||
unregister: make(chan *Client, 100),
|
||||
broadcast: make(chan *BroadcastMessage, 100),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动管理器
|
||||
func (m *WebSocketManager) Start() {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case client := <-m.register:
|
||||
m.mutex.Lock()
|
||||
m.clients[client.UserID] = client
|
||||
m.mutex.Unlock()
|
||||
log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients))
|
||||
|
||||
case client := <-m.unregister:
|
||||
m.mutex.Lock()
|
||||
if _, ok := m.clients[client.UserID]; ok {
|
||||
delete(m.clients, client.UserID)
|
||||
// 使用 closeOnce 确保 channel 只关闭一次,避免 panic
|
||||
client.closeOnce.Do(func() {
|
||||
close(client.Send)
|
||||
})
|
||||
log.Printf("WebSocket client disconnected: userID=%s", client.UserID)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
case broadcast := <-m.broadcast:
|
||||
m.sendMessage(broadcast)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Register 注册客户端
|
||||
func (m *WebSocketManager) Register(client *Client) {
|
||||
m.register <- client
|
||||
}
|
||||
|
||||
// Unregister 注销客户端
|
||||
func (m *WebSocketManager) Unregister(client *Client) {
|
||||
m.unregister <- client
|
||||
}
|
||||
|
||||
// Broadcast 广播消息给所有用户
|
||||
func (m *WebSocketManager) Broadcast(msg *WSMessage) {
|
||||
m.broadcast <- &BroadcastMessage{
|
||||
Message: msg,
|
||||
TargetUser: "",
|
||||
}
|
||||
}
|
||||
|
||||
// SendToUser 发送消息给指定用户
|
||||
func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) {
|
||||
m.broadcast <- &BroadcastMessage{
|
||||
Message: msg,
|
||||
TargetUser: userID,
|
||||
}
|
||||
}
|
||||
|
||||
// SendToUsers 发送消息给指定用户列表
|
||||
func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) {
|
||||
for _, userID := range userIDs {
|
||||
m.SendToUser(userID, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClient 获取客户端
|
||||
func (m *WebSocketManager) GetClient(userID string) (*Client, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
client, ok := m.clients[userID]
|
||||
return client, ok
|
||||
}
|
||||
|
||||
// GetAllClients 获取所有客户端
|
||||
func (m *WebSocketManager) GetAllClients() map[string]*Client {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.clients
|
||||
}
|
||||
|
||||
// GetClientCount 获取在线用户数量
|
||||
func (m *WebSocketManager) GetClientCount() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return len(m.clients)
|
||||
}
|
||||
|
||||
// IsUserOnline 检查用户是否在线
|
||||
func (m *WebSocketManager) IsUserOnline(userID string) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
_, ok := m.clients[userID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// sendMessage 发送消息
|
||||
func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) {
|
||||
msgBytes, err := json.Marshal(broadcast.Message)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for userID, client := range m.clients {
|
||||
// 如果指定了目标用户,只发送给目标用户
|
||||
if broadcast.TargetUser != "" && userID != broadcast.TargetUser {
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果指定了排除用户,跳过
|
||||
if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case client.Send <- msgBytes:
|
||||
default:
|
||||
log.Printf("Failed to send message to user %s: channel full", userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendPing 发送心跳
|
||||
func (c *Client) SendPing() error {
|
||||
c.Mu.Lock()
|
||||
defer c.Mu.Unlock()
|
||||
if c.IsClosed {
|
||||
return nil
|
||||
}
|
||||
msg := WSMessage{
|
||||
Type: MessageTypePing,
|
||||
Data: nil,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
msgBytes, _ := json.Marshal(msg)
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
||||
}
|
||||
|
||||
// SendPong 发送Pong响应
|
||||
func (c *Client) SendPong() error {
|
||||
c.Mu.Lock()
|
||||
defer c.Mu.Unlock()
|
||||
if c.IsClosed {
|
||||
return nil
|
||||
}
|
||||
msg := WSMessage{
|
||||
Type: MessageTypePong,
|
||||
Data: nil,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
msgBytes, _ := json.Marshal(msg)
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
||||
}
|
||||
|
||||
// WritePump 写入泵,将消息从Manager发送到客户端
|
||||
func (c *Client) WritePump() {
|
||||
defer func() {
|
||||
c.Conn.Close()
|
||||
c.Mu.Lock()
|
||||
c.IsClosed = true
|
||||
c.Mu.Unlock()
|
||||
}()
|
||||
|
||||
for {
|
||||
message, ok := <-c.Send
|
||||
if !ok {
|
||||
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
c.Mu.Lock()
|
||||
if c.IsClosed {
|
||||
c.Mu.Unlock()
|
||||
return
|
||||
}
|
||||
err := c.Conn.WriteMessage(websocket.TextMessage, message)
|
||||
c.Mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Write error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPump 读取泵,从客户端读取消息
|
||||
func (c *Client) ReadPump(handler func(msg *WSMessage)) {
|
||||
defer func() {
|
||||
c.Manager.Unregister(c)
|
||||
c.Conn.Close()
|
||||
c.Mu.Lock()
|
||||
c.IsClosed = true
|
||||
c.Mu.Unlock()
|
||||
}()
|
||||
|
||||
c.Conn.SetReadLimit(512 * 1024) // 512KB
|
||||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
c.Conn.SetPongHandler(func(string) error {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, message, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var wsMsg WSMessage
|
||||
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
||||
log.Printf("Failed to unmarshal message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
handler(&wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWSMessage 创建WebSocket消息
|
||||
func CreateWSMessage(msgType string, data interface{}) *WSMessage {
|
||||
return &WSMessage{
|
||||
Type: msgType,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
}
|
||||
@@ -18,32 +18,7 @@ func NewCommentRepository(db *gorm.DB) *CommentRepository {
|
||||
|
||||
// Create 创建评论
|
||||
func (r *CommentRepository) Create(comment *model.Comment) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建评论
|
||||
err := tx.Create(comment).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 增加帖子的评论数并同步热度分
|
||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||
Updates(map[string]interface{}{
|
||||
"comments_count": gorm.Expr("comments_count + 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是回复,增加父评论的回复数
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return r.db.Create(comment).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取评论
|
||||
@@ -87,23 +62,52 @@ func (r *CommentRepository) Delete(id string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 减少帖子的评论数并同步热度分
|
||||
// 仅已发布评论才参与统计,避免 pending/rejected 影响计数
|
||||
if comment.Status == model.CommentStatusPublished {
|
||||
// 减少帖子的评论数并同步热度分
|
||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||
Updates(map[string]interface{}{
|
||||
"comments_count": gorm.Expr("comments_count - 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是回复,减少父评论的回复数
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ApplyPublishedStats 在评论审核通过后更新帖子评论数/回复数
|
||||
func (r *CommentRepository) ApplyPublishedStats(comment *model.Comment) error {
|
||||
if comment == nil {
|
||||
return nil
|
||||
}
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 增加帖子的评论数并同步热度分
|
||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||
Updates(map[string]interface{}{
|
||||
"comments_count": gorm.Expr("comments_count - 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
|
||||
"comments_count": gorm.Expr("comments_count + 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是回复,减少父评论的回复数
|
||||
// 如果是回复,增加父评论的回复数
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -52,9 +53,41 @@ func (r *PostRepository) GetByID(id string) (*model.Post, error) {
|
||||
|
||||
// Update 更新帖子
|
||||
func (r *PostRepository) Update(post *model.Post) error {
|
||||
post.UpdatedAt = time.Now()
|
||||
return r.db.Save(post).Error
|
||||
}
|
||||
|
||||
// UpdateWithImages 更新帖子及其图片(images=nil 表示不更新图片)
|
||||
func (r *PostRepository) UpdateWithImages(post *model.Post, images *[]string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
post.UpdatedAt = time.Now()
|
||||
if err := tx.Save(post).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if images == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := tx.Where("post_id = ?", post.ID).Delete(&model.PostImage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, url := range *images {
|
||||
image := &model.PostImage{
|
||||
PostID: post.ID,
|
||||
URL: url,
|
||||
SortOrder: i,
|
||||
}
|
||||
if err := tx.Create(image).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModerationStatus 更新帖子审核状态
|
||||
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
|
||||
updates := map[string]interface{}{
|
||||
@@ -100,15 +133,24 @@ func (r *PostRepository) Delete(id string) error {
|
||||
}
|
||||
|
||||
// List 分页获取帖子列表
|
||||
func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
// includePending=true 时,仅在指定 userID 下额外返回 pending(用于作者查看自己待审核帖子)
|
||||
func (r *PostRepository) List(page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
|
||||
query := r.db.Model(&model.Post{})
|
||||
|
||||
if userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
if includePending && userID != "" {
|
||||
query = query.Where("status IN ?", []model.PostStatus{
|
||||
model.PostStatusPublished,
|
||||
model.PostStatusPending,
|
||||
})
|
||||
} else {
|
||||
query = query.Where("status = ?", model.PostStatusPublished)
|
||||
}
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
@@ -119,14 +161,32 @@ func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post,
|
||||
}
|
||||
|
||||
// GetUserPosts 获取用户帖子
|
||||
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total)
|
||||
statusQuery := r.db.Model(&model.Post{}).Where("user_id = ?", userID)
|
||||
if includePending {
|
||||
statusQuery = statusQuery.Where("status IN ?", []model.PostStatus{
|
||||
model.PostStatusPublished,
|
||||
model.PostStatusPending,
|
||||
})
|
||||
} else {
|
||||
statusQuery = statusQuery.Where("status = ?", model.PostStatusPublished)
|
||||
}
|
||||
statusQuery.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||
listQuery := r.db.Where("user_id = ?", userID)
|
||||
if includePending {
|
||||
listQuery = listQuery.Where("status IN ?", []model.PostStatus{
|
||||
model.PostStatusPublished,
|
||||
model.PostStatusPending,
|
||||
})
|
||||
} else {
|
||||
listQuery = listQuery.Where("status = ?", model.PostStatusPublished)
|
||||
}
|
||||
err := listQuery.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||
|
||||
return posts, total, err
|
||||
}
|
||||
@@ -256,7 +316,8 @@ func (r *PostRepository) IsFavorited(postID, userID string) bool {
|
||||
// IncrementViews 增加帖子观看量
|
||||
func (r *PostRepository) IncrementViews(postID string) error {
|
||||
return r.db.Model(&model.Post{}).Where("id = ?", postID).
|
||||
Updates(map[string]interface{}{
|
||||
// 浏览量属于统计字段,不应影响帖子内容更新时间(updated_at)
|
||||
UpdateColumns(map[string]interface{}{
|
||||
"views_count": gorm.Expr("views_count + 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
|
||||
}).Error
|
||||
|
||||
@@ -177,7 +177,9 @@ func (r *UserRepository) RefreshFollowersCount(userID string) error {
|
||||
// GetPostsCount 获取用户帖子数(实时计算)
|
||||
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
err := r.db.Model(&model.Post{}).
|
||||
Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -202,7 +204,7 @@ func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64,
|
||||
var counts []CountResult
|
||||
err := r.db.Model(&model.Post{}).
|
||||
Select("user_id, count(*) as count").
|
||||
Where("user_id IN ?", userIDs).
|
||||
Where("user_id IN ? AND status = ?", userIDs, model.PostStatusPublished).
|
||||
Group("user_id").
|
||||
Scan(&counts).Error
|
||||
if err != nil {
|
||||
|
||||
@@ -17,7 +17,6 @@ type Router struct {
|
||||
messageHandler *handler.MessageHandler
|
||||
notificationHandler *handler.NotificationHandler
|
||||
uploadHandler *handler.UploadHandler
|
||||
wsHandler *handler.WebSocketHandler
|
||||
pushHandler *handler.PushHandler
|
||||
systemMessageHandler *handler.SystemMessageHandler
|
||||
groupHandler *handler.GroupHandler
|
||||
@@ -36,7 +35,6 @@ func New(
|
||||
notificationHandler *handler.NotificationHandler,
|
||||
uploadHandler *handler.UploadHandler,
|
||||
jwtService *service.JWTService,
|
||||
wsHandler *handler.WebSocketHandler,
|
||||
pushHandler *handler.PushHandler,
|
||||
systemMessageHandler *handler.SystemMessageHandler,
|
||||
groupHandler *handler.GroupHandler,
|
||||
@@ -55,7 +53,6 @@ func New(
|
||||
messageHandler: messageHandler,
|
||||
notificationHandler: notificationHandler,
|
||||
uploadHandler: uploadHandler,
|
||||
wsHandler: wsHandler,
|
||||
pushHandler: pushHandler,
|
||||
systemMessageHandler: systemMessageHandler,
|
||||
groupHandler: groupHandler,
|
||||
@@ -79,11 +76,6 @@ func (r *Router) setupRoutes() {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// WebSocket 路由
|
||||
if r.wsHandler != nil {
|
||||
r.engine.GET("/ws", r.wsHandler.HandleWebSocket)
|
||||
}
|
||||
|
||||
// API v1
|
||||
v1 := r.engine.Group("/api/v1")
|
||||
{
|
||||
@@ -210,10 +202,18 @@ func (r *Router) setupRoutes() {
|
||||
conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned)
|
||||
// 获取未读消息总数
|
||||
conversations.GET("/unread/count", r.messageHandler.GetUnreadCount)
|
||||
// 上报输入状态
|
||||
conversations.POST("/typing", r.messageHandler.HandleTyping)
|
||||
// 仅自己删除会话
|
||||
conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf)
|
||||
}
|
||||
|
||||
realtime := v1.Group("/realtime")
|
||||
realtime.Use(authMiddleware)
|
||||
{
|
||||
realtime.GET("/sse", r.messageHandler.HandleSSE)
|
||||
}
|
||||
|
||||
// 消息操作路由
|
||||
messages := v1.Group("/messages")
|
||||
messages.Use(authMiddleware)
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/pkg/sse"
|
||||
"carrot_bbs/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -41,17 +41,13 @@ type ChatService interface {
|
||||
RecallMessage(ctx context.Context, messageID string, userID string) error
|
||||
DeleteMessage(ctx context.Context, messageID string, userID string) error
|
||||
|
||||
// WebSocket相关
|
||||
// 实时事件相关
|
||||
SendTyping(ctx context.Context, senderID string, conversationID string)
|
||||
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
|
||||
|
||||
// 系统消息推送
|
||||
// 在线状态
|
||||
IsUserOnline(userID string) bool
|
||||
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
|
||||
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
|
||||
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
|
||||
|
||||
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
|
||||
// 仅保存消息到数据库,不发送实时推送(供群聊等自行推送的场景使用)
|
||||
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||||
}
|
||||
|
||||
@@ -61,7 +57,7 @@ type chatServiceImpl struct {
|
||||
repo *repository.MessageRepository
|
||||
userRepo *repository.UserRepository
|
||||
sensitive SensitiveService
|
||||
wsManager *websocket.WebSocketManager
|
||||
sseHub *sse.Hub
|
||||
}
|
||||
|
||||
// NewChatService 创建聊天服务
|
||||
@@ -70,17 +66,24 @@ func NewChatService(
|
||||
repo *repository.MessageRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
sensitive SensitiveService,
|
||||
wsManager *websocket.WebSocketManager,
|
||||
sseHub *sse.Hub,
|
||||
) ChatService {
|
||||
return &chatServiceImpl{
|
||||
db: db,
|
||||
repo: repo,
|
||||
userRepo: userRepo,
|
||||
sensitive: sensitive,
|
||||
wsManager: wsManager,
|
||||
sseHub: sseHub,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payload interface{}) {
|
||||
if s.sseHub == nil || len(userIDs) == 0 {
|
||||
return
|
||||
}
|
||||
s.sseHub.PublishToUsers(userIDs, event, payload)
|
||||
}
|
||||
|
||||
// GetOrCreateConversation 获取或创建私聊会话
|
||||
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||
@@ -228,30 +231,30 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
|
||||
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||
}
|
||||
|
||||
// 发送消息给接收者
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
|
||||
ID: message.ID,
|
||||
ConversationID: message.ConversationID,
|
||||
SenderID: senderID,
|
||||
Segments: message.Segments,
|
||||
Seq: message.Seq,
|
||||
CreatedAt: message.CreatedAt.UnixMilli(),
|
||||
})
|
||||
|
||||
// 获取会话中的其他参与者
|
||||
// 获取会话中的参与者并发送 SSE
|
||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||
if err == nil {
|
||||
targetIDs := make([]string, 0, len(participants))
|
||||
for _, p := range participants {
|
||||
targetIDs = append(targetIDs, p.UserID)
|
||||
}
|
||||
detailType := "private"
|
||||
if conv.Type == model.ConversationTypeGroup {
|
||||
detailType = "group"
|
||||
}
|
||||
s.publishSSEToUsers(targetIDs, "chat_message", map[string]interface{}{
|
||||
"detail_type": detailType,
|
||||
"message": dto.ConvertMessageToResponse(message),
|
||||
})
|
||||
for _, p := range participants {
|
||||
// 不发给自己
|
||||
if p.UserID == senderID {
|
||||
continue
|
||||
}
|
||||
// 如果接收者在线,发送实时消息
|
||||
if s.wsManager != nil {
|
||||
isOnline := s.wsManager.IsUserOnline(p.UserID)
|
||||
if isOnline {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil {
|
||||
s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
|
||||
"conversation_id": conversationID,
|
||||
"total_unread": totalUnread,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -337,25 +340,33 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string,
|
||||
return fmt.Errorf("failed to update last read seq: %w", err)
|
||||
}
|
||||
|
||||
// 发送已读回执(作为 meta 事件)
|
||||
if s.wsManager != nil {
|
||||
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": websocket.MetaDetailTypeRead,
|
||||
"conversation_id": conversationID,
|
||||
"seq": seq,
|
||||
"user_id": userID,
|
||||
})
|
||||
|
||||
// 获取会话中的所有参与者
|
||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||
if err == nil {
|
||||
// 推送给会话中的所有参与者(包括自己)
|
||||
for _, p := range participants {
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
participants, pErr := s.repo.GetConversationParticipants(conversationID)
|
||||
if pErr == nil {
|
||||
detailType := "private"
|
||||
groupID := ""
|
||||
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||
detailType = "group"
|
||||
if conv.GroupID != nil {
|
||||
groupID = *conv.GroupID
|
||||
}
|
||||
}
|
||||
targetIDs := make([]string, 0, len(participants))
|
||||
for _, p := range participants {
|
||||
targetIDs = append(targetIDs, p.UserID)
|
||||
}
|
||||
s.publishSSEToUsers(targetIDs, "message_read", map[string]interface{}{
|
||||
"detail_type": detailType,
|
||||
"conversation_id": conversationID,
|
||||
"group_id": groupID,
|
||||
"user_id": userID,
|
||||
"seq": seq,
|
||||
})
|
||||
}
|
||||
if totalUnread, uErr := s.repo.GetAllUnreadCount(userID); uErr == nil {
|
||||
s.publishSSEToUsers([]string{userID}, "conversation_unread", map[string]interface{}{
|
||||
"conversation_id": conversationID,
|
||||
"total_unread": totalUnread,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -407,29 +418,35 @@ func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, u
|
||||
return errors.New("message recall timeout (2 minutes)")
|
||||
}
|
||||
|
||||
// 更新消息状态为已撤回
|
||||
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
|
||||
// 更新消息状态为已撤回,并清空原始消息内容,仅保留撤回占位
|
||||
err = s.db.Model(&message).Updates(map[string]interface{}{
|
||||
"status": model.MessageStatusRecalled,
|
||||
"segments": model.MessageSegments{},
|
||||
}).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to recall message: %w", err)
|
||||
}
|
||||
|
||||
// 发送撤回通知
|
||||
if s.wsManager != nil {
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
|
||||
"messageId": messageID,
|
||||
"conversationId": message.ConversationID,
|
||||
"senderId": userID,
|
||||
})
|
||||
|
||||
// 通知会话中的所有参与者
|
||||
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
|
||||
if err == nil {
|
||||
for _, p := range participants {
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
if participants, pErr := s.repo.GetConversationParticipants(message.ConversationID); pErr == nil {
|
||||
detailType := "private"
|
||||
groupID := ""
|
||||
if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||
detailType = "group"
|
||||
if conv.GroupID != nil {
|
||||
groupID = *conv.GroupID
|
||||
}
|
||||
}
|
||||
targetIDs := make([]string, 0, len(participants))
|
||||
for _, p := range participants {
|
||||
targetIDs = append(targetIDs, p.UserID)
|
||||
}
|
||||
s.publishSSEToUsers(targetIDs, "message_recall", map[string]interface{}{
|
||||
"detail_type": detailType,
|
||||
"conversation_id": message.ConversationID,
|
||||
"group_id": groupID,
|
||||
"message_id": messageID,
|
||||
"sender_id": userID,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -473,7 +490,7 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u
|
||||
|
||||
// SendTyping 发送正在输入状态
|
||||
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
|
||||
if s.wsManager == nil {
|
||||
if s.sseHub == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -489,98 +506,34 @@ func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conve
|
||||
return
|
||||
}
|
||||
|
||||
detailType := "private"
|
||||
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||
detailType = "group"
|
||||
}
|
||||
for _, p := range participants {
|
||||
if p.UserID == senderID {
|
||||
continue
|
||||
}
|
||||
// 发送正在输入状态
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
|
||||
"conversationId": conversationID,
|
||||
"senderId": senderID,
|
||||
})
|
||||
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
if s.sseHub != nil {
|
||||
s.sseHub.PublishToUser(p.UserID, "typing", map[string]interface{}{
|
||||
"detail_type": detailType,
|
||||
"conversation_id": conversationID,
|
||||
"user_id": senderID,
|
||||
"is_typing": true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastMessage 广播消息给用户
|
||||
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
|
||||
if s.wsManager != nil {
|
||||
s.wsManager.SendToUser(targetUser, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// IsUserOnline 检查用户是否在线
|
||||
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
if s.sseHub != nil {
|
||||
return s.sseHub.HasSubscribers(userID)
|
||||
}
|
||||
return s.wsManager.IsUserOnline(userID)
|
||||
return false
|
||||
}
|
||||
|
||||
// PushSystemMessage 推送系统消息给指定用户
|
||||
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return errors.New("user is offline")
|
||||
}
|
||||
|
||||
sysMsg := &websocket.SystemMessage{
|
||||
ID: "", // 由调用方生成
|
||||
Type: msgType,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Data: data,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushNotificationMessage 推送通知消息给指定用户
|
||||
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return errors.New("user is offline")
|
||||
}
|
||||
|
||||
// 确保时间戳已设置
|
||||
if notification.CreatedAt == 0 {
|
||||
notification.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushAnnouncementMessage 广播公告消息给所有在线用户
|
||||
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
// 确保时间戳已设置
|
||||
if announcement.CreatedAt == 0 {
|
||||
announcement.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||||
s.wsManager.Broadcast(wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
|
||||
// SaveMessage 仅保存消息到数据库,不发送实时推送
|
||||
// 适用于群聊等由调用方自行负责推送的场景
|
||||
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||
// 验证会话是否存在
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/gorse"
|
||||
"carrot_bbs/internal/repository"
|
||||
@@ -17,6 +18,7 @@ type CommentService struct {
|
||||
commentRepo *repository.CommentRepository
|
||||
postRepo *repository.PostRepository
|
||||
systemMessageService SystemMessageService
|
||||
cache cache.Cache
|
||||
gorseClient gorse.Client
|
||||
postAIService *PostAIService
|
||||
}
|
||||
@@ -27,6 +29,7 @@ func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repo
|
||||
commentRepo: commentRepo,
|
||||
postRepo: postRepo,
|
||||
systemMessageService: systemMessageService,
|
||||
cache: cache.GetCache(),
|
||||
gorseClient: gorseClient,
|
||||
postAIService: postAIService,
|
||||
}
|
||||
@@ -96,6 +99,10 @@ func (s *CommentService) reviewCommentAsync(
|
||||
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
|
||||
return
|
||||
}
|
||||
if err := s.applyCommentPublishedStats(commentID); err != nil {
|
||||
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, err)
|
||||
}
|
||||
s.invalidatePostCaches(postID)
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
return
|
||||
}
|
||||
@@ -116,6 +123,10 @@ func (s *CommentService) reviewCommentAsync(
|
||||
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
|
||||
return
|
||||
}
|
||||
if statsErr := s.applyCommentPublishedStats(commentID); statsErr != nil {
|
||||
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, statsErr)
|
||||
}
|
||||
s.invalidatePostCaches(postID)
|
||||
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
return
|
||||
@@ -125,9 +136,26 @@ func (s *CommentService) reviewCommentAsync(
|
||||
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
|
||||
return
|
||||
}
|
||||
if statsErr := s.applyCommentPublishedStats(commentID); statsErr != nil {
|
||||
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, statsErr)
|
||||
}
|
||||
s.invalidatePostCaches(postID)
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
}
|
||||
|
||||
func (s *CommentService) applyCommentPublishedStats(commentID string) error {
|
||||
comment, err := s.commentRepo.GetByID(commentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.commentRepo.ApplyPublishedStats(comment)
|
||||
}
|
||||
|
||||
func (s *CommentService) invalidatePostCaches(postID string) {
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
cache.InvalidatePostList(s.cache)
|
||||
}
|
||||
|
||||
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
|
||||
// 发送系统消息通知
|
||||
if s.systemMessageService != nil {
|
||||
@@ -212,7 +240,15 @@ func (s *CommentService) Update(ctx context.Context, comment *model.Comment) err
|
||||
|
||||
// Delete 删除评论
|
||||
func (s *CommentService) Delete(ctx context.Context, id string) error {
|
||||
return s.commentRepo.Delete(id)
|
||||
comment, err := s.commentRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.commentRepo.Delete(id); err != nil {
|
||||
return err
|
||||
}
|
||||
s.invalidatePostCaches(comment.PostID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Like 点赞评论
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/sse"
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
"carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
|
||||
// 缓存TTL常量
|
||||
const (
|
||||
GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒
|
||||
GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒
|
||||
GroupMembersNullTTL = 5 * time.Second
|
||||
GroupCacheJitter = 0.1
|
||||
)
|
||||
@@ -99,12 +99,12 @@ type groupService struct {
|
||||
messageRepo *repository.MessageRepository
|
||||
requestRepo repository.GroupJoinRequestRepository
|
||||
notifyRepo *repository.SystemNotificationRepository
|
||||
wsManager *websocket.WebSocketManager
|
||||
sseHub *sse.Hub
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
// NewGroupService 创建群组服务
|
||||
func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, wsManager *websocket.WebSocketManager) GroupService {
|
||||
func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, sseHub *sse.Hub) GroupService {
|
||||
return &groupService{
|
||||
db: db,
|
||||
groupRepo: groupRepo,
|
||||
@@ -112,11 +112,39 @@ func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo
|
||||
messageRepo: messageRepo,
|
||||
requestRepo: repository.NewGroupJoinRequestRepository(db),
|
||||
notifyRepo: repository.NewSystemNotificationRepository(db),
|
||||
wsManager: wsManager,
|
||||
sseHub: sseHub,
|
||||
cache: cache.GetCache(),
|
||||
}
|
||||
}
|
||||
|
||||
type groupNoticeData struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
OperatorID string `json:"operator_id,omitempty"`
|
||||
}
|
||||
|
||||
type groupNoticeMessage struct {
|
||||
NoticeType string `json:"notice_type"`
|
||||
GroupID string `json:"group_id"`
|
||||
Data groupNoticeData `json:"data"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
MessageID string `json:"message_id,omitempty"`
|
||||
Seq int64 `json:"seq,omitempty"`
|
||||
}
|
||||
|
||||
func (s *groupService) publishGroupNotice(groupID string, notice groupNoticeMessage) {
|
||||
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("[groupService] 获取群成员失败: groupID=%s, err=%v", groupID, err)
|
||||
return
|
||||
}
|
||||
if s.sseHub != nil {
|
||||
for _, m := range members {
|
||||
s.sseHub.PublishToUser(m.UserID, "group_notice", notice)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 群组管理 ====================
|
||||
|
||||
// CreateGroup 创建群组
|
||||
@@ -422,14 +450,10 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st
|
||||
}
|
||||
}
|
||||
|
||||
if s.wsManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
noticeMsg := websocket.GroupNoticeMessage{
|
||||
noticeMsg := groupNoticeMessage{
|
||||
NoticeType: "member_join",
|
||||
GroupID: groupID,
|
||||
Data: websocket.GroupNoticeData{
|
||||
Data: groupNoticeData{
|
||||
UserID: targetUserID,
|
||||
Username: targetUserName,
|
||||
OperatorID: operatorID,
|
||||
@@ -441,17 +465,7 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st
|
||||
noticeMsg.Seq = savedMessage.Seq
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg)
|
||||
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("[broadcastMemberJoinNotice] 获取群成员失败: groupID=%s, err=%v", groupID, err)
|
||||
return
|
||||
}
|
||||
for _, m := range members {
|
||||
if s.wsManager.IsUserOnline(m.UserID) {
|
||||
s.wsManager.SendToUser(m.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
s.publishGroupNotice(groupID, noticeMsg)
|
||||
}
|
||||
|
||||
func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userID string, operatorID string) error {
|
||||
@@ -1282,46 +1296,20 @@ func (s *groupService) MuteMember(userID string, groupID string, targetUserID st
|
||||
}
|
||||
}
|
||||
|
||||
// 发送WebSocket通知给群成员
|
||||
if s.wsManager != nil {
|
||||
log.Printf("[MuteMember] 准备发送禁言通知: groupID=%s, targetUserID=%s, noticeType=%s, operatorID=%s", groupID, targetUserID, noticeType, userID)
|
||||
|
||||
// 构建通知消息,包含保存的消息信息
|
||||
noticeMsg := websocket.GroupNoticeMessage{
|
||||
NoticeType: noticeType,
|
||||
GroupID: groupID,
|
||||
Data: websocket.GroupNoticeData{
|
||||
UserID: targetUserID,
|
||||
OperatorID: userID,
|
||||
},
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
// 如果消息已保存,添加消息ID和seq
|
||||
if savedMessage != nil {
|
||||
noticeMsg.MessageID = savedMessage.ID
|
||||
noticeMsg.Seq = savedMessage.Seq
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg)
|
||||
log.Printf("[MuteMember] 创建的WebSocket消息: Type=%s, Data=%+v", wsMsg.Type, wsMsg.Data)
|
||||
|
||||
// 获取所有群成员并发送通知
|
||||
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err == nil {
|
||||
log.Printf("[MuteMember] 获取到群成员数量: %d", len(members))
|
||||
for _, m := range members {
|
||||
isOnline := s.wsManager.IsUserOnline(m.UserID)
|
||||
log.Printf("[MuteMember] 成员 %s 在线状态: %v", m.UserID, isOnline)
|
||||
if isOnline {
|
||||
s.wsManager.SendToUser(m.UserID, wsMsg)
|
||||
log.Printf("[MuteMember] 已发送通知给成员: %s", m.UserID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("[MuteMember] 获取群成员失败: %v", err)
|
||||
}
|
||||
noticeMsg := groupNoticeMessage{
|
||||
NoticeType: noticeType,
|
||||
GroupID: groupID,
|
||||
Data: groupNoticeData{
|
||||
UserID: targetUserID,
|
||||
OperatorID: userID,
|
||||
},
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
if savedMessage != nil {
|
||||
noticeMsg.MessageID = savedMessage.ID
|
||||
noticeMsg.Seq = savedMessage.Seq
|
||||
}
|
||||
s.publishGroupNotice(groupID, noticeMsg)
|
||||
|
||||
// 失效群组成员缓存
|
||||
cache.InvalidateGroupMembers(s.cache, groupID)
|
||||
|
||||
@@ -77,6 +77,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
|
||||
} else {
|
||||
s.invalidatePostCaches(postID)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -87,6 +89,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
||||
if errors.As(err, &rejectedErr) {
|
||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
|
||||
} else {
|
||||
s.invalidatePostCaches(postID)
|
||||
}
|
||||
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
||||
return
|
||||
@@ -95,6 +99,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
||||
// 规则审核不可用时,降级为发布,避免长时间pending
|
||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
|
||||
} else {
|
||||
s.invalidatePostCaches(postID)
|
||||
}
|
||||
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
|
||||
return
|
||||
@@ -104,6 +110,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
||||
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
|
||||
return
|
||||
}
|
||||
s.invalidatePostCaches(postID)
|
||||
|
||||
if s.gorseClient.IsEnabled() {
|
||||
post, getErr := s.postRepo.GetByID(postID)
|
||||
@@ -120,6 +127,11 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostService) invalidatePostCaches(postID string) {
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
cache.InvalidatePostList(s.cache)
|
||||
}
|
||||
|
||||
func (s *PostService) notifyModerationRejected(userID, reason string) {
|
||||
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||
return
|
||||
@@ -149,7 +161,12 @@ func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, erro
|
||||
|
||||
// Update 更新帖子
|
||||
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
|
||||
err := s.postRepo.Update(post)
|
||||
return s.UpdateWithImages(ctx, post, nil)
|
||||
}
|
||||
|
||||
// UpdateWithImages 更新帖子并可选更新图片(images=nil 表示不更新图片)
|
||||
func (s *PostService) UpdateWithImages(ctx context.Context, post *model.Post, images *[]string) error {
|
||||
err := s.postRepo.UpdateWithImages(post, images)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -185,7 +202,7 @@ func (s *PostService) Delete(ctx context.Context, id string) error {
|
||||
}
|
||||
|
||||
// List 获取帖子列表(带缓存)
|
||||
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
postListTTL := cacheSettings.PostListTTL
|
||||
if postListTTL <= 0 {
|
||||
@@ -200,8 +217,12 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
|
||||
jitter = PostListJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染)
|
||||
cacheKey := cache.PostListKey("latest", userID, page, pageSize)
|
||||
// 生成缓存键(包含 userID 维度与可见性维度,避免作者视角污染公开视角)
|
||||
visibilityUserKey := userID
|
||||
if includePending && userID != "" {
|
||||
visibilityUserKey = "owner:" + userID
|
||||
}
|
||||
cacheKey := cache.PostListKey("latest", visibilityUserKey, page, pageSize)
|
||||
|
||||
result, err := cache.GetOrLoadTyped[*PostListResult](
|
||||
s.cache,
|
||||
@@ -210,7 +231,7 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (*PostListResult, error) {
|
||||
posts, total, err := s.postRepo.List(page, pageSize, userID)
|
||||
posts, total, err := s.postRepo.List(page, pageSize, userID, includePending)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -234,7 +255,7 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
|
||||
}
|
||||
}
|
||||
if missingAuthor {
|
||||
posts, total, loadErr := s.postRepo.List(page, pageSize, userID)
|
||||
posts, total, loadErr := s.postRepo.List(page, pageSize, userID, includePending)
|
||||
if loadErr != nil {
|
||||
return nil, 0, loadErr
|
||||
}
|
||||
@@ -247,12 +268,17 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
|
||||
|
||||
// GetLatestPosts 获取最新帖子(语义化别名)
|
||||
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
return s.List(ctx, page, pageSize, userID)
|
||||
return s.List(ctx, page, pageSize, userID, false)
|
||||
}
|
||||
|
||||
// GetLatestPostsForOwner 获取作者视角帖子列表(包含待审核)
|
||||
func (s *PostService) GetLatestPostsForOwner(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
return s.List(ctx, page, pageSize, userID, true)
|
||||
}
|
||||
|
||||
// GetUserPosts 获取用户帖子
|
||||
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
return s.postRepo.GetUserPosts(userID, page, pageSize)
|
||||
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) {
|
||||
return s.postRepo.GetUserPosts(userID, page, pageSize, includePending)
|
||||
}
|
||||
|
||||
// Like 点赞
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/pkg/sse"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
@@ -42,8 +42,6 @@ type PushService interface {
|
||||
|
||||
// 系统消息推送
|
||||
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
|
||||
PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error
|
||||
PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error
|
||||
|
||||
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
|
||||
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
|
||||
@@ -67,7 +65,7 @@ type pushServiceImpl struct {
|
||||
pushRepo *repository.PushRecordRepository
|
||||
deviceRepo *repository.DeviceTokenRepository
|
||||
messageRepo *repository.MessageRepository
|
||||
wsManager *websocket.WebSocketManager
|
||||
sseHub *sse.Hub
|
||||
|
||||
// 推送队列
|
||||
pushQueue chan *pushTask
|
||||
@@ -86,13 +84,13 @@ func NewPushService(
|
||||
pushRepo *repository.PushRecordRepository,
|
||||
deviceRepo *repository.DeviceTokenRepository,
|
||||
messageRepo *repository.MessageRepository,
|
||||
wsManager *websocket.WebSocketManager,
|
||||
sseHub *sse.Hub,
|
||||
) PushService {
|
||||
return &pushServiceImpl{
|
||||
pushRepo: pushRepo,
|
||||
deviceRepo: deviceRepo,
|
||||
messageRepo: messageRepo,
|
||||
wsManager: wsManager,
|
||||
sseHub: sseHub,
|
||||
pushQueue: make(chan *pushTask, PushQueueSize),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
@@ -140,11 +138,7 @@ func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message
|
||||
// pushViaWebSocket 通过WebSocket推送消息
|
||||
// 返回true表示推送成功,false表示用户不在线
|
||||
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -154,36 +148,33 @@ func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, m
|
||||
// 从 segments 中提取文本内容
|
||||
content := dto.ExtractTextContentFromModel(message.Segments)
|
||||
|
||||
notification := &websocket.NotificationMessage{
|
||||
ID: fmt.Sprintf("%s", message.ID),
|
||||
Type: string(message.SystemType),
|
||||
Content: content,
|
||||
Extra: make(map[string]interface{}),
|
||||
CreatedAt: message.CreatedAt.UnixMilli(),
|
||||
notification := map[string]interface{}{
|
||||
"id": fmt.Sprintf("%s", message.ID),
|
||||
"type": string(message.SystemType),
|
||||
"content": content,
|
||||
"extra": map[string]interface{}{},
|
||||
"created_at": message.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 填充额外数据
|
||||
if message.ExtraData != nil {
|
||||
notification.Extra["actor_id"] = message.ExtraData.ActorID
|
||||
notification.Extra["actor_name"] = message.ExtraData.ActorName
|
||||
notification.Extra["avatar_url"] = message.ExtraData.AvatarURL
|
||||
notification.Extra["target_id"] = message.ExtraData.TargetID
|
||||
notification.Extra["target_type"] = message.ExtraData.TargetType
|
||||
notification.Extra["action_url"] = message.ExtraData.ActionURL
|
||||
notification.Extra["action_time"] = message.ExtraData.ActionTime
|
||||
|
||||
// 设置触发用户信息
|
||||
extra := notification["extra"].(map[string]interface{})
|
||||
extra["actor_id"] = message.ExtraData.ActorID
|
||||
extra["actor_name"] = message.ExtraData.ActorName
|
||||
extra["avatar_url"] = message.ExtraData.AvatarURL
|
||||
extra["target_id"] = message.ExtraData.TargetID
|
||||
extra["target_type"] = message.ExtraData.TargetType
|
||||
extra["action_url"] = message.ExtraData.ActionURL
|
||||
extra["action_time"] = message.ExtraData.ActionTime
|
||||
if message.ExtraData.ActorID > 0 {
|
||||
notification.TriggerUser = &websocket.NotificationUser{
|
||||
ID: fmt.Sprintf("%d", message.ExtraData.ActorID),
|
||||
Username: message.ExtraData.ActorName,
|
||||
Avatar: message.ExtraData.AvatarURL,
|
||||
notification["trigger_user"] = map[string]interface{}{
|
||||
"id": fmt.Sprintf("%d", message.ExtraData.ActorID),
|
||||
"username": message.ExtraData.ActorName,
|
||||
"avatar": message.ExtraData.AvatarURL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
s.sseHub.PublishToUser(userID, "system_notification", notification)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -208,8 +199,10 @@ func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, m
|
||||
SenderID: message.SenderID,
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
s.sseHub.PublishToUser(userID, "chat_message", map[string]interface{}{
|
||||
"detail_type": detailType,
|
||||
"message": event,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -451,73 +444,21 @@ func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string,
|
||||
|
||||
// pushSystemViaWebSocket 通过WebSocket推送系统消息
|
||||
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
|
||||
if s.wsManager == nil {
|
||||
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
sysMsg := map[string]interface{}{
|
||||
"type": msgType,
|
||||
"title": title,
|
||||
"content": content,
|
||||
"data": data,
|
||||
"created_at": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
sysMsg := &websocket.SystemMessage{
|
||||
Type: msgType,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Data: data,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
s.sseHub.PublishToUser(userID, "system_notification", sysMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// PushNotification 推送通知消息
|
||||
func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error {
|
||||
// 首先尝试WebSocket推送
|
||||
if s.pushNotificationViaWebSocket(ctx, userID, notification) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 用户不在线,创建待推送记录
|
||||
// 通知消息可以等用户上线后拉取
|
||||
return errors.New("user is offline, notification will be available on next sync")
|
||||
}
|
||||
|
||||
// pushNotificationViaWebSocket 通过WebSocket推送通知消息
|
||||
func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
if notification.CreatedAt == 0 {
|
||||
notification.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// PushAnnouncement 广播公告消息
|
||||
func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if announcement.CreatedAt == 0 {
|
||||
announcement.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||||
s.wsManager.Broadcast(wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
|
||||
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
|
||||
// 首先尝试WebSocket推送
|
||||
@@ -531,45 +472,40 @@ func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID str
|
||||
|
||||
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
|
||||
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
|
||||
if s.wsManager == nil {
|
||||
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 构建 WebSocket 通知消息
|
||||
wsNotification := &websocket.NotificationMessage{
|
||||
ID: fmt.Sprintf("%d", notification.ID),
|
||||
Type: string(notification.Type),
|
||||
Title: notification.Title,
|
||||
Content: notification.Content,
|
||||
Extra: make(map[string]interface{}),
|
||||
CreatedAt: notification.CreatedAt.UnixMilli(),
|
||||
sseNotification := map[string]interface{}{
|
||||
"id": fmt.Sprintf("%d", notification.ID),
|
||||
"type": string(notification.Type),
|
||||
"title": notification.Title,
|
||||
"content": notification.Content,
|
||||
"extra": map[string]interface{}{},
|
||||
"created_at": notification.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 填充额外数据
|
||||
if notification.ExtraData != nil {
|
||||
wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr
|
||||
wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName
|
||||
wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL
|
||||
wsNotification.Extra["target_id"] = notification.ExtraData.TargetID
|
||||
wsNotification.Extra["target_type"] = notification.ExtraData.TargetType
|
||||
wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL
|
||||
wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime
|
||||
extra := sseNotification["extra"].(map[string]interface{})
|
||||
extra["actor_id_str"] = notification.ExtraData.ActorIDStr
|
||||
extra["actor_name"] = notification.ExtraData.ActorName
|
||||
extra["avatar_url"] = notification.ExtraData.AvatarURL
|
||||
extra["target_id"] = notification.ExtraData.TargetID
|
||||
extra["target_type"] = notification.ExtraData.TargetType
|
||||
extra["action_url"] = notification.ExtraData.ActionURL
|
||||
extra["action_time"] = notification.ExtraData.ActionTime
|
||||
|
||||
// 设置触发用户信息
|
||||
if notification.ExtraData.ActorIDStr != "" {
|
||||
wsNotification.TriggerUser = &websocket.NotificationUser{
|
||||
ID: notification.ExtraData.ActorIDStr,
|
||||
Username: notification.ExtraData.ActorName,
|
||||
Avatar: notification.ExtraData.AvatarURL,
|
||||
sseNotification["trigger_user"] = map[string]interface{}{
|
||||
"id": notification.ExtraData.ActorIDStr,
|
||||
"username": notification.ExtraData.ActorName,
|
||||
"avatar": notification.ExtraData.AvatarURL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
s.sseHub.PublishToUser(userID, "system_notification", sseNotification)
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user