Files
backend/internal/service/chat_service.go
lan 4d8f2ec997 Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
2026-03-09 21:28:58 +08:00

623 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"fmt"
"log"
"time"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
)
// RecallMessageTimeout is the window after a message's creation time during
// which its sender may still recall it (2 minutes).
const RecallMessageTimeout = 2 * time.Minute
// ChatService is the service-layer contract for conversations, messages,
// read state and WebSocket pushes.
type ChatService interface {
	// Conversation management.
	GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
	GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
	GetConversationByID(ctx context.Context, conversationID, userID string) (*model.Conversation, error)
	DeleteConversationForSelf(ctx context.Context, conversationID, userID string) error
	SetConversationPinned(ctx context.Context, conversationID, userID string, isPinned bool) error

	// Message operations.
	SendMessage(ctx context.Context, senderID, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
	GetMessages(ctx context.Context, conversationID, userID string, page, pageSize int) ([]*model.Message, int64, error)
	GetMessagesAfterSeq(ctx context.Context, conversationID, userID string, afterSeq int64, limit int) ([]*model.Message, error)
	GetMessagesBeforeSeq(ctx context.Context, conversationID, userID string, beforeSeq int64, limit int) ([]*model.Message, error)

	// Read-state management.
	MarkAsRead(ctx context.Context, conversationID, userID string, seq int64) error
	GetUnreadCount(ctx context.Context, conversationID, userID string) (int64, error)
	GetAllUnreadCount(ctx context.Context, userID string) (int64, error)

	// Extended message features.
	RecallMessage(ctx context.Context, messageID, userID string) error
	DeleteMessage(ctx context.Context, messageID, userID string) error

	// WebSocket related.
	SendTyping(ctx context.Context, senderID, conversationID string)
	BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)

	// System message pushes.
	IsUserOnline(userID string) bool
	PushSystemMessage(userID, msgType, title, content string, data map[string]interface{}) error
	PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
	PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error

	// SaveMessage only persists the message and sends no WebSocket push; it is
	// meant for callers (e.g. group chat) that deliver the message themselves.
	SaveMessage(ctx context.Context, senderID, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
}
// chatServiceImpl is the ChatService implementation returned by NewChatService.
type chatServiceImpl struct {
// db is queried directly by RecallMessage and DeleteMessage.
db *gorm.DB
// repo provides conversation, participant and message persistence.
repo *repository.MessageRepository
// userRepo backs the block/follow checks in SendMessage; those checks are
// skipped when it is nil.
userRepo *repository.UserRepository
// sensitive is stored by the constructor; no method visible in this file reads it.
sensitive SensitiveService
// wsManager delivers real-time pushes; methods treat nil as "WebSocket unavailable".
wsManager *websocket.WebSocketManager
}
// NewChatService wires the given dependencies into a ChatService.
// The arguments are stored as-is; no validation is performed here.
func NewChatService(
	db *gorm.DB,
	repo *repository.MessageRepository,
	userRepo *repository.UserRepository,
	sensitive SensitiveService,
	wsManager *websocket.WebSocketManager,
) ChatService {
	svc := new(chatServiceImpl)
	svc.db = db
	svc.repo = repo
	svc.userRepo = userRepo
	svc.sensitive = sensitive
	svc.wsManager = wsManager
	return svc
}
// GetOrCreateConversation returns the private conversation between the two
// users, delegating the lookup/creation to the message repository.
// ctx is currently not forwarded to the repository.
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
	conv, err := s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
	return conv, err
}
// GetConversationList returns one page of the user's conversations together
// with the total count reported by the repository.
// ctx is currently not forwarded to the repository.
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	conversations, total, err := s.repo.GetConversations(userID, page, pageSize)
	return conversations, total, err
}
// GetConversationByID returns the conversation after confirming that userID is
// one of its participants.
//
// A missing participant row is reported as "conversation not found or no
// permission"; any other repository failure is wrapped and returned.
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
	// Membership gate: the participant record itself is not needed here.
	if _, memberErr := s.repo.GetParticipant(conversationID, userID); memberErr != nil {
		if !errors.Is(memberErr, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("failed to get participant: %w", memberErr)
		}
		return nil, errors.New("conversation not found or no permission")
	}

	conversation, loadErr := s.repo.GetConversation(conversationID)
	if loadErr != nil {
		return nil, fmt.Errorf("failed to get conversation: %w", loadErr)
	}
	return conversation, nil
}
// DeleteConversationForSelf hides the conversation for userID only, after
// confirming the user participates in it.
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
	member, err := s.repo.GetParticipant(conversationID, userID)
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return errors.New("conversation not found or no permission")
	case err != nil:
		return fmt.Errorf("failed to get participant: %w", err)
	case member.ConversationID == "":
		// An empty record is treated the same as a missing one.
		return errors.New("conversation not found or no permission")
	}

	if hideErr := s.repo.HideConversationForUser(conversationID, userID); hideErr != nil {
		return fmt.Errorf("failed to hide conversation: %w", hideErr)
	}
	return nil
}
// SetConversationPinned sets the per-user pinned flag of a conversation after
// confirming that userID participates in it.
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
	member, err := s.repo.GetParticipant(conversationID, userID)
	// Unexpected repository failures are surfaced with context.
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return fmt.Errorf("failed to get participant: %w", err)
	}
	// Either no participant row exists, or the returned row is empty.
	if err != nil || member.ConversationID == "" {
		return errors.New("conversation not found or no permission")
	}

	if pinErr := s.repo.UpdatePinned(conversationID, userID, isPinned); pinErr != nil {
		return fmt.Errorf("failed to update pinned status: %w", pinErr)
	}
	return nil
}
// SendMessage validates, persists and delivers a message made of segments.
//
// Order of checks:
//  1. the conversation must exist;
//  2. senderID must be a participant of it;
//  3. for private conversations (when a user repository is configured) the
//     block and stranger restrictions are applied;
//  4. the message is stored via CreateMessageWithSeq;
//  5. the message is pushed over WebSocket to every other participant that is
//     currently online (best effort: a failure to list participants at this
//     point is only logged, the stored message is still returned).
//
// The participant check runs before the block/follow checks so that a caller
// who is not in the conversation cannot learn another user's block or follow
// state from the returned error.
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	conv, err := s.repo.GetConversation(conversationID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("会话不存在,请重新创建会话")
		}
		return nil, fmt.Errorf("failed to get conversation: %w", err)
	}

	// Membership gate; the participant record itself is not needed.
	if _, err := s.repo.GetParticipant(conversationID, senderID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("您不是该会话的参与者")
		}
		return nil, fmt.Errorf("failed to get participant: %w", err)
	}

	if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
		if err := s.checkPrivateSendRestrictions(conversationID, senderID, segments); err != nil {
			return nil, err
		}
	}

	message := &model.Message{
		ConversationID: conversationID,
		SenderID:       senderID,
		Segments:       segments,
		ReplyToID:      replyToID,
		Status:         model.MessageStatusNormal,
	}
	// The repository stores the message and assigns its sequence number.
	if err := s.repo.CreateMessageWithSeq(message); err != nil {
		return nil, fmt.Errorf("failed to save message: %w", err)
	}

	log.Printf("[DEBUG SendMessage] 私聊消息 segments 类型: %T, 值: %+v", message.Segments, message.Segments)
	wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
		ID:             message.ID,
		ConversationID: message.ConversationID,
		SenderID:       senderID,
		Segments:       message.Segments,
		Seq:            message.Seq,
		CreatedAt:      message.CreatedAt.UnixMilli(),
	})

	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		// Delivery is best effort: the message is already stored.
		log.Printf("[DEBUG SendMessage] 获取参与者失败: %v", err)
		return message, nil
	}
	for _, p := range participants {
		// Never echo the message back to its sender; skip when no manager is configured.
		if p.UserID == senderID || s.wsManager == nil {
			continue
		}
		isOnline := s.wsManager.IsUserOnline(p.UserID)
		log.Printf("[DEBUG SendMessage] 接收者 UserID=%s, 在线状态=%v", p.UserID, isOnline)
		if isOnline {
			log.Printf("[DEBUG SendMessage] 发送WebSocket消息给 UserID=%s, 消息类型=%s", p.UserID, wsMsg.Type)
			s.wsManager.SendToUser(p.UserID, wsMsg)
		}
	}
	return message, nil
}

