Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor
623 lines
20 KiB
Go
623 lines
20 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"time"
|
||
|
||
"carrot_bbs/internal/model"
|
||
"carrot_bbs/internal/pkg/websocket"
|
||
"carrot_bbs/internal/repository"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// RecallMessageTimeout is how long after a message was created its sender
// may still recall it (two minutes).
const RecallMessageTimeout time.Duration = 120 * time.Second
|
||
|
||
// ChatService 聊天服务接口
|
||
type ChatService interface {
|
||
// 会话管理
|
||
GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
|
||
GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
|
||
GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
|
||
DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
|
||
SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error
|
||
|
||
// 消息操作
|
||
SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||
GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
|
||
GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
|
||
GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)
|
||
|
||
// 已读管理
|
||
MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
|
||
GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
|
||
GetAllUnreadCount(ctx context.Context, userID string) (int64, error)
|
||
|
||
// 消息扩展功能
|
||
RecallMessage(ctx context.Context, messageID string, userID string) error
|
||
DeleteMessage(ctx context.Context, messageID string, userID string) error
|
||
|
||
// WebSocket相关
|
||
SendTyping(ctx context.Context, senderID string, conversationID string)
|
||
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
|
||
|
||
// 系统消息推送
|
||
IsUserOnline(userID string) bool
|
||
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
|
||
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
|
||
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
|
||
|
||
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
|
||
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||
}
|
||
|
||
// chatServiceImpl is the default ChatService implementation.
//
// Fields:
//   - db: raw GORM handle, used directly by RecallMessage and DeleteMessage.
//   - repo: message/conversation/participant repository.
//   - userRepo: used by SendMessage for block and follow checks on private
//     conversations; when nil those checks are skipped.
//   - sensitive: sensitive-content service. NOTE(review): not referenced by
//     any method visible in this file.
//   - wsManager: WebSocket manager for real-time pushes; when nil, pushes are
//     skipped (or the Push* methods return an error).
type chatServiceImpl struct {
	db        *gorm.DB
	repo      *repository.MessageRepository
	userRepo  *repository.UserRepository
	sensitive SensitiveService
	wsManager *websocket.WebSocketManager
}
|
||
|
||
// NewChatService 创建聊天服务
|
||
func NewChatService(
|
||
db *gorm.DB,
|
||
repo *repository.MessageRepository,
|
||
userRepo *repository.UserRepository,
|
||
sensitive SensitiveService,
|
||
wsManager *websocket.WebSocketManager,
|
||
) ChatService {
|
||
return &chatServiceImpl{
|
||
db: db,
|
||
repo: repo,
|
||
userRepo: userRepo,
|
||
sensitive: sensitive,
|
||
wsManager: wsManager,
|
||
}
|
||
}
|
||
|
||
// GetOrCreateConversation 获取或创建私聊会话
|
||
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||
}
|
||
|
||
// GetConversationList 获取用户的会话列表
|
||
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||
return s.repo.GetConversations(userID, page, pageSize)
|
||
}
|
||
|
||
// GetConversationByID 获取会话详情
|
||
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
|
||
// 验证用户是否是会话参与者
|
||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("conversation not found or no permission")
|
||
}
|
||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
// 获取会话信息
|
||
conv, err := s.repo.GetConversation(conversationID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||
}
|
||
|
||
// 填充用户的已读位置信息
|
||
_ = participant // 可以用于返回已读位置等信息
|
||
|
||
return conv, nil
|
||
}
|
||
|
||
// DeleteConversationForSelf 仅自己删除会话
|
||
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
|
||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("conversation not found or no permission")
|
||
}
|
||
return fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
if participant.ConversationID == "" {
|
||
return errors.New("conversation not found or no permission")
|
||
}
|
||
|
||
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
|
||
return fmt.Errorf("failed to hide conversation: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// SetConversationPinned 设置会话置顶(用户维度)
|
||
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
|
||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("conversation not found or no permission")
|
||
}
|
||
return fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
if participant.ConversationID == "" {
|
||
return errors.New("conversation not found or no permission")
|
||
}
|
||
|
||
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
|
||
return fmt.Errorf("failed to update pinned status: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// SendMessage 发送消息(使用 segments)
|
||
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||
// 首先验证会话是否存在
|
||
conv, err := s.repo.GetConversation(conversationID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("会话不存在,请重新创建会话")
|
||
}
|
||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||
}
|
||
|
||
// 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向
|
||
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
|
||
participants, pErr := s.repo.GetConversationParticipants(conversationID)
|
||
if pErr != nil {
|
||
return nil, fmt.Errorf("failed to get participants: %w", pErr)
|
||
}
|
||
var sentCount *int64
|
||
for _, p := range participants {
|
||
if p.UserID == senderID {
|
||
continue
|
||
}
|
||
blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
|
||
if bErr != nil {
|
||
return nil, fmt.Errorf("failed to check block status: %w", bErr)
|
||
}
|
||
if blocked {
|
||
return nil, ErrUserBlocked
|
||
}
|
||
|
||
// 陌生人限制:对方未回关前,只允许发送一条文本消息,且禁止发送图片
|
||
isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
|
||
if fErr != nil {
|
||
return nil, fmt.Errorf("failed to check follow status: %w", fErr)
|
||
}
|
||
if !isFollowedBack {
|
||
if containsImageSegment(segments) {
|
||
return nil, errors.New("对方未关注你,暂不支持发送图片")
|
||
}
|
||
if sentCount == nil {
|
||
c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
|
||
if cErr != nil {
|
||
return nil, fmt.Errorf("failed to count sender messages: %w", cErr)
|
||
}
|
||
sentCount = &c
|
||
}
|
||
if *sentCount >= 1 {
|
||
return nil, errors.New("对方未关注你前,仅允许发送一条消息")
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 验证用户是否是会话参与者
|
||
participant, err := s.repo.GetParticipant(conversationID, senderID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("您不是该会话的参与者")
|
||
}
|
||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
// 创建消息
|
||
message := &model.Message{
|
||
ConversationID: conversationID,
|
||
SenderID: senderID, // 直接使用string类型的UUID
|
||
Segments: segments,
|
||
ReplyToID: replyToID,
|
||
Status: model.MessageStatusNormal,
|
||
}
|
||
|
||
// 使用事务创建消息并更新seq
|
||
if err := s.repo.CreateMessageWithSeq(message); err != nil {
|
||
return nil, fmt.Errorf("failed to save message: %w", err)
|
||
}
|
||
|
||
// 发送消息给接收者
|
||
log.Printf("[DEBUG SendMessage] 私聊消息 segments 类型: %T, 值: %+v", message.Segments, message.Segments)
|
||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
|
||
ID: message.ID,
|
||
ConversationID: message.ConversationID,
|
||
SenderID: senderID,
|
||
Segments: message.Segments,
|
||
Seq: message.Seq,
|
||
CreatedAt: message.CreatedAt.UnixMilli(),
|
||
})
|
||
|
||
// 获取会话中的其他参与者
|
||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||
if err == nil {
|
||
for _, p := range participants {
|
||
// 不发给自己
|
||
if p.UserID == senderID {
|
||
continue
|
||
}
|
||
// 如果接收者在线,发送实时消息
|
||
if s.wsManager != nil {
|
||
isOnline := s.wsManager.IsUserOnline(p.UserID)
|
||
log.Printf("[DEBUG SendMessage] 接收者 UserID=%s, 在线状态=%v", p.UserID, isOnline)
|
||
if isOnline {
|
||
log.Printf("[DEBUG SendMessage] 发送WebSocket消息给 UserID=%s, 消息类型=%s", p.UserID, wsMsg.Type)
|
||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
log.Printf("[DEBUG SendMessage] 获取参与者失败: %v", err)
|
||
}
|
||
|
||
_ = participant // 避免未使用变量警告
|
||
|
||
return message, nil
|
||
}
|
||
|
||
func containsImageSegment(segments model.MessageSegments) bool {
|
||
for _, seg := range segments {
|
||
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// GetMessages 获取消息历史(分页)
|
||
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||
// 验证用户是否是会话参与者
|
||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, 0, errors.New("conversation not found or no permission")
|
||
}
|
||
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
return s.repo.GetMessages(conversationID, page, pageSize)
|
||
}
|
||
|
||
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
||
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||
// 验证用户是否是会话参与者
|
||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("conversation not found or no permission")
|
||
}
|
||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
if limit <= 0 {
|
||
limit = 100
|
||
}
|
||
|
||
return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
||
}
|
||
|
||
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
||
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||
// 验证用户是否是会话参与者
|
||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("conversation not found or no permission")
|
||
}
|
||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
if limit <= 0 {
|
||
limit = 20
|
||
}
|
||
|
||
return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
|
||
}
|
||
|
||
// MarkAsRead 标记已读
|
||
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
|
||
// 验证用户是否是会话参与者
|
||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("conversation not found or no permission")
|
||
}
|
||
return fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
// 更新参与者的已读位置
|
||
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to update last read seq: %w", err)
|
||
}
|
||
|
||
// 发送已读回执(作为 meta 事件)
|
||
if s.wsManager != nil {
|
||
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
|
||
"detail_type": websocket.MetaDetailTypeRead,
|
||
"conversation_id": conversationID,
|
||
"seq": seq,
|
||
"user_id": userID,
|
||
})
|
||
|
||
// 获取会话中的所有参与者
|
||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||
if err == nil {
|
||
// 推送给会话中的所有参与者(包括自己)
|
||
for _, p := range participants {
|
||
if s.wsManager.IsUserOnline(p.UserID) {
|
||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetUnreadCount 获取指定会话的未读消息数
|
||
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||
// 验证用户是否是会话参与者
|
||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return 0, errors.New("conversation not found or no permission")
|
||
}
|
||
return 0, fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
return s.repo.GetUnreadCount(conversationID, userID)
|
||
}
|
||
|
||
// GetAllUnreadCount 获取所有会话的未读消息总数
|
||
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
|
||
return s.repo.GetAllUnreadCount(userID)
|
||
}
|
||
|
||
// RecallMessage 撤回消息(2分钟内)
|
||
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
|
||
// 获取消息
|
||
var message model.Message
|
||
err := s.db.First(&message, "id = ?", messageID).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("message not found")
|
||
}
|
||
return fmt.Errorf("failed to get message: %w", err)
|
||
}
|
||
|
||
// 验证是否是消息发送者
|
||
if message.SenderIDStr() != userID {
|
||
return errors.New("can only recall your own messages")
|
||
}
|
||
|
||
// 验证消息是否已被撤回
|
||
if message.Status == model.MessageStatusRecalled {
|
||
return errors.New("message already recalled")
|
||
}
|
||
|
||
// 验证是否在2分钟内
|
||
if time.Since(message.CreatedAt) > RecallMessageTimeout {
|
||
return errors.New("message recall timeout (2 minutes)")
|
||
}
|
||
|
||
// 更新消息状态为已撤回
|
||
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to recall message: %w", err)
|
||
}
|
||
|
||
// 发送撤回通知
|
||
if s.wsManager != nil {
|
||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
|
||
"messageId": messageID,
|
||
"conversationId": message.ConversationID,
|
||
"senderId": userID,
|
||
})
|
||
|
||
// 通知会话中的所有参与者
|
||
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
|
||
if err == nil {
|
||
for _, p := range participants {
|
||
if s.wsManager.IsUserOnline(p.UserID) {
|
||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DeleteMessage 删除消息(仅对自己可见)
|
||
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
|
||
// 获取消息
|
||
var message model.Message
|
||
err := s.db.First(&message, "id = ?", messageID).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("message not found")
|
||
}
|
||
return fmt.Errorf("failed to get message: %w", err)
|
||
}
|
||
|
||
// 验证用户是否是会话参与者
|
||
_, err = s.repo.GetParticipant(message.ConversationID, userID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return errors.New("no permission to delete this message")
|
||
}
|
||
return fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
// 对于删除消息,我们使用软删除,但需要确保只对当前用户隐藏
|
||
// 这里简化处理:只有发送者可以删除自己的消息
|
||
if message.SenderIDStr() != userID {
|
||
return errors.New("can only delete your own messages")
|
||
}
|
||
|
||
// 更新消息状态为已删除
|
||
err = s.db.Model(&message).Update("status", model.MessageStatusDeleted).Error
|
||
if err != nil {
|
||
return fmt.Errorf("failed to delete message: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// SendTyping 发送正在输入状态
|
||
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
|
||
if s.wsManager == nil {
|
||
return
|
||
}
|
||
|
||
// 验证用户是否是会话参与者
|
||
_, err := s.repo.GetParticipant(conversationID, senderID)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 获取会话中的其他参与者
|
||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
for _, p := range participants {
|
||
if p.UserID == senderID {
|
||
continue
|
||
}
|
||
// 发送正在输入状态
|
||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
|
||
"conversationId": conversationID,
|
||
"senderId": senderID,
|
||
})
|
||
|
||
if s.wsManager.IsUserOnline(p.UserID) {
|
||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||
}
|
||
}
|
||
}
|
||
|
||
// BroadcastMessage 广播消息给用户
|
||
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
|
||
if s.wsManager != nil {
|
||
s.wsManager.SendToUser(targetUser, msg)
|
||
}
|
||
}
|
||
|
||
// IsUserOnline 检查用户是否在线
|
||
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
|
||
if s.wsManager == nil {
|
||
return false
|
||
}
|
||
return s.wsManager.IsUserOnline(userID)
|
||
}
|
||
|
||
// PushSystemMessage 推送系统消息给指定用户
|
||
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
|
||
if s.wsManager == nil {
|
||
return errors.New("websocket manager not available")
|
||
}
|
||
|
||
if !s.wsManager.IsUserOnline(userID) {
|
||
return errors.New("user is offline")
|
||
}
|
||
|
||
sysMsg := &websocket.SystemMessage{
|
||
ID: "", // 由调用方生成
|
||
Type: msgType,
|
||
Title: title,
|
||
Content: content,
|
||
Data: data,
|
||
CreatedAt: time.Now().UnixMilli(),
|
||
}
|
||
|
||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||
s.wsManager.SendToUser(userID, wsMsg)
|
||
return nil
|
||
}
|
||
|
||
// PushNotificationMessage 推送通知消息给指定用户
|
||
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
|
||
if s.wsManager == nil {
|
||
return errors.New("websocket manager not available")
|
||
}
|
||
|
||
if !s.wsManager.IsUserOnline(userID) {
|
||
return errors.New("user is offline")
|
||
}
|
||
|
||
// 确保时间戳已设置
|
||
if notification.CreatedAt == 0 {
|
||
notification.CreatedAt = time.Now().UnixMilli()
|
||
}
|
||
|
||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||
s.wsManager.SendToUser(userID, wsMsg)
|
||
return nil
|
||
}
|
||
|
||
// PushAnnouncementMessage 广播公告消息给所有在线用户
|
||
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
|
||
if s.wsManager == nil {
|
||
return errors.New("websocket manager not available")
|
||
}
|
||
|
||
// 确保时间戳已设置
|
||
if announcement.CreatedAt == 0 {
|
||
announcement.CreatedAt = time.Now().UnixMilli()
|
||
}
|
||
|
||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||
s.wsManager.Broadcast(wsMsg)
|
||
return nil
|
||
}
|
||
|
||
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
|
||
// 适用于群聊等由调用方自行负责推送的场景
|
||
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||
// 验证会话是否存在
|
||
_, err := s.repo.GetConversation(conversationID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("会话不存在,请重新创建会话")
|
||
}
|
||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||
}
|
||
|
||
// 验证用户是否是会话参与者
|
||
_, err = s.repo.GetParticipant(conversationID, senderID)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("您不是该会话的参与者")
|
||
}
|
||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||
}
|
||
|
||
message := &model.Message{
|
||
ConversationID: conversationID,
|
||
SenderID: senderID,
|
||
Segments: segments,
|
||
ReplyToID: replyToID,
|
||
Status: model.MessageStatusNormal,
|
||
}
|
||
|
||
if err := s.repo.CreateMessageWithSeq(message); err != nil {
|
||
return nil, fmt.Errorf("failed to save message: %w", err)
|
||
}
|
||
|
||
return message, nil
|
||
}
|