Files
backend/internal/service/chat_service.go
lan 86ef150fec Replace websocket flow with SSE support in backend.
Update handlers, services, router, and data conversion logic to support server-sent events and related message pipeline changes.

Made-with: Cursor
2026-03-10 12:58:23 +08:00

571 lines
19 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"fmt"
"time"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/sse"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
)
// RecallMessageTimeout is how long after sending a message its sender may
// still recall it (two minutes).
const RecallMessageTimeout time.Duration = 120 * time.Second
// ChatService is the chat service interface: conversation management,
// messaging, read tracking, message recall/deletion and real-time events.
type ChatService interface {
	// Conversation management.

	// GetOrCreateConversation returns the private conversation between
	// user1ID and user2ID, creating it if it does not exist yet.
	GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
	// GetConversationList returns one page of userID's conversations. The
	// int64 is presumably the total count — confirm against the repository.
	GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
	// GetConversationByID returns the conversation, provided userID is one
	// of its participants.
	GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
	// DeleteConversationForSelf hides the conversation for userID only.
	DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
	// SetConversationPinned sets userID's own pinned flag for the conversation.
	SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error

	// Message operations.

	// SendMessage stores a message and pushes real-time events to the
	// conversation's participants.
	SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
	// GetMessages returns one page of the conversation's message history.
	GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
	// GetMessagesAfterSeq returns messages after afterSeq (incremental sync).
	GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
	// GetMessagesBeforeSeq returns messages before beforeSeq (loading older
	// history).
	GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)

	// Read tracking.

	// MarkAsRead records seq as userID's read position in the conversation.
	MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
	// GetUnreadCount returns userID's unread count in one conversation.
	GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
	// GetAllUnreadCount returns userID's unread count across all conversations.
	GetAllUnreadCount(ctx context.Context, userID string) (int64, error)

	// Message extensions.

	// RecallMessage lets a sender recall their own message within
	// RecallMessageTimeout.
	RecallMessage(ctx context.Context, messageID string, userID string) error
	// DeleteMessage deletes one of userID's own messages.
	DeleteMessage(ctx context.Context, messageID string, userID string) error

	// Real-time events.

	// SendTyping notifies the other participants that senderID is typing.
	SendTyping(ctx context.Context, senderID string, conversationID string)

	// Presence.

	// IsUserOnline reports whether userID currently has a live real-time
	// subscription.
	IsUserOnline(userID string) bool

	// SaveMessage only persists the message to the database and sends no
	// real-time push; it is meant for callers (such as group chat) that do
	// their own pushing.
	SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
}
// chatServiceImpl is the default ChatService implementation.
type chatServiceImpl struct {
	// db is used directly for message recall and deletion.
	db *gorm.DB
	// repo handles conversation, participant and message persistence.
	repo *repository.MessageRepository
	// userRepo answers block/follow questions; may be nil, in which case
	// SendMessage skips the private-chat send policy.
	userRepo *repository.UserRepository
	// sensitive is the sensitive-content service; not used by the methods
	// shown in this file.
	sensitive SensitiveService
	// sseHub delivers real-time events; may be nil, which disables pushes.
	sseHub *sse.Hub
}
// NewChatService builds a ChatService from its dependencies. Passing a nil
// sseHub disables real-time pushes; a nil userRepo disables the private-chat
// block/stranger policy in SendMessage.
func NewChatService(
	db *gorm.DB,
	repo *repository.MessageRepository,
	userRepo *repository.UserRepository,
	sensitive SensitiveService,
	sseHub *sse.Hub,
) ChatService {
	svc := new(chatServiceImpl)
	svc.db = db
	svc.repo = repo
	svc.userRepo = userRepo
	svc.sensitive = sensitive
	svc.sseHub = sseHub
	return svc
}
// publishSSEToUsers pushes event with payload to every user in userIDs via
// the SSE hub. It does nothing when no hub is configured or the recipient
// list is empty.
func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payload interface{}) {
	if hub := s.sseHub; hub != nil && len(userIDs) > 0 {
		hub.PublishToUsers(userIDs, event, payload)
	}
}
// GetOrCreateConversation returns the private conversation between user1ID
// and user2ID; the repository creates it when it does not exist yet.
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
	conv, err := s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
	return conv, err
}
// GetConversationList returns one page of userID's conversations together
// with the count reported by the repository.
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	convs, total, err := s.repo.GetConversations(userID, page, pageSize)
	return convs, total, err
}
// GetConversationByID returns the conversation identified by conversationID,
// but only if userID is one of its participants. A non-participant gets
// "conversation not found or no permission".
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
	// Membership gate; the participant row itself is not needed afterwards.
	_, err := s.repo.GetParticipant(conversationID, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.New("conversation not found or no permission")
	case err != nil:
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	conv, convErr := s.repo.GetConversation(conversationID)
	if convErr != nil {
		return nil, fmt.Errorf("failed to get conversation: %w", convErr)
	}
	return conv, nil
}
// DeleteConversationForSelf hides conversationID for userID only; other
// participants are unaffected. userID must be a participant.
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
	participant, err := s.repo.GetParticipant(conversationID, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return errors.New("conversation not found or no permission")
	case err != nil:
		return fmt.Errorf("failed to get participant: %w", err)
	case participant.ConversationID == "":
		// Defensive: treat an empty participant row as "not a member".
		return errors.New("conversation not found or no permission")
	}

	if hideErr := s.repo.HideConversationForUser(conversationID, userID); hideErr != nil {
		return fmt.Errorf("failed to hide conversation: %w", hideErr)
	}
	return nil
}
// SetConversationPinned sets userID's own pinned flag for conversationID
// (the flag is per user, not per conversation). userID must be a participant.
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
	participant, err := s.repo.GetParticipant(conversationID, userID)
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return errors.New("conversation not found or no permission")
	}
	if err != nil {
		return fmt.Errorf("failed to get participant: %w", err)
	}
	// Defensive: an empty participant row counts as "not a member".
	if participant.ConversationID == "" {
		return errors.New("conversation not found or no permission")
	}

	if err = s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
		return fmt.Errorf("failed to update pinned status: %w", err)
	}
	return nil
}
// SendMessage persists a message (built from segments, optionally replying to
// replyToID) from senderID into conversationID, then pushes real-time events.
//
// Checks run in this order:
//  1. the conversation must exist;
//  2. senderID must be a participant of it;
//  3. for private conversations (when a user repository is configured), the
//     private-chat policy enforced by checkPrivateSendPolicy must allow it.
//
// Membership is checked before the policy so that a non-participant is
// rejected immediately and cannot learn the members' block/follow state.
//
// After a successful save, events are pushed on a best-effort basis (see
// publishNewMessageEvents); push problems never fail the send.
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	conv, err := s.repo.GetConversation(conversationID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("会话不存在,请重新创建会话")
		}
		return nil, fmt.Errorf("failed to get conversation: %w", err)
	}

	if _, err := s.repo.GetParticipant(conversationID, senderID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("您不是该会话的参与者")
		}
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
		if err := s.checkPrivateSendPolicy(conversationID, senderID, segments); err != nil {
			return nil, err
		}
	}

	message := &model.Message{
		ConversationID: conversationID,
		SenderID:       senderID,
		Segments:       segments,
		ReplyToID:      replyToID,
		Status:         model.MessageStatusNormal,
	}
	// The repository assigns the conversation sequence number while saving.
	if err := s.repo.CreateMessageWithSeq(message); err != nil {
		return nil, fmt.Errorf("failed to save message: %w", err)
	}

	detailType := "private"
	if conv.Type == model.ConversationTypeGroup {
		detailType = "group"
	}
	s.publishNewMessageEvents(conversationID, detailType, senderID, message)
	return message, nil
}

