package service

import (
	"context"
	"errors"
	"fmt"
	"time"

	"carrot_bbs/internal/dto"
	"carrot_bbs/internal/model"
	"carrot_bbs/internal/pkg/sse"
	"carrot_bbs/internal/repository"

	"gorm.io/gorm"
)

// RecallMessageTimeout is how long after creation a sender may still recall a
// message (2 minutes).
const RecallMessageTimeout = 2 * time.Minute

// ChatService is the chat service interface.
type ChatService interface {
	// Conversation management.
	GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
	GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
	GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
	DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
	SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error

	// Message operations.
	SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
	GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
	GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
	GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)

	// Read-state management.
	MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
	GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
	GetAllUnreadCount(ctx context.Context, userID string) (int64, error)

	// Extended message features.
	RecallMessage(ctx context.Context, messageID string, userID string) error
	DeleteMessage(ctx context.Context, messageID string, userID string) error

	// Real-time events.
	SendTyping(ctx context.Context, senderID string, conversationID string)

	// Online status.
	IsUserOnline(userID string) bool

	// SaveMessage only persists the message and sends no real-time push; it is
	// meant for callers (e.g. group chat) that do their own pushing.
	SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
}

// chatServiceImpl is the ChatService implementation.
type chatServiceImpl struct {
	db        *gorm.DB
	repo      *repository.MessageRepository
	userRepo  *repository.UserRepository
	sensitive SensitiveService
	sseHub    *sse.Hub
}

// NewChatService creates a chat service from its dependencies.
func NewChatService(
	db *gorm.DB,
	repo *repository.MessageRepository,
	userRepo *repository.UserRepository,
	sensitive SensitiveService,
	sseHub *sse.Hub,
) ChatService {
	return &chatServiceImpl{
		db:        db,
		repo:      repo,
		userRepo:  userRepo,
		sensitive: sensitive,
		sseHub:    sseHub,
	}
}

// publishSSEToUsers publishes event with payload to the given users through
// the SSE hub. It is a no-op when no hub is configured or userIDs is empty.
func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payload interface{}) {
	if s.sseHub == nil || len(userIDs) == 0 {
		return
	}
	s.sseHub.PublishToUsers(userIDs, event, payload)
}

// ensureParticipant verifies that userID has a participant record in
// conversationID. When the record does not exist it returns a new error whose
// text is deniedMsg; any other repository failure is wrapped and returned.
func (s *chatServiceImpl) ensureParticipant(conversationID, userID, deniedMsg string) error {
	if _, err := s.repo.GetParticipant(conversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New(deniedMsg)
		}
		return fmt.Errorf("failed to get participant: %w", err)
	}
	return nil
}

// conversationDetail returns the "detail_type" value used in SSE payloads
// ("group" for group conversations, "private" otherwise) together with the
// group ID when the conversation has one. A failed conversation lookup is
// treated as a private conversation with an empty group ID.
func (s *chatServiceImpl) conversationDetail(conversationID string) (detailType, groupID string) {
	conv, err := s.repo.GetConversation(conversationID)
	if err != nil || conv.Type != model.ConversationTypeGroup {
		return "private", ""
	}
	if conv.GroupID != nil {
		groupID = *conv.GroupID
	}
	return "group", groupID
}

// loadConversationForSend fetches the conversation a message is about to be
// written to, mapping a missing record to a user-facing error.
func (s *chatServiceImpl) loadConversationForSend(conversationID string) (*model.Conversation, error) {
	conv, err := s.repo.GetConversation(conversationID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("会话不存在,请重新创建会话")
		}
		return nil, fmt.Errorf("failed to get conversation: %w", err)
	}
	return conv, nil
}

// createMessage builds a normal-status message and persists it through
// repo.CreateMessageWithSeq.
func (s *chatServiceImpl) createMessage(senderID, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	message := &model.Message{
		ConversationID: conversationID,
		SenderID:       senderID,
		Segments:       segments,
		ReplyToID:      replyToID,
		Status:         model.MessageStatusNormal,
	}
	if err := s.repo.CreateMessageWithSeq(message); err != nil {
		return nil, fmt.Errorf("failed to save message: %w", err)
	}
	return message, nil
}

