Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor
This commit is contained in:
253
internal/handler/comment_handler.go
Normal file
253
internal/handler/comment_handler.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// CommentHandler 评论处理器
|
||||
type CommentHandler struct {
|
||||
commentService *service.CommentService
|
||||
}
|
||||
|
||||
// NewCommentHandler 创建评论处理器
|
||||
func NewCommentHandler(commentService *service.CommentService) *CommentHandler {
|
||||
return &CommentHandler{
|
||||
commentService: commentService,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建评论
|
||||
func (h *CommentHandler) Create(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type CreateRequest struct {
|
||||
PostID string `json:"post_id" binding:"required"`
|
||||
Content string `json:"content"` // 内容可选,允许只发图片
|
||||
ParentID *string `json:"parent_id"`
|
||||
Images []string `json:"images"` // 图片URL列表
|
||||
}
|
||||
|
||||
var req CreateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证:评论必须有内容或图片
|
||||
if req.Content == "" && len(req.Images) == 0 {
|
||||
response.BadRequest(c, "评论内容或图片不能同时为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 将图片列表转换为JSON字符串
|
||||
var imagesJSON string
|
||||
if len(req.Images) > 0 {
|
||||
imagesBytes, _ := json.Marshal(req.Images)
|
||||
imagesJSON = string(imagesBytes)
|
||||
}
|
||||
|
||||
comment, err := h.commentService.Create(c.Request.Context(), req.PostID, userID, req.Content, req.ParentID, imagesJSON, req.Images)
|
||||
if err != nil {
|
||||
var moderationErr *service.CommentModerationRejectedError
|
||||
if errors.As(err, &moderationErr) {
|
||||
response.BadRequest(c, moderationErr.UserMessage())
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to create comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertCommentToResponse(comment, false))
|
||||
}
|
||||
|
||||
// GetByID 获取单条评论详情
|
||||
func (h *CommentHandler) GetByID(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
id := c.Param("id")
|
||||
|
||||
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "comment not found")
|
||||
return
|
||||
}
|
||||
|
||||
resp := dto.ConvertCommentToResponse(comment, h.commentService.IsLiked(c.Request.Context(), id, userID))
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// GetByPostID 获取帖子评论
|
||||
func (h *CommentHandler) GetByPostID(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
postID := c.Param("id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
comments, total, err := h.commentService.GetByPostID(c.Request.Context(), postID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get comments")
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应结构,检查每个评论的点赞状态
|
||||
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
|
||||
|
||||
response.Paginated(c, commentResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetReplies 获取回复
|
||||
func (h *CommentHandler) GetReplies(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
parentID := c.Param("id")
|
||||
|
||||
comments, err := h.commentService.GetReplies(c.Request.Context(), parentID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get replies")
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应结构,检查每个回复的点赞状态
|
||||
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
|
||||
|
||||
response.Success(c, commentResponses)
|
||||
}
|
||||
|
||||
// GetRepliesByRootID 根据根评论ID分页获取回复(扁平化)
|
||||
func (h *CommentHandler) GetRepliesByRootID(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
rootID := c.Param("id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
|
||||
|
||||
replies, total, err := h.commentService.GetRepliesByRootID(c.Request.Context(), rootID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get replies")
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应结构,检查每个回复的点赞状态
|
||||
replyResponses := dto.ConvertCommentsToResponseWithUser(replies, userID, h.commentService)
|
||||
|
||||
response.Paginated(c, replyResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Update 更新评论
|
||||
func (h *CommentHandler) Update(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "comment not found")
|
||||
return
|
||||
}
|
||||
|
||||
if comment.UserID != userID {
|
||||
response.Forbidden(c, "cannot update others' comment")
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
var req UpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
comment.Content = req.Content
|
||||
|
||||
err = h.commentService.Update(c.Request.Context(), comment)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertCommentToResponse(comment, false))
|
||||
}
|
||||
|
||||
// Delete 删除评论
|
||||
func (h *CommentHandler) Delete(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "comment not found")
|
||||
return
|
||||
}
|
||||
|
||||
if comment.UserID != userID {
|
||||
response.Forbidden(c, "cannot delete others' comment")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.commentService.Delete(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to delete comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "comment deleted", nil)
|
||||
}
|
||||
|
||||
// Like 点赞评论
|
||||
func (h *CommentHandler) Like(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.commentService.Like(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to like comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "liked", nil)
|
||||
}
|
||||
|
||||
// Unlike 取消点赞评论
|
||||
func (h *CommentHandler) Unlike(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.commentService.Unlike(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unlike comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "unliked", nil)
|
||||
}
|
||||
234
internal/handler/gorse_handler.go
Normal file
234
internal/handler/gorse_handler.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/config"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/gorse"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
|
||||
gorseio "github.com/gorse-io/gorse-go"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GorseHandler Gorse推荐处理器
|
||||
type GorseHandler struct {
|
||||
importPassword string
|
||||
gorseConfig config.GorseConfig
|
||||
}
|
||||
|
||||
// NewGorseHandler 创建Gorse处理器
|
||||
func NewGorseHandler(cfg config.GorseConfig) *GorseHandler {
|
||||
return &GorseHandler{
|
||||
importPassword: cfg.ImportPassword,
|
||||
gorseConfig: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// ImportRequest 导入请求
|
||||
type ImportRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// ImportData 导入数据到Gorse
|
||||
// POST /api/v1/gorse/import
|
||||
func (h *GorseHandler) ImportData(c *gin.Context) {
|
||||
// 验证密码
|
||||
if h.importPassword == "" {
|
||||
response.BadRequest(c, "Gorse import is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
var req ImportRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Password != h.importPassword {
|
||||
response.Unauthorized(c, "invalid password")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
stats, err := h.importAllData(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] gorse import failed: %v", err)
|
||||
response.InternalServerError(c, "gorse import failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "import completed",
|
||||
"status": "done",
|
||||
"stats": stats,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStatus 获取Gorse状态
|
||||
// GET /api/v1/gorse/status
|
||||
func (h *GorseHandler) GetStatus(c *gin.Context) {
|
||||
// 返回Gorse连接状态和配置信息
|
||||
hasPassword := h.importPassword != ""
|
||||
response.Success(c, gin.H{
|
||||
"enabled": h.gorseConfig.Enabled,
|
||||
"has_password": hasPassword,
|
||||
"address": h.gorseConfig.Address,
|
||||
"api_key": strings.Repeat("*", 8), // 不返回实际APIKey
|
||||
})
|
||||
}
|
||||
|
||||
func (h *GorseHandler) importAllData(ctx context.Context) (gin.H, error) {
|
||||
gorseClient := gorseio.NewGorseClient(h.gorseConfig.Address, h.gorseConfig.APIKey)
|
||||
gorse.InitEmbeddingWithConfig(h.gorseConfig.EmbeddingAPIKey, h.gorseConfig.EmbeddingURL, h.gorseConfig.EmbeddingModel)
|
||||
|
||||
stats := gin.H{
|
||||
"items": 0,
|
||||
"users": 0,
|
||||
"likes": 0,
|
||||
"favorites": 0,
|
||||
"comments": 0,
|
||||
}
|
||||
|
||||
// 导入帖子
|
||||
var posts []model.Post
|
||||
if err := model.DB.Find(&posts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, post := range posts {
|
||||
embedding, err := gorse.GetEmbedding(strings.TrimSpace(post.Title + " " + post.Content))
|
||||
if err != nil {
|
||||
log.Printf("[WARN] get embedding failed for post %s: %v", post.ID, err)
|
||||
embedding = make([]float64, 1024)
|
||||
}
|
||||
_, err = gorseClient.InsertItem(ctx, gorseio.Item{
|
||||
ItemId: post.ID,
|
||||
IsHidden: post.DeletedAt.Valid,
|
||||
Categories: buildPostCategories(&post),
|
||||
Comment: post.Title,
|
||||
Timestamp: post.CreatedAt.UTC().Truncate(time.Second),
|
||||
Labels: map[string]any{
|
||||
"embedding": embedding,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import item failed (%s): %v", post.ID, err)
|
||||
continue
|
||||
}
|
||||
stats["items"] = stats["items"].(int) + 1
|
||||
}
|
||||
|
||||
// 导入用户
|
||||
var users []model.User
|
||||
if err := model.DB.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, user := range users {
|
||||
_, err := gorseClient.InsertUser(ctx, gorseio.User{
|
||||
UserId: user.ID,
|
||||
Labels: map[string]any{
|
||||
"posts_count": user.PostsCount,
|
||||
"followers_count": user.FollowersCount,
|
||||
"following_count": user.FollowingCount,
|
||||
},
|
||||
Comment: user.Nickname,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import user failed (%s): %v", user.ID, err)
|
||||
continue
|
||||
}
|
||||
stats["users"] = stats["users"].(int) + 1
|
||||
}
|
||||
|
||||
// 导入点赞
|
||||
var likes []model.PostLike
|
||||
if err := model.DB.Find(&likes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, like := range likes {
|
||||
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||
FeedbackType: string(gorse.FeedbackTypeLike),
|
||||
UserId: like.UserID,
|
||||
ItemId: like.PostID,
|
||||
Timestamp: like.CreatedAt.UTC().Truncate(time.Second),
|
||||
}})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import like failed (%s/%s): %v", like.UserID, like.PostID, err)
|
||||
continue
|
||||
}
|
||||
stats["likes"] = stats["likes"].(int) + 1
|
||||
}
|
||||
|
||||
// 导入收藏
|
||||
var favorites []model.Favorite
|
||||
if err := model.DB.Find(&favorites).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, fav := range favorites {
|
||||
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||
FeedbackType: string(gorse.FeedbackTypeStar),
|
||||
UserId: fav.UserID,
|
||||
ItemId: fav.PostID,
|
||||
Timestamp: fav.CreatedAt.UTC().Truncate(time.Second),
|
||||
}})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import favorite failed (%s/%s): %v", fav.UserID, fav.PostID, err)
|
||||
continue
|
||||
}
|
||||
stats["favorites"] = stats["favorites"].(int) + 1
|
||||
}
|
||||
|
||||
// 导入评论(按用户-帖子去重)
|
||||
var comments []model.Comment
|
||||
if err := model.DB.Where("status = ?", model.CommentStatusPublished).Find(&comments).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
for _, cm := range comments {
|
||||
key := cm.UserID + ":" + cm.PostID
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||
FeedbackType: string(gorse.FeedbackTypeComment),
|
||||
UserId: cm.UserID,
|
||||
ItemId: cm.PostID,
|
||||
Timestamp: cm.CreatedAt.UTC().Truncate(time.Second),
|
||||
}})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import comment failed (%s/%s): %v", cm.UserID, cm.PostID, err)
|
||||
continue
|
||||
}
|
||||
stats["comments"] = stats["comments"].(int) + 1
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func buildPostCategories(post *model.Post) []string {
|
||||
var categories []string
|
||||
if post.ViewsCount > 1000 {
|
||||
categories = append(categories, "hot_high")
|
||||
} else if post.ViewsCount > 100 {
|
||||
categories = append(categories, "hot_medium")
|
||||
}
|
||||
if post.LikesCount > 100 {
|
||||
categories = append(categories, "likes_100+")
|
||||
} else if post.LikesCount > 10 {
|
||||
categories = append(categories, "likes_10+")
|
||||
}
|
||||
age := time.Since(post.CreatedAt)
|
||||
if age < 24*time.Hour {
|
||||
categories = append(categories, "today")
|
||||
} else if age < 7*24*time.Hour {
|
||||
categories = append(categories, "this_week")
|
||||
}
|
||||
return categories
|
||||
}
|
||||
1801
internal/handler/group_handler.go
Normal file
1801
internal/handler/group_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
879
internal/handler/message_handler.go
Normal file
879
internal/handler/message_handler.go
Normal file
@@ -0,0 +1,879 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// MessageHandler 消息处理器
|
||||
type MessageHandler struct {
|
||||
chatService service.ChatService
|
||||
messageService *service.MessageService
|
||||
userService *service.UserService
|
||||
groupService service.GroupService
|
||||
}
|
||||
|
||||
// NewMessageHandler 创建消息处理器
|
||||
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler {
|
||||
return &MessageHandler{
|
||||
chatService: chatService,
|
||||
messageService: messageService,
|
||||
userService: userService,
|
||||
groupService: groupService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConversations 获取会话列表
|
||||
// GET /api/conversations
|
||||
func (h *MessageHandler) GetConversations(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
// 添加调试日志
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get conversations")
|
||||
return
|
||||
}
|
||||
|
||||
// 过滤掉系统会话(系统通知现在使用独立的表)
|
||||
filteredConvs := make([]*model.Conversation, 0)
|
||||
for _, conv := range convs {
|
||||
if conv.ID != model.SystemConversationID {
|
||||
filteredConvs = append(filteredConvs, conv)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]*dto.ConversationResponse, len(filteredConvs))
|
||||
for i, conv := range filteredConvs {
|
||||
// 获取未读数
|
||||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
|
||||
|
||||
// 获取最后一条消息
|
||||
var lastMessage *model.Message
|
||||
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
|
||||
if len(messages) > 0 {
|
||||
lastMessage = messages[0]
|
||||
}
|
||||
|
||||
// 群聊时返回member_count,私聊时返回participants
|
||||
var resp *dto.ConversationResponse
|
||||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
|
||||
// 群聊:实时计算群成员数量
|
||||
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
|
||||
// 创建响应并设置member_count
|
||||
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
|
||||
resp.MemberCount = memberCount
|
||||
} else {
|
||||
// 私聊:获取参与者信息
|
||||
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
|
||||
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
|
||||
}
|
||||
result[i] = resp
|
||||
}
|
||||
|
||||
// 更新 total 为过滤后的数量
|
||||
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
|
||||
}
|
||||
|
||||
// CreateConversation 创建私聊会话
|
||||
// POST /api/conversations
|
||||
func (h *MessageHandler) CreateConversation(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.CreateConversationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证目标用户是否存在
|
||||
targetUser, err := h.userService.GetUserByID(c.Request.Context(), req.UserID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "target user not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 不能和自己创建会话
|
||||
if userID == req.UserID {
|
||||
response.BadRequest(c, "cannot create conversation with yourself")
|
||||
return
|
||||
}
|
||||
|
||||
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, req.UserID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to create conversation")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取参与者信息
|
||||
participants := []*model.User{targetUser}
|
||||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||
|
||||
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
|
||||
}
|
||||
|
||||
// GetConversationByID 获取会话详情
|
||||
// GET /api/conversations/:id
|
||||
func (h *MessageHandler) GetConversationByID(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
fmt.Printf("[DEBUG] GetConversationByID: conversationIDStr = %s\n", conversationIDStr)
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] GetConversationByID: failed to parse conversation ID: %v\n", err)
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
fmt.Printf("[DEBUG] GetConversationByID: conversationID = %s\n", conversationID)
|
||||
|
||||
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取未读数
|
||||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||
|
||||
// 获取参与者信息
|
||||
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
|
||||
|
||||
// 获取当前用户的已读位置
|
||||
myLastReadSeq := int64(0)
|
||||
isPinned := false
|
||||
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
|
||||
for _, p := range allParticipants {
|
||||
if p.UserID == userID {
|
||||
myLastReadSeq = p.LastReadSeq
|
||||
isPinned = p.IsPinned
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 获取对方用户的已读位置
|
||||
otherLastReadSeq := int64(0)
|
||||
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
|
||||
}
|
||||
|
||||
// GetMessages 获取消息列表
|
||||
// GET /api/conversations/:id/messages
|
||||
func (h *MessageHandler) GetMessages(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否使用增量同步(after_seq参数)
|
||||
afterSeqStr := c.Query("after_seq")
|
||||
if afterSeqStr != "" {
|
||||
// 增量同步模式
|
||||
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid after_seq")
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
|
||||
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Success(c, &dto.MessageSyncResponse{
|
||||
Messages: result,
|
||||
HasMore: len(messages) == limit,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否使用历史消息加载(before_seq参数)
|
||||
beforeSeqStr := c.Query("before_seq")
|
||||
if beforeSeqStr != "" {
|
||||
// 加载更早的消息(下拉加载更多)
|
||||
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid before_seq")
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
|
||||
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Success(c, &dto.MessageSyncResponse{
|
||||
Messages: result,
|
||||
HasMore: len(messages) == limit,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 分页模式
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// SendMessage 发送消息
|
||||
// POST /api/conversations/:id/messages
|
||||
func (h *MessageHandler) SendMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
fmt.Printf("[DEBUG] SendMessage: conversationIDStr = %s\n", conversationIDStr)
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] SendMessage: failed to parse conversation ID: %v\n", err)
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
fmt.Printf("[DEBUG] SendMessage: conversationID = %s, userID = %s\n", conversationID, userID)
|
||||
|
||||
var req dto.SendMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用 segments
|
||||
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, req.Segments, req.ReplyToID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertMessageToResponse(msg))
|
||||
}
|
||||
|
||||
// HandleSendMessage RESTful 风格的发送消息端点
|
||||
// POST /api/v1/conversations/send_message
|
||||
// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
|
||||
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.SendMessageParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if params.ConversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
if params.DetailType == "" {
|
||||
response.BadRequest(c, "detail_type is required")
|
||||
return
|
||||
}
|
||||
if params.Segments == nil || len(params.Segments) == 0 {
|
||||
response.BadRequest(c, "segments is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 构建 WSEventResponse 格式响应
|
||||
wsResponse := dto.WSEventResponse{
|
||||
ID: msg.ID,
|
||||
Time: msg.CreatedAt.UnixMilli(),
|
||||
Type: "message",
|
||||
DetailType: params.DetailType,
|
||||
Seq: strconv.FormatInt(msg.Seq, 10),
|
||||
Segments: params.Segments,
|
||||
SenderID: userID,
|
||||
}
|
||||
|
||||
response.Success(c, wsResponse)
|
||||
}
|
||||
|
||||
// HandleDeleteMsg 撤回消息
|
||||
// POST /api/v1/messages/delete_msg
|
||||
// 请求体格式: {"message_id": "xxx"}
|
||||
func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.DeleteMsgParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if params.MessageID == "" {
|
||||
response.BadRequest(c, "message_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 撤回消息
|
||||
err := h.chatService.RecallMessage(c.Request.Context(), params.MessageID, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "消息已撤回", nil)
|
||||
}
|
||||
|
||||
// HandleGetConversationList 获取会话列表
|
||||
// GET /api/v1/conversations/list
|
||||
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get conversations")
|
||||
return
|
||||
}
|
||||
|
||||
// 过滤掉系统会话(系统通知现在使用独立的表)
|
||||
filteredConvs := make([]*model.Conversation, 0)
|
||||
for _, conv := range convs {
|
||||
if conv.ID != model.SystemConversationID {
|
||||
filteredConvs = append(filteredConvs, conv)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]*dto.ConversationResponse, len(filteredConvs))
|
||||
for i, conv := range filteredConvs {
|
||||
// 获取未读数
|
||||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
|
||||
|
||||
// 获取最后一条消息
|
||||
var lastMessage *model.Message
|
||||
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
|
||||
if len(messages) > 0 {
|
||||
lastMessage = messages[0]
|
||||
}
|
||||
|
||||
// 群聊时返回member_count,私聊时返回participants
|
||||
var resp *dto.ConversationResponse
|
||||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
|
||||
// 群聊:实时计算群成员数量
|
||||
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
|
||||
// 创建响应并设置member_count
|
||||
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
|
||||
resp.MemberCount = memberCount
|
||||
} else {
|
||||
// 私聊:获取参与者信息
|
||||
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
|
||||
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
|
||||
}
|
||||
result[i] = resp
|
||||
}
|
||||
|
||||
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
|
||||
}
|
||||
|
||||
// HandleDeleteConversationForSelf 仅自己删除会话
|
||||
// DELETE /api/v1/conversations/:id/self
|
||||
func (h *MessageHandler) HandleDeleteConversationForSelf(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationID := c.Param("id")
|
||||
if conversationID == "" {
|
||||
response.BadRequest(c, "conversation id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.chatService.DeleteConversationForSelf(c.Request.Context(), conversationID, userID); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "conversation deleted for self", nil)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
// POST /api/conversations/:id/read
|
||||
func (h *MessageHandler) MarkAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.MarkReadRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "last_read_seq is required")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "marked as read", nil)
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读消息总数
|
||||
// GET /api/conversations/unread/count
|
||||
func (h *MessageHandler) GetUnreadCount(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
// 添加调试日志
|
||||
fmt.Printf("[DEBUG] GetUnreadCount: user_id from context = %q\n", userID)
|
||||
|
||||
if userID == "" {
|
||||
fmt.Printf("[DEBUG] GetUnreadCount: user_id is empty, returning 401\n")
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.chatService.GetAllUnreadCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get unread count")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, &dto.UnreadCountResponse{
|
||||
TotalUnreadCount: count,
|
||||
})
|
||||
}
|
||||
|
||||
// GetConversationUnreadCount 获取单个会话的未读数
|
||||
// GET /api/conversations/:id/unread/count
|
||||
func (h *MessageHandler) GetConversationUnreadCount(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, &dto.ConversationUnreadCountResponse{
|
||||
ConversationID: conversationID,
|
||||
UnreadCount: count,
|
||||
})
|
||||
}
|
||||
|
||||
// RecallMessage 撤回消息
|
||||
// POST /api/messages/:id/recall
|
||||
func (h *MessageHandler) RecallMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
messageIDStr := c.Param("id")
|
||||
// 直接使用 string 类型的 messageID
|
||||
err := h.chatService.RecallMessage(c.Request.Context(), messageIDStr, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "message recalled", nil)
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息
|
||||
// DELETE /api/messages/:id
|
||||
func (h *MessageHandler) DeleteMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
messageIDStr := c.Param("id")
|
||||
// 直接使用 string 类型的 messageID
|
||||
err := h.chatService.DeleteMessage(c.Request.Context(), messageIDStr, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "message deleted", nil)
|
||||
}
|
||||
|
||||
// 辅助函数:验证内容类型
|
||||
func isValidContentType(contentType model.ContentType) bool {
|
||||
switch contentType {
|
||||
case model.ContentTypeText, model.ContentTypeImage, model.ContentTypeVideo, model.ContentTypeAudio, model.ContentTypeFile:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:获取会话参与者信息
|
||||
func (h *MessageHandler) getConversationParticipants(ctx context.Context, conversationID string, currentUserID string) ([]*model.User, error) {
|
||||
// 从repository获取参与者列表
|
||||
participants, err := h.messageService.GetConversationParticipants(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取参与者用户信息
|
||||
var users []*model.User
|
||||
for _, p := range participants {
|
||||
// 跳过当前用户
|
||||
if p.UserID == currentUserID {
|
||||
continue
|
||||
}
|
||||
user, err := h.userService.GetUserByID(ctx, p.UserID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// 获取当前用户在会话中的参与者信息
|
||||
func (h *MessageHandler) getMyConversationParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
|
||||
participants, err := h.messageService.GetConversationParticipants(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, p := range participants {
|
||||
if p.UserID == userID {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ==================== RESTful Action 端点 ====================
|
||||
|
||||
// HandleCreateConversation 创建会话
|
||||
// POST /api/v1/conversations/create
|
||||
func (h *MessageHandler) HandleCreateConversation(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.CreateConversationParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证目标用户是否存在
|
||||
targetUser, err := h.userService.GetUserByID(c.Request.Context(), params.UserID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "target user not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 不能和自己创建会话
|
||||
if userID == params.UserID {
|
||||
response.BadRequest(c, "cannot create conversation with yourself")
|
||||
return
|
||||
}
|
||||
|
||||
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, params.UserID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to create conversation")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取参与者信息
|
||||
participants := []*model.User{targetUser}
|
||||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||
|
||||
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
|
||||
}
|
||||
|
||||
// HandleGetConversation 获取会话详情
|
||||
// GET /api/v1/conversations/get?conversation_id=xxx
|
||||
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationID := c.Query("conversation_id")
|
||||
if conversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取未读数
|
||||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||
|
||||
// 获取参与者信息
|
||||
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
|
||||
|
||||
// 获取当前用户的已读位置
|
||||
myLastReadSeq := int64(0)
|
||||
isPinned := false
|
||||
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
|
||||
for _, p := range allParticipants {
|
||||
if p.UserID == userID {
|
||||
myLastReadSeq = p.LastReadSeq
|
||||
isPinned = p.IsPinned
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 获取对方用户的已读位置
|
||||
otherLastReadSeq := int64(0)
|
||||
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
|
||||
}
|
||||
|
||||
// HandleGetMessages 获取会话消息
|
||||
// GET /api/v1/conversations/get_messages?conversation_id=xxx
|
||||
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationID := c.Query("conversation_id")
|
||||
if conversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否使用增量同步(after_seq参数)
|
||||
afterSeqStr := c.Query("after_seq")
|
||||
if afterSeqStr != "" {
|
||||
// 增量同步模式
|
||||
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid after_seq")
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
|
||||
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Success(c, &dto.MessageSyncResponse{
|
||||
Messages: result,
|
||||
HasMore: len(messages) == limit,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否使用历史消息加载(before_seq参数)
|
||||
beforeSeqStr := c.Query("before_seq")
|
||||
if beforeSeqStr != "" {
|
||||
// 加载更早的消息(下拉加载更多)
|
||||
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid before_seq")
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
|
||||
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Success(c, &dto.MessageSyncResponse{
|
||||
Messages: result,
|
||||
HasMore: len(messages) == limit,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 分页模式
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// HandleMarkRead 标记已读
|
||||
// POST /api/v1/conversations/mark_read
|
||||
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.MarkReadParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if params.ConversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "marked as read", nil)
|
||||
}
|
||||
|
||||
// HandleSetConversationPinned 设置会话置顶
|
||||
// POST /api/v1/conversations/set_pinned
|
||||
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.SetConversationPinnedParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if params.ConversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
|
||||
"conversation_id": params.ConversationID,
|
||||
"is_pinned": params.IsPinned,
|
||||
})
|
||||
}
|
||||
132
internal/handler/notification_handler.go
Normal file
132
internal/handler/notification_handler.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// NotificationHandler 通知处理器
|
||||
type NotificationHandler struct {
|
||||
notificationService *service.NotificationService
|
||||
}
|
||||
|
||||
// NewNotificationHandler 创建通知处理器
|
||||
func NewNotificationHandler(notificationService *service.NotificationService) *NotificationHandler {
|
||||
return &NotificationHandler{
|
||||
notificationService: notificationService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetNotifications 获取通知列表
|
||||
func (h *NotificationHandler) GetNotifications(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
unreadOnly := c.Query("unread_only") == "true"
|
||||
|
||||
notifications, total, err := h.notificationService.GetByUserID(c.Request.Context(), userID, page, pageSize, unreadOnly)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get notifications")
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, notifications, total, page, pageSize)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
func (h *NotificationHandler) MarkAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.notificationService.MarkAsReadWithUserID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to mark as read")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "marked as read", nil)
|
||||
}
|
||||
|
||||
// MarkAllAsRead 标记所有为已读
|
||||
func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.notificationService.MarkAllAsRead(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to mark all as read")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "all marked as read", nil)
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读数量
|
||||
func (h *NotificationHandler) GetUnreadCount(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.notificationService.GetUnreadCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get unread count")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"count": count})
|
||||
}
|
||||
|
||||
// DeleteNotification 删除通知
|
||||
func (h *NotificationHandler) DeleteNotification(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.notificationService.DeleteNotification(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to delete notification")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// ClearAllNotifications 清空所有通知
|
||||
func (h *NotificationHandler) ClearAllNotifications(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.notificationService.ClearAllNotifications(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to clear notifications")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
511
internal/handler/post_handler.go
Normal file
511
internal/handler/post_handler.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// PostHandler 帖子处理器
|
||||
type PostHandler struct {
|
||||
postService *service.PostService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewPostHandler 创建帖子处理器
|
||||
func NewPostHandler(postService *service.PostService, userService *service.UserService) *PostHandler {
|
||||
return &PostHandler{
|
||||
postService: postService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建帖子
|
||||
func (h *PostHandler) Create(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type CreateRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Images []string `json:"images"`
|
||||
}
|
||||
|
||||
var req CreateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
post, err := h.postService.Create(c.Request.Context(), userID, req.Title, req.Content, req.Images)
|
||||
if err != nil {
|
||||
var moderationErr *service.PostModerationRejectedError
|
||||
if errors.As(err, &moderationErr) {
|
||||
response.BadRequest(c, moderationErr.UserMessage())
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to create post")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, false, false))
|
||||
}
|
||||
|
||||
// GetByID 获取帖子(不增加浏览量)
|
||||
func (h *PostHandler) GetByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 非作者不可查看未发布内容
|
||||
currentUserID := c.GetString("user_id")
|
||||
if post.Status != model.PostStatusPublished && post.UserID != currentUserID {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 注意:不再自动增加浏览量,浏览量通过 RecordView 端点单独记录
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
fmt.Printf("[DEBUG] GetByID - postID: %s, currentUserID: %s\n", id, currentUserID)
|
||||
|
||||
var isLiked, isFavorited bool
|
||||
if currentUserID != "" {
|
||||
isLiked = h.postService.IsLiked(c.Request.Context(), id, currentUserID)
|
||||
isFavorited = h.postService.IsFavorited(c.Request.Context(), id, currentUserID)
|
||||
fmt.Printf("[DEBUG] GetByID - isLiked: %v, isFavorited: %v\n", isLiked, isFavorited)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] GetByID - user not logged in, isLiked: false, isFavorited: false\n")
|
||||
}
|
||||
|
||||
// 如果有当前用户,检查与帖子作者的相互关注状态
|
||||
var authorWithFollowStatus *dto.UserResponse
|
||||
if currentUserID != "" && post.User != nil {
|
||||
_, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), post.UserID, currentUserID)
|
||||
if err == nil {
|
||||
authorWithFollowStatus = dto.ConvertUserToResponseWithMutualFollow(post.User, isFollowing, isFollowingMe)
|
||||
} else {
|
||||
// 如果出错,使用默认的author
|
||||
authorWithFollowStatus = dto.ConvertUserToResponse(post.User)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
responseData := &dto.PostResponse{
|
||||
ID: post.ID,
|
||||
UserID: post.UserID,
|
||||
Title: post.Title,
|
||||
Content: post.Content,
|
||||
Images: dto.ConvertPostImagesToResponse(post.Images),
|
||||
LikesCount: post.LikesCount,
|
||||
CommentsCount: post.CommentsCount,
|
||||
FavoritesCount: post.FavoritesCount,
|
||||
SharesCount: post.SharesCount,
|
||||
ViewsCount: post.ViewsCount,
|
||||
IsPinned: post.IsPinned,
|
||||
IsLocked: post.IsLocked,
|
||||
IsVote: post.IsVote,
|
||||
CreatedAt: dto.FormatTime(post.CreatedAt),
|
||||
Author: authorWithFollowStatus,
|
||||
IsLiked: isLiked,
|
||||
IsFavorited: isFavorited,
|
||||
}
|
||||
|
||||
response.Success(c, responseData)
|
||||
}
|
||||
|
||||
// RecordView 记录帖子浏览(增加浏览量)
|
||||
func (h *PostHandler) RecordView(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// 验证帖子存在
|
||||
_, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 增加浏览量
|
||||
if err := h.postService.IncrementViews(c.Request.Context(), id, userID); err != nil {
|
||||
fmt.Printf("[DEBUG] Failed to increment views for post %s: %v\n", id, err)
|
||||
response.InternalServerError(c, "failed to record view")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// List 获取帖子列表
|
||||
func (h *PostHandler) List(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
userID := c.Query("user_id")
|
||||
tab := c.Query("tab") // recommend, follow, hot, latest
|
||||
|
||||
// 获取当前用户ID
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
// 根据 tab 参数选择不同的获取方式
|
||||
switch tab {
|
||||
case "follow":
|
||||
// 获取关注用户的帖子,需要登录
|
||||
if currentUserID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
posts, total, err = h.postService.GetFollowingPosts(c.Request.Context(), currentUserID, page, pageSize)
|
||||
case "hot":
|
||||
// 获取热门帖子
|
||||
posts, total, err = h.postService.GetHotPosts(c.Request.Context(), page, pageSize)
|
||||
case "recommend":
|
||||
// 推荐帖子(从Gorse获取个性化推荐)
|
||||
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
|
||||
case "latest":
|
||||
// 最新帖子
|
||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||
default:
|
||||
// 默认获取最新帖子
|
||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get posts")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] List - tab: %s, currentUserID: %s, posts count: %d\n", tab, currentUserID, len(posts))
|
||||
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
for _, post := range posts {
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
isLikedMap[post.ID] = isLiked
|
||||
isFavoritedMap[post.ID] = isFavorited
|
||||
fmt.Printf("[DEBUG] List - postID: %s, isLiked: %v, isFavorited: %v\n", post.ID, isLiked, isFavorited)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] List - user not logged in\n")
|
||||
}
|
||||
|
||||
// 转换为响应结构
|
||||
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||
|
||||
response.Paginated(c, postResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Update 更新帖子
|
||||
func (h *PostHandler) Update(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
if post.UserID != userID {
|
||||
response.Forbidden(c, "cannot update others' post")
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
var req UpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Title != "" {
|
||||
post.Title = req.Title
|
||||
}
|
||||
if req.Content != "" {
|
||||
post.Content = req.Content
|
||||
}
|
||||
|
||||
err = h.postService.Update(c.Request.Context(), post)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update post")
|
||||
return
|
||||
}
|
||||
|
||||
currentUserID := c.GetString("user_id")
|
||||
var isLiked, isFavorited bool
|
||||
if currentUserID != "" {
|
||||
isLiked = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavorited = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// Delete 删除帖子
|
||||
func (h *PostHandler) Delete(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
if post.UserID != userID {
|
||||
response.Forbidden(c, "cannot delete others' post")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.postService.Delete(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to delete post")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "post deleted", nil)
|
||||
}
|
||||
|
||||
// Like 点赞帖子
|
||||
func (h *PostHandler) Like(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
fmt.Printf("[DEBUG] Like - postID: %s, userID: %s\n", id, userID)
|
||||
|
||||
err := h.postService.Like(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to like post")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的帖子状态
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get post")
|
||||
return
|
||||
}
|
||||
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||
fmt.Printf("[DEBUG] Like - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// Unlike 取消点赞
|
||||
func (h *PostHandler) Unlike(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
fmt.Printf("[DEBUG] Unlike - postID: %s, userID: %s\n", id, userID)
|
||||
|
||||
err := h.postService.Unlike(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unlike post")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的帖子状态
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get post")
|
||||
return
|
||||
}
|
||||
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||
fmt.Printf("[DEBUG] Unlike - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// Favorite 收藏帖子
|
||||
func (h *PostHandler) Favorite(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
fmt.Printf("[DEBUG] Favorite - postID: %s, userID: %s\n", id, userID)
|
||||
|
||||
err := h.postService.Favorite(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to favorite post")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的帖子状态
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get post")
|
||||
return
|
||||
}
|
||||
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||
fmt.Printf("[DEBUG] Favorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// Unfavorite 取消收藏
|
||||
func (h *PostHandler) Unfavorite(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
fmt.Printf("[DEBUG] Unfavorite - postID: %s, userID: %s\n", id, userID)
|
||||
|
||||
err := h.postService.Unfavorite(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unfavorite post")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的帖子状态
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get post")
|
||||
return
|
||||
}
|
||||
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||
fmt.Printf("[DEBUG] Unfavorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// GetUserPosts 获取用户帖子列表
|
||||
func (h *PostHandler) GetUserPosts(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get user posts")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
currentUserID := c.GetString("user_id")
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
for _, post := range posts {
|
||||
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应结构
|
||||
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||
|
||||
response.Paginated(c, postResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetFavorites 获取收藏列表
|
||||
func (h *PostHandler) GetFavorites(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
posts, total, err := h.postService.GetFavorites(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get favorites")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
currentUserID := c.GetString("user_id")
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
for _, post := range posts {
|
||||
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应结构
|
||||
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||
|
||||
response.Paginated(c, postResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Search 搜索帖子
|
||||
func (h *PostHandler) Search(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
posts, total, err := h.postService.Search(c.Request.Context(), keyword, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to search posts")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
currentUserID := c.GetString("user_id")
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
for _, post := range posts {
|
||||
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应结构
|
||||
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||
|
||||
response.Paginated(c, postResponses, total, page, pageSize)
|
||||
}
|
||||
157
internal/handler/push_handler.go
Normal file
157
internal/handler/push_handler.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PushHandler 推送处理器
|
||||
type PushHandler struct {
|
||||
pushService service.PushService
|
||||
}
|
||||
|
||||
// NewPushHandler 创建推送处理器
|
||||
func NewPushHandler(pushService service.PushService) *PushHandler {
|
||||
return &PushHandler{
|
||||
pushService: pushService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterDevice 注册设备
|
||||
// POST /api/v1/push/devices
|
||||
func (h *PushHandler) RegisterDevice(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.RegisterDeviceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证设备类型
|
||||
deviceType := model.DeviceType(req.DeviceType)
|
||||
if !isValidDeviceType(deviceType) {
|
||||
response.BadRequest(c, "invalid device type")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.pushService.RegisterDevice(c.Request.Context(), userID, req.DeviceID, deviceType, req.PushToken)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to register device")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "device registered successfully", nil)
|
||||
}
|
||||
|
||||
// UnregisterDevice 注销设备
|
||||
// DELETE /api/v1/push/devices/:device_id
|
||||
func (h *PushHandler) UnregisterDevice(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := c.Param("device_id")
|
||||
if deviceID == "" {
|
||||
response.BadRequest(c, "device_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.pushService.UnregisterDevice(c.Request.Context(), deviceID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unregister device")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "device unregistered successfully", nil)
|
||||
}
|
||||
|
||||
// GetMyDevices 获取当前用户的设备列表
|
||||
// GET /api/v1/push/devices
|
||||
func (h *PushHandler) GetMyDevices(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
// 这里需要从DeviceTokenRepository获取设备列表
|
||||
// 由于PushService接口没有提供获取设备列表的方法,我们暂时返回空列表
|
||||
// TODO: 在PushService接口中添加GetUserDevices方法
|
||||
_ = userID // 避免未使用变量警告
|
||||
|
||||
response.Success(c, []*dto.DeviceTokenResponse{})
|
||||
}
|
||||
|
||||
// GetPushRecords 获取推送记录
|
||||
// GET /api/v1/push/records
|
||||
func (h *PushHandler) GetPushRecords(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
records, err := h.pushService.GetPendingPushes(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get push records")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, &dto.PushRecordListResponse{
|
||||
Records: dto.PushRecordsToResponse(records),
|
||||
Total: int64(len(records)),
|
||||
})
|
||||
}
|
||||
|
||||
// 辅助函数:验证设备类型
|
||||
func isValidDeviceType(deviceType model.DeviceType) bool {
|
||||
switch deviceType {
|
||||
case model.DeviceTypeIOS, model.DeviceTypeAndroid, model.DeviceTypeWeb:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateDeviceToken 更新设备推送Token
|
||||
// PUT /api/v1/push/devices/:device_id/token
|
||||
func (h *PushHandler) UpdateDeviceToken(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := c.Param("device_id")
|
||||
if deviceID == "" {
|
||||
response.BadRequest(c, "device_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
PushToken string `json:"push_token" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err := h.pushService.UpdateDeviceToken(c.Request.Context(), deviceID, req.PushToken)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update device token")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "device token updated successfully", nil)
|
||||
}
|
||||
164
internal/handler/sticker_handler.go
Normal file
164
internal/handler/sticker_handler.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// StickerHandler 自定义表情处理器
|
||||
type StickerHandler struct {
|
||||
stickerService service.StickerService
|
||||
}
|
||||
|
||||
// NewStickerHandler 创建自定义表情处理器
|
||||
func NewStickerHandler(stickerService service.StickerService) *StickerHandler {
|
||||
return &StickerHandler{
|
||||
stickerService: stickerService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetStickersRequest 获取表情列表请求
|
||||
type GetStickersRequest struct {
|
||||
UserID string `form:"user_id" binding:"required"`
|
||||
}
|
||||
|
||||
// AddStickerRequest 添加表情请求
|
||||
type AddStickerRequest struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
// DeleteStickerRequest 删除表情请求
|
||||
type DeleteStickerRequest struct {
|
||||
StickerID string `json:"sticker_id" binding:"required"`
|
||||
}
|
||||
|
||||
// ReorderStickersRequest 重新排序请求
|
||||
type ReorderStickersRequest struct {
|
||||
Orders map[string]int `json:"orders" binding:"required"`
|
||||
}
|
||||
|
||||
// CheckStickerRequest 检查表情是否存在请求
|
||||
type CheckStickerRequest struct {
|
||||
URL string `form:"url" binding:"required"`
|
||||
}
|
||||
|
||||
// GetStickers 获取用户的表情列表
|
||||
func (h *StickerHandler) GetStickers(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
stickers, err := h.stickerService.GetUserStickers(userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get stickers")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stickers": stickers})
|
||||
}
|
||||
|
||||
// AddSticker 添加表情
|
||||
func (h *StickerHandler) AddSticker(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req AddStickerRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sticker, err := h.stickerService.AddSticker(userID, req.URL, req.Width, req.Height)
|
||||
if err != nil {
|
||||
if err == service.ErrStickerAlreadyExists {
|
||||
response.Error(c, http.StatusConflict, "sticker already exists")
|
||||
return
|
||||
}
|
||||
if err == service.ErrInvalidStickerURL {
|
||||
response.BadRequest(c, "invalid sticker url, only http/https is allowed")
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"sticker": sticker})
|
||||
}
|
||||
|
||||
// DeleteSticker 删除表情
|
||||
func (h *StickerHandler) DeleteSticker(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req DeleteStickerRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.stickerService.DeleteSticker(userID, req.StickerID); err != nil {
|
||||
response.InternalServerError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "sticker deleted successfully", nil)
|
||||
}
|
||||
|
||||
// ReorderStickers 重新排序表情
|
||||
func (h *StickerHandler) ReorderStickers(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req ReorderStickersRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.stickerService.ReorderStickers(userID, req.Orders); err != nil {
|
||||
response.InternalServerError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "stickers reordered successfully", nil)
|
||||
}
|
||||
|
||||
// CheckStickerExists 检查表情是否存在
|
||||
func (h *StickerHandler) CheckStickerExists(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
url := c.Query("url")
|
||||
if url == "" {
|
||||
response.BadRequest(c, "url is required")
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := h.stickerService.CheckExists(userID, url)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"exists": exists})
|
||||
}
|
||||
154
internal/handler/system_message_handler.go
Normal file
154
internal/handler/system_message_handler.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/repository"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// SystemMessageHandler 系统消息处理器
|
||||
type SystemMessageHandler struct {
|
||||
systemMsgService service.SystemMessageService
|
||||
notifyRepo *repository.SystemNotificationRepository
|
||||
}
|
||||
|
||||
// NewSystemMessageHandler 创建系统消息处理器
|
||||
func NewSystemMessageHandler(
|
||||
systemMsgService service.SystemMessageService,
|
||||
notifyRepo *repository.SystemNotificationRepository,
|
||||
) *SystemMessageHandler {
|
||||
return &SystemMessageHandler{
|
||||
systemMsgService: systemMsgService,
|
||||
notifyRepo: notifyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSystemMessages 获取系统消息列表
|
||||
// GET /api/v1/messages/system
|
||||
func (h *SystemMessageHandler) GetSystemMessages(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
// 获取当前用户的系统通知(从独立表中获取)
|
||||
notifications, total, err := h.notifyRepo.GetByReceiverID(userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get system messages")
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]*dto.SystemMessageResponse, 0)
|
||||
for _, n := range notifications {
|
||||
resp := dto.SystemNotificationToResponse(n)
|
||||
result = append(result, resp)
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取系统消息未读数
|
||||
// GET /api/v1/messages/system/unread-count
|
||||
func (h *SystemMessageHandler) GetUnreadCount(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户的未读通知数
|
||||
unreadCount, err := h.notifyRepo.GetUnreadCount(userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get unread count")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, &dto.SystemUnreadCountResponse{
|
||||
UnreadCount: unreadCount,
|
||||
})
|
||||
}
|
||||
|
||||
// MarkAsRead 标记系统消息为已读
|
||||
// PUT /api/v1/messages/system/:id/read
|
||||
func (h *SystemMessageHandler) MarkAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
notificationIDStr := c.Param("id")
|
||||
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid notification id")
|
||||
return
|
||||
}
|
||||
|
||||
// 标记为已读
|
||||
err = h.notifyRepo.MarkAsRead(notificationID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to mark as read")
|
||||
return
|
||||
}
|
||||
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||
|
||||
response.SuccessWithMessage(c, "marked as read", nil)
|
||||
}
|
||||
|
||||
// MarkAllAsRead 标记所有系统消息为已读
|
||||
// PUT /api/v1/messages/system/read-all
|
||||
func (h *SystemMessageHandler) MarkAllAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
// 标记当前用户所有通知为已读
|
||||
err := h.notifyRepo.MarkAllAsRead(userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to mark all as read")
|
||||
return
|
||||
}
|
||||
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||
|
||||
response.SuccessWithMessage(c, "all messages marked as read", nil)
|
||||
}
|
||||
|
||||
// DeleteSystemMessage 删除系统消息
|
||||
// DELETE /api/v1/messages/system/:id
|
||||
func (h *SystemMessageHandler) DeleteSystemMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
notificationIDStr := c.Param("id")
|
||||
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid notification id")
|
||||
return
|
||||
}
|
||||
|
||||
// 删除通知
|
||||
err = h.notifyRepo.Delete(notificationID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to delete notification")
|
||||
return
|
||||
}
|
||||
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||
|
||||
response.SuccessWithMessage(c, "notification deleted", nil)
|
||||
}
|
||||
90
internal/handler/upload_handler.go
Normal file
90
internal/handler/upload_handler.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// UploadHandler 上传处理器
|
||||
type UploadHandler struct {
|
||||
uploadService *service.UploadService
|
||||
}
|
||||
|
||||
// NewUploadHandler 创建上传处理器
|
||||
func NewUploadHandler(uploadService *service.UploadService) *UploadHandler {
|
||||
return &UploadHandler{
|
||||
uploadService: uploadService,
|
||||
}
|
||||
}
|
||||
|
||||
// UploadImage 上传图片
|
||||
func (h *UploadHandler) UploadImage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := c.FormFile("image")
|
||||
if err != nil {
|
||||
response.BadRequest(c, "image file is required")
|
||||
return
|
||||
}
|
||||
|
||||
url, err := h.uploadService.UploadImage(c.Request.Context(), file)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to upload image")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// UploadAvatar 上传头像
|
||||
func (h *UploadHandler) UploadAvatar(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := c.FormFile("image")
|
||||
|
||||
if err != nil {
|
||||
response.BadRequest(c, "avatar file is required")
|
||||
return
|
||||
}
|
||||
|
||||
url, err := h.uploadService.UploadAvatar(c.Request.Context(), userID, file)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to upload avatar")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// UploadCover 上传头图(个人主页封面)
|
||||
func (h *UploadHandler) UploadCover(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := c.FormFile("image")
|
||||
if err != nil {
|
||||
response.BadRequest(c, "image file is required")
|
||||
return
|
||||
}
|
||||
|
||||
url, err := h.uploadService.UploadCover(c.Request.Context(), userID, file)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to upload cover")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
705
internal/handler/user_handler.go
Normal file
705
internal/handler/user_handler.go
Normal file
@@ -0,0 +1,705 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// UserHandler 用户处理器
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
jwtService *service.JWTService
|
||||
}
|
||||
|
||||
// NewUserHandler 创建用户处理器
|
||||
func NewUserHandler(userService *service.UserService) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func (h *UserHandler) Register(c *gin.Context) {
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Nickname string `json:"nickname" binding:"required"`
|
||||
Phone string `json:"phone"`
|
||||
VerificationCode string `json:"verification_code" binding:"required"`
|
||||
}
|
||||
|
||||
var req RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.Register(c.Request.Context(), req.Username, req.Email, req.Password, req.Nickname, req.Phone, req.VerificationCode)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to register")
|
||||
return
|
||||
}
|
||||
|
||||
// 生成Token
|
||||
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
|
||||
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"user": dto.ConvertUserToResponse(user),
|
||||
"token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
func (h *UserHandler) Login(c *gin.Context) {
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Account string `json:"account"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account := req.Account
|
||||
if account == "" {
|
||||
account = req.Username
|
||||
}
|
||||
if account == "" {
|
||||
response.BadRequest(c, "username or account is required")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.Login(c.Request.Context(), account, req.Password)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to login")
|
||||
return
|
||||
}
|
||||
|
||||
// 生成Token
|
||||
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
|
||||
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"user": dto.ConvertUserToResponse(user),
|
||||
"token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// SendRegisterCode 发送注册验证码
|
||||
func (h *UserHandler) SendRegisterCode(c *gin.Context) {
|
||||
type SendCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
|
||||
var req SendCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.SendRegisterCode(c.Request.Context(), req.Email); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to send register verification code")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// SendPasswordResetCode 发送找回密码验证码
|
||||
func (h *UserHandler) SendPasswordResetCode(c *gin.Context) {
|
||||
type SendCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
|
||||
var req SendCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.SendPasswordResetCode(c.Request.Context(), req.Email); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to send reset verification code")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// ResetPassword 找回密码并重置
|
||||
func (h *UserHandler) ResetPassword(c *gin.Context) {
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
VerificationCode string `json:"verification_code" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
var req ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ResetPasswordByEmail(c.Request.Context(), req.Email, req.VerificationCode, req.NewPassword); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to reset password")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetCurrentUser 获取当前用户
|
||||
func (h *UserHandler) GetCurrentUser(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 实时计算帖子数量
|
||||
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
// 如果获取失败,使用数据库中的值
|
||||
postsCount = int64(user.PostsCount)
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
|
||||
}
|
||||
|
||||
// GetUserByID 获取指定用户
|
||||
func (h *UserHandler) GetUserByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
// 获取用户信息,包含双向关注状态
|
||||
user, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), id, currentUserID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 实时计算帖子数量
|
||||
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
// 如果获取失败,使用数据库中的值
|
||||
postsCount = int64(user.PostsCount)
|
||||
}
|
||||
|
||||
// 转换为响应格式,包含双向关注状态和实时计算的帖子数量
|
||||
userResponse := dto.ConvertUserToResponseWithMutualFollowAndPostsCount(user, isFollowing, isFollowingMe, int(postsCount))
|
||||
|
||||
response.Success(c, userResponse)
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户
|
||||
func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
Nickname string `json:"nickname"`
|
||||
Bio string `json:"bio"`
|
||||
Website string `json:"website"`
|
||||
Location string `json:"location"`
|
||||
Avatar string `json:"avatar"`
|
||||
Phone *string `json:"phone"`
|
||||
Email *string `json:"email"`
|
||||
}
|
||||
|
||||
var req UpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Nickname != "" {
|
||||
user.Nickname = req.Nickname
|
||||
}
|
||||
if req.Bio != "" {
|
||||
user.Bio = req.Bio
|
||||
}
|
||||
if req.Website != "" {
|
||||
user.Website = req.Website
|
||||
}
|
||||
if req.Location != "" {
|
||||
user.Location = req.Location
|
||||
}
|
||||
if req.Avatar != "" {
|
||||
user.Avatar = req.Avatar
|
||||
}
|
||||
if req.Phone != nil {
|
||||
user.Phone = req.Phone
|
||||
}
|
||||
if req.Email != nil {
|
||||
if user.Email == nil || *user.Email != *req.Email {
|
||||
user.EmailVerified = false
|
||||
}
|
||||
user.Email = req.Email
|
||||
}
|
||||
|
||||
err = h.userService.UpdateUser(c.Request.Context(), user)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update user")
|
||||
return
|
||||
}
|
||||
|
||||
// 实时计算帖子数量
|
||||
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
// 如果获取失败,使用数据库中的值
|
||||
postsCount = int64(user.PostsCount)
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
|
||||
}
|
||||
|
||||
// SendEmailVerifyCode 发送当前用户邮箱验证码
|
||||
func (h *UserHandler) SendEmailVerifyCode(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type SendCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
var req SendCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.SendCurrentUserEmailVerifyCode(c.Request.Context(), userID, req.Email); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to send email verify code")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// VerifyEmail 验证当前用户邮箱
|
||||
func (h *UserHandler) VerifyEmail(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type VerifyEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
VerificationCode string `json:"verification_code" binding:"required"`
|
||||
}
|
||||
var req VerifyEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.VerifyCurrentUserEmail(c.Request.Context(), userID, req.Email, req.VerificationCode); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to verify email")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// RefreshToken 刷新Token
|
||||
func (h *UserHandler) RefreshToken(c *gin.Context) {
|
||||
type RefreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
}
|
||||
|
||||
var req RefreshRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 refresh token
|
||||
claims, err := h.jwtService.ParseToken(req.RefreshToken)
|
||||
if err != nil {
|
||||
response.Unauthorized(c, "invalid refresh token")
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新 token
|
||||
accessToken, _ := h.jwtService.GenerateAccessToken(claims.UserID, claims.Username)
|
||||
refreshToken, _ := h.jwtService.GenerateRefreshToken(claims.UserID, claims.Username)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// SetJWTService 设置JWT服务
|
||||
func (h *UserHandler) SetJWTService(jwtSvc *service.JWTService) {
|
||||
h.jwtService = jwtSvc
|
||||
}
|
||||
|
||||
// FollowUser 关注用户
|
||||
func (h *UserHandler) FollowUser(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
if userID == currentUserID {
|
||||
response.BadRequest(c, "cannot follow yourself")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.FollowUser(c.Request.Context(), currentUserID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to follow user")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// UnfollowUser 取消关注用户
|
||||
func (h *UserHandler) UnfollowUser(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
err := h.userService.UnfollowUser(c.Request.Context(), currentUserID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unfollow user")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// BlockUser 拉黑用户
|
||||
func (h *UserHandler) BlockUser(c *gin.Context) {
|
||||
targetUserID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
if targetUserID == currentUserID {
|
||||
response.BadRequest(c, "cannot block yourself")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.BlockUser(c.Request.Context(), currentUserID, targetUserID)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to block user")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// UnblockUser 取消拉黑
|
||||
func (h *UserHandler) UnblockUser(c *gin.Context) {
|
||||
targetUserID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
if targetUserID == currentUserID {
|
||||
response.BadRequest(c, "cannot unblock yourself")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.UnblockUser(c.Request.Context(), currentUserID, targetUserID)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to unblock user")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetBlockedUsers 获取黑名单列表
|
||||
func (h *UserHandler) GetBlockedUsers(c *gin.Context) {
|
||||
currentUserID := c.GetString("user_id")
|
||||
if currentUserID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
users, total, err := h.userService.GetBlockedUsers(c.Request.Context(), currentUserID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get blocked users")
|
||||
return
|
||||
}
|
||||
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses := dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||
response.Paginated(c, userResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetBlockStatus 获取拉黑状态
|
||||
func (h *UserHandler) GetBlockStatus(c *gin.Context) {
|
||||
targetUserID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
if currentUserID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
if targetUserID == "" {
|
||||
response.BadRequest(c, "target user id is required")
|
||||
return
|
||||
}
|
||||
|
||||
isBlocked, err := h.userService.IsBlocked(c.Request.Context(), currentUserID, targetUserID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get block status")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"is_blocked": isBlocked})
|
||||
}
|
||||
|
||||
// GetFollowingList 获取关注列表
|
||||
func (h *UserHandler) GetFollowingList(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
page := c.DefaultQuery("page", "1")
|
||||
pageSize := c.DefaultQuery("page_size", "20")
|
||||
|
||||
users, err := h.userService.GetFollowingList(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get following list")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果已登录,获取双向关注状态和实时计算的帖子数量
|
||||
var userResponses []*dto.UserResponse
|
||||
if currentUserID != "" && len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
|
||||
} else if len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||
} else {
|
||||
userResponses = dto.ConvertUsersToResponse(users)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"list": userResponses,
|
||||
})
|
||||
}
|
||||
|
||||
// GetFollowersList 获取粉丝列表
|
||||
func (h *UserHandler) GetFollowersList(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
page := c.DefaultQuery("page", "1")
|
||||
pageSize := c.DefaultQuery("page_size", "20")
|
||||
|
||||
fmt.Printf("[DEBUG] GetFollowersList: userID=%s, currentUserID=%s\n", userID, currentUserID)
|
||||
|
||||
users, err := h.userService.GetFollowersList(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get followers list")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] GetFollowersList: found %d users\n", len(users))
|
||||
|
||||
// 如果已登录,获取双向关注状态和实时计算的帖子数量
|
||||
var userResponses []*dto.UserResponse
|
||||
if currentUserID != "" && len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
fmt.Printf("[DEBUG] GetFollowersList: checking mutual follow status for userIDs=%v\n", userIDs)
|
||||
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
|
||||
} else if len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||
} else {
|
||||
userResponses = dto.ConvertUsersToResponse(users)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"list": userResponses,
|
||||
})
|
||||
}
|
||||
|
||||
// CheckUsername 检查用户名是否可用
|
||||
func (h *UserHandler) CheckUsername(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
if username == "" {
|
||||
response.BadRequest(c, "username is required")
|
||||
return
|
||||
}
|
||||
|
||||
available, err := h.userService.CheckUsernameAvailable(c.Request.Context(), username)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to check username")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"available": available})
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
VerificationCode string `json:"verification_code" binding:"required"`
|
||||
}
|
||||
|
||||
var req ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.ChangePassword(c.Request.Context(), currentUserID, req.OldPassword, req.NewPassword, req.VerificationCode)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to change password")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// SendChangePasswordCode 发送修改密码验证码
|
||||
func (h *UserHandler) SendChangePasswordCode(c *gin.Context) {
|
||||
currentUserID := c.GetString("user_id")
|
||||
if currentUserID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.SendChangePasswordCode(c.Request.Context(), currentUserID)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to send change password code")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (h *UserHandler) Search(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
users, total, err := h.userService.Search(c.Request.Context(), keyword, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to search users")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取实时计算的帖子数量
|
||||
var userResponses []*dto.UserResponse
|
||||
if len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||
} else {
|
||||
userResponses = dto.ConvertUsersToResponse(users)
|
||||
}
|
||||
|
||||
response.Paginated(c, userResponses, total, page, pageSize)
|
||||
}
|
||||
216
internal/handler/vote_handler.go
Normal file
216
internal/handler/vote_handler.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// VoteHandler 投票处理器
|
||||
type VoteHandler struct {
|
||||
voteService *service.VoteService
|
||||
postService *service.PostService
|
||||
}
|
||||
|
||||
// NewVoteHandler 创建投票处理器
|
||||
func NewVoteHandler(voteService *service.VoteService, postService *service.PostService) *VoteHandler {
|
||||
return &VoteHandler{
|
||||
voteService: voteService,
|
||||
postService: postService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateVotePost 创建投票帖子
|
||||
// POST /api/v1/posts/vote
|
||||
func (h *VoteHandler) CreateVotePost(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.CreateVotePostRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
post, err := h.voteService.CreateVotePost(c.Request.Context(), userID, &req)
|
||||
if err != nil {
|
||||
var moderationErr *service.PostModerationRejectedError
|
||||
if errors.As(err, &moderationErr) {
|
||||
response.BadRequest(c, moderationErr.UserMessage())
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, post)
|
||||
}
|
||||
|
||||
// GetVoteResult 获取投票结果
|
||||
// GET /api/v1/posts/:id/vote
|
||||
func (h *VoteHandler) GetVoteResult(c *gin.Context) {
|
||||
postID := c.Param("id")
|
||||
if postID == "" {
|
||||
response.BadRequest(c, "帖子ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证帖子存在
|
||||
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "帖子不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID(可选登录)
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// 如果用户未登录,返回不带has_voted的结果
|
||||
if userID == "" {
|
||||
options, err := h.voteService.GetVoteOptions(postID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "获取投票选项失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 计算总票数
|
||||
totalVotes := 0
|
||||
for _, option := range options {
|
||||
totalVotes += option.VotesCount
|
||||
}
|
||||
|
||||
result := &dto.VoteResultDTO{
|
||||
Options: options,
|
||||
TotalVotes: totalVotes,
|
||||
HasVoted: false,
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
return
|
||||
}
|
||||
|
||||
// 用户已登录,获取完整的投票结果
|
||||
result, err := h.voteService.GetVoteResult(postID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "获取投票结果失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Vote 投票
|
||||
// POST /api/v1/posts/:id/vote
|
||||
func (h *VoteHandler) Vote(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
|
||||
postID := c.Param("id")
|
||||
if postID == "" {
|
||||
response.BadRequest(c, "帖子ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证帖子存在
|
||||
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "帖子不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req struct {
|
||||
OptionID string `json:"option_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.voteService.Vote(c.Request.Context(), postID, userID, req.OptionID); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// Unvote 取消投票
|
||||
// DELETE /api/v1/posts/:id/vote
|
||||
func (h *VoteHandler) Unvote(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
|
||||
postID := c.Param("id")
|
||||
if postID == "" {
|
||||
response.BadRequest(c, "帖子ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证帖子存在
|
||||
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "帖子不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.voteService.Unvote(c.Request.Context(), postID, userID); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// UpdateVoteOption 更新投票选项(仅作者)
|
||||
// PUT /api/v1/vote-options/:id
|
||||
func (h *VoteHandler) UpdateVoteOption(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
|
||||
optionID := c.Param("id")
|
||||
if optionID == "" {
|
||||
response.BadRequest(c, "选项ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取帖子ID(从查询参数或请求体中获取)
|
||||
postID := c.Query("post_id")
|
||||
if postID == "" {
|
||||
response.BadRequest(c, "帖子ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.voteService.UpdateVoteOption(c.Request.Context(), postID, optionID, userID, req.Content); err != nil {
|
||||
response.Error(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
866
internal/handler/websocket_handler.go
Normal file
866
internal/handler/websocket_handler.go
Normal file
@@ -0,0 +1,866 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
ws "carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/repository"
|
||||
"carrot_bbs/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许所有来源,生产环境应该限制
|
||||
},
|
||||
}
|
||||
|
||||
// WebSocketHandler WebSocket处理器
|
||||
type WebSocketHandler struct {
|
||||
jwtService *service.JWTService
|
||||
chatService service.ChatService
|
||||
groupService service.GroupService
|
||||
groupRepo repository.GroupRepository
|
||||
userRepo *repository.UserRepository
|
||||
wsManager *ws.WebSocketManager
|
||||
}
|
||||
|
||||
// NewWebSocketHandler 创建WebSocket处理器
|
||||
func NewWebSocketHandler(
|
||||
jwtService *service.JWTService,
|
||||
chatService service.ChatService,
|
||||
groupService service.GroupService,
|
||||
groupRepo repository.GroupRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
wsManager *ws.WebSocketManager,
|
||||
) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
jwtService: jwtService,
|
||||
chatService: chatService,
|
||||
groupService: groupService,
|
||||
groupRepo: groupRepo,
|
||||
userRepo: userRepo,
|
||||
wsManager: wsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket 处理WebSocket连接
|
||||
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
// 调试:打印请求头信息
|
||||
log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
|
||||
log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
|
||||
c.GetHeader("Connection"),
|
||||
c.GetHeader("Upgrade"))
|
||||
log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
|
||||
c.GetHeader("Sec-WebSocket-Key"),
|
||||
c.GetHeader("Sec-WebSocket-Version"))
|
||||
|
||||
// 从query参数获取token
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
// 尝试从header获取
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := h.jwtService.ParseToken(token)
|
||||
if err != nil {
|
||||
log.Printf("Invalid token: %v", err)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := claims.UserID
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
|
||||
return
|
||||
}
|
||||
|
||||
// 升级HTTP连接为WebSocket连接
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to upgrade connection: %v", err)
|
||||
log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
|
||||
c.GetHeader("User-Agent"),
|
||||
c.GetHeader("Content-Type"))
|
||||
return
|
||||
}
|
||||
|
||||
// 如果用户已在线,先注销旧连接
|
||||
if h.wsManager.IsUserOnline(userID) {
|
||||
log.Printf("[DEBUG] 用户 %s 已有在线连接,复用该连接", userID)
|
||||
} else {
|
||||
log.Printf("[DEBUG] 用户 %s 当前不在线,创建新连接", userID)
|
||||
}
|
||||
|
||||
// 创建客户端
|
||||
client := &ws.Client{
|
||||
ID: userID,
|
||||
UserID: userID,
|
||||
Conn: conn,
|
||||
Send: make(chan []byte, 256),
|
||||
Manager: h.wsManager,
|
||||
}
|
||||
|
||||
// 注册客户端
|
||||
h.wsManager.Register(client)
|
||||
|
||||
// 启动读写协程
|
||||
go client.WritePump()
|
||||
go h.handleMessages(client)
|
||||
|
||||
log.Printf("[DEBUG] WebSocket连接建立: userID=%s, 当前在线=%v", userID, h.wsManager.IsUserOnline(userID))
|
||||
}
|
||||
|
||||
// handleMessages 处理客户端消息
|
||||
// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳
|
||||
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
|
||||
defer func() {
|
||||
h.wsManager.Unregister(client)
|
||||
client.Conn.Close()
|
||||
}()
|
||||
|
||||
client.Conn.SetReadLimit(512 * 1024) // 512KB
|
||||
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
||||
client.Conn.SetPongHandler(func(string) error {
|
||||
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
||||
return nil
|
||||
})
|
||||
|
||||
// 心跳定时器 - 服务端主动 ping 间隔保持 30 秒
|
||||
pingTicker := time.NewTicker(30 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pingTicker.C:
|
||||
// 发送心跳
|
||||
if err := client.SendPing(); err != nil {
|
||||
log.Printf("Failed to send ping: %v", err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
_, message, err := client.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var wsMsg ws.WSMessage
|
||||
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
||||
log.Printf("Failed to unmarshal message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
h.processMessage(client, &wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processMessage 处理消息
|
||||
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
switch msg.Type {
|
||||
case ws.MessageTypePing:
|
||||
// 响应心跳
|
||||
if err := client.SendPong(); err != nil {
|
||||
log.Printf("Failed to send pong: %v", err)
|
||||
}
|
||||
|
||||
case ws.MessageTypePong:
|
||||
// 客户端响应心跳
|
||||
|
||||
case ws.MessageTypeMessage:
|
||||
// 处理聊天消息
|
||||
h.handleChatMessage(client, msg)
|
||||
|
||||
case ws.MessageTypeTyping:
|
||||
// 处理正在输入状态
|
||||
h.handleTyping(client, msg)
|
||||
|
||||
case ws.MessageTypeRead:
|
||||
// 处理已读回执
|
||||
h.handleReadReceipt(client, msg)
|
||||
|
||||
// 群组消息处理
|
||||
case ws.MessageTypeGroupMessage:
|
||||
// 处理群消息
|
||||
h.handleGroupMessage(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupTyping:
|
||||
// 处理群输入状态
|
||||
h.handleGroupTyping(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupRead:
|
||||
// 处理群消息已读
|
||||
h.handleGroupReadReceipt(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupRecall:
|
||||
// 处理群消息撤回
|
||||
h.handleGroupRecall(client, msg)
|
||||
|
||||
default:
|
||||
log.Printf("Unknown message type: %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleChatMessage 处理聊天消息
|
||||
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
log.Printf("Invalid message data format")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG handleChatMessage] 完整data: %+v", data)
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
|
||||
if conversationIDStr == "" {
|
||||
log.Printf("Missing conversationId")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析会话ID
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
log.Printf("Invalid conversation ID: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 segments
|
||||
var segments model.MessageSegments
|
||||
if data["segments"] != nil {
|
||||
segmentsBytes, err := json.Marshal(data["segments"])
|
||||
if err == nil {
|
||||
json.Unmarshal(segmentsBytes, &segments)
|
||||
}
|
||||
}
|
||||
|
||||
// 从 segments 中提取回复消息ID
|
||||
replyToID := dto.GetReplyMessageID(segments)
|
||||
var replyToIDPtr *string
|
||||
if replyToID != "" {
|
||||
replyToIDPtr = &replyToID
|
||||
}
|
||||
|
||||
// 发送消息 - 使用 segments
|
||||
message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to send message: %v", err)
|
||||
// 发送错误消息
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to send message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送确认消息(使用 meta 事件格式,包含完整的消息内容)
|
||||
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": ws.MetaDetailTypeAck,
|
||||
"conversation_id": conversationID,
|
||||
"id": message.ID,
|
||||
"user_id": client.UserID,
|
||||
"sender_id": client.UserID,
|
||||
"seq": message.Seq,
|
||||
"segments": message.Segments,
|
||||
"created_at": message.CreatedAt.UnixMilli(),
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(metaAckMsg)
|
||||
log.Printf("[DEBUG handleChatMessage] 私聊 ack 消息: %s", string(msgBytes))
|
||||
log.Printf("[DEBUG handleChatMessage] message.Segments 类型: %T, 值: %+v", message.Segments, message.Segments)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
}
|
||||
|
||||
// handleTyping 处理正在输入状态
|
||||
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
if conversationIDStr == "" {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用 string 类型的 userID
|
||||
h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
|
||||
}
|
||||
|
||||
// handleReadReceipt 处理已读回执
|
||||
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
if conversationIDStr == "" {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取lastReadSeq
|
||||
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
||||
if lastReadSeq == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用 string 类型的 userID 和 conversationID
|
||||
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
||||
log.Printf("Failed to mark as read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 群组消息处理 ====================
|
||||
|
||||
// handleGroupMessage 处理群消息
|
||||
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
// 打印接收到的消息类型和数据,用于调试
|
||||
log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
|
||||
log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)
|
||||
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
log.Printf("Invalid group message data format: data is not map[string]interface{}")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析群组ID(支持 camelCase 和 snake_case)
|
||||
var groupIDFloat float64
|
||||
groupID := "" // 使用 groupID 作为最终变量名
|
||||
if val, ok := data["groupId"].(float64); ok {
|
||||
groupIDFloat = val
|
||||
groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
} else if val, ok := data["group_id"].(string); ok {
|
||||
groupID = val
|
||||
}
|
||||
|
||||
if groupID == "" {
|
||||
log.Printf("Missing groupId in group message")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析会话ID(支持 camelCase 和 snake_case)
|
||||
var conversationID string
|
||||
if val, ok := data["conversationId"].(string); ok {
|
||||
conversationID = val
|
||||
} else if val, ok := data["conversation_id"].(string); ok {
|
||||
conversationID = val
|
||||
}
|
||||
if conversationID == "" {
|
||||
log.Printf("Missing conversationId in group message")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 segments
|
||||
var segments model.MessageSegments
|
||||
if data["segments"] != nil {
|
||||
segmentsBytes, err := json.Marshal(data["segments"])
|
||||
if err == nil {
|
||||
json.Unmarshal(segmentsBytes, &segments)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析@用户列表(支持 camelCase 和 snake_case)
|
||||
var mentionUsers []uint64
|
||||
var mentionUsersInterface []interface{}
|
||||
if val, ok := data["mentionUsers"].([]interface{}); ok {
|
||||
mentionUsersInterface = val
|
||||
} else if val, ok := data["mention_users"].([]interface{}); ok {
|
||||
mentionUsersInterface = val
|
||||
}
|
||||
if len(mentionUsersInterface) > 0 {
|
||||
for _, uid := range mentionUsersInterface {
|
||||
if uidFloat, ok := uid.(float64); ok {
|
||||
mentionUsers = append(mentionUsers, uint64(uidFloat))
|
||||
} else if uidStr, ok := uid.(string); ok {
|
||||
// 处理字符串格式的用户ID
|
||||
if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil {
|
||||
mentionUsers = append(mentionUsers, uidInt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析@所有人(支持 camelCase 和 snake_case)
|
||||
var mentionAll bool
|
||||
if val, ok := data["mentionAll"].(bool); ok {
|
||||
mentionAll = val
|
||||
} else if val, ok := data["mention_all"].(bool); ok {
|
||||
mentionAll = val
|
||||
}
|
||||
|
||||
// 检查用户是否可以发送群消息(验证成员身份和禁言状态)
|
||||
// client.UserID 已经是 string 格式的 UUID
|
||||
if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
|
||||
log.Printf("User cannot send group message: %v", err)
|
||||
// 发送错误消息
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Cannot send group message",
|
||||
"reason": err.Error(),
|
||||
"type": "group_message_error",
|
||||
"groupId": groupID,
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 检查@所有人权限(只有群主和管理员可以@所有人)
|
||||
if mentionAll {
|
||||
if !h.groupService.IsGroupAdmin(client.UserID, groupID) {
|
||||
log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
|
||||
mentionAll = false // 取消@所有人标记
|
||||
}
|
||||
}
|
||||
|
||||
// 创建消息
|
||||
message := &model.Message{
|
||||
ConversationID: conversationID,
|
||||
SenderID: client.UserID,
|
||||
Segments: segments,
|
||||
Status: model.MessageStatusNormal,
|
||||
MentionAll: mentionAll,
|
||||
}
|
||||
|
||||
// 序列化mentionUsers为JSON
|
||||
if len(mentionUsers) > 0 {
|
||||
mentionUsersJSON, _ := json.Marshal(mentionUsers)
|
||||
message.MentionUsers = string(mentionUsersJSON)
|
||||
}
|
||||
|
||||
// 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播)
|
||||
savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to save group message: %v", err)
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to save group message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 更新消息的mention信息
|
||||
if len(mentionUsers) > 0 || mentionAll {
|
||||
message.ID = savedMessage.ID
|
||||
message.Seq = savedMessage.Seq
|
||||
}
|
||||
|
||||
// 构造群消息响应
|
||||
groupMsg := &ws.GroupMessage{
|
||||
ID: savedMessage.ID,
|
||||
ConversationID: conversationID,
|
||||
GroupID: groupID,
|
||||
SenderID: client.UserID,
|
||||
Seq: savedMessage.Seq,
|
||||
Segments: segments,
|
||||
MentionUsers: mentionUsers,
|
||||
MentionAll: mentionAll,
|
||||
CreatedAt: savedMessage.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 广播消息给群组所有成员(排除发送者)
|
||||
h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)
|
||||
|
||||
// 发送确认消息给发送者(作为meta事件)
|
||||
// 使用 meta 事件格式发送 ack
|
||||
log.Printf("[DEBUG HandleGroupMessageSend] 准备发送 ack 消息, userID=%s, messageID=%s, seq=%d",
|
||||
client.UserID, savedMessage.ID, savedMessage.Seq)
|
||||
|
||||
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": ws.MetaDetailTypeAck,
|
||||
"conversation_id": conversationID,
|
||||
"group_id": groupID,
|
||||
"id": savedMessage.ID,
|
||||
"user_id": client.UserID,
|
||||
"sender_id": client.UserID,
|
||||
"seq": savedMessage.Seq,
|
||||
"segments": segments,
|
||||
"created_at": savedMessage.CreatedAt.UnixMilli(),
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(metaAckMsg)
|
||||
log.Printf("[DEBUG HandleGroupMessageSend] 发送 ack 消息到 channel, userID=%s, msg=%s",
|
||||
client.UserID, string(msgBytes))
|
||||
client.Send <- msgBytes
|
||||
} else {
|
||||
log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
|
||||
}
|
||||
|
||||
// 处理@提及通知
|
||||
if len(mentionUsers) > 0 || mentionAll {
|
||||
// 提取文本正文(不含 @ 部分)
|
||||
textContent := dto.ExtractTextContentFromModel(segments)
|
||||
// 在通知内容前拼接被@的真实昵称,通过群成员列表查找
|
||||
mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
|
||||
h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGroupTyping 处理群输入状态
|
||||
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
groupIDFloat, _ := data["groupId"].(float64)
|
||||
if groupIDFloat == 0 {
|
||||
return
|
||||
}
|
||||
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
|
||||
isTyping, _ := data["isTyping"].(bool)
|
||||
|
||||
// 验证用户是否是群成员
|
||||
// client.UserID 已经是 string 格式的 UUID
|
||||
isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
|
||||
if err != nil || !isMember {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := h.userRepo.GetByID(client.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 构造输入状态消息
|
||||
typingMsg := &ws.GroupTypingMessage{
|
||||
GroupID: groupID,
|
||||
UserID: client.UserID,
|
||||
Username: user.Username,
|
||||
IsTyping: isTyping,
|
||||
}
|
||||
|
||||
// 广播给群组其他成员
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
||||
h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
|
||||
}
|
||||
|
||||
// handleGroupReadReceipt 处理群消息已读回执
|
||||
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, _ := data["conversationId"].(string)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
||||
if lastReadSeq == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 标记已读
|
||||
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
||||
log.Printf("Failed to mark group message as read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGroupRecall 处理群消息撤回
|
||||
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
messageID, _ := data["messageId"].(string)
|
||||
if messageID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
groupIDFloat, _ := data["groupId"].(float64)
|
||||
if groupIDFloat == 0 {
|
||||
return
|
||||
}
|
||||
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
|
||||
// 撤回消息
|
||||
if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
|
||||
log.Printf("Failed to recall group message: %v", err)
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to recall message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 广播撤回通知给群组所有成员
|
||||
recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
|
||||
"messageId": messageID,
|
||||
"groupId": groupID,
|
||||
"userId": client.UserID,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
})
|
||||
h.BroadcastGroupNotice(groupID, recallNotice)
|
||||
}
|
||||
|
||||
// handleGroupMention 处理群消息@提及通知
|
||||
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
|
||||
// 如果@所有人,获取所有群成员
|
||||
if mentionAll {
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for mention all: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, member := range members {
|
||||
// 不通知发送者自己
|
||||
memberIDStr := member.UserID
|
||||
if memberIDStr == senderID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送@提及通知
|
||||
mentionMsg := &ws.GroupMentionMessage{
|
||||
GroupID: groupID,
|
||||
MessageID: messageID,
|
||||
FromUserID: senderID,
|
||||
Content: truncateContent(content, 50),
|
||||
MentionAll: true,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 处理特定用户的@提及
|
||||
for _, userID := range mentionUsers {
|
||||
// userID 是 uint64,转换为 string
|
||||
userIDStr := strconv.FormatUint(userID, 10)
|
||||
if userIDStr == senderID {
|
||||
continue // 不通知发送者自己
|
||||
}
|
||||
|
||||
mentionMsg := &ws.GroupMentionMessage{
|
||||
GroupID: groupID,
|
||||
MessageID: messageID,
|
||||
FromUserID: senderID,
|
||||
Content: truncateContent(content, 50),
|
||||
MentionAll: false,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
||||
h.wsManager.SendToUser(userIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称
|
||||
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
|
||||
var prefix string
|
||||
if mentionAll {
|
||||
prefix = "@所有人 "
|
||||
} else if len(mentionUsers) > 0 {
|
||||
// 查询群成员列表,找到被@用户的昵称
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err == nil {
|
||||
memberNickMap := make(map[string]string, len(members))
|
||||
for _, m := range members {
|
||||
displayName := m.Nickname
|
||||
if displayName == "" {
|
||||
displayName = m.UserID
|
||||
}
|
||||
memberNickMap[m.UserID] = displayName
|
||||
}
|
||||
for _, uid := range mentionUsers {
|
||||
uidStr := strconv.FormatUint(uid, 10)
|
||||
if name, ok := memberNickMap[uidStr]; ok {
|
||||
prefix += "@" + name + " "
|
||||
} else {
|
||||
prefix += "@某人 "
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for range mentionUsers {
|
||||
prefix += "@某人 "
|
||||
}
|
||||
}
|
||||
}
|
||||
return prefix + textBody
|
||||
}
|
||||
|
||||
// BroadcastGroupMessage 向群组所有成员广播消息
|
||||
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除发送者
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupNotice 广播群组通知给所有成员
|
||||
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历成员,如果在线则发送通知
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
h.wsManager.SendToUser(memberIDStr, notice)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户)
|
||||
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历成员,如果在线则发送通知
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
h.wsManager.SendToUser(memberIDStr, notice)
|
||||
}
|
||||
}
|
||||
|
||||
// SendGroupMemberNotice 发送群成员变动通知
|
||||
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
|
||||
notice := &ws.GroupNoticeMessage{
|
||||
NoticeType: noticeType,
|
||||
GroupID: groupID,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
|
||||
h.BroadcastGroupNotice(groupID, wsMsg)
|
||||
}
|
||||
|
||||
// truncateContent 截断内容
|
||||
func truncateContent(content string, maxLen int) string {
|
||||
if len(content) <= maxLen {
|
||||
return content
|
||||
}
|
||||
return content[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// BroadcastGroupTyping 向群组所有成员广播输入状态
|
||||
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for typing broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除指定用户
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupRead 向群组所有成员广播已读状态
|
||||
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for read broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除指定用户
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user