// checkPrivateSendRestrictions applies the private-chat sending rules for
// senderID against every other participant of the conversation:
//   - if the other participant has blocked the sender, ErrUserBlocked is returned
//     (only the "blocked user -> blocker" direction is rejected);
//   - if the other participant does not follow the sender, image segments are
//     rejected and the sender may have at most one message in the conversation.
//
// It must only be called when s.userRepo is non-nil.
func (s *chatServiceImpl) checkPrivateSendRestrictions(conversationID, senderID string, segments model.MessageSegments) error {
	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return fmt.Errorf("failed to get participants: %w", err)
	}

	// sentCount is loaded lazily, at most once, and only if a stranger check needs it.
	var sentCount *int64
	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}

		blocked, err := s.userRepo.IsBlocked(p.UserID, senderID)
		if err != nil {
			return fmt.Errorf("failed to check block status: %w", err)
		}
		if blocked {
			return ErrUserBlocked
		}

		followedBack, err := s.userRepo.IsFollowing(p.UserID, senderID)
		if err != nil {
			return fmt.Errorf("failed to check follow status: %w", err)
		}
		if followedBack {
			continue
		}

		if containsImageSegment(segments) {
			return errors.New("对方未关注你,暂不支持发送图片")
		}
		if sentCount == nil {
			c, err := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
			if err != nil {
				return fmt.Errorf("failed to count sender messages: %w", err)
			}
			sentCount = &c
		}
		if *sentCount >= 1 {
			return errors.New("对方未关注你前,仅允许发送一条消息")
		}
	}
	return nil
}
// containsImageSegment reports whether any segment is an image segment, i.e.
// its Type equals model.ContentTypeImage (as a string) or the literal "image".
func containsImageSegment(segments model.MessageSegments) bool {
	imageType := string(model.ContentTypeImage)
	for i := range segments {
		switch segments[i].Type {
		case imageType, "image":
			return true
		}
	}
	return false
}
// GetMessages returns one page of the conversation's message history plus the
// total reported by the repository, after confirming that userID is a
// participant of the conversation.
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
	if _, memberErr := s.repo.GetParticipant(conversationID, userID); memberErr != nil {
		if !errors.Is(memberErr, gorm.ErrRecordNotFound) {
			return nil, 0, fmt.Errorf("failed to get participant: %w", memberErr)
		}
		return nil, 0, errors.New("conversation not found or no permission")
	}
	return s.repo.GetMessages(conversationID, page, pageSize)
}
// GetMessagesAfterSeq returns messages with a sequence number after afterSeq,
// intended for incremental sync. userID must be a participant.
// A non-positive limit falls back to 100.
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
	const defaultLimit = 100

	_, memberErr := s.repo.GetParticipant(conversationID, userID)
	switch {
	case errors.Is(memberErr, gorm.ErrRecordNotFound):
		return nil, errors.New("conversation not found or no permission")
	case memberErr != nil:
		return nil, fmt.Errorf("failed to get participant: %w", memberErr)
	}

	if limit <= 0 {
		limit = defaultLimit
	}
	return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}