// GetOrCreateConversation returns the private conversation between the two
// users, delegating to the repository (which creates it when needed).
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
	return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
}

// GetConversationList returns one page of the user's conversations and the
// total count, as reported by the repository.
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	return s.repo.GetConversations(userID, page, pageSize)
}

// GetConversationByID returns the conversation if userID is one of its
// participants; otherwise it returns a "not found or no permission" error.
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
	if err := s.ensureParticipant(conversationID, userID, "conversation not found or no permission"); err != nil {
		return nil, err
	}

	conv, err := s.repo.GetConversation(conversationID)
	if err != nil {
		return nil, fmt.Errorf("failed to get conversation: %w", err)
	}
	return conv, nil
}

// DeleteConversationForSelf hides the conversation for userID only.
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
	participant, err := s.repo.GetParticipant(conversationID, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return errors.New("conversation not found or no permission")
	case err != nil:
		return fmt.Errorf("failed to get participant: %w", err)
	case participant.ConversationID == "":
		// Defensive: a record without a conversation ID is treated as missing.
		return errors.New("conversation not found or no permission")
	}

	if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
		return fmt.Errorf("failed to hide conversation: %w", err)
	}
	return nil
}

// SetConversationPinned sets the per-user pinned flag of a conversation.
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
	participant, err := s.repo.GetParticipant(conversationID, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return errors.New("conversation not found or no permission")
	case err != nil:
		return fmt.Errorf("failed to get participant: %w", err)
	case participant.ConversationID == "":
		// Defensive: a record without a conversation ID is treated as missing.
		return errors.New("conversation not found or no permission")
	}

	if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
		return fmt.Errorf("failed to update pinned status: %w", err)
	}
	return nil
}

// SendMessage validates, persists and pushes a segment-based message.
//
// Checks run in this order: the conversation must exist, the sender must be a
// participant, and (private conversations only) the block / stranger
// restrictions must pass. Membership is verified before the relationship
// checks so that a non-participant cannot learn block or follow status from
// the returned error. After the message is stored, a "chat_message" event is
// pushed to all participants and a "conversation_unread" event to every
// participant except the sender; push failures never fail the send.
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	conv, err := s.loadConversationForSend(conversationID)
	if err != nil {
		return nil, err
	}

	if err := s.ensureParticipant(conversationID, senderID, "您不是该会话的参与者"); err != nil {
		return nil, err
	}

	if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
		if err := s.checkPrivateSendAllowed(conversationID, senderID, segments); err != nil {
			return nil, err
		}
	}

	message, err := s.createMessage(senderID, conversationID, segments, replyToID)
	if err != nil {
		return nil, err
	}

	s.notifyNewMessage(conv, conversationID, senderID, message)
	return message, nil
}

// checkPrivateSendAllowed enforces the private-chat sending restrictions
// against every other participant of the conversation:
//   - block check: userRepo.IsBlocked(other, sender) must be false, otherwise
//     ErrUserBlocked is returned (only the "blocked user -> blocker"
//     direction is rejected);
//   - stranger restriction: while userRepo.IsFollowing(other, sender) is false
//     the sender may not send image segments and may have at most one message
//     in the conversation.
func (s *chatServiceImpl) checkPrivateSendAllowed(conversationID, senderID string, segments model.MessageSegments) error {
	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return fmt.Errorf("failed to get participants: %w", err)
	}

	// The sender's message count is loaded lazily and at most once.
	var (
		sentCount int64
		counted   bool
	)
	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}

		blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
		if bErr != nil {
			return fmt.Errorf("failed to check block status: %w", bErr)
		}
		if blocked {
			return ErrUserBlocked
		}

		followsBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
		if fErr != nil {
			return fmt.Errorf("failed to check follow status: %w", fErr)
		}
		if followsBack {
			continue
		}

		if containsImageSegment(segments) {
			return errors.New("对方未关注你,暂不支持发送图片")
		}
		if !counted {
			c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
			if cErr != nil {
				return fmt.Errorf("failed to count sender messages: %w", cErr)
			}
			sentCount, counted = c, true
		}
		if sentCount >= 1 {
			return errors.New("对方未关注你前,仅允许发送一条消息")
		}
	}
	return nil
}