// checkPrivateSendPolicy applies the private-chat rules to senderID. For every
// other participant of conversationID:
//   - if that participant has blocked the sender, ErrUserBlocked is returned;
//   - if that participant does not follow the sender ("stranger" mode),
//     image segments are rejected and the sender may have at most one message
//     in the conversation before this one.
//
// s.userRepo must be non-nil.
func (s *chatServiceImpl) checkPrivateSendPolicy(conversationID, senderID string, segments model.MessageSegments) error {
	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return fmt.Errorf("failed to get participants: %w", err)
	}

	// Loaded lazily so the count query runs at most once per call.
	var sentCount *int64
	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}

		blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
		if bErr != nil {
			return fmt.Errorf("failed to check block status: %w", bErr)
		}
		if blocked {
			return ErrUserBlocked
		}

		isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
		if fErr != nil {
			return fmt.Errorf("failed to check follow status: %w", fErr)
		}
		if isFollowedBack {
			continue
		}

		if containsImageSegment(segments) {
			return errors.New("对方未关注你,暂不支持发送图片")
		}
		if sentCount == nil {
			c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
			if cErr != nil {
				return fmt.Errorf("failed to count sender messages: %w", cErr)
			}
			sentCount = &c
		}
		if *sentCount >= 1 {
			return errors.New("对方未关注你前,仅允许发送一条消息")
		}
	}
	return nil
}

