Files
backend/internal/service/chat_service.go
lan 0a0cbacbcc feat(schedule): add course table screens and navigation
Add complete schedule functionality including:
- Schedule screen with weekly course table view
- Course detail screen with transparent modal presentation
- New ScheduleStack navigator integrated into main tab bar
- Schedule service for API interactions
- Type definitions for course entities

Also includes bug fixes for group invite/request handlers
to include required groupId parameter.
2026-03-12 08:38:14 +08:00

726 lines
24 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"fmt"
"log"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/sse"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
)
// RecallMessageTimeout is the time limit for recalling a message (2 minutes):
// RecallMessage rejects a recall once time.Since(message.CreatedAt) exceeds it.
const RecallMessageTimeout = 2 * time.Minute
// ChatService is the chat service interface: conversations, messages, read
// state and real-time (SSE) events.
type ChatService interface {
	// Conversation management.

	// GetOrCreateConversation returns the private conversation between user1ID
	// and user2ID, creating it when it does not exist yet.
	GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
	// GetConversationList returns one page of userID's conversations plus an
	// int64 count (presumably the total; confirm against the repository).
	GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
	// GetConversationByID returns the conversation if userID is a participant.
	GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
	// DeleteConversationForSelf hides the conversation for userID only.
	DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
	// SetConversationPinned sets the per-user pinned flag of a conversation.
	SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error
	// Message operations.

	// SendMessage validates, stores and pushes (via SSE) a message made of
	// segments; replyToID optionally references the message being replied to.
	SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
	// GetMessages returns one page of messages plus an int64 count (presumably the total).
	GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
	// GetMessagesAfterSeq returns messages after afterSeq, for incremental sync.
	GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
	// GetMessagesBeforeSeq returns messages before beforeSeq, for loading older history.
	GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)
	// Read-state management.

	// MarkAsRead records seq as userID's last read position in the conversation.
	MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
	// GetUnreadCount returns userID's unread message count in one conversation.
	GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
	// GetAllUnreadCount returns userID's unread message count across all conversations.
	GetAllUnreadCount(ctx context.Context, userID string) (int64, error)
	// Extended message operations.

	// RecallMessage lets the sender recall a message within RecallMessageTimeout.
	RecallMessage(ctx context.Context, messageID string, userID string) error
	// DeleteMessage lets the sender delete their own message.
	DeleteMessage(ctx context.Context, messageID string, userID string) error
	// Real-time events.

	// SendTyping pushes a "typing" event to the other participants.
	SendTyping(ctx context.Context, senderID string, conversationID string)
	// Online status.

	// IsUserOnline reports whether userID currently has SSE subscribers.
	IsUserOnline(userID string) bool
	// SaveMessage only stores the message in the database and sends no
	// real-time push (for callers such as group chat that push on their own).
	SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
}
// chatServiceImpl is the ChatService implementation.
type chatServiceImpl struct {
	// db is used directly by RecallMessage and DeleteMessage to load and update messages.
	db *gorm.DB
	// repo is the message/conversation repository.
	repo *repository.MessageRepository
	// userRepo provides block/follow lookups for SendMessage; when nil those checks are skipped.
	userRepo *repository.UserRepository
	// NOTE(review): sensitive is not referenced by any method visible in this file.
	sensitive SensitiveService
	// sseHub delivers real-time events; when nil, publishing is skipped.
	sseHub *sse.Hub
	// Cache-related fields.
	// conversationCache is a read-through cache for conversations, participants,
	// messages and unread counts; when nil, reads go straight to repo.
	conversationCache *cache.ConversationCache
}
// NewChatService builds a ChatService on top of the given database handle,
// repositories and SSE hub. Conversation and message reads are routed through a
// ConversationCache that uses the cache instance returned by cache.GetCache()
// and wraps repo through the cache package's repository adapters.
func NewChatService(
	db *gorm.DB,
	repo *repository.MessageRepository,
	userRepo *repository.UserRepository,
	sensitive SensitiveService,
	sseHub *sse.Hub,
) ChatService {
	// The cache talks to the repository through adapters.
	conversationsAdapter := cache.NewConversationRepositoryAdapter(repo)
	messagesAdapter := cache.NewMessageRepositoryAdapter(repo)

	svc := &chatServiceImpl{
		db:        db,
		repo:      repo,
		userRepo:  userRepo,
		sensitive: sensitive,
		sseHub:    sseHub,
	}
	svc.conversationCache = cache.NewConversationCache(
		cache.GetCache(),
		conversationsAdapter,
		messagesAdapter,
		cache.DefaultConversationCacheSettings(),
	)
	return svc
}
// publishSSEToUsers sends a single SSE event carrying payload to every user in
// userIDs. Nothing is sent when no hub is configured or userIDs is empty.
func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payload interface{}) {
	hub := s.sseHub
	if hub != nil && len(userIDs) > 0 {
		hub.PublishToUsers(userIDs, event, payload)
	}
}
// GetOrCreateConversation returns the private conversation between user1ID and
// user2ID, creating it in the repository if needed. Both users' cached
// conversation lists are invalidated afterwards so a new conversation shows up.
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
	conv, err := s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
	if err != nil {
		return nil, err
	}
	if c := s.conversationCache; c != nil {
		for _, uid := range [...]string{user1ID, user2ID} {
			c.InvalidateConversationList(uid)
		}
	}
	return conv, nil
}
// GetConversationList returns one page of userID's conversations plus an int64
// count (presumably the total). The conversation cache is consulted when it is
// configured; otherwise the repository is queried directly.
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	if s.conversationCache == nil {
		return s.repo.GetConversations(userID, page, pageSize)
	}
	return s.conversationCache.GetConversationList(ctx, userID, page, pageSize)
}
// GetConversationByID 获取会话详情(带缓存)
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
// 验证用户是否是会话参与者
participant, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
// 获取会话信息(优先使用缓存)
var conv *model.Conversation
if s.conversationCache != nil {
conv, err = s.conversationCache.GetConversation(ctx, conversationID)
} else {
conv, err = s.repo.GetConversation(conversationID)
}
if err != nil {
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
_ = participant // 可以用于返回已读位置等信息
return conv, nil
}
// getParticipant loads userID's participant record in conversationID. It reads
// through the conversation cache when one is configured and falls back to the
// repository otherwise.
func (s *chatServiceImpl) getParticipant(ctx context.Context, conversationID, userID string) (*model.ConversationParticipant, error) {
	if s.conversationCache == nil {
		return s.repo.GetParticipant(conversationID, userID)
	}
	return s.conversationCache.GetParticipant(ctx, conversationID, userID)
}
// DeleteConversationForSelf hides conversationID for userID only, via
// repo.HideConversationForUser. The caller must be a participant; otherwise
// "conversation not found or no permission" is returned.
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
	participant, err := s.getParticipant(ctx, conversationID, userID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("conversation not found or no permission")
		}
		return fmt.Errorf("failed to get participant: %w", err)
	}
	// Treat a nil or empty record like a missing one; a (nil, nil) result would
	// otherwise panic on the field access.
	if participant == nil || participant.ConversationID == "" {
		return errors.New("conversation not found or no permission")
	}

	if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
		return fmt.Errorf("failed to hide conversation: %w", err)
	}

	// Like SetConversationPinned after its per-user update, drop the cached
	// participant record (in case hidden state is stored on it) as well as the
	// user's cached conversation list.
	if s.conversationCache != nil {
		s.conversationCache.InvalidateParticipant(conversationID, userID)
		s.conversationCache.InvalidateConversationList(userID)
	}
	return nil
}
// SetConversationPinned sets whether conversationID is pinned for userID. The
// flag is per user (repo.UpdatePinned takes the user ID), so other
// participants are unaffected. The caller must be a participant; otherwise
// "conversation not found or no permission" is returned.
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
	participant, err := s.getParticipant(ctx, conversationID, userID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("conversation not found or no permission")
		}
		return fmt.Errorf("failed to get participant: %w", err)
	}
	// Treat a nil or empty record like a missing one; a (nil, nil) result would
	// otherwise panic on the field access.
	if participant == nil || participant.ConversationID == "" {
		return errors.New("conversation not found or no permission")
	}

	if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
		return fmt.Errorf("failed to update pinned status: %w", err)
	}

	// Drop the cached participant record and the user's cached list so the new
	// flag is read back from the repository.
	if s.conversationCache != nil {
		s.conversationCache.InvalidateParticipant(conversationID, userID)
		s.conversationCache.InvalidateConversationList(userID)
	}
	return nil
}
// SendMessage validates, stores and fans out a message (given as segments)
// from senderID into conversationID. replyToID, when non-nil, is stored on the
// message as the ID of the message being replied to.
//
// The checks run in this order:
//  1. the conversation must exist;
//  2. senderID must be a participant;
//  3. for private conversations (when userRepo is set), the block and stranger
//     rules in enforcePrivateSendPolicy.
//
// The message is then persisted with repo.CreateMessageWithSeq, cached
// asynchronously, and pushed to participants over SSE by broadcastNewMessage.
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	conv, err := s.getConversation(ctx, conversationID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("会话不存在,请重新创建会话")
		}
		return nil, fmt.Errorf("failed to get conversation: %w", err)
	}

	// Authorize before evaluating block/follow rules, so a non-participant
	// cannot probe other users' block or follow relationships through the
	// distinct errors those rules return.
	if _, err := s.getParticipant(ctx, conversationID, senderID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("您不是该会话的参与者")
		}
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
		if err := s.enforcePrivateSendPolicy(ctx, conversationID, senderID, segments); err != nil {
			return nil, err
		}
	}

	message := &model.Message{
		ConversationID: conversationID,
		SenderID:       senderID,
		Segments:       segments,
		ReplyToID:      replyToID,
		Status:         model.MessageStatusNormal,
	}
	// Persists the message and assigns its sequence number (transactional
	// according to the repository call's original note).
	if err := s.repo.CreateMessageWithSeq(message); err != nil {
		return nil, fmt.Errorf("failed to save message: %w", err)
	}

	// A new message changes paged results; drop the page cache first so no
	// reader is served the old list.
	if s.conversationCache != nil {
		s.conversationCache.InvalidateMessagePages(conversationID)
	}

	// Cache the message in the background, detached from the request context.
	go func() {
		if err := s.cacheMessage(context.Background(), conversationID, message); err != nil {
			log.Printf("[ChatService] async cache message failed, convID=%s, msgID=%s, err=%v", conversationID, message.ID, err)
		}
	}()

	s.broadcastNewMessage(ctx, conv, conversationID, senderID, message)
	return message, nil
}