// GetMessagesBeforeSeq returns older messages with a sequence number before
// beforeSeq, intended for "load more" pagination. userID must be a participant.
// A non-positive limit falls back to 20.
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	const defaultLimit = 20

	if _, memberErr := s.repo.GetParticipant(conversationID, userID); memberErr != nil {
		if errors.Is(memberErr, gorm.ErrRecordNotFound) {
			return nil, errors.New("conversation not found or no permission")
		}
		return nil, fmt.Errorf("failed to get participant: %w", memberErr)
	}

	pageLimit := limit
	if pageLimit <= 0 {
		pageLimit = defaultLimit
	}
	return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, pageLimit)
}
// MarkAsRead records seq as the user's last-read position in the conversation
// and then pushes a read receipt (a "meta" event) to every online participant,
// the reader included.
//
// The receipt is best effort: it is skipped when no WebSocket manager is
// configured or when the participant list cannot be loaded, and neither case
// is reported as an error.
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
	if _, memberErr := s.repo.GetParticipant(conversationID, userID); memberErr != nil {
		if errors.Is(memberErr, gorm.ErrRecordNotFound) {
			return errors.New("conversation not found or no permission")
		}
		return fmt.Errorf("failed to get participant: %w", memberErr)
	}

	if updateErr := s.repo.UpdateLastReadSeq(conversationID, userID, seq); updateErr != nil {
		return fmt.Errorf("failed to update last read seq: %w", updateErr)
	}

	if s.wsManager == nil {
		return nil
	}

	receipt := websocket.CreateWSMessage("meta", map[string]interface{}{
		"detail_type":     websocket.MetaDetailTypeRead,
		"conversation_id": conversationID,
		"seq":             seq,
		"user_id":         userID,
	})

	members, listErr := s.repo.GetConversationParticipants(conversationID)
	if listErr != nil {
		// The read position is already saved; only the notification is dropped.
		return nil
	}
	for _, member := range members {
		if !s.wsManager.IsUserOnline(member.UserID) {
			continue
		}
		s.wsManager.SendToUser(member.UserID, receipt)
	}
	return nil
}
// GetUnreadCount returns the user's unread message count for one conversation,
// after confirming that the user participates in it.
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
	_, memberErr := s.repo.GetParticipant(conversationID, userID)
	switch {
	case memberErr == nil:
		return s.repo.GetUnreadCount(conversationID, userID)
	case errors.Is(memberErr, gorm.ErrRecordNotFound):
		return 0, errors.New("conversation not found or no permission")
	default:
		return 0, fmt.Errorf("failed to get participant: %w", memberErr)
	}
}
// GetAllUnreadCount returns the user's total unread message count across all
// conversations, as computed by the repository.
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
	total, err := s.repo.GetAllUnreadCount(userID)
	return total, err
}
// RecallMessage marks a message as recalled on behalf of its sender.
//
// The call fails when the message does not exist, when userID is not the
// sender, when the message is already recalled, or when more than
// RecallMessageTimeout has elapsed since the message was created.
//
// Both database statements run with ctx attached, so request cancellation and
// deadlines reach the driver.
//
// After the status update, a recall event is pushed to every online
// participant of the conversation. The push is best effort: it is skipped when
// no WebSocket manager is configured or the participant list cannot be loaded.
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
	db := s.db.WithContext(ctx)

	var message model.Message
	if err := db.First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	// Only the sender may recall.
	if message.SenderIDStr() != userID {
		return errors.New("can only recall your own messages")
	}
	if message.Status == model.MessageStatusRecalled {
		return errors.New("message already recalled")
	}
	if time.Since(message.CreatedAt) > RecallMessageTimeout {
		return errors.New("message recall timeout (2 minutes)")
	}

	if err := db.Model(&message).Update("status", model.MessageStatusRecalled).Error; err != nil {
		return fmt.Errorf("failed to recall message: %w", err)
	}

	if s.wsManager == nil {
		return nil
	}
	wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
		"messageId":      messageID,
		"conversationId": message.ConversationID,
		"senderId":       userID,
	})
	participants, err := s.repo.GetConversationParticipants(message.ConversationID)
	if err != nil {
		// The recall itself succeeded; only the notification is dropped.
		return nil
	}
	for _, p := range participants {
		if s.wsManager.IsUserOnline(p.UserID) {
			s.wsManager.SendToUser(p.UserID, wsMsg)
		}
	}
	return nil
}
// DeleteMessage soft-deletes a message by setting its status to
// model.MessageStatusDeleted.
//
// userID must be a participant of the message's conversation and must also be
// the message's sender. The original note describes the goal as hiding the
// message only for the current user; this simplified implementation instead
// updates the status on the message row itself.
//
// Both database statements run with ctx attached, so request cancellation and
// deadlines reach the driver.
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
	db := s.db.WithContext(ctx)

	var message model.Message
	if err := db.First(&message, "id = ?", messageID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("message not found")
		}
		return fmt.Errorf("failed to get message: %w", err)
	}

	// The caller must belong to the conversation that holds the message.
	if _, err := s.repo.GetParticipant(message.ConversationID, userID); err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return errors.New("no permission to delete this message")
		}
		return fmt.Errorf("failed to get participant: %w", err)
	}

	// Simplification: only the sender may delete the message.
	if message.SenderIDStr() != userID {
		return errors.New("can only delete your own messages")
	}

	if err := db.Model(&message).Update("status", model.MessageStatusDeleted).Error; err != nil {
		return fmt.Errorf("failed to delete message: %w", err)
	}
	return nil
}
// SendTyping pushes a "typing" event from senderID to every other participant
// of the conversation that is currently online.
//
// The call is a silent no-op when no WebSocket manager is configured, when
// senderID is not a participant, or when the participant list cannot be
// loaded; typing indicators are best effort and report no error.
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
	if s.wsManager == nil {
		return
	}

	// Only participants may emit typing events.
	if _, err := s.repo.GetParticipant(conversationID, senderID); err != nil {
		return
	}

	participants, err := s.repo.GetConversationParticipants(conversationID)
	if err != nil {
		return
	}

	// The payload is identical for every recipient, so build it once.
	wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
		"conversationId": conversationID,
		"senderId":       senderID,
	})
	for _, p := range participants {
		if p.UserID == senderID {
			continue
		}
		if s.wsManager.IsUserOnline(p.UserID) {
			s.wsManager.SendToUser(p.UserID, wsMsg)
		}
	}
}
// BroadcastMessage forwards msg to targetUser through the WebSocket manager.
// It does nothing when no manager is configured.
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
	if s.wsManager == nil {
		return
	}
	s.wsManager.SendToUser(targetUser, msg)
}
// IsUserOnline reports whether the WebSocket manager considers the user
// online. Without a configured manager every user counts as offline.
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
	return s.wsManager != nil && s.wsManager.IsUserOnline(userID)
}
// PushSystemMessage sends a system message to one user over WebSocket.
//
// It returns an error when no WebSocket manager is configured or when the user
// is offline. The message's CreatedAt is set to the current time in Unix
// milliseconds.
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
	switch {
	case s.wsManager == nil:
		return errors.New("websocket manager not available")
	case !s.wsManager.IsUserOnline(userID):
		return errors.New("user is offline")
	}

	payload := &websocket.SystemMessage{
		// NOTE(review): the original comment says the ID is generated by the
		// caller, but this function has no ID parameter, so it is sent empty.
		ID:        "",
		Type:      msgType,
		Title:     title,
		Content:   content,
		Data:      data,
		CreatedAt: time.Now().UnixMilli(),
	}
	s.wsManager.SendToUser(userID, websocket.CreateWSMessage(websocket.MessageTypeSystem, payload))
	return nil
}
// PushNotificationMessage sends a notification to one user over WebSocket.
//
// It returns an error when no WebSocket manager is configured, when the user
// is offline, or when notification is nil (previously a nil notification
// caused a nil-pointer panic). If notification.CreatedAt is zero it is set, on
// the caller's value, to the current time in Unix milliseconds.
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
	if s.wsManager == nil {
		return errors.New("websocket manager not available")
	}
	if !s.wsManager.IsUserOnline(userID) {
		return errors.New("user is offline")
	}
	// Checked last so the pre-existing error results keep their precedence.
	if notification == nil {
		return errors.New("notification is nil")
	}

	// Make sure the timestamp is populated before sending.
	if notification.CreatedAt == 0 {
		notification.CreatedAt = time.Now().UnixMilli()
	}
	wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
	s.wsManager.SendToUser(userID, wsMsg)
	return nil
}
// PushAnnouncementMessage broadcasts an announcement through the WebSocket
// manager's Broadcast.
//
// It returns an error when no WebSocket manager is configured or when
// announcement is nil (previously a nil announcement caused a nil-pointer
// panic). If announcement.CreatedAt is zero it is set, on the caller's value,
// to the current time in Unix milliseconds.
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
	if s.wsManager == nil {
		return errors.New("websocket manager not available")
	}
	// Checked after the manager so the pre-existing error keeps its precedence.
	if announcement == nil {
		return errors.New("announcement is nil")
	}

	// Make sure the timestamp is populated before sending.
	if announcement.CreatedAt == 0 {
		announcement.CreatedAt = time.Now().UnixMilli()
	}
	wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
	s.wsManager.Broadcast(wsMsg)
	return nil
}
// SaveMessage stores a message without sending any WebSocket push. It is meant
// for callers (group chat, for example) that handle delivery themselves.
//
// The conversation must exist and senderID must be one of its participants.
// Unlike SendMessage, no block or stranger restrictions are applied here.
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
	// The conversation record is only checked for existence.
	if _, convErr := s.repo.GetConversation(conversationID); convErr != nil {
		if !errors.Is(convErr, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("failed to get conversation: %w", convErr)
		}
		return nil, errors.New("会话不存在,请重新创建会话")
	}

	// The sender must belong to the conversation.
	if _, memberErr := s.repo.GetParticipant(conversationID, senderID); memberErr != nil {
		if !errors.Is(memberErr, gorm.ErrRecordNotFound) {
			return nil, fmt.Errorf("failed to get participant: %w", memberErr)
		}
		return nil, errors.New("您不是该会话的参与者")
	}

	saved := new(model.Message)
	saved.ConversationID = conversationID
	saved.SenderID = senderID
	saved.Segments = segments
	saved.ReplyToID = replyToID
	saved.Status = model.MessageStatusNormal

	if saveErr := s.repo.CreateMessageWithSeq(saved); saveErr != nil {
		return nil, fmt.Errorf("failed to save message: %w", saveErr)
	}
	return saved, nil
}