Files
backend/internal/service/message_service.go

216 lines
6.6 KiB
Go
Raw Permalink Normal View History

package service
import (
"context"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
)
// Fallback cache timings for MessageService. Several of these are used when
// the corresponding value returned by cache.GetSettings() is zero or negative.

// Lifetimes for cached positive results.
const (
	ConversationListTTL   = time.Minute      // conversation list cache: 60s
	ConversationDetailTTL = time.Minute      // conversation detail cache: 60s
	UnreadCountTTL        = 30 * time.Second // unread count cache: 30s
)

// Lifetimes passed to cache.GetOrLoadTyped as its null TTL argument
// (presumably how long an empty/missing load result is cached — confirm
// against the cache package).
const (
	ConversationNullTTL = 5 * time.Second
	UnreadNullTTL       = 5 * time.Second
)

// CacheJitterRatio is the fallback jitter ratio handed to cache.GetOrLoadTyped
// when the configured ratio is not positive.
const CacheJitterRatio = 0.1
// MessageService is the messaging service: it sends messages, lists
// conversations and messages, tracks read positions and unread counts, and
// fronts the message repository with a cache for conversation lists and
// unread counts.
type MessageService struct {
messageRepo *repository.MessageRepository // persistence for conversations and messages
cache cache.Cache // cache handle, populated from cache.GetCache() by NewMessageService
}
// NewMessageService builds a MessageService around messageRepo, wiring in the
// cache returned by cache.GetCache().
func NewMessageService(messageRepo *repository.MessageRepository) *MessageService {
	svc := new(MessageService)
	svc.messageRepo = messageRepo
	svc.cache = cache.GetCache()
	return svc
}
// ConversationListResult is the value stored in the cache for one page of a
// user's conversation list (see GetConversations).
type ConversationListResult struct {
Conversations []*model.Conversation // conversations on the requested page
Total int64 // count reported by the repository (presumably across all pages — confirm)
}
// SendMessage sends a private message, composed of segments, from senderID to
// receiverID.
//
// senderID and receiverID are strings (UUID format) so that they match the
// user_id carried in the JWT.
//
// The private conversation between the two users is fetched or created first;
// the message is then persisted with status model.MessageStatusNormal via
// CreateMessageWithSeq, which creates the message and updates the conversation
// seq inside a transaction. On success both users' cached conversation lists
// and the receiver's unread caches are invalidated. ctx is currently unused.
func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID string, segments model.MessageSegments) (*model.Message, error) {
	conv, err := s.messageRepo.GetOrCreatePrivateConversation(senderID, receiverID)
	if err != nil {
		return nil, err
	}

	msg := &model.Message{
		ConversationID: conv.ID,
		SenderID:       senderID,
		Segments:       segments,
		Status:         model.MessageStatusNormal,
	}
	if err := s.messageRepo.CreateMessageWithSeq(msg); err != nil {
		return nil, err
	}

	// Both participants now see a changed conversation (last message / order).
	for _, uid := range [...]string{senderID, receiverID} {
		cache.InvalidateConversationList(s.cache, uid)
	}
	// Only the receiver gains an unread message.
	cache.InvalidateUnreadConversation(s.cache, receiverID)
	cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID)

	return msg, nil
}
// GetConversations returns one page of userID's conversations together with
// the total reported by the repository, reading through the cache.
//
// userID is a string (UUID format) so that it matches the user_id carried in
// the JWT.
//
// The TTL, null TTL and jitter ratio come from cache.GetSettings(); any of them
// that is zero or negative falls back to ConversationListTTL,
// ConversationNullTTL and CacheJitterRatio respectively. On a cache miss the
// page is loaded from the repository and cached under
// cache.ConversationListKey(userID, page, pageSize).
//
// On success the returned slice is never nil: an absent result or a result
// without conversations both yield an empty, non-nil slice, so callers see a
// consistent "empty list" rather than sometimes nil and sometimes [].
//
// ctx is currently unused: neither the repository nor the cache call used
// here accepts a context.
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	settings := cache.GetSettings()

	conversationTTL := settings.ConversationTTL
	if conversationTTL <= 0 {
		conversationTTL = ConversationListTTL
	}
	nullTTL := settings.NullTTL
	if nullTTL <= 0 {
		nullTTL = ConversationNullTTL
	}
	jitter := settings.JitterRatio
	if jitter <= 0 {
		jitter = CacheJitterRatio
	}

	cacheKey := cache.ConversationListKey(userID, page, pageSize)
	result, err := cache.GetOrLoadTyped[*ConversationListResult](
		s.cache,
		cacheKey,
		conversationTTL,
		jitter,
		nullTTL,
		func() (*ConversationListResult, error) {
			conversations, total, err := s.messageRepo.GetConversations(userID, page, pageSize)
			if err != nil {
				return nil, err
			}
			return &ConversationListResult{
				Conversations: conversations,
				Total:         total,
			}, nil
		},
	)
	if err != nil {
		return nil, 0, err
	}
	if result == nil {
		return []*model.Conversation{}, 0, nil
	}

	conversations := result.Conversations
	if conversations == nil {
		// Match the result == nil branch: never hand callers a nil slice
		// (it would e.g. JSON-encode as null instead of []).
		conversations = []*model.Conversation{}
	}
	return conversations, result.Total, nil
}
// GetMessages returns one page of messages from the conversation identified by
// conversationID, plus the int64 count reported by the repository (presumably
// the total). Results come straight from the repository without caching;
// ctx is currently unused.
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
	messages, total, err := s.messageRepo.GetMessages(conversationID, page, pageSize)
	return messages, total, err
}
// GetMessagesAfterSeq returns messages of conversationID that come after
// afterSeq, capped by limit, for incremental sync. The exact bounds
// (exclusive/inclusive, ordering) are defined by the repository. Results are
// not cached; ctx is currently unused.
func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
	msgs, err := s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
	return msgs, err
}
// MarkAsRead records lastReadSeq as userID's read position in conversationID.
//
// userID is a string (UUID format) so that it matches the user_id carried in
// the JWT.
//
// After the repository update succeeds, the user's unread caches (overall and
// for this conversation) and conversation list cache are invalidated. ctx is
// currently unused.
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
	if err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq); err != nil {
		return err
	}

	store := s.cache
	cache.InvalidateUnreadConversation(store, userID)
	cache.InvalidateUnreadDetail(store, userID, conversationID)
	cache.InvalidateConversationList(store, userID)
	return nil
}
// GetUnreadCount returns how many messages in conversationID are unread for
// userID, reading through the cache key cache.UnreadDetailKey(userID,
// conversationID).
//
// userID is a string (UUID format) so that it matches the user_id carried in
// the JWT.
//
// The TTL and null TTL come from cache.GetSettings() when positive, otherwise
// UnreadCountTTL and UnreadNullTTL are used; a non-positive jitter ratio falls
// back to CacheJitterRatio. ctx is currently unused.
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
	settings := cache.GetSettings()

	ttl := UnreadCountTTL
	if settings.UnreadCountTTL > 0 {
		ttl = settings.UnreadCountTTL
	}
	nullTTL := UnreadNullTTL
	if settings.NullTTL > 0 {
		nullTTL = settings.NullTTL
	}
	// Kept in "<= 0 means unset" form so non-finite ratios behave as before.
	jitter := settings.JitterRatio
	if jitter <= 0 {
		jitter = CacheJitterRatio
	}

	load := func() (int64, error) {
		return s.messageRepo.GetUnreadCount(conversationID, userID)
	}
	return cache.GetOrLoadTyped[int64](s.cache, cache.UnreadDetailKey(userID, conversationID), ttl, jitter, nullTTL, load)
}
// GetOrCreateConversation returns the private conversation between user1ID
// and user2ID, creating it if needed.
//
// user1ID and user2ID are strings (UUID format) so that they match the user_id
// carried in the JWT.
//
// Both users' cached conversation lists are invalidated afterwards, since the
// conversation may have just been created. ctx is currently unused.
func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
	conv, err := s.messageRepo.GetOrCreatePrivateConversation(user1ID, user2ID)
	if err != nil {
		return nil, err
	}

	for _, uid := range [...]string{user1ID, user2ID} {
		cache.InvalidateConversationList(s.cache, uid)
	}
	return conv, nil
}
// GetConversationParticipants lists the participants of conversationID,
// straight from the repository. Unlike most methods of this service it takes
// no context argument.
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
	participants, err := s.messageRepo.GetConversationParticipants(conversationID)
	return participants, err
}
// ParseConversationID returns idStr unchanged and a nil error.
//
// Conversation IDs are already strings, so nothing is parsed or validated:
// empty or non-UUID input is passed through as-is.
func ParseConversationID(idStr string) (string, error) {
	id := idStr
	var err error
	return id, err
}
// InvalidateUserConversationCache drops userID's cached conversation list and
// unread-conversation entry. It is exported for callers outside this service.
func (s *MessageService) InvalidateUserConversationCache(userID string) {
	store := s.cache
	cache.InvalidateConversationList(store, userID)
	cache.InvalidateUnreadConversation(store, userID)
}
// InvalidateUserUnreadCache drops userID's cached unread entries: the
// unread-conversation entry and the unread detail for conversationID.
// It is exported for callers outside this service.
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
	store := s.cache
	cache.InvalidateUnreadConversation(store, userID)
	cache.InvalidateUnreadDetail(store, userID, conversationID)
}