// publishNewMessageEvents pushes a freshly saved message to the participants
// of conversationID: "chat_message" goes to every participant (sender
// included), and "conversation_unread" with the recipient's total unread count
// goes to every participant except senderID.
//
// Best-effort by design: if the participant list cannot be loaded, nothing is
// pushed; if one recipient's unread total cannot be computed, only that
// recipient's unread event is skipped.
func (s *chatServiceImpl) publishNewMessageEvents(conversationID, detailType, senderID string, message *model.Message) {
	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return
	}

	targetIDs := make([]string, 0, len(participants))
	for _, p := range participants {
		targetIDs = append(targetIDs, p.UserID)
	}

	s.publishSSEToUsers(targetIDs, "chat_message", map[string]interface{}{
		"detail_type": detailType,
		"message":     dto.ConvertMessageToResponse(message),
	})

	for _, uid := range targetIDs {
		if uid == senderID {
			continue
		}
		totalUnread, uErr := s.repo.GetAllUnreadCount(uid)
		if uErr != nil {
			continue
		}
		s.publishSSEToUsers([]string{uid}, "conversation_unread", map[string]interface{}{
			"conversation_id": conversationID,
			"total_unread":    totalUnread,
		})
	}
}
// containsImageSegment reports whether any segment is an image, matching
// either model.ContentTypeImage or the literal "image" type string.
func containsImageSegment(segments model.MessageSegments) bool {
	imageType := string(model.ContentTypeImage)
	found := false
	for _, seg := range segments {
		if seg.Type == "image" || seg.Type == imageType {
			found = true
			break
		}
	}
	return found
}
// GetMessages returns one page of conversationID's message history plus the
// count reported by the repository. userID must be a participant.
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
	_, err := s.repo.GetParticipant(conversationID, userID)
	switch {
	case err == nil:
		return s.repo.GetMessages(conversationID, page, pageSize)
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, 0, errors.New("conversation not found or no permission")
	default:
		return nil, 0, fmt.Errorf("failed to get participant: %w", err)
	}
}
// GetMessagesAfterSeq returns up to limit messages of conversationID whose
// sequence is after afterSeq, for incremental sync. A non-positive limit
// falls back to 100. userID must be a participant.
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
	const defaultLimit = 100

	if _, err := s.repo.GetParticipant(conversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("conversation not found or no permission")
		}
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	effectiveLimit := limit
	if effectiveLimit <= 0 {
		effectiveLimit = defaultLimit
	}
	return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, effectiveLimit)
}
// GetMessagesBeforeSeq returns up to limit messages of conversationID whose
// sequence is before beforeSeq, used when loading older history. A
// non-positive limit falls back to 20. userID must be a participant.
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	const defaultLimit = 20

	_, err := s.repo.GetParticipant(conversationID, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, errors.New("conversation not found or no permission")
	case err != nil:
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	if limit <= 0 {
		return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, defaultLimit)
	}
	return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
}
// MarkAsRead records seq as userID's last-read position in conversationID.
//
// After the update it pushes, best-effort, a "message_read" event to every
// participant (with detail_type "group" and the group ID for group
// conversations, "private" otherwise). It then pushes a "conversation_unread"
// event with userID's new total unread count to userID. Push-side lookup
// failures are ignored; only the membership check and the seq update can fail
// the call.
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
	_, err := s.repo.GetParticipant(conversationID, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return errors.New("conversation not found or no permission")
	case err != nil:
		return fmt.Errorf("failed to get participant: %w", err)
	}

	if err = s.repo.UpdateLastReadSeq(conversationID, userID, seq); err != nil {
		return fmt.Errorf("failed to update last read seq: %w", err)
	}

	// Broadcast the new read position to everyone in the conversation.
	if members, pErr := s.repo.GetConversationParticipants(conversationID); pErr == nil {
		detailType, groupID := "private", ""
		conv, convErr := s.repo.GetConversation(conversationID)
		if convErr == nil && conv.Type == model.ConversationTypeGroup {
			detailType = "group"
			if conv.GroupID != nil {
				groupID = *conv.GroupID
			}
		}

		recipients := make([]string, 0, len(members))
		for _, m := range members {
			recipients = append(recipients, m.UserID)
		}
		s.publishSSEToUsers(recipients, "message_read", map[string]interface{}{
			"detail_type":     detailType,
			"conversation_id": conversationID,
			"group_id":        groupID,
			"user_id":         userID,
			"seq":             seq,
		})
	}

	// Refresh the reader's own unread badge.
	totalUnread, uErr := s.repo.GetAllUnreadCount(userID)
	if uErr == nil {
		s.publishSSEToUsers([]string{userID}, "conversation_unread", map[string]interface{}{
			"conversation_id": conversationID,
			"total_unread":    totalUnread,
		})
	}
	return nil
}
// GetUnreadCount returns userID's unread message count in conversationID.
// userID must be a participant.
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
	_, err := s.repo.GetParticipant(conversationID, userID)
	if err == nil {
		return s.repo.GetUnreadCount(conversationID, userID)
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return 0, errors.New("conversation not found or no permission")
	}
	return 0, fmt.Errorf("failed to get participant: %w", err)
}
// GetAllUnreadCount returns userID's total unread message count across all
// of their conversations.
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
	total, err := s.repo.GetAllUnreadCount(userID)
	return total, err
}
// RecallMessage recalls messageID on behalf of userID.
//
// Only the message's sender may recall it, only within RecallMessageTimeout
// of its creation, and only once. On success the message's status becomes
// MessageStatusRecalled and its segments are cleared, leaving only a recall
// placeholder. A best-effort "message_recall" SSE event is then pushed to
// every participant of the conversation.
//
// The status update is conditional on the row not already being recalled, so
// two concurrent recalls cannot both succeed and both broadcast. The loser
// gets "message already recalled".
//
// DB operations are bound to ctx, so they are cancelled along with the
// caller's request.
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
	db := s.db.WithContext(ctx)

	var message model.Message
	if err := db.First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	if message.SenderIDStr() != userID {
		return errors.New("can only recall your own messages")
	}
	if message.Status == model.MessageStatusRecalled {
		return errors.New("message already recalled")
	}
	if time.Since(message.CreatedAt) > RecallMessageTimeout {
		return errors.New("message recall timeout (2 minutes)")
	}

	// Compare-and-set: only flip rows that are not recalled yet, and clear the
	// original content so nothing but the placeholder remains.
	result := db.Model(&message).
		Where("status <> ?", model.MessageStatusRecalled).
		Updates(map[string]interface{}{
			"status":   model.MessageStatusRecalled,
			"segments": model.MessageSegments{},
		})
	if result.Error != nil {
		return fmt.Errorf("failed to recall message: %w", result.Error)
	}
	if result.RowsAffected == 0 {
		// A concurrent request recalled it between our read and this update.
		return errors.New("message already recalled")
	}

	participants, pErr := s.repo.GetConversationParticipants(message.ConversationID)
	if pErr != nil {
		// Best-effort push; the recall itself has already succeeded.
		return nil
	}

	detailType := "private"
	groupID := ""
	if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
		detailType = "group"
		if conv.GroupID != nil {
			groupID = *conv.GroupID
		}
	}

	targetIDs := make([]string, 0, len(participants))
	for _, p := range participants {
		targetIDs = append(targetIDs, p.UserID)
	}
	s.publishSSEToUsers(targetIDs, "message_recall", map[string]interface{}{
		"detail_type":     detailType,
		"conversation_id": message.ConversationID,
		"group_id":        groupID,
		"message_id":      messageID,
		"sender_id":       userID,
	})
	return nil
}
// DeleteMessage marks messageID as deleted on behalf of userID.
//
// userID must be a participant of the message's conversation and must be the
// message's sender. The deletion sets the message row's status to
// MessageStatusDeleted. That is a message-level column, so it is not scoped
// to userID alone. NOTE(review): the interface describes this as a
// "visible only to self" deletion. Confirm whether per-user hiding is wanted.
// No real-time event is pushed.
//
// DB operations are bound to ctx, so they are cancelled along with the
// caller's request.
//
// Errors: "message not found" if messageID does not exist; "no permission to
// delete this message" if userID is not a participant; "can only delete your
// own messages" if userID is not the sender.
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
	db := s.db.WithContext(ctx)

	var message model.Message
	if err := db.First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	if _, err := s.repo.GetParticipant(message.ConversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("no permission to delete this message")
		}
		return fmt.Errorf("failed to get participant: %w", err)
	}

	// Simplification: only the sender may delete their own message.
	if message.SenderIDStr() != userID {
		return errors.New("can only delete your own messages")
	}

	if err := db.Model(&message).Update("status", model.MessageStatusDeleted).Error; err != nil {
		return fmt.Errorf("failed to delete message: %w", err)
	}
	return nil
}
// SendTyping pushes a "typing" event (is_typing: true) from senderID to every
// other participant of conversationID. It is fire-and-forget: it silently
// returns when no SSE hub is configured, when senderID is not a participant,
// or when the participant list cannot be loaded.
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
	hub := s.sseHub
	if hub == nil {
		return
	}
	if _, err := s.repo.GetParticipant(conversationID, senderID); err != nil {
		return
	}
	members, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return
	}

	detailType := "private"
	conv, convErr := s.repo.GetConversation(conversationID)
	if convErr == nil && conv.Type == model.ConversationTypeGroup {
		detailType = "group"
	}

	for _, m := range members {
		if m.UserID == senderID {
			continue
		}
		// A fresh payload per recipient, as each push gets its own map.
		hub.PublishToUser(m.UserID, "typing", map[string]interface{}{
			"detail_type":     detailType,
			"conversation_id": conversationID,
			"user_id":         senderID,
			"is_typing":       true,
		})
	}
}
// IsUserOnline reports whether userID currently has at least one SSE
// subscriber. It is always false when no SSE hub is configured.
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
	return s.sseHub != nil && s.sseHub.HasSubscribers(userID)
}
// SaveMessage persists a message from senderID into conversationID without
// pushing any real-time events. It is intended for callers (such as group
// chat) that handle pushing themselves. It checks only that the conversation
// exists and that senderID is a participant; the private-chat block/stranger
// policy applied by SendMessage is not applied here.
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	if _, err := s.repo.GetConversation(conversationID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("会话不存在,请重新创建会话")
		}
		return nil, fmt.Errorf("failed to get conversation: %w", err)
	}

	if _, err := s.repo.GetParticipant(conversationID, senderID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("您不是该会话的参与者")
		}
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	msg := &model.Message{
		ConversationID: conversationID,
		SenderID:       senderID,
		Segments:       segments,
		ReplyToID:      replyToID,
		Status:         model.MessageStatusNormal,
	}
	if err := s.repo.CreateMessageWithSeq(msg); err != nil {
		return nil, fmt.Errorf("failed to save message: %w", err)
	}
	return msg, nil
}