// notifyNewMessage pushes a freshly stored message over SSE. It is
// best-effort: if the participant list cannot be loaded nothing is pushed, and
// a failed unread-count lookup only skips that user's unread event.
func (s *chatServiceImpl) notifyNewMessage(conv *model.Conversation, conversationID, senderID string, message *model.Message) {
	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return
	}

	targetIDs := make([]string, 0, len(participants))
	for _, p := range participants {
		targetIDs = append(targetIDs, p.UserID)
	}

	detailType := "private"
	if conv.Type == model.ConversationTypeGroup {
		detailType = "group"
	}
	s.publishSSEToUsers(targetIDs, "chat_message", map[string]interface{}{
		"detail_type": detailType,
		"message":     dto.ConvertMessageToResponse(message),
	})

	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}
		totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID)
		if uErr != nil {
			continue
		}
		s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
			"conversation_id": conversationID,
			"total_unread":    totalUnread,
		})
	}
}

// containsImageSegment reports whether any segment is an image segment.
func containsImageSegment(segments model.MessageSegments) bool {
	for _, seg := range segments {
		if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
			return true
		}
	}
	return false
}

// GetMessages returns one page of message history for a conversation the user
// participates in, together with the total count reported by the repository.
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
	if err := s.ensureParticipant(conversationID, userID, "conversation not found or no permission"); err != nil {
		return nil, 0, err
	}
	return s.repo.GetMessages(conversationID, page, pageSize)
}

// GetMessagesAfterSeq returns messages after afterSeq (incremental sync).
// A non-positive limit defaults to 100.
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
	if err := s.ensureParticipant(conversationID, userID, "conversation not found or no permission"); err != nil {
		return nil, err
	}
	if limit <= 0 {
		limit = 100
	}
	return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}

// GetMessagesBeforeSeq returns history before beforeSeq (load-more on scroll).
// A non-positive limit defaults to 20.
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	if err := s.ensureParticipant(conversationID, userID, "conversation not found or no permission"); err != nil {
		return nil, err
	}
	if limit <= 0 {
		limit = 20
	}
	return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
}

// MarkAsRead stores seq as the user's last-read position in the conversation.
// Afterwards it pushes, best-effort, a "message_read" event to all
// participants and a "conversation_unread" event to the user.
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
	if err := s.ensureParticipant(conversationID, userID, "conversation not found or no permission"); err != nil {
		return err
	}

	if err := s.repo.UpdateLastReadSeq(conversationID, userID, seq); err != nil {
		return fmt.Errorf("failed to update last read seq: %w", err)
	}

	if participants, pErr := s.repo.GetConversationParticipants(conversationID); pErr == nil {
		detailType, groupID := s.conversationDetail(conversationID)

		targetIDs := make([]string, 0, len(participants))
		for _, p := range participants {
			targetIDs = append(targetIDs, p.UserID)
		}
		s.publishSSEToUsers(targetIDs, "message_read", map[string]interface{}{
			"detail_type":     detailType,
			"conversation_id": conversationID,
			"group_id":        groupID,
			"user_id":         userID,
			"seq":             seq,
		})
	}

	if totalUnread, uErr := s.repo.GetAllUnreadCount(userID); uErr == nil {
		s.publishSSEToUsers([]string{userID}, "conversation_unread", map[string]interface{}{
			"conversation_id": conversationID,
			"total_unread":    totalUnread,
		})
	}
	return nil
}

// GetUnreadCount returns the user's unread count for one conversation the user
// participates in.
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
	if err := s.ensureParticipant(conversationID, userID, "conversation not found or no permission"); err != nil {
		return 0, err
	}
	return s.repo.GetUnreadCount(conversationID, userID)
}

