Files
backend/internal/service/chat_service.go

571 lines
19 KiB
Go
Raw Permalink Normal View History

package service
import (
"context"
"errors"
"fmt"
"time"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/sse"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
)
// RecallMessageTimeout is the time limit for recalling a message: a message
// can only be recalled within 2 minutes of its creation time.
const RecallMessageTimeout = 2 * time.Minute
// ChatService is the chat service interface. It covers conversation
// management, message sending and retrieval, read-state tracking, message
// recall/deletion, and realtime events.
type ChatService interface {
// Conversation management.
GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error
// Message operations.
SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)
// Read-state management.
MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
GetAllUnreadCount(ctx context.Context, userID string) (int64, error)
// Extended message features.
RecallMessage(ctx context.Context, messageID string, userID string) error
DeleteMessage(ctx context.Context, messageID string, userID string) error
// Realtime events.
SendTyping(ctx context.Context, senderID string, conversationID string)
// Online status.
IsUserOnline(userID string) bool
// SaveMessage only persists the message to the database and does not send
// any realtime push (for callers, such as group chat, that push on their own).
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
}
// chatServiceImpl is the ChatService implementation.
type chatServiceImpl struct {
// db is queried directly by RecallMessage and DeleteMessage.
db *gorm.DB
// repo handles conversation, participant and message persistence.
repo *repository.MessageRepository
// userRepo provides block/follow lookups; SendMessage skips those checks when it is nil.
userRepo *repository.UserRepository
// sensitive is stored by NewChatService; it is not referenced by the methods visible in this file section.
sensitive SensitiveService
// sseHub delivers realtime events; when nil, publishing is skipped and IsUserOnline reports false.
sseHub *sse.Hub
}
// NewChatService builds a ChatService from the supplied dependencies.
// The arguments are stored as-is; no validation is performed here.
func NewChatService(
	db *gorm.DB,
	repo *repository.MessageRepository,
	userRepo *repository.UserRepository,
	sensitive SensitiveService,
	sseHub *sse.Hub,
) ChatService {
	svc := new(chatServiceImpl)
	svc.db = db
	svc.repo = repo
	svc.userRepo = userRepo
	svc.sensitive = sensitive
	svc.sseHub = sseHub
	return svc
}
// publishSSEToUsers hands an event and its payload to the SSE hub for the
// given users. It does nothing when no hub is configured or when userIDs is
// empty.
func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payload interface{}) {
	if hub := s.sseHub; hub != nil && len(userIDs) > 0 {
		hub.PublishToUsers(userIDs, event, payload)
	}
}
// GetOrCreateConversation returns the private conversation between the two
// users, delegating lookup/creation to the message repository.
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
	conversation, err := s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
	return conversation, err
}
// GetConversationList returns one page of the user's conversations together
// with the total reported by the repository.
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	conversations, total, err := s.repo.GetConversations(userID, page, pageSize)
	return conversations, total, err
}
// GetConversationByID returns the conversation with the given ID, but only
// if userID is one of its participants. A missing participant record is
// reported as "conversation not found or no permission".
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
	// Membership gate: the participant record itself is not needed here.
	if _, lookupErr := s.repo.GetParticipant(conversationID, userID); lookupErr != nil {
		switch {
		case errors.Is(lookupErr, gorm.ErrRecordNotFound):
			return nil, errors.New("conversation not found or no permission")
		default:
			return nil, fmt.Errorf("failed to get participant: %w", lookupErr)
		}
	}

	// Load the conversation itself.
	conversation, convErr := s.repo.GetConversation(conversationID)
	if convErr != nil {
		return nil, fmt.Errorf("failed to get conversation: %w", convErr)
	}
	return conversation, nil
}
// DeleteConversationForSelf hides the conversation for userID only. The
// caller must be a participant; otherwise "conversation not found or no
// permission" is returned.
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
	member, lookupErr := s.repo.GetParticipant(conversationID, userID)
	switch {
	case lookupErr == nil:
		// Participant record loaded; continue with the checks below.
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return errors.New("conversation not found or no permission")
	default:
		return fmt.Errorf("failed to get participant: %w", lookupErr)
	}

	// A record without a conversation ID is treated like a missing one.
	if member.ConversationID == "" {
		return errors.New("conversation not found or no permission")
	}

	hideErr := s.repo.HideConversationForUser(conversationID, userID)
	if hideErr != nil {
		return fmt.Errorf("failed to hide conversation: %w", hideErr)
	}
	return nil
}
// SetConversationPinned sets or clears the pinned flag of a conversation for
// a single user (the flag is per user, not per conversation). The caller must
// be a participant of the conversation.
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
	member, err := s.repo.GetParticipant(conversationID, userID)

	// Any lookup failure other than "record not found" is surfaced as-is.
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return fmt.Errorf("failed to get participant: %w", err)
	}
	// Missing record, or a record without a conversation ID: same answer.
	if err != nil || member.ConversationID == "" {
		return errors.New("conversation not found or no permission")
	}

	if pinErr := s.repo.UpdatePinned(conversationID, userID, isPinned); pinErr != nil {
		return fmt.Errorf("failed to update pinned status: %w", pinErr)
	}
	return nil
}
// SendMessage stores a new message (built from segments) in a conversation
// and pushes realtime SSE events to the conversation's participants.
//
// Validation order:
//  1. The conversation must exist.
//  2. The sender must be a participant. This runs before any relationship
//     lookup so that a non-participant cannot learn block/follow state from
//     the returned error.
//  3. For private conversations, when a user repository is configured: the
//     other participant must not have blocked the sender (ErrUserBlocked),
//     and until the other participant follows the sender, image segments are
//     rejected and the sender may have at most one message in the
//     conversation.
//
// Once the message is saved, SSE delivery is best-effort: a failure to load
// participants or unread totals is ignored and the saved message is still
// returned.
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	// The conversation must exist.
	conv, err := s.repo.GetConversation(conversationID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("会话不存在,请重新创建会话")
		}
		return nil, fmt.Errorf("failed to get conversation: %w", err)
	}

	// The sender must be a participant. Only the existence of the record
	// matters, so the record itself is discarded.
	if _, err := s.repo.GetParticipant(conversationID, senderID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("您不是该会话的参与者")
		}
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	// Block restriction: only the "blocked user -> blocker" direction is rejected.
	if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
		peers, pErr := s.repo.GetConversationParticipants(conversationID)
		if pErr != nil {
			return nil, fmt.Errorf("failed to get participants: %w", pErr)
		}

		// Counted lazily and at most once, only if a stranger limit applies.
		var sentCount *int64
		for _, p := range peers {
			if p.UserID == senderID {
				continue
			}

			blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
			if bErr != nil {
				return nil, fmt.Errorf("failed to check block status: %w", bErr)
			}
			if blocked {
				return nil, ErrUserBlocked
			}

			// Stranger restriction: until the other side follows the sender,
			// only a single message is allowed and images are forbidden.
			isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
			if fErr != nil {
				return nil, fmt.Errorf("failed to check follow status: %w", fErr)
			}
			if isFollowedBack {
				continue
			}
			if containsImageSegment(segments) {
				return nil, errors.New("对方未关注你,暂不支持发送图片")
			}
			if sentCount == nil {
				c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
				if cErr != nil {
					return nil, fmt.Errorf("failed to count sender messages: %w", cErr)
				}
				sentCount = &c
			}
			if *sentCount >= 1 {
				return nil, errors.New("对方未关注你前,仅允许发送一条消息")
			}
		}
	}

	// Build the message; the sender ID is the string UUID, used directly.
	message := &model.Message{
		ConversationID: conversationID,
		SenderID:       senderID,
		Segments:       segments,
		ReplyToID:      replyToID,
		Status:         model.MessageStatusNormal,
	}

	// The repository creates the message and assigns its sequence number.
	if err := s.repo.CreateMessageWithSeq(message); err != nil {
		return nil, fmt.Errorf("failed to save message: %w", err)
	}

	// Realtime push is best-effort: the message is already stored, so a
	// failure to load participants must not fail the send.
	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return message, nil
	}

	targetIDs := make([]string, 0, len(participants))
	for _, p := range participants {
		targetIDs = append(targetIDs, p.UserID)
	}
	detailType := "private"
	if conv.Type == model.ConversationTypeGroup {
		detailType = "group"
	}
	s.publishSSEToUsers(targetIDs, "chat_message", map[string]interface{}{
		"detail_type": detailType,
		"message":     dto.ConvertMessageToResponse(message),
	})

	// Every participant except the sender also gets a refreshed unread total.
	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}
		if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil {
			s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
				"conversation_id": conversationID,
				"total_unread":    totalUnread,
			})
		}
	}

	return message, nil
}
// containsImageSegment reports whether any segment's type equals the image
// content type or the literal "image".
func containsImageSegment(segments model.MessageSegments) bool {
	imageType := string(model.ContentTypeImage)
	for _, segment := range segments {
		switch segment.Type {
		case imageType, "image":
			return true
		}
	}
	return false
}
// GetMessages returns one page of the conversation's message history plus
// the total reported by the repository. The caller must be a participant.
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
	// Only participants may read the history.
	if _, err := s.repo.GetParticipant(conversationID, userID); errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, 0, errors.New("conversation not found or no permission")
	} else if err != nil {
		return nil, 0, fmt.Errorf("failed to get participant: %w", err)
	}

	messages, total, err := s.repo.GetMessages(conversationID, page, pageSize)
	return messages, total, err
}
// GetMessagesAfterSeq returns messages after the given sequence number, for
// incremental sync. The caller must be a participant. A non-positive limit
// falls back to 100.
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
	// Only participants may sync messages.
	_, memberErr := s.repo.GetParticipant(conversationID, userID)
	switch {
	case memberErr == nil:
		// Caller is a participant.
	case errors.Is(memberErr, gorm.ErrRecordNotFound):
		return nil, errors.New("conversation not found or no permission")
	default:
		return nil, fmt.Errorf("failed to get participant: %w", memberErr)
	}

	batch := limit
	if batch <= 0 {
		batch = 100
	}
	return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, batch)
}
// GetMessagesBeforeSeq returns older messages before the given sequence
// number, for "load more" scrolling. The caller must be a participant. A
// non-positive limit falls back to 20.
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	const defaultLimit = 20

	// Only participants may page through history.
	if _, memberErr := s.repo.GetParticipant(conversationID, userID); memberErr != nil {
		if !errors.Is(memberErr, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("failed to get participant: %w", memberErr)
		}
		return nil, errors.New("conversation not found or no permission")
	}

	if limit < 1 {
		limit = defaultLimit
	}
	return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
}
// MarkAsRead records seq as the user's last-read sequence in the
// conversation, then publishes a "message_read" event to all participants
// and a refreshed "conversation_unread" total to the user. The caller must be
// a participant. Event publishing is best-effort: lookup failures there are
// ignored.
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
	// Only participants may update their read position.
	if _, memberErr := s.repo.GetParticipant(conversationID, userID); memberErr != nil {
		if !errors.Is(memberErr, gorm.ErrRecordNotFound) {
			return fmt.Errorf("failed to get participant: %w", memberErr)
		}
		return errors.New("conversation not found or no permission")
	}

	// Persist the participant's read position.
	if updateErr := s.repo.UpdateLastReadSeq(conversationID, userID, seq); updateErr != nil {
		return fmt.Errorf("failed to update last read seq: %w", updateErr)
	}

	// Tell every participant about the new read position.
	if members, membersErr := s.repo.GetConversationParticipants(conversationID); membersErr == nil {
		kind, group := "private", ""
		conv, convErr := s.repo.GetConversation(conversationID)
		if convErr == nil && conv.Type == model.ConversationTypeGroup {
			kind = "group"
			if conv.GroupID != nil {
				group = *conv.GroupID
			}
		}

		recipients := make([]string, 0, len(members))
		for _, member := range members {
			recipients = append(recipients, member.UserID)
		}

		readEvent := map[string]interface{}{
			"detail_type":     kind,
			"conversation_id": conversationID,
			"group_id":        group,
			"user_id":         userID,
			"seq":             seq,
		}
		s.publishSSEToUsers(recipients, "message_read", readEvent)
	}

	// Refresh the caller's overall unread total.
	totalUnread, unreadErr := s.repo.GetAllUnreadCount(userID)
	if unreadErr == nil {
		unreadEvent := map[string]interface{}{
			"conversation_id": conversationID,
			"total_unread":    totalUnread,
		}
		s.publishSSEToUsers([]string{userID}, "conversation_unread", unreadEvent)
	}
	return nil
}
// GetUnreadCount returns the user's unread message count for a single
// conversation. The caller must be a participant.
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
	// Only participants may query the unread count.
	_, memberErr := s.repo.GetParticipant(conversationID, userID)
	if memberErr == nil {
		return s.repo.GetUnreadCount(conversationID, userID)
	}
	if errors.Is(memberErr, gorm.ErrRecordNotFound) {
		return 0, errors.New("conversation not found or no permission")
	}
	return 0, fmt.Errorf("failed to get participant: %w", memberErr)
}
// GetAllUnreadCount returns the user's total unread message count across all
// conversations, as reported by the repository.
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
	total, err := s.repo.GetAllUnreadCount(userID)
	return total, err
}
// RecallMessage recalls a message sent by userID, provided it was created no
// longer than RecallMessageTimeout (2 minutes) ago and has not already been
// recalled. The message's status is set to recalled and its segments are
// cleared, leaving only a recall placeholder; a "message_recall" SSE event is
// then published to all participants on a best-effort basis.
//
// Database queries are bound to ctx so that they honour cancellation and
// deadlines of the calling request.
//
// Errors: "message not found", "can only recall your own messages",
// "message already recalled", "message recall timeout (2 minutes)", or a
// wrapped database error.
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
	// Load the message.
	var message model.Message
	if err := s.db.WithContext(ctx).First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	// Only the sender may recall the message.
	if message.SenderIDStr() != userID {
		return errors.New("can only recall your own messages")
	}

	// A message can be recalled only once.
	if message.Status == model.MessageStatusRecalled {
		return errors.New("message already recalled")
	}

	// The recall window is measured from the message's creation time.
	if time.Since(message.CreatedAt) > RecallMessageTimeout {
		return errors.New("message recall timeout (2 minutes)")
	}

	// Mark as recalled and wipe the original content.
	err := s.db.WithContext(ctx).Model(&message).Updates(map[string]interface{}{
		"status":   model.MessageStatusRecalled,
		"segments": model.MessageSegments{},
	}).Error
	if err != nil {
		return fmt.Errorf("failed to recall message: %w", err)
	}

	// Best-effort notification: the recall is already persisted, so lookup
	// failures here are ignored.
	if participants, pErr := s.repo.GetConversationParticipants(message.ConversationID); pErr == nil {
		detailType := "private"
		groupID := ""
		if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
			detailType = "group"
			if conv.GroupID != nil {
				groupID = *conv.GroupID
			}
		}
		targetIDs := make([]string, 0, len(participants))
		for _, p := range participants {
			targetIDs = append(targetIDs, p.UserID)
		}
		s.publishSSEToUsers(targetIDs, "message_recall", map[string]interface{}{
			"detail_type":     detailType,
			"conversation_id": message.ConversationID,
			"group_id":        groupID,
			"message_id":      messageID,
			"sender_id":       userID,
		})
	}
	return nil
}
// DeleteMessage marks a message as deleted by setting its status to
// model.MessageStatusDeleted. The caller must be a participant of the
// message's conversation and must be the sender of the message.
//
// Note: this is a simplified implementation. The status is stored on the
// message itself rather than per user, so it is not a "hide for me only"
// operation.
//
// Database queries are bound to ctx so that they honour cancellation and
// deadlines of the calling request.
//
// Errors: "message not found", "no permission to delete this message",
// "can only delete your own messages", or a wrapped database error.
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
	// Load the message.
	var message model.Message
	if err := s.db.WithContext(ctx).First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	// The caller must belong to the conversation the message lives in.
	if _, err := s.repo.GetParticipant(message.ConversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("no permission to delete this message")
		}
		return fmt.Errorf("failed to get participant: %w", err)
	}

	// Simplification: only the sender may delete their own message.
	if message.SenderIDStr() != userID {
		return errors.New("can only delete your own messages")
	}

	// Soft delete: flip the status instead of removing the row.
	if err := s.db.WithContext(ctx).Model(&message).Update("status", model.MessageStatusDeleted).Error; err != nil {
		return fmt.Errorf("failed to delete message: %w", err)
	}
	return nil
}
// SendTyping publishes a "typing" event to every participant of the
// conversation except the sender. It silently does nothing when no SSE hub is
// configured, when the sender is not a participant, or when the participant
// list cannot be loaded.
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
	hub := s.sseHub
	if hub == nil {
		return
	}

	// The sender must be a participant of the conversation.
	if _, memberErr := s.repo.GetParticipant(conversationID, senderID); memberErr != nil {
		return
	}

	// Load everyone in the conversation.
	members, membersErr := s.repo.GetConversationParticipants(conversationID)
	if membersErr != nil {
		return
	}

	kind := "private"
	conv, convErr := s.repo.GetConversation(conversationID)
	if convErr == nil && conv.Type == model.ConversationTypeGroup {
		kind = "group"
	}

	for _, member := range members {
		if member.UserID != senderID {
			hub.PublishToUser(member.UserID, "typing", map[string]interface{}{
				"detail_type":     kind,
				"conversation_id": conversationID,
				"user_id":         senderID,
				"is_typing":       true,
			})
		}
	}
}
// IsUserOnline reports whether the SSE hub currently has subscribers for the
// user. It returns false when no hub is configured.
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
	return s.sseHub != nil && s.sseHub.HasSubscribers(userID)
}
// SaveMessage only persists the message to the database and sends no
// realtime push. It is meant for callers (such as group chat) that take care
// of pushing on their own. The conversation must exist and the sender must be
// one of its participants.
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	// The conversation must exist.
	if _, convErr := s.repo.GetConversation(conversationID); convErr != nil {
		if !errors.Is(convErr, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("failed to get conversation: %w", convErr)
		}
		return nil, errors.New("会话不存在,请重新创建会话")
	}

	// The sender must be a participant.
	if _, memberErr := s.repo.GetParticipant(conversationID, senderID); memberErr != nil {
		if !errors.Is(memberErr, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("failed to get participant: %w", memberErr)
		}
		return nil, errors.New("您不是该会话的参与者")
	}

	saved := &model.Message{
		ConversationID: conversationID,
		SenderID:       senderID,
		Segments:       segments,
		ReplyToID:      replyToID,
		Status:         model.MessageStatusNormal,
	}
	if saveErr := s.repo.CreateMessageWithSeq(saved); saveErr != nil {
		return nil, fmt.Errorf("failed to save message: %w", saveErr)
	}
	return saved, nil
}