// enforcePrivateSendPolicy applies the private-chat sending rules for senderID
// against every other participant of conversationID:
//   - if that user has blocked the sender, ErrUserBlocked is returned (only the
//     blocked -> blocker direction is restricted);
//   - if that user does not follow the sender back, image segments are
//     rejected and the sender may have at most one message in the conversation.
func (s *chatServiceImpl) enforcePrivateSendPolicy(ctx context.Context, conversationID, senderID string, segments model.MessageSegments) error {
	participants, err := s.getParticipants(ctx, conversationID)
	if err != nil {
		return fmt.Errorf("failed to get participants: %w", err)
	}

	// Computed lazily and shared across participants so the count query runs at
	// most once.
	var sentCount *int64
	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}

		blocked, err := s.userRepo.IsBlocked(p.UserID, senderID)
		if err != nil {
			return fmt.Errorf("failed to check block status: %w", err)
		}
		if blocked {
			return ErrUserBlocked
		}

		// Stranger restriction: until the other user follows back, only one
		// text message is allowed and images are forbidden.
		isFollowedBack, err := s.userRepo.IsFollowing(p.UserID, senderID)
		if err != nil {
			return fmt.Errorf("failed to check follow status: %w", err)
		}
		if isFollowedBack {
			continue
		}
		if containsImageSegment(segments) {
			return errors.New("对方未关注你,暂不支持发送图片")
		}
		if sentCount == nil {
			c, err := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
			if err != nil {
				return fmt.Errorf("failed to count sender messages: %w", err)
			}
			sentCount = &c
		}
		if *sentCount >= 1 {
			return errors.New("对方未关注你前,仅允许发送一条消息")
		}
	}
	return nil
}