// GetAllUnreadCount returns the user's unread total across all conversations.
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
	return s.repo.GetAllUnreadCount(userID)
}

// RecallMessage recalls a message. Only the sender may recall, only once, and
// only within RecallMessageTimeout of the message's creation time. On success
// the status becomes recalled, the segments are cleared, and a
// "message_recall" event is pushed (best-effort) to all participants.
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
	// Bind the request context so cancellation and deadlines reach the queries.
	db := s.db.WithContext(ctx)

	var message model.Message
	if err := db.First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	switch {
	case message.SenderIDStr() != userID:
		return errors.New("can only recall your own messages")
	case message.Status == model.MessageStatusRecalled:
		return errors.New("message already recalled")
	case time.Since(message.CreatedAt) > RecallMessageTimeout:
		return errors.New("message recall timeout (2 minutes)")
	}

	// Mark as recalled and drop the original content, keeping only the
	// recalled placeholder row.
	err := db.Model(&message).Updates(map[string]interface{}{
		"status":   model.MessageStatusRecalled,
		"segments": model.MessageSegments{},
	}).Error
	if err != nil {
		return fmt.Errorf("failed to recall message: %w", err)
	}

	participants, pErr := s.repo.GetConversationParticipants(message.ConversationID)
	if pErr != nil {
		// The recall itself succeeded; the push is best-effort.
		return nil
	}

	detailType, groupID := s.conversationDetail(message.ConversationID)
	targetIDs := make([]string, 0, len(participants))
	for _, p := range participants {
		targetIDs = append(targetIDs, p.UserID)
	}
	s.publishSSEToUsers(targetIDs, "message_recall", map[string]interface{}{
		"detail_type":     detailType,
		"conversation_id": message.ConversationID,
		"group_id":        groupID,
		"message_id":      messageID,
		"sender_id":       userID,
	})
	return nil
}

// DeleteMessage marks a message as deleted. The caller must be a participant
// of the message's conversation and must be the message's sender.
//
// NOTE: this is a simplified implementation. It updates the status on the
// message row itself rather than hiding the message for the calling user only.
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
	// Bind the request context so cancellation and deadlines reach the queries.
	db := s.db.WithContext(ctx)

	var message model.Message
	if err := db.First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	if err := s.ensureParticipant(message.ConversationID, userID, "no permission to delete this message"); err != nil {
		return err
	}

	if message.SenderIDStr() != userID {
		return errors.New("can only delete your own messages")
	}

	if err := db.Model(&message).Update("status", model.MessageStatusDeleted).Error; err != nil {
		return fmt.Errorf("failed to delete message: %w", err)
	}
	return nil
}

// SendTyping pushes a "typing" event to every other participant of the
// conversation. It silently does nothing when no SSE hub is configured, when
// the sender's participant record cannot be loaded, or when the participant
// list cannot be loaded.
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
	if s.sseHub == nil {
		return
	}

	if _, err := s.repo.GetParticipant(conversationID, senderID); err != nil {
		return
	}

	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return
	}

	detailType, _ := s.conversationDetail(conversationID)
	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}
		s.sseHub.PublishToUser(p.UserID, "typing", map[string]interface{}{
			"detail_type":     detailType,
			"conversation_id": conversationID,
			"user_id":         senderID,
			"is_typing":       true,
		})
	}
}

// IsUserOnline reports whether the SSE hub has subscribers for the user. It
// returns false when no hub is configured.
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
	return s.sseHub != nil && s.sseHub.HasSubscribers(userID)
}

// SaveMessage only persists the message and sends no real-time push. It is
// intended for callers (e.g. group chat) that handle pushing themselves. The
// conversation must exist and the sender must be a participant.
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	if _, err := s.loadConversationForSend(conversationID); err != nil {
		return nil, err
	}

	if err := s.ensureParticipant(conversationID, senderID, "您不是该会话的参与者"); err != nil {
		return nil, err
	}

	return s.createMessage(senderID, conversationID, segments, replyToID)
}