Replace websocket flow with SSE support in backend.
Update handlers, services, router, and data conversion logic to support server-sent events and related message pipeline changes. Made-with: Cursor
This commit is contained in:
@@ -36,9 +36,9 @@ database:
|
|||||||
redis:
|
redis:
|
||||||
type: miniredis # miniredis 或 redis
|
type: miniredis # miniredis 或 redis
|
||||||
redis:
|
redis:
|
||||||
host: localhost
|
host: 1Panel-redis-dfmM
|
||||||
port: 6379
|
port: 6379
|
||||||
password: ""
|
password: "redis_j8CMza"
|
||||||
db: 0
|
db: 0
|
||||||
miniredis:
|
miniredis:
|
||||||
host: localhost
|
host: localhost
|
||||||
@@ -67,13 +67,13 @@ cache:
|
|||||||
# S3对象存储配置
|
# S3对象存储配置
|
||||||
# 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN
|
# 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN
|
||||||
s3:
|
s3:
|
||||||
endpoint: ""
|
endpoint: "files.littlelan.cn"
|
||||||
access_key: ""
|
access_key: "E6bMcYkQzCldRTrtmhvi"
|
||||||
secret_key: ""
|
secret_key: "4R9yjmwKNoHphiBkv05Oa8WGEIFbnlZeTLXfSgx3"
|
||||||
bucket: ""
|
bucket: "test"
|
||||||
use_ssl: true
|
use_ssl: true
|
||||||
region: us-east-1
|
region: us-east-1
|
||||||
domain: ""
|
domain: "files.littlelan.cn"
|
||||||
# JWT配置
|
# JWT配置
|
||||||
# 环境变量: APP_JWT_SECRET
|
# 环境变量: APP_JWT_SECRET
|
||||||
jwt:
|
jwt:
|
||||||
@@ -130,12 +130,12 @@ audit:
|
|||||||
# Gorse推荐系统配置
|
# Gorse推荐系统配置
|
||||||
# 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD
|
# 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD
|
||||||
gorse:
|
gorse:
|
||||||
enabled: false
|
enabled: true
|
||||||
address: "" # Gorse server地址
|
address: "http://111.170.19.33:8088" # Gorse server地址
|
||||||
api_key: "" # API密钥
|
api_key: "" # API密钥
|
||||||
dashboard: "" # Gorse dashboard地址
|
dashboard: "" # Gorse dashboard地址
|
||||||
import_password: "" # 导入数据密码
|
import_password: "lanyimin123" # 导入数据密码
|
||||||
embedding_api_key: ""
|
embedding_api_key: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD"
|
||||||
embedding_url: "https://api.littlelan.cn/v1/embeddings"
|
embedding_url: "https://api.littlelan.cn/v1/embeddings"
|
||||||
embedding_model: "BAAI/bge-m3"
|
embedding_model: "BAAI/bge-m3"
|
||||||
|
|
||||||
@@ -147,7 +147,7 @@ gorse:
|
|||||||
openai:
|
openai:
|
||||||
enabled: true
|
enabled: true
|
||||||
base_url: "https://api.littlelan.cn/"
|
base_url: "https://api.littlelan.cn/"
|
||||||
api_key: ""
|
api_key: "sk-y7LOeKsNfzbZWTRSFsTs79jd8WYlezbIVgdVPgMvG4Xz2AlV"
|
||||||
moderation_model: "qwen3.5-122b"
|
moderation_model: "qwen3.5-122b"
|
||||||
moderation_max_images_per_request: 1
|
moderation_max_images_per_request: 1
|
||||||
request_timeout: 30
|
request_timeout: 30
|
||||||
@@ -160,12 +160,12 @@ openai:
|
|||||||
# APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME
|
# APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME
|
||||||
# APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT
|
# APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT
|
||||||
email:
|
email:
|
||||||
enabled: false
|
enabled: true
|
||||||
host: ""
|
host: "smtp.exmail.qq.com"
|
||||||
port: 587
|
port: 465
|
||||||
username: ""
|
username: "no-reply@qczlit.cn"
|
||||||
password: ""
|
password: "HbvwwVjRyiWg9gsK"
|
||||||
from_address: ""
|
from_address: "no-reply@qczlit.cn"
|
||||||
from_name: "Carrot BBS"
|
from_name: "Carrot BBS"
|
||||||
use_tls: true
|
use_tls: true
|
||||||
insecure_skip_verify: false
|
insecure_skip_verify: false
|
||||||
|
|||||||
@@ -284,6 +284,7 @@ func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostRes
|
|||||||
Title: post.Title,
|
Title: post.Title,
|
||||||
Content: post.Content,
|
Content: post.Content,
|
||||||
Images: images,
|
Images: images,
|
||||||
|
Status: string(post.Status),
|
||||||
LikesCount: post.LikesCount,
|
LikesCount: post.LikesCount,
|
||||||
CommentsCount: post.CommentsCount,
|
CommentsCount: post.CommentsCount,
|
||||||
FavoritesCount: post.FavoritesCount,
|
FavoritesCount: post.FavoritesCount,
|
||||||
@@ -293,6 +294,7 @@ func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostRes
|
|||||||
IsLocked: post.IsLocked,
|
IsLocked: post.IsLocked,
|
||||||
IsVote: post.IsVote,
|
IsVote: post.IsVote,
|
||||||
CreatedAt: FormatTime(post.CreatedAt),
|
CreatedAt: FormatTime(post.CreatedAt),
|
||||||
|
UpdatedAt: FormatTime(post.UpdatedAt),
|
||||||
Author: author,
|
Author: author,
|
||||||
IsLiked: isLiked,
|
IsLiked: isLiked,
|
||||||
IsFavorited: isFavorited,
|
IsFavorited: isFavorited,
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ type PostResponse struct {
|
|||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Images []PostImageResponse `json:"images"`
|
Images []PostImageResponse `json:"images"`
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
LikesCount int `json:"likes_count"`
|
LikesCount int `json:"likes_count"`
|
||||||
CommentsCount int `json:"comments_count"`
|
CommentsCount int `json:"comments_count"`
|
||||||
FavoritesCount int `json:"favorites_count"`
|
FavoritesCount int `json:"favorites_count"`
|
||||||
@@ -77,6 +78,7 @@ type PostResponse struct {
|
|||||||
IsLocked bool `json:"is_locked"`
|
IsLocked bool `json:"is_locked"`
|
||||||
IsVote bool `json:"is_vote"`
|
IsVote bool `json:"is_vote"`
|
||||||
CreatedAt string `json:"created_at"`
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
Author *UserResponse `json:"author"`
|
Author *UserResponse `json:"author"`
|
||||||
IsLiked bool `json:"is_liked"`
|
IsLiked bool `json:"is_liked"`
|
||||||
IsFavorited bool `json:"is_favorited"`
|
IsFavorited bool `json:"is_favorited"`
|
||||||
|
|||||||
@@ -2,12 +2,16 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
"carrot_bbs/internal/dto"
|
"carrot_bbs/internal/dto"
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/sse"
|
||||||
"carrot_bbs/internal/pkg/response"
|
"carrot_bbs/internal/pkg/response"
|
||||||
"carrot_bbs/internal/service"
|
"carrot_bbs/internal/service"
|
||||||
)
|
)
|
||||||
@@ -18,18 +22,111 @@ type MessageHandler struct {
|
|||||||
messageService *service.MessageService
|
messageService *service.MessageService
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
groupService service.GroupService
|
groupService service.GroupService
|
||||||
|
sseHub *sse.Hub
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMessageHandler 创建消息处理器
|
// NewMessageHandler 创建消息处理器
|
||||||
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler {
|
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService, sseHub *sse.Hub) *MessageHandler {
|
||||||
return &MessageHandler{
|
return &MessageHandler{
|
||||||
chatService: chatService,
|
chatService: chatService,
|
||||||
messageService: messageService,
|
messageService: messageService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
groupService: groupService,
|
groupService: groupService,
|
||||||
|
sseHub: sseHub,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandleSSE 实时消息订阅(SSE)
|
||||||
|
// GET /api/v1/realtime/sse
|
||||||
|
func (h *MessageHandler) HandleSSE(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.sseHub == nil {
|
||||||
|
response.InternalServerError(c, "sse hub not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastID := sse.ParseEventID(c.GetHeader("Last-Event-ID"))
|
||||||
|
if lastID == 0 {
|
||||||
|
lastID = sse.ParseEventID(c.Query("last_event_id"))
|
||||||
|
}
|
||||||
|
ch, cancel, replay := h.sseHub.Subscribe(userID, lastID)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
w := c.Writer
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
response.InternalServerError(c, "streaming unsupported")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
writeEvent := func(ev sse.Event) bool {
|
||||||
|
data, err := sse.EncodeData(ev)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.ID, ev.Event, data); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range replay {
|
||||||
|
if !writeEvent(ev) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
heartbeat := time.NewTicker(25 * time.Second)
|
||||||
|
defer heartbeat.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
return
|
||||||
|
case ev, ok := <-ch:
|
||||||
|
if !ok || !writeEvent(ev) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-heartbeat.C:
|
||||||
|
if _, err := fmt.Fprint(w, "event: heartbeat\ndata: {}\n\n"); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleTyping 输入状态上报
|
||||||
|
// POST /api/v1/conversations/typing
|
||||||
|
func (h *MessageHandler) HandleTyping(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var params struct {
|
||||||
|
ConversationID string `json:"conversation_id" binding:"required"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.chatService.SendTyping(c.Request.Context(), userID, params.ConversationID)
|
||||||
|
response.SuccessWithMessage(c, "typing sent", nil)
|
||||||
|
}
|
||||||
|
|
||||||
// GetConversations 获取会话列表
|
// GetConversations 获取会话列表
|
||||||
// GET /api/conversations
|
// GET /api/conversations
|
||||||
func (h *MessageHandler) GetConversations(c *gin.Context) {
|
func (h *MessageHandler) GetConversations(c *gin.Context) {
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ func (h *PostHandler) GetByID(c *gin.Context) {
|
|||||||
Title: post.Title,
|
Title: post.Title,
|
||||||
Content: post.Content,
|
Content: post.Content,
|
||||||
Images: dto.ConvertPostImagesToResponse(post.Images),
|
Images: dto.ConvertPostImagesToResponse(post.Images),
|
||||||
|
Status: string(post.Status),
|
||||||
LikesCount: post.LikesCount,
|
LikesCount: post.LikesCount,
|
||||||
CommentsCount: post.CommentsCount,
|
CommentsCount: post.CommentsCount,
|
||||||
FavoritesCount: post.FavoritesCount,
|
FavoritesCount: post.FavoritesCount,
|
||||||
@@ -114,6 +115,7 @@ func (h *PostHandler) GetByID(c *gin.Context) {
|
|||||||
IsLocked: post.IsLocked,
|
IsLocked: post.IsLocked,
|
||||||
IsVote: post.IsVote,
|
IsVote: post.IsVote,
|
||||||
CreatedAt: dto.FormatTime(post.CreatedAt),
|
CreatedAt: dto.FormatTime(post.CreatedAt),
|
||||||
|
UpdatedAt: dto.FormatTime(post.UpdatedAt),
|
||||||
Author: authorWithFollowStatus,
|
Author: authorWithFollowStatus,
|
||||||
IsLiked: isLiked,
|
IsLiked: isLiked,
|
||||||
IsFavorited: isFavorited,
|
IsFavorited: isFavorited,
|
||||||
@@ -175,10 +177,18 @@ func (h *PostHandler) List(c *gin.Context) {
|
|||||||
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
|
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
|
||||||
case "latest":
|
case "latest":
|
||||||
// 最新帖子
|
// 最新帖子
|
||||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
if userID != "" && userID == currentUserID {
|
||||||
|
posts, total, err = h.postService.GetLatestPostsForOwner(c.Request.Context(), page, pageSize, userID)
|
||||||
|
} else {
|
||||||
|
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
// 默认获取最新帖子
|
// 默认获取最新帖子
|
||||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
if userID != "" && userID == currentUserID {
|
||||||
|
posts, total, err = h.postService.GetLatestPostsForOwner(c.Request.Context(), page, pageSize, userID)
|
||||||
|
} else {
|
||||||
|
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -225,8 +235,9 @@ func (h *PostHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UpdateRequest struct {
|
type UpdateRequest struct {
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
Images *[]string `json:"images"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var req UpdateRequest
|
var req UpdateRequest
|
||||||
@@ -242,12 +253,18 @@ func (h *PostHandler) Update(c *gin.Context) {
|
|||||||
post.Content = req.Content
|
post.Content = req.Content
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.postService.Update(c.Request.Context(), post)
|
err = h.postService.UpdateWithImages(c.Request.Context(), post, req.Images)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalServerError(c, "failed to update post")
|
response.InternalServerError(c, "failed to update post")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
post, err = h.postService.GetByID(c.Request.Context(), post.ID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get updated post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
currentUserID := c.GetString("user_id")
|
currentUserID := c.GetString("user_id")
|
||||||
var isLiked, isFavorited bool
|
var isLiked, isFavorited bool
|
||||||
if currentUserID != "" {
|
if currentUserID != "" {
|
||||||
@@ -410,14 +427,15 @@ func (h *PostHandler) GetUserPosts(c *gin.Context) {
|
|||||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize)
|
currentUserID := c.GetString("user_id")
|
||||||
|
includePending := currentUserID != "" && currentUserID == userID
|
||||||
|
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize, includePending)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.InternalServerError(c, "failed to get user posts")
|
response.InternalServerError(c, "failed to get user posts")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取当前用户ID用于判断点赞和收藏状态
|
// 获取当前用户ID用于判断点赞和收藏状态
|
||||||
currentUserID := c.GetString("user_id")
|
|
||||||
isLikedMap := make(map[string]bool)
|
isLikedMap := make(map[string]bool)
|
||||||
isFavoritedMap := make(map[string]bool)
|
isFavoritedMap := make(map[string]bool)
|
||||||
if currentUserID != "" {
|
if currentUserID != "" {
|
||||||
|
|||||||
@@ -1,849 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"carrot_bbs/internal/dto"
|
|
||||||
"carrot_bbs/internal/model"
|
|
||||||
ws "carrot_bbs/internal/pkg/websocket"
|
|
||||||
"carrot_bbs/internal/repository"
|
|
||||||
"carrot_bbs/internal/service"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{
|
|
||||||
ReadBufferSize: 1024,
|
|
||||||
WriteBufferSize: 1024,
|
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
|
||||||
return true // 允许所有来源,生产环境应该限制
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// WebSocketHandler WebSocket处理器
|
|
||||||
type WebSocketHandler struct {
|
|
||||||
jwtService *service.JWTService
|
|
||||||
chatService service.ChatService
|
|
||||||
groupService service.GroupService
|
|
||||||
groupRepo repository.GroupRepository
|
|
||||||
userRepo *repository.UserRepository
|
|
||||||
wsManager *ws.WebSocketManager
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWebSocketHandler 创建WebSocket处理器
|
|
||||||
func NewWebSocketHandler(
|
|
||||||
jwtService *service.JWTService,
|
|
||||||
chatService service.ChatService,
|
|
||||||
groupService service.GroupService,
|
|
||||||
groupRepo repository.GroupRepository,
|
|
||||||
userRepo *repository.UserRepository,
|
|
||||||
wsManager *ws.WebSocketManager,
|
|
||||||
) *WebSocketHandler {
|
|
||||||
return &WebSocketHandler{
|
|
||||||
jwtService: jwtService,
|
|
||||||
chatService: chatService,
|
|
||||||
groupService: groupService,
|
|
||||||
groupRepo: groupRepo,
|
|
||||||
userRepo: userRepo,
|
|
||||||
wsManager: wsManager,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleWebSocket 处理WebSocket连接
|
|
||||||
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
|
||||||
// 调试:打印请求头信息
|
|
||||||
log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
|
|
||||||
log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
|
|
||||||
c.GetHeader("Connection"),
|
|
||||||
c.GetHeader("Upgrade"))
|
|
||||||
log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
|
|
||||||
c.GetHeader("Sec-WebSocket-Key"),
|
|
||||||
c.GetHeader("Sec-WebSocket-Version"))
|
|
||||||
|
|
||||||
// 从query参数获取token
|
|
||||||
token := c.Query("token")
|
|
||||||
if token == "" {
|
|
||||||
// 尝试从header获取
|
|
||||||
authHeader := c.GetHeader("Authorization")
|
|
||||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
|
||||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if token == "" {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证token
|
|
||||||
claims, err := h.jwtService.ParseToken(token)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Invalid token: %v", err)
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userID := claims.UserID
|
|
||||||
if userID == "" {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 升级HTTP连接为WebSocket连接
|
|
||||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to upgrade connection: %v", err)
|
|
||||||
log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
|
|
||||||
c.GetHeader("User-Agent"),
|
|
||||||
c.GetHeader("Content-Type"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建客户端
|
|
||||||
client := &ws.Client{
|
|
||||||
ID: userID,
|
|
||||||
UserID: userID,
|
|
||||||
Conn: conn,
|
|
||||||
Send: make(chan []byte, 256),
|
|
||||||
Manager: h.wsManager,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注册客户端
|
|
||||||
h.wsManager.Register(client)
|
|
||||||
|
|
||||||
// 启动读写协程
|
|
||||||
go client.WritePump()
|
|
||||||
go h.handleMessages(client)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleMessages 处理客户端消息
|
|
||||||
// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳
|
|
||||||
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
|
|
||||||
defer func() {
|
|
||||||
h.wsManager.Unregister(client)
|
|
||||||
client.Conn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
client.Conn.SetReadLimit(512 * 1024) // 512KB
|
|
||||||
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
|
||||||
client.Conn.SetPongHandler(func(string) error {
|
|
||||||
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// 心跳定时器 - 服务端主动 ping 间隔保持 30 秒
|
|
||||||
pingTicker := time.NewTicker(30 * time.Second)
|
|
||||||
defer pingTicker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-pingTicker.C:
|
|
||||||
// 发送心跳
|
|
||||||
if err := client.SendPing(); err != nil {
|
|
||||||
log.Printf("Failed to send ping: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
_, message, err := client.Conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
||||||
log.Printf("WebSocket error: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var wsMsg ws.WSMessage
|
|
||||||
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
|
||||||
log.Printf("Failed to unmarshal message: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
h.processMessage(client, &wsMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// processMessage 处理消息
|
|
||||||
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
|
|
||||||
switch msg.Type {
|
|
||||||
case ws.MessageTypePing:
|
|
||||||
// 响应心跳
|
|
||||||
if err := client.SendPong(); err != nil {
|
|
||||||
log.Printf("Failed to send pong: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
case ws.MessageTypePong:
|
|
||||||
// 客户端响应心跳
|
|
||||||
|
|
||||||
case ws.MessageTypeMessage:
|
|
||||||
// 处理聊天消息
|
|
||||||
h.handleChatMessage(client, msg)
|
|
||||||
|
|
||||||
case ws.MessageTypeTyping:
|
|
||||||
// 处理正在输入状态
|
|
||||||
h.handleTyping(client, msg)
|
|
||||||
|
|
||||||
case ws.MessageTypeRead:
|
|
||||||
// 处理已读回执
|
|
||||||
h.handleReadReceipt(client, msg)
|
|
||||||
|
|
||||||
// 群组消息处理
|
|
||||||
case ws.MessageTypeGroupMessage:
|
|
||||||
// 处理群消息
|
|
||||||
h.handleGroupMessage(client, msg)
|
|
||||||
|
|
||||||
case ws.MessageTypeGroupTyping:
|
|
||||||
// 处理群输入状态
|
|
||||||
h.handleGroupTyping(client, msg)
|
|
||||||
|
|
||||||
case ws.MessageTypeGroupRead:
|
|
||||||
// 处理群消息已读
|
|
||||||
h.handleGroupReadReceipt(client, msg)
|
|
||||||
|
|
||||||
case ws.MessageTypeGroupRecall:
|
|
||||||
// 处理群消息撤回
|
|
||||||
h.handleGroupRecall(client, msg)
|
|
||||||
|
|
||||||
default:
|
|
||||||
log.Printf("Unknown message type: %s", msg.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleChatMessage 处理聊天消息
|
|
||||||
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
|
|
||||||
data, ok := msg.Data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
log.Printf("Invalid message data format")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conversationIDStr, _ := data["conversationId"].(string)
|
|
||||||
|
|
||||||
if conversationIDStr == "" {
|
|
||||||
log.Printf("Missing conversationId")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析会话ID
|
|
||||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Invalid conversation ID: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析 segments
|
|
||||||
var segments model.MessageSegments
|
|
||||||
if data["segments"] != nil {
|
|
||||||
segmentsBytes, err := json.Marshal(data["segments"])
|
|
||||||
if err == nil {
|
|
||||||
json.Unmarshal(segmentsBytes, &segments)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 从 segments 中提取回复消息ID
|
|
||||||
replyToID := dto.GetReplyMessageID(segments)
|
|
||||||
var replyToIDPtr *string
|
|
||||||
if replyToID != "" {
|
|
||||||
replyToIDPtr = &replyToID
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送消息 - 使用 segments
|
|
||||||
message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to send message: %v", err)
|
|
||||||
// 发送错误消息
|
|
||||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
|
||||||
"error": "Failed to send message",
|
|
||||||
})
|
|
||||||
if client.Send != nil {
|
|
||||||
msgBytes, _ := json.Marshal(errorMsg)
|
|
||||||
client.Send <- msgBytes
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送确认消息(使用 meta 事件格式,包含完整的消息内容)
|
|
||||||
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
|
||||||
"detail_type": ws.MetaDetailTypeAck,
|
|
||||||
"conversation_id": conversationID,
|
|
||||||
"id": message.ID,
|
|
||||||
"user_id": client.UserID,
|
|
||||||
"sender_id": client.UserID,
|
|
||||||
"seq": message.Seq,
|
|
||||||
"segments": message.Segments,
|
|
||||||
"created_at": message.CreatedAt.UnixMilli(),
|
|
||||||
})
|
|
||||||
if client.Send != nil {
|
|
||||||
msgBytes, _ := json.Marshal(metaAckMsg)
|
|
||||||
client.Send <- msgBytes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleTyping 处理正在输入状态
|
|
||||||
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
|
|
||||||
data, ok := msg.Data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conversationIDStr, _ := data["conversationId"].(string)
|
|
||||||
if conversationIDStr == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 直接使用 string 类型的 userID
|
|
||||||
h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleReadReceipt 处理已读回执
|
|
||||||
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
|
||||||
data, ok := msg.Data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conversationIDStr, _ := data["conversationId"].(string)
|
|
||||||
if conversationIDStr == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取lastReadSeq
|
|
||||||
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
|
||||||
if lastReadSeq == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 直接使用 string 类型的 userID 和 conversationID
|
|
||||||
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
|
||||||
log.Printf("Failed to mark as read: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== 群组消息处理 ====================
|
|
||||||
|
|
||||||
// handleGroupMessage 处理群消息
|
|
||||||
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
|
|
||||||
// 打印接收到的消息类型和数据,用于调试
|
|
||||||
log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
|
|
||||||
log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)
|
|
||||||
|
|
||||||
data, ok := msg.Data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
log.Printf("Invalid group message data format: data is not map[string]interface{}")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析群组ID(支持 camelCase 和 snake_case)
|
|
||||||
var groupIDFloat float64
|
|
||||||
groupID := "" // 使用 groupID 作为最终变量名
|
|
||||||
if val, ok := data["groupId"].(float64); ok {
|
|
||||||
groupIDFloat = val
|
|
||||||
groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
|
||||||
} else if val, ok := data["group_id"].(string); ok {
|
|
||||||
groupID = val
|
|
||||||
}
|
|
||||||
|
|
||||||
if groupID == "" {
|
|
||||||
log.Printf("Missing groupId in group message")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析会话ID(支持 camelCase 和 snake_case)
|
|
||||||
var conversationID string
|
|
||||||
if val, ok := data["conversationId"].(string); ok {
|
|
||||||
conversationID = val
|
|
||||||
} else if val, ok := data["conversation_id"].(string); ok {
|
|
||||||
conversationID = val
|
|
||||||
}
|
|
||||||
if conversationID == "" {
|
|
||||||
log.Printf("Missing conversationId in group message")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析 segments
|
|
||||||
var segments model.MessageSegments
|
|
||||||
if data["segments"] != nil {
|
|
||||||
segmentsBytes, err := json.Marshal(data["segments"])
|
|
||||||
if err == nil {
|
|
||||||
json.Unmarshal(segmentsBytes, &segments)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析@用户列表(支持 camelCase 和 snake_case)
|
|
||||||
var mentionUsers []uint64
|
|
||||||
var mentionUsersInterface []interface{}
|
|
||||||
if val, ok := data["mentionUsers"].([]interface{}); ok {
|
|
||||||
mentionUsersInterface = val
|
|
||||||
} else if val, ok := data["mention_users"].([]interface{}); ok {
|
|
||||||
mentionUsersInterface = val
|
|
||||||
}
|
|
||||||
if len(mentionUsersInterface) > 0 {
|
|
||||||
for _, uid := range mentionUsersInterface {
|
|
||||||
if uidFloat, ok := uid.(float64); ok {
|
|
||||||
mentionUsers = append(mentionUsers, uint64(uidFloat))
|
|
||||||
} else if uidStr, ok := uid.(string); ok {
|
|
||||||
// 处理字符串格式的用户ID
|
|
||||||
if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil {
|
|
||||||
mentionUsers = append(mentionUsers, uidInt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析@所有人(支持 camelCase 和 snake_case)
|
|
||||||
var mentionAll bool
|
|
||||||
if val, ok := data["mentionAll"].(bool); ok {
|
|
||||||
mentionAll = val
|
|
||||||
} else if val, ok := data["mention_all"].(bool); ok {
|
|
||||||
mentionAll = val
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查用户是否可以发送群消息(验证成员身份和禁言状态)
|
|
||||||
// client.UserID 已经是 string 格式的 UUID
|
|
||||||
if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
|
|
||||||
log.Printf("User cannot send group message: %v", err)
|
|
||||||
// 发送错误消息
|
|
||||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
|
||||||
"error": "Cannot send group message",
|
|
||||||
"reason": err.Error(),
|
|
||||||
"type": "group_message_error",
|
|
||||||
"groupId": groupID,
|
|
||||||
})
|
|
||||||
if client.Send != nil {
|
|
||||||
msgBytes, _ := json.Marshal(errorMsg)
|
|
||||||
client.Send <- msgBytes
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查@所有人权限(只有群主和管理员可以@所有人)
|
|
||||||
if mentionAll {
|
|
||||||
if !h.groupService.IsGroupAdmin(client.UserID, groupID) {
|
|
||||||
log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
|
|
||||||
mentionAll = false // 取消@所有人标记
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建消息
|
|
||||||
message := &model.Message{
|
|
||||||
ConversationID: conversationID,
|
|
||||||
SenderID: client.UserID,
|
|
||||||
Segments: segments,
|
|
||||||
Status: model.MessageStatusNormal,
|
|
||||||
MentionAll: mentionAll,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 序列化mentionUsers为JSON
|
|
||||||
if len(mentionUsers) > 0 {
|
|
||||||
mentionUsersJSON, _ := json.Marshal(mentionUsers)
|
|
||||||
message.MentionUsers = string(mentionUsersJSON)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播)
|
|
||||||
savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to save group message: %v", err)
|
|
||||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
|
||||||
"error": "Failed to save group message",
|
|
||||||
})
|
|
||||||
if client.Send != nil {
|
|
||||||
msgBytes, _ := json.Marshal(errorMsg)
|
|
||||||
client.Send <- msgBytes
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新消息的mention信息
|
|
||||||
if len(mentionUsers) > 0 || mentionAll {
|
|
||||||
message.ID = savedMessage.ID
|
|
||||||
message.Seq = savedMessage.Seq
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构造群消息响应
|
|
||||||
groupMsg := &ws.GroupMessage{
|
|
||||||
ID: savedMessage.ID,
|
|
||||||
ConversationID: conversationID,
|
|
||||||
GroupID: groupID,
|
|
||||||
SenderID: client.UserID,
|
|
||||||
Seq: savedMessage.Seq,
|
|
||||||
Segments: segments,
|
|
||||||
MentionUsers: mentionUsers,
|
|
||||||
MentionAll: mentionAll,
|
|
||||||
CreatedAt: savedMessage.CreatedAt.UnixMilli(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 广播消息给群组所有成员(排除发送者)
|
|
||||||
h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)
|
|
||||||
|
|
||||||
// 发送确认消息给发送者(作为meta事件)
|
|
||||||
// 使用 meta 事件格式发送 ack
|
|
||||||
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
|
||||||
"detail_type": ws.MetaDetailTypeAck,
|
|
||||||
"conversation_id": conversationID,
|
|
||||||
"group_id": groupID,
|
|
||||||
"id": savedMessage.ID,
|
|
||||||
"user_id": client.UserID,
|
|
||||||
"sender_id": client.UserID,
|
|
||||||
"seq": savedMessage.Seq,
|
|
||||||
"segments": segments,
|
|
||||||
"created_at": savedMessage.CreatedAt.UnixMilli(),
|
|
||||||
})
|
|
||||||
if client.Send != nil {
|
|
||||||
msgBytes, _ := json.Marshal(metaAckMsg)
|
|
||||||
client.Send <- msgBytes
|
|
||||||
} else {
|
|
||||||
log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理@提及通知
|
|
||||||
if len(mentionUsers) > 0 || mentionAll {
|
|
||||||
// 提取文本正文(不含 @ 部分)
|
|
||||||
textContent := dto.ExtractTextContentFromModel(segments)
|
|
||||||
// 在通知内容前拼接被@的真实昵称,通过群成员列表查找
|
|
||||||
mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
|
|
||||||
h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleGroupTyping 处理群输入状态
|
|
||||||
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
|
|
||||||
data, ok := msg.Data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
groupIDFloat, _ := data["groupId"].(float64)
|
|
||||||
if groupIDFloat == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
|
||||||
|
|
||||||
isTyping, _ := data["isTyping"].(bool)
|
|
||||||
|
|
||||||
// 验证用户是否是群成员
|
|
||||||
// client.UserID 已经是 string 格式的 UUID
|
|
||||||
isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
|
|
||||||
if err != nil || !isMember {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取用户信息
|
|
||||||
user, err := h.userRepo.GetByID(client.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构造输入状态消息
|
|
||||||
typingMsg := &ws.GroupTypingMessage{
|
|
||||||
GroupID: groupID,
|
|
||||||
UserID: client.UserID,
|
|
||||||
Username: user.Username,
|
|
||||||
IsTyping: isTyping,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 广播给群组其他成员
|
|
||||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
|
||||||
h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleGroupReadReceipt 处理群消息已读回执
|
|
||||||
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
|
||||||
data, ok := msg.Data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conversationID, _ := data["conversationId"].(string)
|
|
||||||
if conversationID == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
|
||||||
if lastReadSeq == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 标记已读
|
|
||||||
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
|
||||||
log.Printf("Failed to mark group message as read: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleGroupRecall 处理群消息撤回
|
|
||||||
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
|
|
||||||
data, ok := msg.Data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
messageID, _ := data["messageId"].(string)
|
|
||||||
if messageID == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
groupIDFloat, _ := data["groupId"].(float64)
|
|
||||||
if groupIDFloat == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
|
||||||
|
|
||||||
// 撤回消息
|
|
||||||
if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
|
|
||||||
log.Printf("Failed to recall group message: %v", err)
|
|
||||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
|
||||||
"error": "Failed to recall message",
|
|
||||||
})
|
|
||||||
if client.Send != nil {
|
|
||||||
msgBytes, _ := json.Marshal(errorMsg)
|
|
||||||
client.Send <- msgBytes
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 广播撤回通知给群组所有成员
|
|
||||||
recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
|
|
||||||
"messageId": messageID,
|
|
||||||
"groupId": groupID,
|
|
||||||
"userId": client.UserID,
|
|
||||||
"timestamp": time.Now().UnixMilli(),
|
|
||||||
})
|
|
||||||
h.BroadcastGroupNotice(groupID, recallNotice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleGroupMention 处理群消息@提及通知
|
|
||||||
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
|
|
||||||
// 如果@所有人,获取所有群成员
|
|
||||||
if mentionAll {
|
|
||||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get group members for mention all: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, member := range members {
|
|
||||||
// 不通知发送者自己
|
|
||||||
memberIDStr := member.UserID
|
|
||||||
if memberIDStr == senderID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送@提及通知
|
|
||||||
mentionMsg := &ws.GroupMentionMessage{
|
|
||||||
GroupID: groupID,
|
|
||||||
MessageID: messageID,
|
|
||||||
FromUserID: senderID,
|
|
||||||
Content: truncateContent(content, 50),
|
|
||||||
MentionAll: true,
|
|
||||||
CreatedAt: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
|
||||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理特定用户的@提及
|
|
||||||
for _, userID := range mentionUsers {
|
|
||||||
// userID 是 uint64,转换为 string
|
|
||||||
userIDStr := strconv.FormatUint(userID, 10)
|
|
||||||
if userIDStr == senderID {
|
|
||||||
continue // 不通知发送者自己
|
|
||||||
}
|
|
||||||
|
|
||||||
mentionMsg := &ws.GroupMentionMessage{
|
|
||||||
GroupID: groupID,
|
|
||||||
MessageID: messageID,
|
|
||||||
FromUserID: senderID,
|
|
||||||
Content: truncateContent(content, 50),
|
|
||||||
MentionAll: false,
|
|
||||||
CreatedAt: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
|
||||||
h.wsManager.SendToUser(userIDStr, wsMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称
|
|
||||||
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
|
|
||||||
var prefix string
|
|
||||||
if mentionAll {
|
|
||||||
prefix = "@所有人 "
|
|
||||||
} else if len(mentionUsers) > 0 {
|
|
||||||
// 查询群成员列表,找到被@用户的昵称
|
|
||||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err == nil {
|
|
||||||
memberNickMap := make(map[string]string, len(members))
|
|
||||||
for _, m := range members {
|
|
||||||
displayName := m.Nickname
|
|
||||||
if displayName == "" {
|
|
||||||
displayName = m.UserID
|
|
||||||
}
|
|
||||||
memberNickMap[m.UserID] = displayName
|
|
||||||
}
|
|
||||||
for _, uid := range mentionUsers {
|
|
||||||
uidStr := strconv.FormatUint(uid, 10)
|
|
||||||
if name, ok := memberNickMap[uidStr]; ok {
|
|
||||||
prefix += "@" + name + " "
|
|
||||||
} else {
|
|
||||||
prefix += "@某人 "
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for range mentionUsers {
|
|
||||||
prefix += "@某人 "
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return prefix + textBody
|
|
||||||
}
|
|
||||||
|
|
||||||
// BroadcastGroupMessage 向群组所有成员广播消息
|
|
||||||
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
|
|
||||||
// 获取群组所有成员
|
|
||||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get group members for broadcast: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建WebSocket消息
|
|
||||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
|
|
||||||
|
|
||||||
// 遍历成员,如果在线则发送消息
|
|
||||||
for _, member := range members {
|
|
||||||
memberIDStr := member.UserID
|
|
||||||
|
|
||||||
// 排除发送者
|
|
||||||
if memberIDStr == excludeUserID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送消息
|
|
||||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BroadcastGroupNotice 广播群组通知给所有成员
|
|
||||||
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
|
|
||||||
// 获取群组所有成员
|
|
||||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 遍历成员,如果在线则发送通知
|
|
||||||
for _, member := range members {
|
|
||||||
memberIDStr := member.UserID
|
|
||||||
h.wsManager.SendToUser(memberIDStr, notice)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户)
|
|
||||||
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
|
|
||||||
// 获取群组所有成员
|
|
||||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 遍历成员,如果在线则发送通知
|
|
||||||
for _, member := range members {
|
|
||||||
memberIDStr := member.UserID
|
|
||||||
if memberIDStr == excludeUserID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
h.wsManager.SendToUser(memberIDStr, notice)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendGroupMemberNotice 发送群成员变动通知
|
|
||||||
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
|
|
||||||
notice := &ws.GroupNoticeMessage{
|
|
||||||
NoticeType: noticeType,
|
|
||||||
GroupID: groupID,
|
|
||||||
Data: data,
|
|
||||||
Timestamp: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
|
|
||||||
h.BroadcastGroupNotice(groupID, wsMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncateContent 截断内容
|
|
||||||
func truncateContent(content string, maxLen int) string {
|
|
||||||
if len(content) <= maxLen {
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
return content[:maxLen] + "..."
|
|
||||||
}
|
|
||||||
|
|
||||||
// BroadcastGroupTyping 向群组所有成员广播输入状态
|
|
||||||
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
|
|
||||||
// 获取群组所有成员
|
|
||||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get group members for typing broadcast: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建WebSocket消息
|
|
||||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
|
||||||
|
|
||||||
// 遍历成员,如果在线则发送消息
|
|
||||||
for _, member := range members {
|
|
||||||
memberIDStr := member.UserID
|
|
||||||
|
|
||||||
// 排除指定用户
|
|
||||||
if memberIDStr == excludeUserID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送消息
|
|
||||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BroadcastGroupRead 向群组所有成员广播已读状态
|
|
||||||
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
|
|
||||||
// 获取群组所有成员
|
|
||||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get group members for read broadcast: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建WebSocket消息
|
|
||||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
|
|
||||||
|
|
||||||
// 遍历成员,如果在线则发送消息
|
|
||||||
for _, member := range members {
|
|
||||||
memberIDStr := member.UserID
|
|
||||||
|
|
||||||
// 排除指定用户
|
|
||||||
if memberIDStr == excludeUserID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送消息
|
|
||||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,18 +1,10 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import "github.com/gin-gonic/gin"
|
||||||
"log"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CORS CORS中间件
|
// CORS CORS中间件
|
||||||
func CORS() gin.HandlerFunc {
|
func CORS() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 获取请求路径
|
|
||||||
path := c.Request.URL.Path
|
|
||||||
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||||
// 添加 WebSocket 升级所需的头
|
// 添加 WebSocket 升级所需的头
|
||||||
@@ -22,25 +14,10 @@ func CORS() gin.HandlerFunc {
|
|||||||
|
|
||||||
// 处理 WebSocket 升级请求的预检
|
// 处理 WebSocket 升级请求的预检
|
||||||
if c.Request.Method == "OPTIONS" {
|
if c.Request.Method == "OPTIONS" {
|
||||||
log.Printf("[CORS] OPTIONS 预检请求: %s", path)
|
|
||||||
c.AbortWithStatus(204)
|
c.AbortWithStatus(204)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 针对 WebSocket 路径的特殊处理
|
|
||||||
if path == "/ws" {
|
|
||||||
connection := c.GetHeader("Connection")
|
|
||||||
upgrade := c.GetHeader("Upgrade")
|
|
||||||
log.Printf("[CORS] WebSocket 请求: Connection=%s, Upgrade=%s", connection, upgrade)
|
|
||||||
|
|
||||||
// 检查是否是有效的 WebSocket 升级请求
|
|
||||||
if strings.Contains(strings.ToLower(connection), "upgrade") && strings.ToLower(upgrade) == "websocket" {
|
|
||||||
log.Printf("[CORS] 有效的 WebSocket 升级请求")
|
|
||||||
} else {
|
|
||||||
log.Printf("[CORS] 警告: 不是有效的 WebSocket 升级请求!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ type Post struct {
|
|||||||
|
|
||||||
// 时间戳
|
// 时间戳
|
||||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"`
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"`
|
||||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime:false"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// BeforeCreate 创建前生成UUID
|
// BeforeCreate 创建前生成UUID
|
||||||
|
|||||||
152
internal/pkg/sse/hub.go
Normal file
152
internal/pkg/sse/hub.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package sse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultUserBufferSize = 128
|
||||||
|
maxReplayEvents = 200
|
||||||
|
)
|
||||||
|
|
||||||
|
type Event struct {
|
||||||
|
ID uint64 `json:"event_id"`
|
||||||
|
Event string `json:"event"`
|
||||||
|
TS int64 `json:"ts"`
|
||||||
|
Payload interface{} `json:"payload"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type subscriber struct {
|
||||||
|
id uint64
|
||||||
|
ch chan Event
|
||||||
|
quit chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Hub struct {
|
||||||
|
seq uint64
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
subscribers map[string]map[uint64]*subscriber
|
||||||
|
history map[string][]Event
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHub() *Hub {
|
||||||
|
return &Hub{
|
||||||
|
subscribers: make(map[string]map[uint64]*subscriber),
|
||||||
|
history: make(map[string][]Event),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) NextID() uint64 {
|
||||||
|
return atomic.AddUint64(&h.seq, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseEventID(raw string) uint64 {
|
||||||
|
if raw == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
id, err := strconv.ParseUint(raw, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) Subscribe(userID string, afterID uint64) (chan Event, func(), []Event) {
|
||||||
|
subID := h.NextID()
|
||||||
|
sub := &subscriber{
|
||||||
|
id: subID,
|
||||||
|
ch: make(chan Event, defaultUserBufferSize),
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
if _, ok := h.subscribers[userID]; !ok {
|
||||||
|
h.subscribers[userID] = make(map[uint64]*subscriber)
|
||||||
|
}
|
||||||
|
h.subscribers[userID][subID] = sub
|
||||||
|
|
||||||
|
replay := make([]Event, 0)
|
||||||
|
for _, e := range h.history[userID] {
|
||||||
|
if e.ID > afterID {
|
||||||
|
replay = append(replay, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
cancel := func() {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
if userSubs, ok := h.subscribers[userID]; ok {
|
||||||
|
if s, exists := userSubs[subID]; exists {
|
||||||
|
close(s.quit)
|
||||||
|
delete(userSubs, subID)
|
||||||
|
close(s.ch)
|
||||||
|
}
|
||||||
|
if len(userSubs) == 0 {
|
||||||
|
delete(h.subscribers, userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sub.ch, cancel, replay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) HasSubscribers(userID string) bool {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
return len(h.subscribers[userID]) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) PublishToUser(userID string, eventName string, payload interface{}) Event {
|
||||||
|
ev := Event{
|
||||||
|
ID: h.NextID(),
|
||||||
|
Event: eventName,
|
||||||
|
TS: time.Now().UnixMilli(),
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
h.publish(userID, ev)
|
||||||
|
return ev
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) PublishToUsers(userIDs []string, eventName string, payload interface{}) {
|
||||||
|
for _, uid := range userIDs {
|
||||||
|
h.PublishToUser(uid, eventName, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hub) publish(userID string, ev Event) {
|
||||||
|
h.mu.Lock()
|
||||||
|
history := append(h.history[userID], ev)
|
||||||
|
if len(history) > maxReplayEvents {
|
||||||
|
history = history[len(history)-maxReplayEvents:]
|
||||||
|
}
|
||||||
|
h.history[userID] = history
|
||||||
|
|
||||||
|
targets := make([]*subscriber, 0, len(h.subscribers[userID]))
|
||||||
|
for _, s := range h.subscribers[userID] {
|
||||||
|
targets = append(targets, s)
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
for _, s := range targets {
|
||||||
|
select {
|
||||||
|
case <-s.quit:
|
||||||
|
case s.ch <- ev:
|
||||||
|
default:
|
||||||
|
// 慢消费者丢弃单条消息,客户端可通过 Last-Event-ID + HTTP 同步补偿
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeData(ev Event) (string, error) {
|
||||||
|
body, err := json.Marshal(ev.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(body), nil
|
||||||
|
}
|
||||||
@@ -1,435 +0,0 @@
|
|||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"carrot_bbs/internal/model"
|
|
||||||
"encoding/json"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WebSocket消息类型常量
|
|
||||||
const (
|
|
||||||
MessageTypePing = "ping"
|
|
||||||
MessageTypePong = "pong"
|
|
||||||
MessageTypeMessage = "message"
|
|
||||||
MessageTypeTyping = "typing"
|
|
||||||
MessageTypeRead = "read"
|
|
||||||
MessageTypeAck = "ack"
|
|
||||||
MessageTypeError = "error"
|
|
||||||
MessageTypeRecall = "recall" // 撤回消息
|
|
||||||
MessageTypeSystem = "system" // 系统消息
|
|
||||||
MessageTypeNotification = "notification" // 通知消息
|
|
||||||
MessageTypeAnnouncement = "announcement" // 公告消息
|
|
||||||
|
|
||||||
// 群组相关消息类型
|
|
||||||
MessageTypeGroupMessage = "group_message" // 群消息
|
|
||||||
MessageTypeGroupTyping = "group_typing" // 群输入状态
|
|
||||||
MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等)
|
|
||||||
MessageTypeGroupMention = "group_mention" // @提及通知
|
|
||||||
MessageTypeGroupRead = "group_read" // 群消息已读
|
|
||||||
MessageTypeGroupRecall = "group_recall" // 群消息撤回
|
|
||||||
|
|
||||||
// Meta事件详细类型
|
|
||||||
MetaDetailTypeHeartbeat = "heartbeat"
|
|
||||||
MetaDetailTypeTyping = "typing"
|
|
||||||
MetaDetailTypeAck = "ack" // 消息发送确认
|
|
||||||
MetaDetailTypeRead = "read" // 已读回执
|
|
||||||
)
|
|
||||||
|
|
||||||
// WSMessage WebSocket消息结构
|
|
||||||
type WSMessage struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Data interface{} `json:"data"`
|
|
||||||
Timestamp int64 `json:"timestamp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatMessage 聊天消息结构
|
|
||||||
type ChatMessage struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
ConversationID string `json:"conversation_id"`
|
|
||||||
SenderID string `json:"sender_id"`
|
|
||||||
Seq int64 `json:"seq"`
|
|
||||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
|
||||||
ReplyToID *string `json:"reply_to_id,omitempty"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SystemMessage 系统消息结构
|
|
||||||
type SystemMessage struct {
|
|
||||||
ID string `json:"id"` // 消息ID
|
|
||||||
Type string `json:"type"` // 消息子类型(如:account_banned, post_approved等)
|
|
||||||
Title string `json:"title"` // 消息标题
|
|
||||||
Content string `json:"content"` // 消息内容
|
|
||||||
Data map[string]interface{} `json:"data"` // 额外数据
|
|
||||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotificationMessage 通知消息结构
|
|
||||||
type NotificationMessage struct {
|
|
||||||
ID string `json:"id"` // 通知ID
|
|
||||||
Type string `json:"type"` // 通知类型(like, comment, follow, mention等)
|
|
||||||
Title string `json:"title"` // 通知标题
|
|
||||||
Content string `json:"content"` // 通知内容
|
|
||||||
TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户
|
|
||||||
ResourceType string `json:"resource_type"` // 资源类型(post, comment等)
|
|
||||||
ResourceID string `json:"resource_id"` // 资源ID
|
|
||||||
Extra map[string]interface{} `json:"extra"` // 额外数据
|
|
||||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotificationUser 通知中的用户信息
|
|
||||||
type NotificationUser struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
Avatar string `json:"avatar"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// AnnouncementMessage 公告消息结构
|
|
||||||
type AnnouncementMessage struct {
|
|
||||||
ID string `json:"id"` // 公告ID
|
|
||||||
Title string `json:"title"` // 公告标题
|
|
||||||
Content string `json:"content"` // 公告内容
|
|
||||||
Priority int `json:"priority"` // 优先级(1-10)
|
|
||||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupMessage 群消息结构
|
|
||||||
type GroupMessage struct {
|
|
||||||
ID string `json:"id"` // 消息ID
|
|
||||||
ConversationID string `json:"conversation_id"` // 会话ID(群聊会话)
|
|
||||||
GroupID string `json:"group_id"` // 群组ID
|
|
||||||
SenderID string `json:"sender_id"` // 发送者ID
|
|
||||||
Seq int64 `json:"seq"` // 消息序号
|
|
||||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
|
||||||
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
|
|
||||||
MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表
|
|
||||||
MentionAll bool `json:"mention_all"` // 是否@所有人
|
|
||||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupTypingMessage 群输入状态消息
|
|
||||||
type GroupTypingMessage struct {
|
|
||||||
GroupID string `json:"group_id"` // 群组ID
|
|
||||||
UserID string `json:"user_id"` // 用户ID
|
|
||||||
Username string `json:"username"` // 用户名
|
|
||||||
IsTyping bool `json:"is_typing"` // 是否正在输入
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupNoticeMessage 群组通知消息
|
|
||||||
type GroupNoticeMessage struct {
|
|
||||||
NoticeType string `json:"notice_type"` // 通知类型:member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved
|
|
||||||
GroupID string `json:"group_id"` // 群组ID
|
|
||||||
Data interface{} `json:"data"` // 通知数据
|
|
||||||
Timestamp int64 `json:"timestamp"` // 时间戳
|
|
||||||
MessageID string `json:"message_id,omitempty"` // 消息ID(如果通知保存为消息)
|
|
||||||
Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupNoticeData 通知数据结构
|
|
||||||
type GroupNoticeData struct {
|
|
||||||
// 成员变动
|
|
||||||
UserID string `json:"user_id,omitempty"` // 变动的用户ID
|
|
||||||
Username string `json:"username,omitempty"` // 用户名
|
|
||||||
OperatorID string `json:"operator_id,omitempty"` // 操作者ID
|
|
||||||
OpName string `json:"op_name,omitempty"` // 操作者名称
|
|
||||||
NewRole string `json:"new_role,omitempty"` // 新角色
|
|
||||||
OldRole string `json:"old_role,omitempty"` // 旧角色
|
|
||||||
MemberCount int `json:"member_count,omitempty"` // 当前成员数
|
|
||||||
|
|
||||||
// 群设置变更
|
|
||||||
MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupMentionMessage @提及通知消息
|
|
||||||
type GroupMentionMessage struct {
|
|
||||||
GroupID string `json:"group_id"` // 群组ID
|
|
||||||
MessageID string `json:"message_id"` // 消息ID
|
|
||||||
FromUserID string `json:"from_user_id"` // 发送者ID
|
|
||||||
FromName string `json:"from_name"` // 发送者名称
|
|
||||||
Content string `json:"content"` // 消息内容预览
|
|
||||||
MentionAll bool `json:"mention_all"` // 是否@所有人
|
|
||||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
|
||||||
}
|
|
||||||
|
|
||||||
// AckMessage 消息发送确认结构
|
|
||||||
type AckMessage struct {
|
|
||||||
ConversationID string `json:"conversation_id"` // 会话ID
|
|
||||||
GroupID string `json:"group_id,omitempty"` // 群组ID(群聊时)
|
|
||||||
ID string `json:"id"` // 消息ID
|
|
||||||
SenderID string `json:"sender_id"` // 发送者ID
|
|
||||||
Seq int64 `json:"seq"` // 消息序号
|
|
||||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
|
||||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client WebSocket客户端
|
|
||||||
type Client struct {
|
|
||||||
ID string
|
|
||||||
UserID string
|
|
||||||
Conn *websocket.Conn
|
|
||||||
Send chan []byte
|
|
||||||
Manager *WebSocketManager
|
|
||||||
IsClosed bool
|
|
||||||
Mu sync.Mutex
|
|
||||||
closeOnce sync.Once // 确保 Send channel 只关闭一次
|
|
||||||
}
|
|
||||||
|
|
||||||
// WebSocketManager WebSocket连接管理器
|
|
||||||
type WebSocketManager struct {
|
|
||||||
clients map[string]*Client // userID -> Client
|
|
||||||
register chan *Client
|
|
||||||
unregister chan *Client
|
|
||||||
broadcast chan *BroadcastMessage
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// BroadcastMessage 广播消息
|
|
||||||
type BroadcastMessage struct {
|
|
||||||
Message *WSMessage
|
|
||||||
ExcludeUser string // 排除的用户ID,为空表示不排除
|
|
||||||
TargetUser string // 目标用户ID,为空表示广播给所有用户
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWebSocketManager 创建WebSocket管理器
|
|
||||||
func NewWebSocketManager() *WebSocketManager {
|
|
||||||
return &WebSocketManager{
|
|
||||||
clients: make(map[string]*Client),
|
|
||||||
register: make(chan *Client, 100),
|
|
||||||
unregister: make(chan *Client, 100),
|
|
||||||
broadcast: make(chan *BroadcastMessage, 100),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start 启动管理器
|
|
||||||
func (m *WebSocketManager) Start() {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case client := <-m.register:
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.clients[client.UserID] = client
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients))
|
|
||||||
|
|
||||||
case client := <-m.unregister:
|
|
||||||
m.mutex.Lock()
|
|
||||||
if _, ok := m.clients[client.UserID]; ok {
|
|
||||||
delete(m.clients, client.UserID)
|
|
||||||
// 使用 closeOnce 确保 channel 只关闭一次,避免 panic
|
|
||||||
client.closeOnce.Do(func() {
|
|
||||||
close(client.Send)
|
|
||||||
})
|
|
||||||
log.Printf("WebSocket client disconnected: userID=%s", client.UserID)
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
case broadcast := <-m.broadcast:
|
|
||||||
m.sendMessage(broadcast)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register 注册客户端
|
|
||||||
func (m *WebSocketManager) Register(client *Client) {
|
|
||||||
m.register <- client
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unregister 注销客户端
|
|
||||||
func (m *WebSocketManager) Unregister(client *Client) {
|
|
||||||
m.unregister <- client
|
|
||||||
}
|
|
||||||
|
|
||||||
// Broadcast 广播消息给所有用户
|
|
||||||
func (m *WebSocketManager) Broadcast(msg *WSMessage) {
|
|
||||||
m.broadcast <- &BroadcastMessage{
|
|
||||||
Message: msg,
|
|
||||||
TargetUser: "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendToUser 发送消息给指定用户
|
|
||||||
func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) {
|
|
||||||
m.broadcast <- &BroadcastMessage{
|
|
||||||
Message: msg,
|
|
||||||
TargetUser: userID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendToUsers 发送消息给指定用户列表
|
|
||||||
func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) {
|
|
||||||
for _, userID := range userIDs {
|
|
||||||
m.SendToUser(userID, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClient 获取客户端
|
|
||||||
func (m *WebSocketManager) GetClient(userID string) (*Client, bool) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
client, ok := m.clients[userID]
|
|
||||||
return client, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllClients 获取所有客户端
|
|
||||||
func (m *WebSocketManager) GetAllClients() map[string]*Client {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
return m.clients
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientCount 获取在线用户数量
|
|
||||||
func (m *WebSocketManager) GetClientCount() int {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
return len(m.clients)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsUserOnline 检查用户是否在线
|
|
||||||
func (m *WebSocketManager) IsUserOnline(userID string) bool {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
_, ok := m.clients[userID]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendMessage 发送消息
|
|
||||||
func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) {
|
|
||||||
msgBytes, err := json.Marshal(broadcast.Message)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to marshal message: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
for userID, client := range m.clients {
|
|
||||||
// 如果指定了目标用户,只发送给目标用户
|
|
||||||
if broadcast.TargetUser != "" && userID != broadcast.TargetUser {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果指定了排除用户,跳过
|
|
||||||
if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case client.Send <- msgBytes:
|
|
||||||
default:
|
|
||||||
log.Printf("Failed to send message to user %s: channel full", userID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendPing 发送心跳
|
|
||||||
func (c *Client) SendPing() error {
|
|
||||||
c.Mu.Lock()
|
|
||||||
defer c.Mu.Unlock()
|
|
||||||
if c.IsClosed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
msg := WSMessage{
|
|
||||||
Type: MessageTypePing,
|
|
||||||
Data: nil,
|
|
||||||
Timestamp: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
msgBytes, _ := json.Marshal(msg)
|
|
||||||
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendPong 发送Pong响应
|
|
||||||
func (c *Client) SendPong() error {
|
|
||||||
c.Mu.Lock()
|
|
||||||
defer c.Mu.Unlock()
|
|
||||||
if c.IsClosed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
msg := WSMessage{
|
|
||||||
Type: MessageTypePong,
|
|
||||||
Data: nil,
|
|
||||||
Timestamp: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
msgBytes, _ := json.Marshal(msg)
|
|
||||||
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// WritePump 写入泵,将消息从Manager发送到客户端
|
|
||||||
func (c *Client) WritePump() {
|
|
||||||
defer func() {
|
|
||||||
c.Conn.Close()
|
|
||||||
c.Mu.Lock()
|
|
||||||
c.IsClosed = true
|
|
||||||
c.Mu.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
message, ok := <-c.Send
|
|
||||||
if !ok {
|
|
||||||
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Mu.Lock()
|
|
||||||
if c.IsClosed {
|
|
||||||
c.Mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err := c.Conn.WriteMessage(websocket.TextMessage, message)
|
|
||||||
c.Mu.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Write error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadPump 读取泵,从客户端读取消息
|
|
||||||
func (c *Client) ReadPump(handler func(msg *WSMessage)) {
|
|
||||||
defer func() {
|
|
||||||
c.Manager.Unregister(c)
|
|
||||||
c.Conn.Close()
|
|
||||||
c.Mu.Lock()
|
|
||||||
c.IsClosed = true
|
|
||||||
c.Mu.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
c.Conn.SetReadLimit(512 * 1024) // 512KB
|
|
||||||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
||||||
c.Conn.SetPongHandler(func(string) error {
|
|
||||||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
for {
|
|
||||||
_, message, err := c.Conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
||||||
log.Printf("WebSocket error: %v", err)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
var wsMsg WSMessage
|
|
||||||
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
|
||||||
log.Printf("Failed to unmarshal message: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
handler(&wsMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateWSMessage 创建WebSocket消息
|
|
||||||
func CreateWSMessage(msgType string, data interface{}) *WSMessage {
|
|
||||||
return &WSMessage{
|
|
||||||
Type: msgType,
|
|
||||||
Data: data,
|
|
||||||
Timestamp: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -18,32 +18,7 @@ func NewCommentRepository(db *gorm.DB) *CommentRepository {
|
|||||||
|
|
||||||
// Create 创建评论
|
// Create 创建评论
|
||||||
func (r *CommentRepository) Create(comment *model.Comment) error {
|
func (r *CommentRepository) Create(comment *model.Comment) error {
|
||||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
return r.db.Create(comment).Error
|
||||||
// 创建评论
|
|
||||||
err := tx.Create(comment).Error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 增加帖子的评论数并同步热度分
|
|
||||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
|
||||||
Updates(map[string]interface{}{
|
|
||||||
"comments_count": gorm.Expr("comments_count + 1"),
|
|
||||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
|
|
||||||
}).Error; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果是回复,增加父评论的回复数
|
|
||||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
|
||||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
|
||||||
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID 根据ID获取评论
|
// GetByID 根据ID获取评论
|
||||||
@@ -87,23 +62,52 @@ func (r *CommentRepository) Delete(id string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 减少帖子的评论数并同步热度分
|
// 仅已发布评论才参与统计,避免 pending/rejected 影响计数
|
||||||
|
if comment.Status == model.CommentStatusPublished {
|
||||||
|
// 减少帖子的评论数并同步热度分
|
||||||
|
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"comments_count": gorm.Expr("comments_count - 1"),
|
||||||
|
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
|
||||||
|
}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是回复,减少父评论的回复数
|
||||||
|
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||||
|
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||||
|
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyPublishedStats 在评论审核通过后更新帖子评论数/回复数
|
||||||
|
func (r *CommentRepository) ApplyPublishedStats(comment *model.Comment) error {
|
||||||
|
if comment == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 增加帖子的评论数并同步热度分
|
||||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]interface{}{
|
||||||
"comments_count": gorm.Expr("comments_count - 1"),
|
"comments_count": gorm.Expr("comments_count + 1"),
|
||||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
|
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
|
||||||
}).Error; err != nil {
|
}).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果是回复,减少父评论的回复数
|
// 如果是回复,增加父评论的回复数
|
||||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||||
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
|
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -52,9 +53,41 @@ func (r *PostRepository) GetByID(id string) (*model.Post, error) {
|
|||||||
|
|
||||||
// Update 更新帖子
|
// Update 更新帖子
|
||||||
func (r *PostRepository) Update(post *model.Post) error {
|
func (r *PostRepository) Update(post *model.Post) error {
|
||||||
|
post.UpdatedAt = time.Now()
|
||||||
return r.db.Save(post).Error
|
return r.db.Save(post).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateWithImages 更新帖子及其图片(images=nil 表示不更新图片)
|
||||||
|
func (r *PostRepository) UpdateWithImages(post *model.Post, images *[]string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
post.UpdatedAt = time.Now()
|
||||||
|
if err := tx.Save(post).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if images == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Where("post_id = ?", post.ID).Delete(&model.PostImage{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, url := range *images {
|
||||||
|
image := &model.PostImage{
|
||||||
|
PostID: post.ID,
|
||||||
|
URL: url,
|
||||||
|
SortOrder: i,
|
||||||
|
}
|
||||||
|
if err := tx.Create(image).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateModerationStatus 更新帖子审核状态
|
// UpdateModerationStatus 更新帖子审核状态
|
||||||
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
|
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
|
||||||
updates := map[string]interface{}{
|
updates := map[string]interface{}{
|
||||||
@@ -100,15 +133,24 @@ func (r *PostRepository) Delete(id string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List 分页获取帖子列表
|
// List 分页获取帖子列表
|
||||||
func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
// includePending=true 时,仅在指定 userID 下额外返回 pending(用于作者查看自己待审核帖子)
|
||||||
|
func (r *PostRepository) List(page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) {
|
||||||
var posts []*model.Post
|
var posts []*model.Post
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
|
query := r.db.Model(&model.Post{})
|
||||||
|
|
||||||
if userID != "" {
|
if userID != "" {
|
||||||
query = query.Where("user_id = ?", userID)
|
query = query.Where("user_id = ?", userID)
|
||||||
}
|
}
|
||||||
|
if includePending && userID != "" {
|
||||||
|
query = query.Where("status IN ?", []model.PostStatus{
|
||||||
|
model.PostStatusPublished,
|
||||||
|
model.PostStatusPending,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
query = query.Where("status = ?", model.PostStatusPublished)
|
||||||
|
}
|
||||||
|
|
||||||
query.Count(&total)
|
query.Count(&total)
|
||||||
|
|
||||||
@@ -119,14 +161,32 @@ func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserPosts 获取用户帖子
|
// GetUserPosts 获取用户帖子
|
||||||
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) {
|
||||||
var posts []*model.Post
|
var posts []*model.Post
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total)
|
statusQuery := r.db.Model(&model.Post{}).Where("user_id = ?", userID)
|
||||||
|
if includePending {
|
||||||
|
statusQuery = statusQuery.Where("status IN ?", []model.PostStatus{
|
||||||
|
model.PostStatusPublished,
|
||||||
|
model.PostStatusPending,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
statusQuery = statusQuery.Where("status = ?", model.PostStatusPublished)
|
||||||
|
}
|
||||||
|
statusQuery.Count(&total)
|
||||||
|
|
||||||
offset := (page - 1) * pageSize
|
offset := (page - 1) * pageSize
|
||||||
err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
listQuery := r.db.Where("user_id = ?", userID)
|
||||||
|
if includePending {
|
||||||
|
listQuery = listQuery.Where("status IN ?", []model.PostStatus{
|
||||||
|
model.PostStatusPublished,
|
||||||
|
model.PostStatusPending,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
listQuery = listQuery.Where("status = ?", model.PostStatusPublished)
|
||||||
|
}
|
||||||
|
err := listQuery.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||||
|
|
||||||
return posts, total, err
|
return posts, total, err
|
||||||
}
|
}
|
||||||
@@ -256,7 +316,8 @@ func (r *PostRepository) IsFavorited(postID, userID string) bool {
|
|||||||
// IncrementViews 增加帖子观看量
|
// IncrementViews 增加帖子观看量
|
||||||
func (r *PostRepository) IncrementViews(postID string) error {
|
func (r *PostRepository) IncrementViews(postID string) error {
|
||||||
return r.db.Model(&model.Post{}).Where("id = ?", postID).
|
return r.db.Model(&model.Post{}).Where("id = ?", postID).
|
||||||
Updates(map[string]interface{}{
|
// 浏览量属于统计字段,不应影响帖子内容更新时间(updated_at)
|
||||||
|
UpdateColumns(map[string]interface{}{
|
||||||
"views_count": gorm.Expr("views_count + 1"),
|
"views_count": gorm.Expr("views_count + 1"),
|
||||||
"hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
|
"hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
|
||||||
}).Error
|
}).Error
|
||||||
|
|||||||
@@ -177,7 +177,9 @@ func (r *UserRepository) RefreshFollowersCount(userID string) error {
|
|||||||
// GetPostsCount 获取用户帖子数(实时计算)
|
// GetPostsCount 获取用户帖子数(实时计算)
|
||||||
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
|
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
|
||||||
var count int64
|
var count int64
|
||||||
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
|
err := r.db.Model(&model.Post{}).
|
||||||
|
Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).
|
||||||
|
Count(&count).Error
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,7 +204,7 @@ func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64,
|
|||||||
var counts []CountResult
|
var counts []CountResult
|
||||||
err := r.db.Model(&model.Post{}).
|
err := r.db.Model(&model.Post{}).
|
||||||
Select("user_id, count(*) as count").
|
Select("user_id, count(*) as count").
|
||||||
Where("user_id IN ?", userIDs).
|
Where("user_id IN ? AND status = ?", userIDs, model.PostStatusPublished).
|
||||||
Group("user_id").
|
Group("user_id").
|
||||||
Scan(&counts).Error
|
Scan(&counts).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ type Router struct {
|
|||||||
messageHandler *handler.MessageHandler
|
messageHandler *handler.MessageHandler
|
||||||
notificationHandler *handler.NotificationHandler
|
notificationHandler *handler.NotificationHandler
|
||||||
uploadHandler *handler.UploadHandler
|
uploadHandler *handler.UploadHandler
|
||||||
wsHandler *handler.WebSocketHandler
|
|
||||||
pushHandler *handler.PushHandler
|
pushHandler *handler.PushHandler
|
||||||
systemMessageHandler *handler.SystemMessageHandler
|
systemMessageHandler *handler.SystemMessageHandler
|
||||||
groupHandler *handler.GroupHandler
|
groupHandler *handler.GroupHandler
|
||||||
@@ -36,7 +35,6 @@ func New(
|
|||||||
notificationHandler *handler.NotificationHandler,
|
notificationHandler *handler.NotificationHandler,
|
||||||
uploadHandler *handler.UploadHandler,
|
uploadHandler *handler.UploadHandler,
|
||||||
jwtService *service.JWTService,
|
jwtService *service.JWTService,
|
||||||
wsHandler *handler.WebSocketHandler,
|
|
||||||
pushHandler *handler.PushHandler,
|
pushHandler *handler.PushHandler,
|
||||||
systemMessageHandler *handler.SystemMessageHandler,
|
systemMessageHandler *handler.SystemMessageHandler,
|
||||||
groupHandler *handler.GroupHandler,
|
groupHandler *handler.GroupHandler,
|
||||||
@@ -55,7 +53,6 @@ func New(
|
|||||||
messageHandler: messageHandler,
|
messageHandler: messageHandler,
|
||||||
notificationHandler: notificationHandler,
|
notificationHandler: notificationHandler,
|
||||||
uploadHandler: uploadHandler,
|
uploadHandler: uploadHandler,
|
||||||
wsHandler: wsHandler,
|
|
||||||
pushHandler: pushHandler,
|
pushHandler: pushHandler,
|
||||||
systemMessageHandler: systemMessageHandler,
|
systemMessageHandler: systemMessageHandler,
|
||||||
groupHandler: groupHandler,
|
groupHandler: groupHandler,
|
||||||
@@ -79,11 +76,6 @@ func (r *Router) setupRoutes() {
|
|||||||
c.JSON(200, gin.H{"status": "ok"})
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
})
|
})
|
||||||
|
|
||||||
// WebSocket 路由
|
|
||||||
if r.wsHandler != nil {
|
|
||||||
r.engine.GET("/ws", r.wsHandler.HandleWebSocket)
|
|
||||||
}
|
|
||||||
|
|
||||||
// API v1
|
// API v1
|
||||||
v1 := r.engine.Group("/api/v1")
|
v1 := r.engine.Group("/api/v1")
|
||||||
{
|
{
|
||||||
@@ -210,10 +202,18 @@ func (r *Router) setupRoutes() {
|
|||||||
conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned)
|
conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned)
|
||||||
// 获取未读消息总数
|
// 获取未读消息总数
|
||||||
conversations.GET("/unread/count", r.messageHandler.GetUnreadCount)
|
conversations.GET("/unread/count", r.messageHandler.GetUnreadCount)
|
||||||
|
// 上报输入状态
|
||||||
|
conversations.POST("/typing", r.messageHandler.HandleTyping)
|
||||||
// 仅自己删除会话
|
// 仅自己删除会话
|
||||||
conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf)
|
conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
realtime := v1.Group("/realtime")
|
||||||
|
realtime.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
realtime.GET("/sse", r.messageHandler.HandleSSE)
|
||||||
|
}
|
||||||
|
|
||||||
// 消息操作路由
|
// 消息操作路由
|
||||||
messages := v1.Group("/messages")
|
messages := v1.Group("/messages")
|
||||||
messages.Use(authMiddleware)
|
messages.Use(authMiddleware)
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
"carrot_bbs/internal/pkg/websocket"
|
"carrot_bbs/internal/pkg/sse"
|
||||||
"carrot_bbs/internal/repository"
|
"carrot_bbs/internal/repository"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -41,17 +41,13 @@ type ChatService interface {
|
|||||||
RecallMessage(ctx context.Context, messageID string, userID string) error
|
RecallMessage(ctx context.Context, messageID string, userID string) error
|
||||||
DeleteMessage(ctx context.Context, messageID string, userID string) error
|
DeleteMessage(ctx context.Context, messageID string, userID string) error
|
||||||
|
|
||||||
// WebSocket相关
|
// 实时事件相关
|
||||||
SendTyping(ctx context.Context, senderID string, conversationID string)
|
SendTyping(ctx context.Context, senderID string, conversationID string)
|
||||||
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
|
|
||||||
|
|
||||||
// 系统消息推送
|
// 在线状态
|
||||||
IsUserOnline(userID string) bool
|
IsUserOnline(userID string) bool
|
||||||
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
|
|
||||||
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
|
|
||||||
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
|
|
||||||
|
|
||||||
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
|
// 仅保存消息到数据库,不发送实时推送(供群聊等自行推送的场景使用)
|
||||||
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,7 +57,7 @@ type chatServiceImpl struct {
|
|||||||
repo *repository.MessageRepository
|
repo *repository.MessageRepository
|
||||||
userRepo *repository.UserRepository
|
userRepo *repository.UserRepository
|
||||||
sensitive SensitiveService
|
sensitive SensitiveService
|
||||||
wsManager *websocket.WebSocketManager
|
sseHub *sse.Hub
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChatService 创建聊天服务
|
// NewChatService 创建聊天服务
|
||||||
@@ -70,17 +66,24 @@ func NewChatService(
|
|||||||
repo *repository.MessageRepository,
|
repo *repository.MessageRepository,
|
||||||
userRepo *repository.UserRepository,
|
userRepo *repository.UserRepository,
|
||||||
sensitive SensitiveService,
|
sensitive SensitiveService,
|
||||||
wsManager *websocket.WebSocketManager,
|
sseHub *sse.Hub,
|
||||||
) ChatService {
|
) ChatService {
|
||||||
return &chatServiceImpl{
|
return &chatServiceImpl{
|
||||||
db: db,
|
db: db,
|
||||||
repo: repo,
|
repo: repo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
sensitive: sensitive,
|
sensitive: sensitive,
|
||||||
wsManager: wsManager,
|
sseHub: sseHub,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payload interface{}) {
|
||||||
|
if s.sseHub == nil || len(userIDs) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.sseHub.PublishToUsers(userIDs, event, payload)
|
||||||
|
}
|
||||||
|
|
||||||
// GetOrCreateConversation 获取或创建私聊会话
|
// GetOrCreateConversation 获取或创建私聊会话
|
||||||
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||||
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||||
@@ -228,30 +231,30 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
|
|||||||
return nil, fmt.Errorf("failed to save message: %w", err)
|
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送消息给接收者
|
// 获取会话中的参与者并发送 SSE
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
|
|
||||||
ID: message.ID,
|
|
||||||
ConversationID: message.ConversationID,
|
|
||||||
SenderID: senderID,
|
|
||||||
Segments: message.Segments,
|
|
||||||
Seq: message.Seq,
|
|
||||||
CreatedAt: message.CreatedAt.UnixMilli(),
|
|
||||||
})
|
|
||||||
|
|
||||||
// 获取会话中的其他参与者
|
|
||||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
targetIDs := make([]string, 0, len(participants))
|
||||||
|
for _, p := range participants {
|
||||||
|
targetIDs = append(targetIDs, p.UserID)
|
||||||
|
}
|
||||||
|
detailType := "private"
|
||||||
|
if conv.Type == model.ConversationTypeGroup {
|
||||||
|
detailType = "group"
|
||||||
|
}
|
||||||
|
s.publishSSEToUsers(targetIDs, "chat_message", map[string]interface{}{
|
||||||
|
"detail_type": detailType,
|
||||||
|
"message": dto.ConvertMessageToResponse(message),
|
||||||
|
})
|
||||||
for _, p := range participants {
|
for _, p := range participants {
|
||||||
// 不发给自己
|
|
||||||
if p.UserID == senderID {
|
if p.UserID == senderID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 如果接收者在线,发送实时消息
|
if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil {
|
||||||
if s.wsManager != nil {
|
s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
|
||||||
isOnline := s.wsManager.IsUserOnline(p.UserID)
|
"conversation_id": conversationID,
|
||||||
if isOnline {
|
"total_unread": totalUnread,
|
||||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
})
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -337,25 +340,33 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string,
|
|||||||
return fmt.Errorf("failed to update last read seq: %w", err)
|
return fmt.Errorf("failed to update last read seq: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送已读回执(作为 meta 事件)
|
participants, pErr := s.repo.GetConversationParticipants(conversationID)
|
||||||
if s.wsManager != nil {
|
if pErr == nil {
|
||||||
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
|
detailType := "private"
|
||||||
"detail_type": websocket.MetaDetailTypeRead,
|
groupID := ""
|
||||||
"conversation_id": conversationID,
|
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||||
"seq": seq,
|
detailType = "group"
|
||||||
"user_id": userID,
|
if conv.GroupID != nil {
|
||||||
})
|
groupID = *conv.GroupID
|
||||||
|
|
||||||
// 获取会话中的所有参与者
|
|
||||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
|
||||||
if err == nil {
|
|
||||||
// 推送给会话中的所有参与者(包括自己)
|
|
||||||
for _, p := range participants {
|
|
||||||
if s.wsManager.IsUserOnline(p.UserID) {
|
|
||||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
targetIDs := make([]string, 0, len(participants))
|
||||||
|
for _, p := range participants {
|
||||||
|
targetIDs = append(targetIDs, p.UserID)
|
||||||
|
}
|
||||||
|
s.publishSSEToUsers(targetIDs, "message_read", map[string]interface{}{
|
||||||
|
"detail_type": detailType,
|
||||||
|
"conversation_id": conversationID,
|
||||||
|
"group_id": groupID,
|
||||||
|
"user_id": userID,
|
||||||
|
"seq": seq,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if totalUnread, uErr := s.repo.GetAllUnreadCount(userID); uErr == nil {
|
||||||
|
s.publishSSEToUsers([]string{userID}, "conversation_unread", map[string]interface{}{
|
||||||
|
"conversation_id": conversationID,
|
||||||
|
"total_unread": totalUnread,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -407,29 +418,35 @@ func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, u
|
|||||||
return errors.New("message recall timeout (2 minutes)")
|
return errors.New("message recall timeout (2 minutes)")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新消息状态为已撤回
|
// 更新消息状态为已撤回,并清空原始消息内容,仅保留撤回占位
|
||||||
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
|
err = s.db.Model(&message).Updates(map[string]interface{}{
|
||||||
|
"status": model.MessageStatusRecalled,
|
||||||
|
"segments": model.MessageSegments{},
|
||||||
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to recall message: %w", err)
|
return fmt.Errorf("failed to recall message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送撤回通知
|
if participants, pErr := s.repo.GetConversationParticipants(message.ConversationID); pErr == nil {
|
||||||
if s.wsManager != nil {
|
detailType := "private"
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
|
groupID := ""
|
||||||
"messageId": messageID,
|
if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||||
"conversationId": message.ConversationID,
|
detailType = "group"
|
||||||
"senderId": userID,
|
if conv.GroupID != nil {
|
||||||
})
|
groupID = *conv.GroupID
|
||||||
|
|
||||||
// 通知会话中的所有参与者
|
|
||||||
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
|
|
||||||
if err == nil {
|
|
||||||
for _, p := range participants {
|
|
||||||
if s.wsManager.IsUserOnline(p.UserID) {
|
|
||||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
targetIDs := make([]string, 0, len(participants))
|
||||||
|
for _, p := range participants {
|
||||||
|
targetIDs = append(targetIDs, p.UserID)
|
||||||
|
}
|
||||||
|
s.publishSSEToUsers(targetIDs, "message_recall", map[string]interface{}{
|
||||||
|
"detail_type": detailType,
|
||||||
|
"conversation_id": message.ConversationID,
|
||||||
|
"group_id": groupID,
|
||||||
|
"message_id": messageID,
|
||||||
|
"sender_id": userID,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -473,7 +490,7 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u
|
|||||||
|
|
||||||
// SendTyping 发送正在输入状态
|
// SendTyping 发送正在输入状态
|
||||||
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
|
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
|
||||||
if s.wsManager == nil {
|
if s.sseHub == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -489,98 +506,34 @@ func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conve
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
detailType := "private"
|
||||||
|
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||||
|
detailType = "group"
|
||||||
|
}
|
||||||
for _, p := range participants {
|
for _, p := range participants {
|
||||||
if p.UserID == senderID {
|
if p.UserID == senderID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 发送正在输入状态
|
if s.sseHub != nil {
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
|
s.sseHub.PublishToUser(p.UserID, "typing", map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"detail_type": detailType,
|
||||||
"senderId": senderID,
|
"conversation_id": conversationID,
|
||||||
})
|
"user_id": senderID,
|
||||||
|
"is_typing": true,
|
||||||
if s.wsManager.IsUserOnline(p.UserID) {
|
})
|
||||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BroadcastMessage 广播消息给用户
|
|
||||||
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
|
|
||||||
if s.wsManager != nil {
|
|
||||||
s.wsManager.SendToUser(targetUser, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsUserOnline 检查用户是否在线
|
// IsUserOnline 检查用户是否在线
|
||||||
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
|
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
|
||||||
if s.wsManager == nil {
|
if s.sseHub != nil {
|
||||||
return false
|
return s.sseHub.HasSubscribers(userID)
|
||||||
}
|
}
|
||||||
return s.wsManager.IsUserOnline(userID)
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// PushSystemMessage 推送系统消息给指定用户
|
// SaveMessage 仅保存消息到数据库,不发送实时推送
|
||||||
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
|
|
||||||
if s.wsManager == nil {
|
|
||||||
return errors.New("websocket manager not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.wsManager.IsUserOnline(userID) {
|
|
||||||
return errors.New("user is offline")
|
|
||||||
}
|
|
||||||
|
|
||||||
sysMsg := &websocket.SystemMessage{
|
|
||||||
ID: "", // 由调用方生成
|
|
||||||
Type: msgType,
|
|
||||||
Title: title,
|
|
||||||
Content: content,
|
|
||||||
Data: data,
|
|
||||||
CreatedAt: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
|
||||||
s.wsManager.SendToUser(userID, wsMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushNotificationMessage 推送通知消息给指定用户
|
|
||||||
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
|
|
||||||
if s.wsManager == nil {
|
|
||||||
return errors.New("websocket manager not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.wsManager.IsUserOnline(userID) {
|
|
||||||
return errors.New("user is offline")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确保时间戳已设置
|
|
||||||
if notification.CreatedAt == 0 {
|
|
||||||
notification.CreatedAt = time.Now().UnixMilli()
|
|
||||||
}
|
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
|
||||||
s.wsManager.SendToUser(userID, wsMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushAnnouncementMessage 广播公告消息给所有在线用户
|
|
||||||
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
|
|
||||||
if s.wsManager == nil {
|
|
||||||
return errors.New("websocket manager not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确保时间戳已设置
|
|
||||||
if announcement.CreatedAt == 0 {
|
|
||||||
announcement.CreatedAt = time.Now().UnixMilli()
|
|
||||||
}
|
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
|
||||||
s.wsManager.Broadcast(wsMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
|
|
||||||
// 适用于群聊等由调用方自行负责推送的场景
|
// 适用于群聊等由调用方自行负责推送的场景
|
||||||
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||||
// 验证会话是否存在
|
// 验证会话是否存在
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
"carrot_bbs/internal/pkg/gorse"
|
"carrot_bbs/internal/pkg/gorse"
|
||||||
"carrot_bbs/internal/repository"
|
"carrot_bbs/internal/repository"
|
||||||
@@ -17,6 +18,7 @@ type CommentService struct {
|
|||||||
commentRepo *repository.CommentRepository
|
commentRepo *repository.CommentRepository
|
||||||
postRepo *repository.PostRepository
|
postRepo *repository.PostRepository
|
||||||
systemMessageService SystemMessageService
|
systemMessageService SystemMessageService
|
||||||
|
cache cache.Cache
|
||||||
gorseClient gorse.Client
|
gorseClient gorse.Client
|
||||||
postAIService *PostAIService
|
postAIService *PostAIService
|
||||||
}
|
}
|
||||||
@@ -27,6 +29,7 @@ func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repo
|
|||||||
commentRepo: commentRepo,
|
commentRepo: commentRepo,
|
||||||
postRepo: postRepo,
|
postRepo: postRepo,
|
||||||
systemMessageService: systemMessageService,
|
systemMessageService: systemMessageService,
|
||||||
|
cache: cache.GetCache(),
|
||||||
gorseClient: gorseClient,
|
gorseClient: gorseClient,
|
||||||
postAIService: postAIService,
|
postAIService: postAIService,
|
||||||
}
|
}
|
||||||
@@ -96,6 +99,10 @@ func (s *CommentService) reviewCommentAsync(
|
|||||||
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
|
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := s.applyCommentPublishedStats(commentID); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, err)
|
||||||
|
}
|
||||||
|
s.invalidatePostCaches(postID)
|
||||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -116,6 +123,10 @@ func (s *CommentService) reviewCommentAsync(
|
|||||||
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
|
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if statsErr := s.applyCommentPublishedStats(commentID); statsErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, statsErr)
|
||||||
|
}
|
||||||
|
s.invalidatePostCaches(postID)
|
||||||
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
|
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
|
||||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||||
return
|
return
|
||||||
@@ -125,9 +136,26 @@ func (s *CommentService) reviewCommentAsync(
|
|||||||
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
|
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if statsErr := s.applyCommentPublishedStats(commentID); statsErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to apply published stats for comment %s: %v", commentID, statsErr)
|
||||||
|
}
|
||||||
|
s.invalidatePostCaches(postID)
|
||||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *CommentService) applyCommentPublishedStats(commentID string) error {
|
||||||
|
comment, err := s.commentRepo.GetByID(commentID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.commentRepo.ApplyPublishedStats(comment)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CommentService) invalidatePostCaches(postID string) {
|
||||||
|
cache.InvalidatePostDetail(s.cache, postID)
|
||||||
|
cache.InvalidatePostList(s.cache)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
|
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
|
||||||
// 发送系统消息通知
|
// 发送系统消息通知
|
||||||
if s.systemMessageService != nil {
|
if s.systemMessageService != nil {
|
||||||
@@ -212,7 +240,15 @@ func (s *CommentService) Update(ctx context.Context, comment *model.Comment) err
|
|||||||
|
|
||||||
// Delete 删除评论
|
// Delete 删除评论
|
||||||
func (s *CommentService) Delete(ctx context.Context, id string) error {
|
func (s *CommentService) Delete(ctx context.Context, id string) error {
|
||||||
return s.commentRepo.Delete(id)
|
comment, err := s.commentRepo.GetByID(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := s.commentRepo.Delete(id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.invalidatePostCaches(comment.PostID)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Like 点赞评论
|
// Like 点赞评论
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
|
|
||||||
"carrot_bbs/internal/cache"
|
"carrot_bbs/internal/cache"
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/sse"
|
||||||
"carrot_bbs/internal/pkg/utils"
|
"carrot_bbs/internal/pkg/utils"
|
||||||
"carrot_bbs/internal/pkg/websocket"
|
|
||||||
"carrot_bbs/internal/repository"
|
"carrot_bbs/internal/repository"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
|
|
||||||
// 缓存TTL常量
|
// 缓存TTL常量
|
||||||
const (
|
const (
|
||||||
GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒
|
GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒
|
||||||
GroupMembersNullTTL = 5 * time.Second
|
GroupMembersNullTTL = 5 * time.Second
|
||||||
GroupCacheJitter = 0.1
|
GroupCacheJitter = 0.1
|
||||||
)
|
)
|
||||||
@@ -99,12 +99,12 @@ type groupService struct {
|
|||||||
messageRepo *repository.MessageRepository
|
messageRepo *repository.MessageRepository
|
||||||
requestRepo repository.GroupJoinRequestRepository
|
requestRepo repository.GroupJoinRequestRepository
|
||||||
notifyRepo *repository.SystemNotificationRepository
|
notifyRepo *repository.SystemNotificationRepository
|
||||||
wsManager *websocket.WebSocketManager
|
sseHub *sse.Hub
|
||||||
cache cache.Cache
|
cache cache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGroupService 创建群组服务
|
// NewGroupService 创建群组服务
|
||||||
func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, wsManager *websocket.WebSocketManager) GroupService {
|
func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, sseHub *sse.Hub) GroupService {
|
||||||
return &groupService{
|
return &groupService{
|
||||||
db: db,
|
db: db,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
@@ -112,11 +112,39 @@ func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo
|
|||||||
messageRepo: messageRepo,
|
messageRepo: messageRepo,
|
||||||
requestRepo: repository.NewGroupJoinRequestRepository(db),
|
requestRepo: repository.NewGroupJoinRequestRepository(db),
|
||||||
notifyRepo: repository.NewSystemNotificationRepository(db),
|
notifyRepo: repository.NewSystemNotificationRepository(db),
|
||||||
wsManager: wsManager,
|
sseHub: sseHub,
|
||||||
cache: cache.GetCache(),
|
cache: cache.GetCache(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type groupNoticeData struct {
|
||||||
|
UserID string `json:"user_id,omitempty"`
|
||||||
|
Username string `json:"username,omitempty"`
|
||||||
|
OperatorID string `json:"operator_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type groupNoticeMessage struct {
|
||||||
|
NoticeType string `json:"notice_type"`
|
||||||
|
GroupID string `json:"group_id"`
|
||||||
|
Data groupNoticeData `json:"data"`
|
||||||
|
Timestamp int64 `json:"timestamp"`
|
||||||
|
MessageID string `json:"message_id,omitempty"`
|
||||||
|
Seq int64 `json:"seq,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *groupService) publishGroupNotice(groupID string, notice groupNoticeMessage) {
|
||||||
|
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[groupService] 获取群成员失败: groupID=%s, err=%v", groupID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.sseHub != nil {
|
||||||
|
for _, m := range members {
|
||||||
|
s.sseHub.PublishToUser(m.UserID, "group_notice", notice)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== 群组管理 ====================
|
// ==================== 群组管理 ====================
|
||||||
|
|
||||||
// CreateGroup 创建群组
|
// CreateGroup 创建群组
|
||||||
@@ -422,14 +450,10 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.wsManager == nil {
|
noticeMsg := groupNoticeMessage{
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
noticeMsg := websocket.GroupNoticeMessage{
|
|
||||||
NoticeType: "member_join",
|
NoticeType: "member_join",
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
Data: websocket.GroupNoticeData{
|
Data: groupNoticeData{
|
||||||
UserID: targetUserID,
|
UserID: targetUserID,
|
||||||
Username: targetUserName,
|
Username: targetUserName,
|
||||||
OperatorID: operatorID,
|
OperatorID: operatorID,
|
||||||
@@ -441,17 +465,7 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st
|
|||||||
noticeMsg.Seq = savedMessage.Seq
|
noticeMsg.Seq = savedMessage.Seq
|
||||||
}
|
}
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg)
|
s.publishGroupNotice(groupID, noticeMsg)
|
||||||
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[broadcastMemberJoinNotice] 获取群成员失败: groupID=%s, err=%v", groupID, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, m := range members {
|
|
||||||
if s.wsManager.IsUserOnline(m.UserID) {
|
|
||||||
s.wsManager.SendToUser(m.UserID, wsMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userID string, operatorID string) error {
|
func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userID string, operatorID string) error {
|
||||||
@@ -1282,46 +1296,20 @@ func (s *groupService) MuteMember(userID string, groupID string, targetUserID st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送WebSocket通知给群成员
|
noticeMsg := groupNoticeMessage{
|
||||||
if s.wsManager != nil {
|
NoticeType: noticeType,
|
||||||
log.Printf("[MuteMember] 准备发送禁言通知: groupID=%s, targetUserID=%s, noticeType=%s, operatorID=%s", groupID, targetUserID, noticeType, userID)
|
GroupID: groupID,
|
||||||
|
Data: groupNoticeData{
|
||||||
// 构建通知消息,包含保存的消息信息
|
UserID: targetUserID,
|
||||||
noticeMsg := websocket.GroupNoticeMessage{
|
OperatorID: userID,
|
||||||
NoticeType: noticeType,
|
},
|
||||||
GroupID: groupID,
|
Timestamp: time.Now().UnixMilli(),
|
||||||
Data: websocket.GroupNoticeData{
|
|
||||||
UserID: targetUserID,
|
|
||||||
OperatorID: userID,
|
|
||||||
},
|
|
||||||
Timestamp: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果消息已保存,添加消息ID和seq
|
|
||||||
if savedMessage != nil {
|
|
||||||
noticeMsg.MessageID = savedMessage.ID
|
|
||||||
noticeMsg.Seq = savedMessage.Seq
|
|
||||||
}
|
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg)
|
|
||||||
log.Printf("[MuteMember] 创建的WebSocket消息: Type=%s, Data=%+v", wsMsg.Type, wsMsg.Data)
|
|
||||||
|
|
||||||
// 获取所有群成员并发送通知
|
|
||||||
members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000)
|
|
||||||
if err == nil {
|
|
||||||
log.Printf("[MuteMember] 获取到群成员数量: %d", len(members))
|
|
||||||
for _, m := range members {
|
|
||||||
isOnline := s.wsManager.IsUserOnline(m.UserID)
|
|
||||||
log.Printf("[MuteMember] 成员 %s 在线状态: %v", m.UserID, isOnline)
|
|
||||||
if isOnline {
|
|
||||||
s.wsManager.SendToUser(m.UserID, wsMsg)
|
|
||||||
log.Printf("[MuteMember] 已发送通知给成员: %s", m.UserID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Printf("[MuteMember] 获取群成员失败: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if savedMessage != nil {
|
||||||
|
noticeMsg.MessageID = savedMessage.ID
|
||||||
|
noticeMsg.Seq = savedMessage.Seq
|
||||||
|
}
|
||||||
|
s.publishGroupNotice(groupID, noticeMsg)
|
||||||
|
|
||||||
// 失效群组成员缓存
|
// 失效群组成员缓存
|
||||||
cache.InvalidateGroupMembers(s.cache, groupID)
|
cache.InvalidateGroupMembers(s.cache, groupID)
|
||||||
|
|||||||
@@ -77,6 +77,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||||
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
|
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
|
||||||
|
} else {
|
||||||
|
s.invalidatePostCaches(postID)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -87,6 +89,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
if errors.As(err, &rejectedErr) {
|
if errors.As(err, &rejectedErr) {
|
||||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
||||||
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
|
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
|
||||||
|
} else {
|
||||||
|
s.invalidatePostCaches(postID)
|
||||||
}
|
}
|
||||||
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
||||||
return
|
return
|
||||||
@@ -95,6 +99,8 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
// 规则审核不可用时,降级为发布,避免长时间pending
|
// 规则审核不可用时,降级为发布,避免长时间pending
|
||||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
||||||
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
|
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
|
||||||
|
} else {
|
||||||
|
s.invalidatePostCaches(postID)
|
||||||
}
|
}
|
||||||
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
|
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
|
||||||
return
|
return
|
||||||
@@ -104,6 +110,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
|
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
s.invalidatePostCaches(postID)
|
||||||
|
|
||||||
if s.gorseClient.IsEnabled() {
|
if s.gorseClient.IsEnabled() {
|
||||||
post, getErr := s.postRepo.GetByID(postID)
|
post, getErr := s.postRepo.GetByID(postID)
|
||||||
@@ -120,6 +127,11 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PostService) invalidatePostCaches(postID string) {
|
||||||
|
cache.InvalidatePostDetail(s.cache, postID)
|
||||||
|
cache.InvalidatePostList(s.cache)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PostService) notifyModerationRejected(userID, reason string) {
|
func (s *PostService) notifyModerationRejected(userID, reason string) {
|
||||||
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||||
return
|
return
|
||||||
@@ -149,7 +161,12 @@ func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, erro
|
|||||||
|
|
||||||
// Update 更新帖子
|
// Update 更新帖子
|
||||||
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
|
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
|
||||||
err := s.postRepo.Update(post)
|
return s.UpdateWithImages(ctx, post, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateWithImages 更新帖子并可选更新图片(images=nil 表示不更新图片)
|
||||||
|
func (s *PostService) UpdateWithImages(ctx context.Context, post *model.Post, images *[]string) error {
|
||||||
|
err := s.postRepo.UpdateWithImages(post, images)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -185,7 +202,7 @@ func (s *PostService) Delete(ctx context.Context, id string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// List 获取帖子列表(带缓存)
|
// List 获取帖子列表(带缓存)
|
||||||
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) {
|
||||||
cacheSettings := cache.GetSettings()
|
cacheSettings := cache.GetSettings()
|
||||||
postListTTL := cacheSettings.PostListTTL
|
postListTTL := cacheSettings.PostListTTL
|
||||||
if postListTTL <= 0 {
|
if postListTTL <= 0 {
|
||||||
@@ -200,8 +217,12 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
|
|||||||
jitter = PostListJitterRatio
|
jitter = PostListJitterRatio
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染)
|
// 生成缓存键(包含 userID 维度与可见性维度,避免作者视角污染公开视角)
|
||||||
cacheKey := cache.PostListKey("latest", userID, page, pageSize)
|
visibilityUserKey := userID
|
||||||
|
if includePending && userID != "" {
|
||||||
|
visibilityUserKey = "owner:" + userID
|
||||||
|
}
|
||||||
|
cacheKey := cache.PostListKey("latest", visibilityUserKey, page, pageSize)
|
||||||
|
|
||||||
result, err := cache.GetOrLoadTyped[*PostListResult](
|
result, err := cache.GetOrLoadTyped[*PostListResult](
|
||||||
s.cache,
|
s.cache,
|
||||||
@@ -210,7 +231,7 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
|
|||||||
jitter,
|
jitter,
|
||||||
nullTTL,
|
nullTTL,
|
||||||
func() (*PostListResult, error) {
|
func() (*PostListResult, error) {
|
||||||
posts, total, err := s.postRepo.List(page, pageSize, userID)
|
posts, total, err := s.postRepo.List(page, pageSize, userID, includePending)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -234,7 +255,7 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if missingAuthor {
|
if missingAuthor {
|
||||||
posts, total, loadErr := s.postRepo.List(page, pageSize, userID)
|
posts, total, loadErr := s.postRepo.List(page, pageSize, userID, includePending)
|
||||||
if loadErr != nil {
|
if loadErr != nil {
|
||||||
return nil, 0, loadErr
|
return nil, 0, loadErr
|
||||||
}
|
}
|
||||||
@@ -247,12 +268,17 @@ func (s *PostService) List(ctx context.Context, page, pageSize int, userID strin
|
|||||||
|
|
||||||
// GetLatestPosts 获取最新帖子(语义化别名)
|
// GetLatestPosts 获取最新帖子(语义化别名)
|
||||||
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||||
return s.List(ctx, page, pageSize, userID)
|
return s.List(ctx, page, pageSize, userID, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLatestPostsForOwner 获取作者视角帖子列表(包含待审核)
|
||||||
|
func (s *PostService) GetLatestPostsForOwner(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||||
|
return s.List(ctx, page, pageSize, userID, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserPosts 获取用户帖子
|
// GetUserPosts 获取用户帖子
|
||||||
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) {
|
||||||
return s.postRepo.GetUserPosts(userID, page, pageSize)
|
return s.postRepo.GetUserPosts(userID, page, pageSize, includePending)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Like 点赞
|
// Like 点赞
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
"carrot_bbs/internal/dto"
|
"carrot_bbs/internal/dto"
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
"carrot_bbs/internal/pkg/websocket"
|
"carrot_bbs/internal/pkg/sse"
|
||||||
"carrot_bbs/internal/repository"
|
"carrot_bbs/internal/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,8 +42,6 @@ type PushService interface {
|
|||||||
|
|
||||||
// 系统消息推送
|
// 系统消息推送
|
||||||
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
|
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
|
||||||
PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error
|
|
||||||
PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error
|
|
||||||
|
|
||||||
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
|
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
|
||||||
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
|
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
|
||||||
@@ -67,7 +65,7 @@ type pushServiceImpl struct {
|
|||||||
pushRepo *repository.PushRecordRepository
|
pushRepo *repository.PushRecordRepository
|
||||||
deviceRepo *repository.DeviceTokenRepository
|
deviceRepo *repository.DeviceTokenRepository
|
||||||
messageRepo *repository.MessageRepository
|
messageRepo *repository.MessageRepository
|
||||||
wsManager *websocket.WebSocketManager
|
sseHub *sse.Hub
|
||||||
|
|
||||||
// 推送队列
|
// 推送队列
|
||||||
pushQueue chan *pushTask
|
pushQueue chan *pushTask
|
||||||
@@ -86,13 +84,13 @@ func NewPushService(
|
|||||||
pushRepo *repository.PushRecordRepository,
|
pushRepo *repository.PushRecordRepository,
|
||||||
deviceRepo *repository.DeviceTokenRepository,
|
deviceRepo *repository.DeviceTokenRepository,
|
||||||
messageRepo *repository.MessageRepository,
|
messageRepo *repository.MessageRepository,
|
||||||
wsManager *websocket.WebSocketManager,
|
sseHub *sse.Hub,
|
||||||
) PushService {
|
) PushService {
|
||||||
return &pushServiceImpl{
|
return &pushServiceImpl{
|
||||||
pushRepo: pushRepo,
|
pushRepo: pushRepo,
|
||||||
deviceRepo: deviceRepo,
|
deviceRepo: deviceRepo,
|
||||||
messageRepo: messageRepo,
|
messageRepo: messageRepo,
|
||||||
wsManager: wsManager,
|
sseHub: sseHub,
|
||||||
pushQueue: make(chan *pushTask, PushQueueSize),
|
pushQueue: make(chan *pushTask, PushQueueSize),
|
||||||
stopChan: make(chan struct{}),
|
stopChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
@@ -140,11 +138,7 @@ func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message
|
|||||||
// pushViaWebSocket 通过WebSocket推送消息
|
// pushViaWebSocket 通过WebSocket推送消息
|
||||||
// 返回true表示推送成功,false表示用户不在线
|
// 返回true表示推送成功,false表示用户不在线
|
||||||
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
|
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
|
||||||
if s.wsManager == nil {
|
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.wsManager.IsUserOnline(userID) {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,36 +148,33 @@ func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, m
|
|||||||
// 从 segments 中提取文本内容
|
// 从 segments 中提取文本内容
|
||||||
content := dto.ExtractTextContentFromModel(message.Segments)
|
content := dto.ExtractTextContentFromModel(message.Segments)
|
||||||
|
|
||||||
notification := &websocket.NotificationMessage{
|
notification := map[string]interface{}{
|
||||||
ID: fmt.Sprintf("%s", message.ID),
|
"id": fmt.Sprintf("%s", message.ID),
|
||||||
Type: string(message.SystemType),
|
"type": string(message.SystemType),
|
||||||
Content: content,
|
"content": content,
|
||||||
Extra: make(map[string]interface{}),
|
"extra": map[string]interface{}{},
|
||||||
CreatedAt: message.CreatedAt.UnixMilli(),
|
"created_at": message.CreatedAt.UnixMilli(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 填充额外数据
|
// 填充额外数据
|
||||||
if message.ExtraData != nil {
|
if message.ExtraData != nil {
|
||||||
notification.Extra["actor_id"] = message.ExtraData.ActorID
|
extra := notification["extra"].(map[string]interface{})
|
||||||
notification.Extra["actor_name"] = message.ExtraData.ActorName
|
extra["actor_id"] = message.ExtraData.ActorID
|
||||||
notification.Extra["avatar_url"] = message.ExtraData.AvatarURL
|
extra["actor_name"] = message.ExtraData.ActorName
|
||||||
notification.Extra["target_id"] = message.ExtraData.TargetID
|
extra["avatar_url"] = message.ExtraData.AvatarURL
|
||||||
notification.Extra["target_type"] = message.ExtraData.TargetType
|
extra["target_id"] = message.ExtraData.TargetID
|
||||||
notification.Extra["action_url"] = message.ExtraData.ActionURL
|
extra["target_type"] = message.ExtraData.TargetType
|
||||||
notification.Extra["action_time"] = message.ExtraData.ActionTime
|
extra["action_url"] = message.ExtraData.ActionURL
|
||||||
|
extra["action_time"] = message.ExtraData.ActionTime
|
||||||
// 设置触发用户信息
|
|
||||||
if message.ExtraData.ActorID > 0 {
|
if message.ExtraData.ActorID > 0 {
|
||||||
notification.TriggerUser = &websocket.NotificationUser{
|
notification["trigger_user"] = map[string]interface{}{
|
||||||
ID: fmt.Sprintf("%d", message.ExtraData.ActorID),
|
"id": fmt.Sprintf("%d", message.ExtraData.ActorID),
|
||||||
Username: message.ExtraData.ActorName,
|
"username": message.ExtraData.ActorName,
|
||||||
Avatar: message.ExtraData.AvatarURL,
|
"avatar": message.ExtraData.AvatarURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
s.sseHub.PublishToUser(userID, "system_notification", notification)
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
|
||||||
s.wsManager.SendToUser(userID, wsMsg)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,8 +199,10 @@ func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, m
|
|||||||
SenderID: message.SenderID,
|
SenderID: message.SenderID,
|
||||||
}
|
}
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event)
|
s.sseHub.PublishToUser(userID, "chat_message", map[string]interface{}{
|
||||||
s.wsManager.SendToUser(userID, wsMsg)
|
"detail_type": detailType,
|
||||||
|
"message": event,
|
||||||
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,73 +444,21 @@ func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string,
|
|||||||
|
|
||||||
// pushSystemViaWebSocket 通过WebSocket推送系统消息
|
// pushSystemViaWebSocket 通过WebSocket推送系统消息
|
||||||
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
|
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
|
||||||
if s.wsManager == nil {
|
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.wsManager.IsUserOnline(userID) {
|
sysMsg := map[string]interface{}{
|
||||||
return false
|
"type": msgType,
|
||||||
|
"title": title,
|
||||||
|
"content": content,
|
||||||
|
"data": data,
|
||||||
|
"created_at": time.Now().UnixMilli(),
|
||||||
}
|
}
|
||||||
|
s.sseHub.PublishToUser(userID, "system_notification", sysMsg)
|
||||||
sysMsg := &websocket.SystemMessage{
|
|
||||||
Type: msgType,
|
|
||||||
Title: title,
|
|
||||||
Content: content,
|
|
||||||
Data: data,
|
|
||||||
CreatedAt: time.Now().UnixMilli(),
|
|
||||||
}
|
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
|
||||||
s.wsManager.SendToUser(userID, wsMsg)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// PushNotification 推送通知消息
|
|
||||||
func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error {
|
|
||||||
// 首先尝试WebSocket推送
|
|
||||||
if s.pushNotificationViaWebSocket(ctx, userID, notification) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 用户不在线,创建待推送记录
|
|
||||||
// 通知消息可以等用户上线后拉取
|
|
||||||
return errors.New("user is offline, notification will be available on next sync")
|
|
||||||
}
|
|
||||||
|
|
||||||
// pushNotificationViaWebSocket 通过WebSocket推送通知消息
|
|
||||||
func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool {
|
|
||||||
if s.wsManager == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.wsManager.IsUserOnline(userID) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if notification.CreatedAt == 0 {
|
|
||||||
notification.CreatedAt = time.Now().UnixMilli()
|
|
||||||
}
|
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
|
||||||
s.wsManager.SendToUser(userID, wsMsg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushAnnouncement 广播公告消息
|
|
||||||
func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error {
|
|
||||||
if s.wsManager == nil {
|
|
||||||
return errors.New("websocket manager not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
if announcement.CreatedAt == 0 {
|
|
||||||
announcement.CreatedAt = time.Now().UnixMilli()
|
|
||||||
}
|
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
|
||||||
s.wsManager.Broadcast(wsMsg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
|
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
|
||||||
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
|
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
|
||||||
// 首先尝试WebSocket推送
|
// 首先尝试WebSocket推送
|
||||||
@@ -531,45 +472,40 @@ func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID str
|
|||||||
|
|
||||||
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
|
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
|
||||||
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
|
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
|
||||||
if s.wsManager == nil {
|
if s.sseHub == nil || !s.sseHub.HasSubscribers(userID) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.wsManager.IsUserOnline(userID) {
|
sseNotification := map[string]interface{}{
|
||||||
return false
|
"id": fmt.Sprintf("%d", notification.ID),
|
||||||
}
|
"type": string(notification.Type),
|
||||||
|
"title": notification.Title,
|
||||||
// 构建 WebSocket 通知消息
|
"content": notification.Content,
|
||||||
wsNotification := &websocket.NotificationMessage{
|
"extra": map[string]interface{}{},
|
||||||
ID: fmt.Sprintf("%d", notification.ID),
|
"created_at": notification.CreatedAt.UnixMilli(),
|
||||||
Type: string(notification.Type),
|
|
||||||
Title: notification.Title,
|
|
||||||
Content: notification.Content,
|
|
||||||
Extra: make(map[string]interface{}),
|
|
||||||
CreatedAt: notification.CreatedAt.UnixMilli(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 填充额外数据
|
// 填充额外数据
|
||||||
if notification.ExtraData != nil {
|
if notification.ExtraData != nil {
|
||||||
wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr
|
extra := sseNotification["extra"].(map[string]interface{})
|
||||||
wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName
|
extra["actor_id_str"] = notification.ExtraData.ActorIDStr
|
||||||
wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL
|
extra["actor_name"] = notification.ExtraData.ActorName
|
||||||
wsNotification.Extra["target_id"] = notification.ExtraData.TargetID
|
extra["avatar_url"] = notification.ExtraData.AvatarURL
|
||||||
wsNotification.Extra["target_type"] = notification.ExtraData.TargetType
|
extra["target_id"] = notification.ExtraData.TargetID
|
||||||
wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL
|
extra["target_type"] = notification.ExtraData.TargetType
|
||||||
wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime
|
extra["action_url"] = notification.ExtraData.ActionURL
|
||||||
|
extra["action_time"] = notification.ExtraData.ActionTime
|
||||||
|
|
||||||
// 设置触发用户信息
|
// 设置触发用户信息
|
||||||
if notification.ExtraData.ActorIDStr != "" {
|
if notification.ExtraData.ActorIDStr != "" {
|
||||||
wsNotification.TriggerUser = &websocket.NotificationUser{
|
sseNotification["trigger_user"] = map[string]interface{}{
|
||||||
ID: notification.ExtraData.ActorIDStr,
|
"id": notification.ExtraData.ActorIDStr,
|
||||||
Username: notification.ExtraData.ActorName,
|
"username": notification.ExtraData.ActorName,
|
||||||
Avatar: notification.ExtraData.AvatarURL,
|
"avatar": notification.ExtraData.AvatarURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification)
|
s.sseHub.PublishToUser(userID, "system_notification", sseNotification)
|
||||||
s.wsManager.SendToUser(userID, wsMsg)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user