// broadcastNewMessage invalidates the caches a new message affects and then
// pushes the "chat_message" event plus each recipient's refreshed
// "conversation_unread" total. Failures here are best-effort: the message is
// already stored, so they are logged or skipped rather than returned.
func (s *chatServiceImpl) broadcastNewMessage(ctx context.Context, conv *model.Conversation, conversationID, senderID string, message *model.Message) {
	participants, err := s.getParticipants(ctx, conversationID)
	if err != nil {
		log.Printf("[ChatService] get participants for push failed, convID=%s, msgID=%s, err=%v", conversationID, message.ID, err)
		return
	}

	// Invalidate before publishing: a client that refetches as soon as it
	// receives an event must not be able to read a stale cached list or count.
	if s.conversationCache != nil {
		for _, p := range participants {
			if p.UserID != senderID {
				s.conversationCache.InvalidateUnreadCount(p.UserID, conversationID)
			}
			s.conversationCache.InvalidateConversationList(p.UserID)
		}
	}

	isPrivate := conv.Type == model.ConversationTypePrivate
	detailType := "private"
	if conv.Type == model.ConversationTypeGroup {
		detailType = "group"
	}
	targetIDs := make([]string, 0, len(participants))
	for _, p := range participants {
		// In private chats the sender already has the message from the HTTP
		// response; pushing it back over SSE would display it twice.
		if isPrivate && p.UserID == senderID {
			continue
		}
		targetIDs = append(targetIDs, p.UserID)
	}
	s.publishSSEToUsers(targetIDs, "chat_message", map[string]interface{}{
		"detail_type": detailType,
		"message":     dto.ConvertMessageToResponse(message),
	})

	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}
		if totalUnread, err := s.repo.GetAllUnreadCount(p.UserID); err == nil {
			s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
				"conversation_id": conversationID,
				"total_unread":    totalUnread,
			})
		}
	}
}
// getConversation loads conversationID, reading through the conversation cache
// when configured and falling back to the repository otherwise.
func (s *chatServiceImpl) getConversation(ctx context.Context, conversationID string) (*model.Conversation, error) {
	if s.conversationCache == nil {
		return s.repo.GetConversation(conversationID)
	}
	return s.conversationCache.GetConversation(ctx, conversationID)
}
// getParticipants lists every participant of conversationID, reading through
// the conversation cache when configured and falling back to the repository.
func (s *chatServiceImpl) getParticipants(ctx context.Context, conversationID string) ([]*model.ConversationParticipant, error) {
	if s.conversationCache == nil {
		return s.repo.GetConversationParticipants(conversationID)
	}
	return s.conversationCache.GetParticipants(ctx, conversationID)
}
// cacheMessage stores msg in the conversation cache under convID. It is a no-op
// without a cache. The write is bounded by a 5-second timeout derived from ctx
// (effective to the extent CacheMessage honors context cancellation).
func (s *chatServiceImpl) cacheMessage(ctx context.Context, convID string, msg *model.Message) error {
	c := s.conversationCache
	if c == nil {
		return nil
	}
	boundedCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	return c.CacheMessage(boundedCtx, convID, msg)
}
// containsImageSegment reports whether any segment is an image. A segment
// counts as an image when its Type equals string(model.ContentTypeImage) or the
// literal "image".
func containsImageSegment(segments model.MessageSegments) bool {
	imageType := string(model.ContentTypeImage)
	for _, seg := range segments {
		switch seg.Type {
		case imageType, "image":
			return true
		}
	}
	return false
}
// GetMessages returns one page of conversationID's message history plus an
// int64 count (presumably the total). userID must be a participant; reads go
// through the conversation cache when configured.
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
	if _, err := s.getParticipant(ctx, conversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, 0, errors.New("conversation not found or no permission")
		}
		return nil, 0, fmt.Errorf("failed to get participant: %w", err)
	}
	if c := s.conversationCache; c != nil {
		return c.GetMessages(ctx, conversationID, page, pageSize)
	}
	return s.repo.GetMessages(conversationID, page, pageSize)
}
// GetMessagesAfterSeq 获取指定seq之后的消息用于增量同步
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
// 验证用户是否是会话参与者
_, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
if limit <= 0 {
limit = 100
}
return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}
// GetMessagesBeforeSeq 获取指定seq之前的历史消息用于下拉加载更多
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
// 验证用户是否是会话参与者
_, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
if limit <= 0 {
limit = 20
}
return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
}
// MarkAsRead records seq as userID's last read position in conversationID.
//
// The database write happens first (it is the source of truth); the affected
// cache entries are invalidated afterwards (cache-aside). A "message_read"
// event then goes to every participant, the reader included. Finally the
// reader receives a "conversation_unread" event with their new total. Push
// failures are ignored.
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
	if _, err := s.getParticipant(ctx, conversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("conversation not found or no permission")
		}
		return fmt.Errorf("failed to get participant: %w", err)
	}

	if err := s.repo.UpdateLastReadSeq(conversationID, userID, seq); err != nil {
		return fmt.Errorf("failed to update last read seq: %w", err)
	}
	if c := s.conversationCache; c != nil {
		c.InvalidateParticipant(conversationID, userID)
		c.InvalidateUnreadCount(userID, conversationID)
		c.InvalidateConversationList(userID)
	}

	if participants, err := s.getParticipants(ctx, conversationID); err == nil {
		detailType, groupID := "private", ""
		conv, convErr := s.getConversation(ctx, conversationID)
		if convErr == nil && conv.Type == model.ConversationTypeGroup {
			detailType = "group"
			if conv.GroupID != nil {
				groupID = *conv.GroupID
			}
		}

		recipients := make([]string, len(participants))
		for i, p := range participants {
			recipients[i] = p.UserID
		}
		s.publishSSEToUsers(recipients, "message_read", map[string]interface{}{
			"detail_type":     detailType,
			"conversation_id": conversationID,
			"group_id":        groupID,
			"user_id":         userID,
			"seq":             seq,
		})
	}

	totalUnread, err := s.repo.GetAllUnreadCount(userID)
	if err == nil {
		s.publishSSEToUsers([]string{userID}, "conversation_unread", map[string]interface{}{
			"conversation_id": conversationID,
			"total_unread":    totalUnread,
		})
	}
	return nil
}
// GetUnreadCount returns userID's unread message count in conversationID.
// userID must be a participant; the conversation cache is consulted when it is
// configured.
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
	if _, err := s.getParticipant(ctx, conversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return 0, errors.New("conversation not found or no permission")
		}
		return 0, fmt.Errorf("failed to get participant: %w", err)
	}
	if c := s.conversationCache; c != nil {
		return c.GetUnreadCount(ctx, userID, conversationID)
	}
	return s.repo.GetUnreadCount(conversationID, userID)
}
// GetAllUnreadCount returns userID's unread message total across all
// conversations, always read from the repository (no cache).
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
	total, err := s.repo.GetAllUnreadCount(userID)
	return total, err
}
// RecallMessage lets the sender of messageID withdraw it within
// RecallMessageTimeout of its creation.
//
// The message's status becomes model.MessageStatusRecalled and its segments are
// cleared, so only a recall placeholder remains. Participants are then notified
// with a "message_recall" SSE event. Deleted messages are reported as not found
// instead of being turned back into a visible placeholder. Database calls are
// bound to ctx.
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
	db := s.db.WithContext(ctx)

	var message model.Message
	if err := db.First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	if message.SenderIDStr() != userID {
		return errors.New("can only recall your own messages")
	}
	switch message.Status {
	case model.MessageStatusRecalled:
		return errors.New("message already recalled")
	case model.MessageStatusDeleted:
		// Recalling would overwrite the deleted status and resurface the message.
		return errors.New("message not found")
	}
	if time.Since(message.CreatedAt) > RecallMessageTimeout {
		return errors.New("message recall timeout (2 minutes)")
	}

	// Clear the content and keep only the recall placeholder. The update is
	// conditioned on the status read above. If a concurrent recall or delete
	// changed it in the meantime, no row matches, and that is reported instead
	// of applying a second transition.
	res := db.Model(&message).
		Where("status = ?", message.Status).
		Updates(map[string]interface{}{
			"status":   model.MessageStatusRecalled,
			"segments": model.MessageSegments{},
		})
	if res.Error != nil {
		return fmt.Errorf("failed to recall message: %w", res.Error)
	}
	if res.RowsAffected == 0 {
		return errors.New("message already recalled")
	}

	// The recalled message's content changed, so cached conversation data and
	// any cached message pages containing it are stale.
	if s.conversationCache != nil {
		s.conversationCache.InvalidateConversation(message.ConversationID)
		s.conversationCache.InvalidateMessagePages(message.ConversationID)
	}

	participants, err := s.getParticipants(ctx, message.ConversationID)
	if err != nil {
		// Best effort: the recall is already persisted.
		return nil
	}
	detailType := "private"
	groupID := ""
	if conv, convErr := s.getConversation(ctx, message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
		detailType = "group"
		if conv.GroupID != nil {
			groupID = *conv.GroupID
		}
	}
	targetIDs := make([]string, 0, len(participants))
	for _, p := range participants {
		targetIDs = append(targetIDs, p.UserID)
	}
	s.publishSSEToUsers(targetIDs, "message_recall", map[string]interface{}{
		"detail_type":     detailType,
		"conversation_id": message.ConversationID,
		"group_id":        groupID,
		"message_id":      messageID,
		"sender_id":       userID,
	})
	return nil
}
// DeleteMessage lets the sender of messageID delete it. The caller must still
// be a participant of the message's conversation.
//
// NOTE: although the operation is described as hiding the message only for the
// caller, it currently sets the shared message row's status to
// model.MessageStatusDeleted. For that reason only the sender may delete.
// True per-user hiding would need per-user state. Database calls are bound to ctx.
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
	db := s.db.WithContext(ctx)

	var message model.Message
	if err := db.First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	if _, err := s.getParticipant(ctx, message.ConversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("no permission to delete this message")
		}
		return fmt.Errorf("failed to get participant: %w", err)
	}

	if message.SenderIDStr() != userID {
		return errors.New("can only delete your own messages")
	}

	if err := db.Model(&message).Update("status", model.MessageStatusDeleted).Error; err != nil {
		return fmt.Errorf("failed to delete message: %w", err)
	}

	// Drop cached conversation data and any cached message pages that may still
	// contain the deleted message.
	if s.conversationCache != nil {
		s.conversationCache.InvalidateConversation(message.ConversationID)
		s.conversationCache.InvalidateMessagePages(message.ConversationID)
	}
	return nil
}
// SendTyping pushes a "typing" SSE event from senderID to every other
// participant of conversationID. It silently does nothing when there is no SSE
// hub, the sender is not a participant, or the participants cannot be loaded.
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
	hub := s.sseHub
	if hub == nil {
		return
	}
	if _, err := s.getParticipant(ctx, conversationID, senderID); err != nil {
		return
	}
	participants, err := s.getParticipants(ctx, conversationID)
	if err != nil {
		return
	}

	detailType := "private"
	if conv, convErr := s.getConversation(ctx, conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
		detailType = "group"
	}

	for _, p := range participants {
		if p.UserID != senderID {
			hub.PublishToUser(p.UserID, "typing", map[string]interface{}{
				"detail_type":     detailType,
				"conversation_id": conversationID,
				"user_id":         senderID,
				"is_typing":       true,
			})
		}
	}
}
// IsUserOnline reports whether userID currently has at least one SSE
// subscriber. It is always false when no hub is configured.
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
	return s.sseHub != nil && s.sseHub.HasSubscribers(userID)
}
// SaveMessage stores a message without sending any real-time push. It is meant
// for callers (such as group chat) that deliver the message themselves.
//
// The conversation must exist and senderID must be a participant. Unlike
// SendMessage, no block/stranger rules are applied. After persisting, cached
// message pages are invalidated and the message is cached asynchronously.
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	if _, err := s.getConversation(ctx, conversationID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("会话不存在,请重新创建会话")
		}
		return nil, fmt.Errorf("failed to get conversation: %w", err)
	}
	if _, err := s.getParticipant(ctx, conversationID, senderID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("您不是该会话的参与者")
		}
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	msg := &model.Message{
		ConversationID: conversationID,
		SenderID:       senderID,
		Segments:       segments,
		ReplyToID:      replyToID,
		Status:         model.MessageStatusNormal,
	}
	if err := s.repo.CreateMessageWithSeq(msg); err != nil {
		return nil, fmt.Errorf("failed to save message: %w", err)
	}

	// Paged results now include a new message; drop them before anyone reads a stale page.
	if c := s.conversationCache; c != nil {
		c.InvalidateMessagePages(conversationID)
	}

	go func() {
		if cacheErr := s.cacheMessage(context.Background(), conversationID, msg); cacheErr != nil {
			log.Printf("[ChatService] async cache message failed, convID=%s, msgID=%s, err=%v", conversationID, msg.ID, cacheErr)
		}
	}()
	return msg, nil
}