Initial backend repository commit.

Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
This commit is contained in:
2026-03-09 21:28:58 +08:00
commit 4d8f2ec997
102 changed files with 25022 additions and 0 deletions

View File

@@ -0,0 +1,759 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"sync"
"time"
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// ==================== 内容审核服务接口和实现 ====================
// AuditServiceProvider 内容审核服务提供商接口
type AuditServiceProvider interface {
// AuditText 审核文本
AuditText(ctx context.Context, text string, scene string) (*AuditResult, error)
// AuditImage 审核图片
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
// GetName 获取提供商名称
GetName() string
}
// AuditResult 审核结果
type AuditResult struct {
Pass bool `json:"pass"` // 是否通过
Risk string `json:"risk"` // 风险等级: low, medium, high
Labels []string `json:"labels"` // 标签列表
Suggest string `json:"suggest"` // 建议: pass, review, block
Detail string `json:"detail"` // 详细说明
Provider string `json:"provider"` // 服务提供商
}
// AuditService 内容审核服务接口
type AuditService interface {
// AuditText 审核文本
AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error)
// AuditImage 审核图片
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
// GetAuditResult 获取审核结果
GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error)
// SetProvider 设置审核服务提供商
SetProvider(provider AuditServiceProvider)
// GetProvider 获取当前审核服务提供商
GetProvider() AuditServiceProvider
}
// auditServiceImpl 内容审核服务实现
type auditServiceImpl struct {
db *gorm.DB
provider AuditServiceProvider
config *AuditConfig
mu sync.RWMutex
}
// AuditConfig 内容审核服务配置
type AuditConfig struct {
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
// 审核服务提供商: local, aliyun, tencent, baidu
Provider string `mapstructure:"provider" yaml:"provider"`
// 阿里云配置
AliyunAccessKey string `mapstructure:"aliyun_access_key" yaml:"aliyun_access_key"`
AliyunSecretKey string `mapstructure:"aliyun_secret_key" yaml:"aliyun_secret_key"`
AliyunRegion string `mapstructure:"aliyun_region" yaml:"aliyun_region"`
// 腾讯云配置
TencentSecretID string `mapstructure:"tencent_secret_id" yaml:"tencent_secret_id"`
TencentSecretKey string `mapstructure:"tencent_secret_key" yaml:"tencent_secret_key"`
// 百度云配置
BaiduAPIKey string `mapstructure:"baidu_api_key" yaml:"baidu_api_key"`
BaiduSecretKey string `mapstructure:"baidu_secret_key" yaml:"baidu_secret_key"`
// 是否自动审核
AutoAudit bool `mapstructure:"auto_audit" yaml:"auto_audit"`
// 审核超时时间(秒)
Timeout int `mapstructure:"timeout" yaml:"timeout"`
}
// NewAuditService 创建内容审核服务
func NewAuditService(db *gorm.DB, config *AuditConfig) AuditService {
s := &auditServiceImpl{
db: db,
config: config,
}
// 根据配置初始化提供商
if config.Enabled {
provider := s.initProvider(config.Provider)
s.provider = provider
}
return s
}
// initProvider 根据配置初始化审核服务提供商
func (s *auditServiceImpl) initProvider(providerType string) AuditServiceProvider {
switch strings.ToLower(providerType) {
case "aliyun":
return NewAliyunAuditProvider(s.config.AliyunAccessKey, s.config.AliyunSecretKey, s.config.AliyunRegion)
case "tencent":
return NewTencentAuditProvider(s.config.TencentSecretID, s.config.TencentSecretKey)
case "baidu":
return NewBaiduAuditProvider(s.config.BaiduAPIKey, s.config.BaiduSecretKey)
case "local":
fallthrough
default:
// 默认使用本地审核服务
return NewLocalAuditProvider()
}
}
// AuditText 审核文本
func (s *auditServiceImpl) AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error) {
if !s.config.Enabled {
// 如果审核服务未启用,直接返回通过
return &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Detail: "Audit service disabled",
}, nil
}
if text == "" {
return &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Detail: "Empty text",
}, nil
}
var result *AuditResult
var err error
// 使用提供商审核
if s.provider != nil {
result, err = s.provider.AuditText(ctx, text, auditType)
} else {
// 如果没有设置提供商,使用本地审核
localProvider := NewLocalAuditProvider()
result, err = localProvider.AuditText(ctx, text, auditType)
}
if err != nil {
log.Printf("Audit text error: %v", err)
return &AuditResult{
Pass: false,
Risk: "high",
Suggest: "review",
Detail: fmt.Sprintf("Audit error: %v", err),
}, err
}
// 记录审核日志
go s.saveAuditLog(ctx, "text", "", text, auditType, result)
return result, nil
}
// AuditImage 审核图片
func (s *auditServiceImpl) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
if !s.config.Enabled {
return &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Detail: "Audit service disabled",
}, nil
}
if imageURL == "" {
return &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Detail: "Empty image URL",
}, nil
}
var result *AuditResult
var err error
// 使用提供商审核
if s.provider != nil {
result, err = s.provider.AuditImage(ctx, imageURL)
} else {
// 如果没有设置提供商,使用本地审核
localProvider := NewLocalAuditProvider()
result, err = localProvider.AuditImage(ctx, imageURL)
}
if err != nil {
log.Printf("Audit image error: %v", err)
return &AuditResult{
Pass: false,
Risk: "high",
Suggest: "review",
Detail: fmt.Sprintf("Audit error: %v", err),
}, err
}
// 记录审核日志
go s.saveAuditLog(ctx, "image", "", "", "image", result)
return result, nil
}
// GetAuditResult 获取审核结果
func (s *auditServiceImpl) GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error) {
if s.db == nil || auditID == "" {
return nil, fmt.Errorf("invalid audit ID")
}
var auditLog model.AuditLog
if err := s.db.Where("id = ?", auditID).First(&auditLog).Error; err != nil {
return nil, err
}
result := &AuditResult{
Pass: auditLog.Result == model.AuditResultPass,
Risk: string(auditLog.RiskLevel),
Suggest: auditLog.Suggestion,
Detail: auditLog.Detail,
}
// 解析标签
if auditLog.Labels != "" {
json.Unmarshal([]byte(auditLog.Labels), &result.Labels)
}
return result, nil
}
// SetProvider 设置审核服务提供商
func (s *auditServiceImpl) SetProvider(provider AuditServiceProvider) {
s.mu.Lock()
defer s.mu.Unlock()
s.provider = provider
}
// GetProvider 获取当前审核服务提供商
func (s *auditServiceImpl) GetProvider() AuditServiceProvider {
s.mu.RLock()
defer s.mu.RUnlock()
return s.provider
}
// saveAuditLog 保存审核日志
func (s *auditServiceImpl) saveAuditLog(ctx context.Context, contentType, content, imageURL, auditType string, result *AuditResult) {
if s.db == nil {
return
}
auditLog := model.AuditLog{
ContentType: contentType,
Content: content,
ContentURL: imageURL,
AuditType: auditType,
Labels: strings.Join(result.Labels, ","),
Suggestion: result.Suggest,
Detail: result.Detail,
Source: model.AuditSourceAuto,
Status: "completed",
}
if result.Pass {
auditLog.Result = model.AuditResultPass
} else if result.Suggest == "review" {
auditLog.Result = model.AuditResultReview
} else {
auditLog.Result = model.AuditResultBlock
}
switch result.Risk {
case "low":
auditLog.RiskLevel = model.AuditRiskLevelLow
case "medium":
auditLog.RiskLevel = model.AuditRiskLevelMedium
case "high":
auditLog.RiskLevel = model.AuditRiskLevelHigh
default:
auditLog.RiskLevel = model.AuditRiskLevelLow
}
if err := s.db.Create(&auditLog).Error; err != nil {
log.Printf("Failed to save audit log: %v", err)
}
}
// ==================== 本地审核服务提供商 ====================
// localAuditProvider 本地审核服务提供商
type localAuditProvider struct {
// 可以注入敏感词服务进行本地审核
sensitiveService SensitiveService
}
// NewLocalAuditProvider 创建本地审核服务提供商
func NewLocalAuditProvider() AuditServiceProvider {
return &localAuditProvider{
sensitiveService: nil,
}
}
// GetName 获取提供商名称
func (p *localAuditProvider) GetName() string {
return "local"
}
// AuditText 审核文本
func (p *localAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
// 本地审核逻辑
// 1. 敏感词检查
// 2. 规则匹配
// 3. 简单的关键词检测
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "local",
}
// 如果有敏感词服务,使用它进行检测
if p.sensitiveService != nil {
hasSensitive, words := p.sensitiveService.Check(ctx, text)
if hasSensitive {
result.Pass = false
result.Risk = "high"
result.Suggest = "block"
result.Detail = fmt.Sprintf("包含敏感词: %s", strings.Join(words, ","))
result.Labels = append(result.Labels, "sensitive")
}
}
// 简单的关键词检测规则
// 实际项目中应该从数据库加载
suspiciousPatterns := []string{
"诈骗",
"钓鱼",
"木马",
"病毒",
}
for _, pattern := range suspiciousPatterns {
if strings.Contains(text, pattern) {
result.Pass = false
result.Risk = "high"
result.Suggest = "block"
result.Labels = append(result.Labels, "suspicious")
if result.Detail == "" {
result.Detail = fmt.Sprintf("包含可疑内容: %s", pattern)
} else {
result.Detail += fmt.Sprintf(", %s", pattern)
}
}
}
return result, nil
}
// AuditImage 审核图片
func (p *localAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
// 本地图片审核逻辑
// 1. 图片URL合法性检查
// 2. 图片格式检查
// 3. 可以扩展接入本地图片识别服务
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "local",
}
// 检查URL是否为空
if imageURL == "" {
result.Detail = "Empty image URL"
return result, nil
}
// 检查是否为支持的图片URL格式
validPrefixes := []string{"http://", "https://", "s3://", "oss://", "cos://"}
isValid := false
for _, prefix := range validPrefixes {
if strings.HasPrefix(strings.ToLower(imageURL), prefix) {
isValid = true
break
}
}
if !isValid {
result.Pass = false
result.Risk = "medium"
result.Suggest = "review"
result.Detail = "Invalid image URL format"
result.Labels = append(result.Labels, "invalid_url")
}
return result, nil
}
// SetSensitiveService 设置敏感词服务
func (p *localAuditProvider) SetSensitiveService(ss SensitiveService) {
p.sensitiveService = ss
}
// ==================== 阿里云审核服务提供商 ====================
// aliyunAuditProvider 阿里云审核服务提供商
type aliyunAuditProvider struct {
accessKey string
secretKey string
region string
}
// NewAliyunAuditProvider 创建阿里云审核服务提供商
func NewAliyunAuditProvider(accessKey, secretKey, region string) AuditServiceProvider {
return &aliyunAuditProvider{
accessKey: accessKey,
secretKey: secretKey,
region: region,
}
}
// GetName 获取提供商名称
func (p *aliyunAuditProvider) GetName() string {
return "aliyun"
}
// AuditText 审核文本
func (p *aliyunAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
// 阿里云内容安全API调用
// 实际项目中需要实现阿里云SDK调用
// 这里预留接口
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "aliyun",
Detail: "Aliyun audit not implemented, using pass",
}
// TODO: 实现阿里云内容安全API调用
// 具体参考: https://help.aliyun.com/document_detail/28417.html
return result, nil
}
// AuditImage 审核图片
func (p *aliyunAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "aliyun",
Detail: "Aliyun image audit not implemented, using pass",
}
// TODO: 实现阿里云图片审核API调用
return result, nil
}
// ==================== 腾讯云审核服务提供商 ====================
// tencentAuditProvider 腾讯云审核服务提供商
type tencentAuditProvider struct {
secretID string
secretKey string
}
// NewTencentAuditProvider 创建腾讯云审核服务提供商
func NewTencentAuditProvider(secretID, secretKey string) AuditServiceProvider {
return &tencentAuditProvider{
secretID: secretID,
secretKey: secretKey,
}
}
// GetName 获取提供商名称
func (p *tencentAuditProvider) GetName() string {
return "tencent"
}
// AuditText 审核文本
func (p *tencentAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "tencent",
Detail: "Tencent audit not implemented, using pass",
}
// TODO: 实现腾讯云内容审核API调用
// 具体参考: https://cloud.tencent.com/document/product/1124/64508
return result, nil
}
// AuditImage 审核图片
func (p *tencentAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "tencent",
Detail: "Tencent image audit not implemented, using pass",
}
// TODO: 实现腾讯云图片审核API调用
return result, nil
}
// ==================== 百度云审核服务提供商 ====================
// baiduAuditProvider 百度云审核服务提供商
type baiduAuditProvider struct {
apiKey string
secretKey string
}
// NewBaiduAuditProvider 创建百度云审核服务提供商
func NewBaiduAuditProvider(apiKey, secretKey string) AuditServiceProvider {
return &baiduAuditProvider{
apiKey: apiKey,
secretKey: secretKey,
}
}
// GetName 获取提供商名称
func (p *baiduAuditProvider) GetName() string {
return "baidu"
}
// AuditText 审核文本
func (p *baiduAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "baidu",
Detail: "Baidu audit not implemented, using pass",
}
// TODO: 实现百度云内容审核API调用
// 具体参考: https://cloud.baidu.com/doc/ANTISPAM/s/Jjw0r1iF6
return result, nil
}
// AuditImage 审核图片
func (p *baiduAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "baidu",
Detail: "Baidu image audit not implemented, using pass",
}
// TODO: 实现百度云图片审核API调用
return result, nil
}
// ==================== 审核结果回调处理 ====================
// AuditCallback 审核回调处理
type AuditCallback struct {
service AuditService
}
// NewAuditCallback 创建审核回调处理
func NewAuditCallback(service AuditService) *AuditCallback {
return &AuditCallback{
service: service,
}
}
// HandleTextCallback 处理文本审核回调
func (c *AuditCallback) HandleTextCallback(ctx context.Context, auditID string, result *AuditResult) error {
if c.service == nil || auditID == "" || result == nil {
return fmt.Errorf("invalid parameters")
}
log.Printf("Processing text audit callback: auditID=%s, result=%+v", auditID, result)
// 根据审核结果执行相应操作
// 例如: 更新帖子状态、发送通知等
return nil
}
// HandleImageCallback 处理图片审核回调
func (c *AuditCallback) HandleImageCallback(ctx context.Context, auditID string, result *AuditResult) error {
if c.service == nil || auditID == "" || result == nil {
return fmt.Errorf("invalid parameters")
}
log.Printf("Processing image audit callback: auditID=%s, result=%+v", auditID, result)
// 根据审核结果执行相应操作
// 例如: 更新图片状态、删除违规图片等
return nil
}
// ==================== 辅助函数 ====================
// IsContentSafe 判断内容是否安全
func IsContentSafe(result *AuditResult) bool {
if result == nil {
return true
}
return result.Pass && result.Suggest != "block"
}
// NeedReview 判断内容是否需要人工复审
func NeedReview(result *AuditResult) bool {
if result == nil {
return false
}
return result.Suggest == "review"
}
// GetRiskLevel 获取风险等级
func GetRiskLevel(result *AuditResult) string {
if result == nil {
return "low"
}
return result.Risk
}
// FormatAuditResult 格式化审核结果为字符串
func FormatAuditResult(result *AuditResult) string {
if result == nil {
return "{}"
}
data, _ := json.Marshal(result)
return string(data)
}
// ParseAuditResult 从字符串解析审核结果
func ParseAuditResult(data string) (*AuditResult, error) {
if data == "" {
return nil, fmt.Errorf("empty data")
}
var result AuditResult
if err := json.Unmarshal([]byte(data), &result); err != nil {
return nil, err
}
return &result, nil
}
// ==================== 审核日志查询 ====================
// GetAuditLogs 获取审核日志列表
func GetAuditLogs(db *gorm.DB, targetType string, targetID string, result string, page, pageSize int) ([]model.AuditLog, int64, error) {
query := db.Model(&model.AuditLog{})
if targetType != "" {
query = query.Where("target_type = ?", targetType)
}
if targetID != "" {
query = query.Where("target_id = ?", targetID)
}
if result != "" {
query = query.Where("result = ?", result)
}
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var logs []model.AuditLog
offset := (page - 1) * pageSize
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// ==================== 定时任务 ====================
// AuditScheduler 审核调度器
type AuditScheduler struct {
db *gorm.DB
service AuditService
interval time.Duration
stopCh chan bool
}
// NewAuditScheduler 创建审核调度器
func NewAuditScheduler(db *gorm.DB, service AuditService, interval time.Duration) *AuditScheduler {
return &AuditScheduler{
db: db,
service: service,
interval: interval,
stopCh: make(chan bool),
}
}
// Start 启动调度器
func (s *AuditScheduler) Start() {
go func() {
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.processPendingAudits()
case <-s.stopCh:
return
}
}
}()
}
// Stop 停止调度器
func (s *AuditScheduler) Stop() {
s.stopCh <- true
}
// processPendingAudits 处理待审核内容
func (s *AuditScheduler) processPendingAudits() {
// 查询待审核的内容
// 1. 查询审核状态为 pending 的记录
// 2. 调用审核服务
// 3. 更新审核状态
// 示例逻辑,实际需要根据业务需求实现
log.Println("Processing pending audits...")
}
// CleanupOldLogs 清理旧的审核日志
func CleanupOldLogs(db *gorm.DB, days int) error {
// 清理指定天数之前的审核日志
cutoffTime := time.Now().AddDate(0, 0, -days)
return db.Where("created_at < ? AND result = ?", cutoffTime, model.AuditResultPass).Delete(&model.AuditLog{}).Error
}

View File

@@ -0,0 +1,622 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"time"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
)
// 撤回消息的时间限制2分钟
const RecallMessageTimeout = 2 * time.Minute
// ChatService 聊天服务接口
type ChatService interface {
// 会话管理
GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error
// 消息操作
SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)
// 已读管理
MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
GetAllUnreadCount(ctx context.Context, userID string) (int64, error)
// 消息扩展功能
RecallMessage(ctx context.Context, messageID string, userID string) error
DeleteMessage(ctx context.Context, messageID string, userID string) error
// WebSocket相关
SendTyping(ctx context.Context, senderID string, conversationID string)
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
// 系统消息推送
IsUserOnline(userID string) bool
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
}
// chatServiceImpl 聊天服务实现
type chatServiceImpl struct {
db *gorm.DB
repo *repository.MessageRepository
userRepo *repository.UserRepository
sensitive SensitiveService
wsManager *websocket.WebSocketManager
}
// NewChatService 创建聊天服务
func NewChatService(
db *gorm.DB,
repo *repository.MessageRepository,
userRepo *repository.UserRepository,
sensitive SensitiveService,
wsManager *websocket.WebSocketManager,
) ChatService {
return &chatServiceImpl{
db: db,
repo: repo,
userRepo: userRepo,
sensitive: sensitive,
wsManager: wsManager,
}
}
// GetOrCreateConversation 获取或创建私聊会话
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
}
// GetConversationList 获取用户的会话列表
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
return s.repo.GetConversations(userID, page, pageSize)
}
// GetConversationByID 获取会话详情
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
// 验证用户是否是会话参与者
participant, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
// 获取会话信息
conv, err := s.repo.GetConversation(conversationID)
if err != nil {
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
// 填充用户的已读位置信息
_ = participant // 可以用于返回已读位置等信息
return conv, nil
}
// DeleteConversationForSelf 仅自己删除会话
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
participant, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
}
return fmt.Errorf("failed to get participant: %w", err)
}
if participant.ConversationID == "" {
return errors.New("conversation not found or no permission")
}
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
return fmt.Errorf("failed to hide conversation: %w", err)
}
return nil
}
// SetConversationPinned 设置会话置顶(用户维度)
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
participant, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
}
return fmt.Errorf("failed to get participant: %w", err)
}
if participant.ConversationID == "" {
return errors.New("conversation not found or no permission")
}
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
return fmt.Errorf("failed to update pinned status: %w", err)
}
return nil
}
// SendMessage 发送消息(使用 segments
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
// 首先验证会话是否存在
conv, err := s.repo.GetConversation(conversationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("会话不存在,请重新创建会话")
}
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
// 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
participants, pErr := s.repo.GetConversationParticipants(conversationID)
if pErr != nil {
return nil, fmt.Errorf("failed to get participants: %w", pErr)
}
var sentCount *int64
for _, p := range participants {
if p.UserID == senderID {
continue
}
blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
if bErr != nil {
return nil, fmt.Errorf("failed to check block status: %w", bErr)
}
if blocked {
return nil, ErrUserBlocked
}
// 陌生人限制:对方未回关前,只允许发送一条文本消息,且禁止发送图片
isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
if fErr != nil {
return nil, fmt.Errorf("failed to check follow status: %w", fErr)
}
if !isFollowedBack {
if containsImageSegment(segments) {
return nil, errors.New("对方未关注你,暂不支持发送图片")
}
if sentCount == nil {
c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
if cErr != nil {
return nil, fmt.Errorf("failed to count sender messages: %w", cErr)
}
sentCount = &c
}
if *sentCount >= 1 {
return nil, errors.New("对方未关注你前,仅允许发送一条消息")
}
}
}
}
// 验证用户是否是会话参与者
participant, err := s.repo.GetParticipant(conversationID, senderID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("您不是该会话的参与者")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
// 创建消息
message := &model.Message{
ConversationID: conversationID,
SenderID: senderID, // 直接使用string类型的UUID
Segments: segments,
ReplyToID: replyToID,
Status: model.MessageStatusNormal,
}
// 使用事务创建消息并更新seq
if err := s.repo.CreateMessageWithSeq(message); err != nil {
return nil, fmt.Errorf("failed to save message: %w", err)
}
// 发送消息给接收者
log.Printf("[DEBUG SendMessage] 私聊消息 segments 类型: %T, 值: %+v", message.Segments, message.Segments)
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
ID: message.ID,
ConversationID: message.ConversationID,
SenderID: senderID,
Segments: message.Segments,
Seq: message.Seq,
CreatedAt: message.CreatedAt.UnixMilli(),
})
// 获取会话中的其他参与者
participants, err := s.repo.GetConversationParticipants(conversationID)
if err == nil {
for _, p := range participants {
// 不发给自己
if p.UserID == senderID {
continue
}
// 如果接收者在线,发送实时消息
if s.wsManager != nil {
isOnline := s.wsManager.IsUserOnline(p.UserID)
log.Printf("[DEBUG SendMessage] 接收者 UserID=%s, 在线状态=%v", p.UserID, isOnline)
if isOnline {
log.Printf("[DEBUG SendMessage] 发送WebSocket消息给 UserID=%s, 消息类型=%s", p.UserID, wsMsg.Type)
s.wsManager.SendToUser(p.UserID, wsMsg)
}
}
}
} else {
log.Printf("[DEBUG SendMessage] 获取参与者失败: %v", err)
}
_ = participant // 避免未使用变量警告
return message, nil
}
func containsImageSegment(segments model.MessageSegments) bool {
for _, seg := range segments {
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
return true
}
}
return false
}
// GetMessages 获取消息历史(分页)
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, 0, errors.New("conversation not found or no permission")
}
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
}
return s.repo.GetMessages(conversationID, page, pageSize)
}
// GetMessagesAfterSeq 获取指定seq之后的消息用于增量同步
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
if limit <= 0 {
limit = 100
}
return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}
// GetMessagesBeforeSeq 获取指定seq之前的历史消息用于下拉加载更多
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
if limit <= 0 {
limit = 20
}
return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
}
// MarkAsRead 标记已读
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
}
return fmt.Errorf("failed to get participant: %w", err)
}
// 更新参与者的已读位置
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
if err != nil {
return fmt.Errorf("failed to update last read seq: %w", err)
}
// 发送已读回执(作为 meta 事件)
if s.wsManager != nil {
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
"detail_type": websocket.MetaDetailTypeRead,
"conversation_id": conversationID,
"seq": seq,
"user_id": userID,
})
// 获取会话中的所有参与者
participants, err := s.repo.GetConversationParticipants(conversationID)
if err == nil {
// 推送给会话中的所有参与者(包括自己)
for _, p := range participants {
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
}
}
}
return nil
}
// GetUnreadCount 获取指定会话的未读消息数
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return 0, errors.New("conversation not found or no permission")
}
return 0, fmt.Errorf("failed to get participant: %w", err)
}
return s.repo.GetUnreadCount(conversationID, userID)
}
// GetAllUnreadCount 获取所有会话的未读消息总数
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
return s.repo.GetAllUnreadCount(userID)
}
// RecallMessage 撤回消息2分钟内
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
// 获取消息
var message model.Message
err := s.db.First(&message, "id = ?", messageID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("message not found")
}
return fmt.Errorf("failed to get message: %w", err)
}
// 验证是否是消息发送者
if message.SenderIDStr() != userID {
return errors.New("can only recall your own messages")
}
// 验证消息是否已被撤回
if message.Status == model.MessageStatusRecalled {
return errors.New("message already recalled")
}
// 验证是否在2分钟内
if time.Since(message.CreatedAt) > RecallMessageTimeout {
return errors.New("message recall timeout (2 minutes)")
}
// 更新消息状态为已撤回
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
if err != nil {
return fmt.Errorf("failed to recall message: %w", err)
}
// 发送撤回通知
if s.wsManager != nil {
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
"messageId": messageID,
"conversationId": message.ConversationID,
"senderId": userID,
})
// 通知会话中的所有参与者
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
if err == nil {
for _, p := range participants {
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
}
}
}
return nil
}
// DeleteMessage 删除消息(仅对自己可见)
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
// 获取消息
var message model.Message
err := s.db.First(&message, "id = ?", messageID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("message not found")
}
return fmt.Errorf("failed to get message: %w", err)
}
// 验证用户是否是会话参与者
_, err = s.repo.GetParticipant(message.ConversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("no permission to delete this message")
}
return fmt.Errorf("failed to get participant: %w", err)
}
// 对于删除消息,我们使用软删除,但需要确保只对当前用户隐藏
// 这里简化处理:只有发送者可以删除自己的消息
if message.SenderIDStr() != userID {
return errors.New("can only delete your own messages")
}
// 更新消息状态为已删除
err = s.db.Model(&message).Update("status", model.MessageStatusDeleted).Error
if err != nil {
return fmt.Errorf("failed to delete message: %w", err)
}
return nil
}
// SendTyping 发送正在输入状态
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
if s.wsManager == nil {
return
}
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, senderID)
if err != nil {
return
}
// 获取会话中的其他参与者
participants, err := s.repo.GetConversationParticipants(conversationID)
if err != nil {
return
}
for _, p := range participants {
if p.UserID == senderID {
continue
}
// 发送正在输入状态
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
"conversationId": conversationID,
"senderId": senderID,
})
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
}
}
// BroadcastMessage 广播消息给用户
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
if s.wsManager != nil {
s.wsManager.SendToUser(targetUser, msg)
}
}
// IsUserOnline 检查用户是否在线
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
if s.wsManager == nil {
return false
}
return s.wsManager.IsUserOnline(userID)
}
// PushSystemMessage 推送系统消息给指定用户
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if !s.wsManager.IsUserOnline(userID) {
return errors.New("user is offline")
}
sysMsg := &websocket.SystemMessage{
ID: "", // 由调用方生成
Type: msgType,
Title: title,
Content: content,
Data: data,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
s.wsManager.SendToUser(userID, wsMsg)
return nil
}
// PushNotificationMessage 推送通知消息给指定用户
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if !s.wsManager.IsUserOnline(userID) {
return errors.New("user is offline")
}
// 确保时间戳已设置
if notification.CreatedAt == 0 {
notification.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
return nil
}
// PushAnnouncementMessage 广播公告消息给所有在线用户
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
// 确保时间戳已设置
if announcement.CreatedAt == 0 {
announcement.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
s.wsManager.Broadcast(wsMsg)
return nil
}
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
// 适用于群聊等由调用方自行负责推送的场景
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
// 验证会话是否存在
_, err := s.repo.GetConversation(conversationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("会话不存在,请重新创建会话")
}
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
// 验证用户是否是会话参与者
_, err = s.repo.GetParticipant(conversationID, senderID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("您不是该会话的参与者")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
message := &model.Message{
ConversationID: conversationID,
SenderID: senderID,
Segments: segments,
ReplyToID: replyToID,
Status: model.MessageStatusNormal,
}
if err := s.repo.CreateMessageWithSeq(message); err != nil {
return nil, fmt.Errorf("failed to save message: %w", err)
}
return message, nil
}

View File

@@ -0,0 +1,273 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"strings"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/gorse"
"carrot_bbs/internal/repository"
)
// CommentService 评论服务
type CommentService struct {
commentRepo *repository.CommentRepository
postRepo *repository.PostRepository
systemMessageService SystemMessageService
gorseClient gorse.Client
postAIService *PostAIService
}
// NewCommentService 创建评论服务
func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *CommentService {
return &CommentService{
commentRepo: commentRepo,
postRepo: postRepo,
systemMessageService: systemMessageService,
gorseClient: gorseClient,
postAIService: postAIService,
}
}
// Create 创建评论
func (s *CommentService) Create(ctx context.Context, postID, userID, content string, parentID *string, images string, imageURLs []string) (*model.Comment, error) {
if s.postAIService != nil {
// 采用异步审核,前端先立即返回
}
// 获取帖子信息用于发送通知
post, err := s.postRepo.GetByID(postID)
if err != nil {
return nil, err
}
comment := &model.Comment{
PostID: postID,
UserID: userID,
Content: content,
ParentID: parentID,
Images: images,
Status: model.CommentStatusPending,
}
// 如果有父评论设置根评论ID
var parentUserID string
if parentID != nil {
parent, err := s.commentRepo.GetByID(*parentID)
if err == nil && parent != nil {
if parent.RootID != nil {
comment.RootID = parent.RootID
} else {
comment.RootID = parentID
}
parentUserID = parent.UserID
}
}
err = s.commentRepo.Create(comment)
if err != nil {
return nil, err
}
// 重新查询以获取关联的 User
comment, err = s.commentRepo.GetByID(comment.ID)
if err != nil {
return nil, err
}
go s.reviewCommentAsync(comment.ID, userID, postID, content, imageURLs, parentID, parentUserID, post.UserID)
return comment, nil
}
func (s *CommentService) reviewCommentAsync(
commentID, userID, postID, content string,
imageURLs []string,
parentID *string,
parentUserID string,
postOwnerID string,
) {
// 未启用AI时直接通过审核并发送后续通知
if s.postAIService == nil || !s.postAIService.IsEnabled() {
if err := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); err != nil {
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
return
}
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
return
}
err := s.postAIService.ModerateComment(context.Background(), content, imageURLs)
if err != nil {
var rejectedErr *CommentModerationRejectedError
if errors.As(err, &rejectedErr) {
if delErr := s.commentRepo.Delete(commentID); delErr != nil {
log.Printf("[WARN] Failed to delete rejected comment %s: %v", commentID, delErr)
}
s.notifyCommentModerationRejected(userID, rejectedErr.Reason)
return
}
// 审核服务异常时降级放行避免评论长期pending
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
return
}
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
return
}
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
return
}
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
}
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
// 发送系统消息通知
if s.systemMessageService != nil {
go func() {
if parentID != nil && parentUserID != "" {
// 回复评论,通知被回复的人
if parentUserID != userID {
notifyErr := s.systemMessageService.SendReplyNotification(context.Background(), parentUserID, userID, postID, *parentID, commentID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending reply notification: %v\n", notifyErr)
}
}
} else {
// 评论帖子,通知帖子作者
if postOwnerID != userID {
notifyErr := s.systemMessageService.SendCommentNotification(context.Background(), postOwnerID, userID, postID, commentID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending comment notification: %v\n", notifyErr)
}
}
}
}()
}
// 推送评论行为到Gorse异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeComment, userID, postID); err != nil {
log.Printf("[WARN] Failed to insert comment feedback to Gorse: %v", err)
}
}
}()
}
func (s *CommentService) notifyCommentModerationRejected(userID, reason string) {
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
return
}
content := "您发布的评论未通过AI审核请修改后重试。"
if strings.TrimSpace(reason) != "" {
content = fmt.Sprintf("您发布的评论未通过AI审核原因%s。请修改后重试。", reason)
}
go func() {
if err := s.systemMessageService.SendSystemAnnouncement(
context.Background(),
[]string{userID},
"评论审核未通过",
content,
); err != nil {
log.Printf("[WARN] Failed to send comment moderation reject notification: %v", err)
}
}()
}
// GetByID 根据ID获取评论
func (s *CommentService) GetByID(ctx context.Context, id string) (*model.Comment, error) {
return s.commentRepo.GetByID(id)
}
// GetByPostID 获取帖子评论
func (s *CommentService) GetByPostID(ctx context.Context, postID string, page, pageSize int) ([]*model.Comment, int64, error) {
// 使用带回复的查询默认加载前3条回复
return s.commentRepo.GetByPostIDWithReplies(postID, page, pageSize, 3)
}
// GetRepliesByRootID 根据根评论ID分页获取回复
func (s *CommentService) GetRepliesByRootID(ctx context.Context, rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
return s.commentRepo.GetRepliesByRootID(rootID, page, pageSize)
}
// GetReplies 获取回复
func (s *CommentService) GetReplies(ctx context.Context, parentID string) ([]*model.Comment, error) {
return s.commentRepo.GetReplies(parentID)
}
// Update 更新评论
func (s *CommentService) Update(ctx context.Context, comment *model.Comment) error {
return s.commentRepo.Update(comment)
}
// Delete 删除评论
func (s *CommentService) Delete(ctx context.Context, id string) error {
return s.commentRepo.Delete(id)
}
// Like 点赞评论
func (s *CommentService) Like(ctx context.Context, commentID, userID string) error {
// 获取评论信息用于发送通知
comment, err := s.commentRepo.GetByID(commentID)
if err != nil {
return err
}
err = s.commentRepo.Like(commentID, userID)
if err != nil {
return err
}
// 发送评论/回复点赞通知(只有不是给自己点赞时才发送)
if s.systemMessageService != nil && comment.UserID != userID {
go func() {
var notifyErr error
if comment.ParentID != nil {
notifyErr = s.systemMessageService.SendLikeReplyNotification(
context.Background(),
comment.UserID,
userID,
comment.PostID,
commentID,
comment.Content,
)
} else {
notifyErr = s.systemMessageService.SendLikeCommentNotification(
context.Background(),
comment.UserID,
userID,
comment.PostID,
commentID,
comment.Content,
)
}
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
} else {
fmt.Printf("[DEBUG] Like notification sent successfully\n")
}
}()
}
return nil
}
// Unlike 取消点赞评论
func (s *CommentService) Unlike(ctx context.Context, commentID, userID string) error {
return s.commentRepo.Unlike(commentID, userID)
}
// IsLiked 检查是否已点赞
func (s *CommentService) IsLiked(ctx context.Context, commentID, userID string) bool {
return s.commentRepo.IsLiked(commentID, userID)
}

View File

@@ -0,0 +1,234 @@
package service
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"math/big"
"strings"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/pkg/utils"
)
const (
verifyCodeTTL = 10 * time.Minute
verifyCodeRateLimitTTL = 60 * time.Second
)
const (
CodePurposeRegister = "register"
CodePurposePasswordReset = "password_reset"
CodePurposeEmailVerify = "email_verify"
CodePurposeChangePassword = "change_password"
)
type verificationCodePayload struct {
Code string `json:"code"`
Purpose string `json:"purpose"`
Email string `json:"email"`
ExpiresAt int64 `json:"expires_at"`
}
type EmailCodeService interface {
SendCode(ctx context.Context, purpose, email string) error
VerifyCode(purpose, email, code string) error
}
type emailCodeServiceImpl struct {
emailService EmailService
cache cache.Cache
}
func NewEmailCodeService(emailService EmailService, cacheBackend cache.Cache) EmailCodeService {
if cacheBackend == nil {
cacheBackend = cache.GetCache()
}
return &emailCodeServiceImpl{
emailService: emailService,
cache: cacheBackend,
}
}
func verificationCodeCacheKey(purpose, email string) string {
return fmt.Sprintf("auth:verify_code:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
}
func verificationCodeRateLimitKey(purpose, email string) string {
return fmt.Sprintf("auth:verify_code_rate_limit:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
}
func generateNumericCode(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("invalid code length")
}
max := big.NewInt(10)
result := make([]byte, length)
for i := 0; i < length; i++ {
n, err := rand.Int(rand.Reader, max)
if err != nil {
return "", err
}
result[i] = byte('0' + n.Int64())
}
return string(result), nil
}
func (s *emailCodeServiceImpl) SendCode(ctx context.Context, purpose, email string) error {
if strings.TrimSpace(email) == "" || !utils.ValidateEmail(email) {
return ErrInvalidEmail
}
if s.emailService == nil || !s.emailService.IsEnabled() {
return ErrEmailServiceUnavailable
}
if s.cache == nil {
return ErrVerificationCodeUnavailable
}
rateLimitKey := verificationCodeRateLimitKey(purpose, email)
if s.cache.Exists(rateLimitKey) {
return ErrVerificationCodeTooFrequent
}
code, err := generateNumericCode(6)
if err != nil {
return fmt.Errorf("generate verification code failed: %w", err)
}
payload := verificationCodePayload{
Code: code,
Purpose: purpose,
Email: strings.ToLower(strings.TrimSpace(email)),
ExpiresAt: time.Now().Add(verifyCodeTTL).Unix(),
}
cacheKey := verificationCodeCacheKey(purpose, email)
s.cache.Set(cacheKey, payload, verifyCodeTTL)
s.cache.Set(rateLimitKey, "1", verifyCodeRateLimitTTL)
subject, sceneText := verificationEmailMeta(purpose)
textBody := fmt.Sprintf("【%s】验证码%s\n有效期10分钟\n请勿将验证码泄露给他人。", sceneText, code)
htmlBody := buildVerificationEmailHTML(sceneText, code)
if err := s.emailService.Send(ctx, SendEmailRequest{
To: []string{email},
Subject: subject,
TextBody: textBody,
HTMLBody: htmlBody,
}); err != nil {
s.cache.Delete(cacheKey)
return fmt.Errorf("send verification email failed: %w", err)
}
return nil
}
func (s *emailCodeServiceImpl) VerifyCode(purpose, email, code string) error {
if strings.TrimSpace(email) == "" || strings.TrimSpace(code) == "" {
return ErrVerificationCodeInvalid
}
if s.cache == nil {
return ErrVerificationCodeUnavailable
}
cacheKey := verificationCodeCacheKey(purpose, email)
raw, ok := s.cache.Get(cacheKey)
if !ok {
return ErrVerificationCodeExpired
}
var payload verificationCodePayload
switch v := raw.(type) {
case string:
if err := json.Unmarshal([]byte(v), &payload); err != nil {
return ErrVerificationCodeInvalid
}
case []byte:
if err := json.Unmarshal(v, &payload); err != nil {
return ErrVerificationCodeInvalid
}
case verificationCodePayload:
payload = v
default:
data, err := json.Marshal(v)
if err != nil {
return ErrVerificationCodeInvalid
}
if err := json.Unmarshal(data, &payload); err != nil {
return ErrVerificationCodeInvalid
}
}
if payload.Purpose != purpose || payload.Email != strings.ToLower(strings.TrimSpace(email)) {
return ErrVerificationCodeInvalid
}
if payload.ExpiresAt > 0 && time.Now().Unix() > payload.ExpiresAt {
s.cache.Delete(cacheKey)
return ErrVerificationCodeExpired
}
if payload.Code != strings.TrimSpace(code) {
return ErrVerificationCodeInvalid
}
s.cache.Delete(cacheKey)
return nil
}
func verificationEmailMeta(purpose string) (subject string, sceneText string) {
switch purpose {
case CodePurposeRegister:
return "Carrot BBS 注册验证码", "注册账号"
case CodePurposePasswordReset:
return "Carrot BBS 找回密码验证码", "找回密码"
case CodePurposeEmailVerify:
return "Carrot BBS 邮箱验证验证码", "验证邮箱"
case CodePurposeChangePassword:
return "Carrot BBS 修改密码验证码", "修改密码"
default:
return "Carrot BBS 验证码", "身份验证"
}
}
func buildVerificationEmailHTML(sceneText, code string) string {
return fmt.Sprintf(`<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Carrot BBS 验证码</title>
</head>
<body style="margin:0;padding:0;background:#f4f6fb;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,'PingFang SC','Microsoft YaHei',sans-serif;color:#1f2937;">
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="background:#f4f6fb;padding:24px 12px;">
<tr>
<td align="center">
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="max-width:560px;background:#ffffff;border-radius:14px;overflow:hidden;box-shadow:0 8px 30px rgba(15,23,42,0.08);">
<tr>
<td style="background:linear-gradient(135deg,#ff6b35,#ff8f66);padding:24px 28px;color:#ffffff;">
<div style="font-size:22px;font-weight:700;line-height:1.2;">Carrot BBS</div>
<div style="margin-top:6px;font-size:14px;opacity:0.95;">%s 验证</div>
</td>
</tr>
<tr>
<td style="padding:28px;">
<p style="margin:0 0 14px;font-size:15px;line-height:1.75;">你好,</p>
<p style="margin:0 0 20px;font-size:15px;line-height:1.75;">你正在进行 <strong>%s</strong> 操作,请使用下方验证码完成验证:</p>
<div style="margin:0 auto 18px;max-width:320px;border:1px dashed #ff8f66;background:#fff8f4;border-radius:12px;padding:14px 12px;text-align:center;">
<div style="font-size:13px;color:#9a3412;letter-spacing:0.5px;">验证码10分钟内有效</div>
<div style="margin-top:8px;font-size:34px;line-height:1;font-weight:800;letter-spacing:8px;color:#ea580c;">%s</div>
</div>
<p style="margin:0 0 8px;font-size:13px;color:#6b7280;line-height:1.7;">如果不是你本人操作,请忽略此邮件,并及时检查账号安全。</p>
<p style="margin:0;font-size:13px;color:#6b7280;line-height:1.7;">请勿向任何人透露验证码,平台不会以任何理由索取验证码。</p>
</td>
</tr>
<tr>
<td style="padding:14px 28px;background:#f8fafc;border-top:1px solid #e5e7eb;color:#94a3b8;font-size:12px;line-height:1.7;">
此邮件由系统自动发送,请勿直接回复。<br/>
© Carrot BBS
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>`, sceneText, sceneText, code)
}

View File

@@ -0,0 +1,82 @@
package service
import (
"context"
"fmt"
"strings"
emailpkg "carrot_bbs/internal/pkg/email"
)
// SendEmailRequest 发信请求
type SendEmailRequest struct {
To []string
Cc []string
Bcc []string
ReplyTo []string
Subject string
TextBody string
HTMLBody string
Attachments []string
}
type EmailService interface {
IsEnabled() bool
Send(ctx context.Context, req SendEmailRequest) error
SendText(ctx context.Context, to []string, subject, body string) error
SendHTML(ctx context.Context, to []string, subject, html string) error
}
type emailServiceImpl struct {
client emailpkg.Client
}
func NewEmailService(client emailpkg.Client) EmailService {
return &emailServiceImpl{client: client}
}
func (s *emailServiceImpl) IsEnabled() bool {
return s.client != nil && s.client.IsEnabled()
}
func (s *emailServiceImpl) Send(ctx context.Context, req SendEmailRequest) error {
if s.client == nil {
return fmt.Errorf("email client is nil")
}
if !s.client.IsEnabled() {
return fmt.Errorf("email service is disabled")
}
if len(req.To) == 0 {
return fmt.Errorf("email recipient is empty")
}
if strings.TrimSpace(req.Subject) == "" {
return fmt.Errorf("email subject is empty")
}
return s.client.Send(ctx, emailpkg.Message{
To: req.To,
Cc: req.Cc,
Bcc: req.Bcc,
ReplyTo: req.ReplyTo,
Subject: req.Subject,
TextBody: req.TextBody,
HTMLBody: req.HTMLBody,
Attachments: req.Attachments,
})
}
func (s *emailServiceImpl) SendText(ctx context.Context, to []string, subject, body string) error {
return s.Send(ctx, SendEmailRequest{
To: to,
Subject: subject,
TextBody: body,
})
}
func (s *emailServiceImpl) SendHTML(ctx context.Context, to []string, subject, html string) error {
return s.Send(ctx, SendEmailRequest{
To: to,
Subject: subject,
HTMLBody: html,
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,38 @@
package service
import (
"carrot_bbs/internal/pkg/jwt"
"time"
)
// JWTService JWT服务
type JWTService struct {
jwt *jwt.JWT
}
// NewJWTService 创建JWT服务
func NewJWTService(secret string, accessExpire, refreshExpire int64) *JWTService {
return &JWTService{
jwt: jwt.New(secret, time.Duration(accessExpire)*time.Second, time.Duration(refreshExpire)*time.Second),
}
}
// GenerateAccessToken 生成访问令牌
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
return s.jwt.GenerateAccessToken(userID, username)
}
// GenerateRefreshToken 生成刷新令牌
func (s *JWTService) GenerateRefreshToken(userID, username string) (string, error) {
return s.jwt.GenerateRefreshToken(userID, username)
}
// ParseToken 解析令牌
func (s *JWTService) ParseToken(tokenString string) (*jwt.Claims, error) {
return s.jwt.ParseToken(tokenString)
}
// ValidateToken 验证令牌
func (s *JWTService) ValidateToken(tokenString string) error {
return s.jwt.ValidateToken(tokenString)
}

View File

@@ -0,0 +1,215 @@
package service
import (
"context"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
)
// 缓存TTL常量
const (
ConversationListTTL = 60 * time.Second // 会话列表缓存60秒
ConversationDetailTTL = 60 * time.Second // 会话详情缓存60秒
UnreadCountTTL = 30 * time.Second // 未读数缓存30秒
ConversationNullTTL = 5 * time.Second
UnreadNullTTL = 5 * time.Second
CacheJitterRatio = 0.1
)
// MessageService 消息服务
type MessageService struct {
messageRepo *repository.MessageRepository
cache cache.Cache
}
// NewMessageService 创建消息服务
func NewMessageService(messageRepo *repository.MessageRepository) *MessageService {
return &MessageService{
messageRepo: messageRepo,
cache: cache.GetCache(),
}
}
// ConversationListResult 会话列表缓存结果
type ConversationListResult struct {
Conversations []*model.Conversation
Total int64
}
// SendMessage 发送消息(使用 segments
// senderID 和 receiverID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID string, segments model.MessageSegments) (*model.Message, error) {
// 获取或创建会话
conv, err := s.messageRepo.GetOrCreatePrivateConversation(senderID, receiverID)
if err != nil {
return nil, err
}
msg := &model.Message{
ConversationID: conv.ID,
SenderID: senderID,
Segments: segments,
Status: model.MessageStatusNormal,
}
// 使用事务创建消息并更新seq
err = s.messageRepo.CreateMessageWithSeq(msg)
if err != nil {
return nil, err
}
// 失效会话列表缓存(发送者和接收者)
cache.InvalidateConversationList(s.cache, senderID)
cache.InvalidateConversationList(s.cache, receiverID)
// 失效未读数缓存
cache.InvalidateUnreadConversation(s.cache, receiverID)
cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID)
return msg, nil
}
// GetConversations 获取会话列表(带缓存)
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
cacheSettings := cache.GetSettings()
conversationTTL := cacheSettings.ConversationTTL
if conversationTTL <= 0 {
conversationTTL = ConversationListTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = ConversationNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = CacheJitterRatio
}
// 生成缓存键
cacheKey := cache.ConversationListKey(userID, page, pageSize)
result, err := cache.GetOrLoadTyped[*ConversationListResult](
s.cache,
cacheKey,
conversationTTL,
jitter,
nullTTL,
func() (*ConversationListResult, error) {
conversations, total, err := s.messageRepo.GetConversations(userID, page, pageSize)
if err != nil {
return nil, err
}
return &ConversationListResult{
Conversations: conversations,
Total: total,
}, nil
},
)
if err != nil {
return nil, 0, err
}
if result == nil {
return []*model.Conversation{}, 0, nil
}
return result.Conversations, result.Total, nil
}
// GetMessages 获取消息列表
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
return s.messageRepo.GetMessages(conversationID, page, pageSize)
}
// GetMessagesAfterSeq 获取指定seq之后的消息增量同步
func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}
// MarkAsRead 标记为已读
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadConversation(s.cache, userID)
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
// 失效会话列表缓存
cache.InvalidateConversationList(s.cache, userID)
return nil
}
// GetUnreadCount 获取未读消息数(带缓存)
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
cacheSettings := cache.GetSettings()
unreadTTL := cacheSettings.UnreadCountTTL
if unreadTTL <= 0 {
unreadTTL = UnreadCountTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = UnreadNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = CacheJitterRatio
}
// 生成缓存键
cacheKey := cache.UnreadDetailKey(userID, conversationID)
return cache.GetOrLoadTyped[int64](
s.cache,
cacheKey,
unreadTTL,
jitter,
nullTTL,
func() (int64, error) {
return s.messageRepo.GetUnreadCount(conversationID, userID)
},
)
}
// GetOrCreateConversation 获取或创建私聊会话
// user1ID 和 user2ID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
conv, err := s.messageRepo.GetOrCreatePrivateConversation(user1ID, user2ID)
if err != nil {
return nil, err
}
// 失效会话列表缓存
cache.InvalidateConversationList(s.cache, user1ID)
cache.InvalidateConversationList(s.cache, user2ID)
return conv, nil
}
// GetConversationParticipants 获取会话参与者列表
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
return s.messageRepo.GetConversationParticipants(conversationID)
}
// ParseConversationID 辅助函数直接返回字符串ID已经是string类型
func ParseConversationID(idStr string) (string, error) {
return idStr, nil
}
// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用)
func (s *MessageService) InvalidateUserConversationCache(userID string) {
cache.InvalidateConversationList(s.cache, userID)
cache.InvalidateUnreadConversation(s.cache, userID)
}
// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用)
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
cache.InvalidateUnreadConversation(s.cache, userID)
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
}

View File

@@ -0,0 +1,169 @@
package service
import (
"context"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
)
// 缓存TTL常量
const (
NotificationUnreadCountTTL = 30 * time.Second // 通知未读数缓存30秒
NotificationNullTTL = 5 * time.Second
NotificationCacheJitter = 0.1
)
// NotificationService 通知服务
type NotificationService struct {
notificationRepo *repository.NotificationRepository
cache cache.Cache
}
// NewNotificationService 创建通知服务
func NewNotificationService(notificationRepo *repository.NotificationRepository) *NotificationService {
return &NotificationService{
notificationRepo: notificationRepo,
cache: cache.GetCache(),
}
}
// Create 创建通知
func (s *NotificationService) Create(ctx context.Context, userID string, notificationType model.NotificationType, title, content string) (*model.Notification, error) {
notification := &model.Notification{
UserID: userID,
Type: notificationType,
Title: title,
Content: content,
IsRead: false,
}
err := s.notificationRepo.Create(notification)
if err != nil {
return nil, err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return notification, nil
}
// GetByUserID 获取用户通知
func (s *NotificationService) GetByUserID(ctx context.Context, userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
return s.notificationRepo.GetByUserID(userID, page, pageSize, unreadOnly)
}
// MarkAsRead 标记为已读
func (s *NotificationService) MarkAsRead(ctx context.Context, id string) error {
err := s.notificationRepo.MarkAsRead(id)
if err != nil {
return err
}
// 注意这里无法获取userID所以不在缓存中失效
// 调用方应该使用MarkAsReadWithUserID方法
return nil
}
// MarkAsReadWithUserID 标记为已读带用户ID用于缓存失效
func (s *NotificationService) MarkAsReadWithUserID(ctx context.Context, id, userID string) error {
err := s.notificationRepo.MarkAsRead(id)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return nil
}
// MarkAllAsRead 标记所有为已读
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID string) error {
err := s.notificationRepo.MarkAllAsRead(userID)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return nil
}
// Delete 删除通知
func (s *NotificationService) Delete(ctx context.Context, id string) error {
return s.notificationRepo.Delete(id)
}
// GetUnreadCount 获取未读数量(带缓存)
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID string) (int64, error) {
cacheSettings := cache.GetSettings()
unreadTTL := cacheSettings.UnreadCountTTL
if unreadTTL <= 0 {
unreadTTL = NotificationUnreadCountTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = NotificationNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = NotificationCacheJitter
}
// 生成缓存键
cacheKey := cache.UnreadSystemKey(userID)
return cache.GetOrLoadTyped[int64](
s.cache,
cacheKey,
unreadTTL,
jitter,
nullTTL,
func() (int64, error) {
return s.notificationRepo.GetUnreadCount(userID)
},
)
}
// DeleteNotification 删除通知(带用户验证)
func (s *NotificationService) DeleteNotification(ctx context.Context, id, userID string) error {
// 先检查通知是否属于该用户
notification, err := s.notificationRepo.GetByID(id)
if err != nil {
return err
}
if notification.UserID != userID {
return ErrUnauthorizedNotification
}
err = s.notificationRepo.Delete(id)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return nil
}
// ClearAllNotifications 清空所有通知
func (s *NotificationService) ClearAllNotifications(ctx context.Context, userID string) error {
err := s.notificationRepo.DeleteAllByUserID(userID)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return nil
}
// 错误定义
var ErrUnauthorizedNotification = &ServiceError{Code: 403, Message: "unauthorized to delete this notification"}

View File

@@ -0,0 +1,103 @@
package service
import (
"context"
"log"
"strings"
"carrot_bbs/internal/pkg/openai"
)
// PostModerationRejectedError 帖子审核拒绝错误
type PostModerationRejectedError struct {
Reason string
}
func (e *PostModerationRejectedError) Error() string {
if strings.TrimSpace(e.Reason) == "" {
return "post rejected by moderation"
}
return "post rejected by moderation: " + e.Reason
}
// UserMessage 返回给前端的用户可读文案
func (e *PostModerationRejectedError) UserMessage() string {
if strings.TrimSpace(e.Reason) == "" {
return "内容未通过审核,请修改后重试"
}
return strings.TrimSpace(e.Reason)
}
// CommentModerationRejectedError 评论审核拒绝错误
type CommentModerationRejectedError struct {
Reason string
}
func (e *CommentModerationRejectedError) Error() string {
if strings.TrimSpace(e.Reason) == "" {
return "comment rejected by moderation"
}
return "comment rejected by moderation: " + e.Reason
}
// UserMessage 返回给前端的用户可读文案
func (e *CommentModerationRejectedError) UserMessage() string {
if strings.TrimSpace(e.Reason) == "" {
return "评论未通过审核,请修改后重试"
}
return strings.TrimSpace(e.Reason)
}
type PostAIService struct {
openAIClient openai.Client
}
func NewPostAIService(openAIClient openai.Client) *PostAIService {
return &PostAIService{
openAIClient: openAIClient,
}
}
func (s *PostAIService) IsEnabled() bool {
return s != nil && s.openAIClient != nil && s.openAIClient.IsEnabled()
}
// ModeratePost 审核帖子内容,返回 nil 表示通过
func (s *PostAIService) ModeratePost(ctx context.Context, title, content string, images []string) error {
if !s.IsEnabled() {
return nil
}
approved, reason, err := s.openAIClient.ModeratePost(ctx, title, content, images)
if err != nil {
if s.openAIClient.Config().StrictModeration {
return err
}
log.Printf("[WARN] AI moderation failed, fallback allow: %v", err)
return nil
}
if !approved {
return &PostModerationRejectedError{Reason: reason}
}
return nil
}
// ModerateComment 审核评论内容,返回 nil 表示通过
func (s *PostAIService) ModerateComment(ctx context.Context, content string, images []string) error {
if !s.IsEnabled() {
return nil
}
approved, reason, err := s.openAIClient.ModerateComment(ctx, content, images)
if err != nil {
if s.openAIClient.Config().StrictModeration {
return err
}
log.Printf("[WARN] AI comment moderation failed, fallback allow: %v", err)
return nil
}
if !approved {
return &CommentModerationRejectedError{Reason: reason}
}
return nil
}

View File

@@ -0,0 +1,593 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"strings"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/gorse"
"carrot_bbs/internal/repository"
)
// 缓存TTL常量
const (
PostListTTL = 30 * time.Second // 帖子列表缓存30秒
PostListNullTTL = 5 * time.Second
PostListJitterRatio = 0.15
anonymousViewUserID = "_anon_view"
)
// PostService 帖子服务
type PostService struct {
postRepo *repository.PostRepository
systemMessageService SystemMessageService
cache cache.Cache
gorseClient gorse.Client
postAIService *PostAIService
}
// NewPostService 创建帖子服务
func NewPostService(postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *PostService {
return &PostService{
postRepo: postRepo,
systemMessageService: systemMessageService,
cache: cache.GetCache(),
gorseClient: gorseClient,
postAIService: postAIService,
}
}
// PostListResult 帖子列表缓存结果
type PostListResult struct {
Posts []*model.Post
Total int64
}
// Create 创建帖子
func (s *PostService) Create(ctx context.Context, userID, title, content string, images []string) (*model.Post, error) {
post := &model.Post{
UserID: userID,
Title: title,
Content: content,
Status: model.PostStatusPending,
}
err := s.postRepo.Create(post, images)
if err != nil {
return nil, err
}
// 失效帖子列表缓存
cache.InvalidatePostList(s.cache)
// 同步到Gorse推荐系统异步
go s.reviewPostAsync(post.ID, userID, title, content, images)
// 重新查询以获取关联的 User 和 Images
return s.postRepo.GetByID(post.ID)
}
func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) {
// 未启用AI时直接发布
if s.postAIService == nil || !s.postAIService.IsEnabled() {
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
}
return
}
err := s.postAIService.ModeratePost(context.Background(), title, content, images)
if err != nil {
var rejectedErr *PostModerationRejectedError
if errors.As(err, &rejectedErr) {
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
}
s.notifyModerationRejected(userID, rejectedErr.Reason)
return
}
// 规则审核不可用时降级为发布避免长时间pending
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
}
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
return
}
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
return
}
if s.gorseClient.IsEnabled() {
post, getErr := s.postRepo.GetByID(postID)
if getErr != nil {
log.Printf("[WARN] Failed to load published post for gorse sync: %v", getErr)
return
}
categories := s.buildPostCategories(post)
comment := post.Title
textToEmbed := post.Title + " " + post.Content
if upsertErr := s.gorseClient.UpsertItemWithEmbedding(context.Background(), post.ID, categories, comment, textToEmbed); upsertErr != nil {
log.Printf("[WARN] Failed to upsert item to Gorse: %v", upsertErr)
}
}
}
func (s *PostService) notifyModerationRejected(userID, reason string) {
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
return
}
content := "您发布的帖子未通过AI审核请修改后重试。"
if strings.TrimSpace(reason) != "" {
content = fmt.Sprintf("您发布的帖子未通过AI审核原因%s。请修改后重试。", reason)
}
go func() {
if err := s.systemMessageService.SendSystemAnnouncement(
context.Background(),
[]string{userID},
"帖子审核未通过",
content,
); err != nil {
log.Printf("[WARN] Failed to send moderation reject notification: %v", err)
}
}()
}
// GetByID 根据ID获取帖子
func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, error) {
return s.postRepo.GetByID(id)
}
// Update 更新帖子
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
err := s.postRepo.Update(post)
if err != nil {
return err
}
// 失效帖子详情缓存和列表缓存
cache.InvalidatePostDetail(s.cache, post.ID)
cache.InvalidatePostList(s.cache)
return nil
}
// Delete 删除帖子
func (s *PostService) Delete(ctx context.Context, id string) error {
err := s.postRepo.Delete(id)
if err != nil {
return err
}
// 失效帖子详情缓存和列表缓存
cache.InvalidatePostDetail(s.cache, id)
cache.InvalidatePostList(s.cache)
// 从Gorse中删除帖子异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.DeleteItem(context.Background(), id); err != nil {
log.Printf("[WARN] Failed to delete item from Gorse: %v", err)
}
}
}()
return nil
}
// List 获取帖子列表(带缓存)
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
cacheSettings := cache.GetSettings()
postListTTL := cacheSettings.PostListTTL
if postListTTL <= 0 {
postListTTL = PostListTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = PostListNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = PostListJitterRatio
}
// 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染)
cacheKey := cache.PostListKey("latest", userID, page, pageSize)
result, err := cache.GetOrLoadTyped[*PostListResult](
s.cache,
cacheKey,
postListTTL,
jitter,
nullTTL,
func() (*PostListResult, error) {
posts, total, err := s.postRepo.List(page, pageSize, userID)
if err != nil {
return nil, err
}
return &PostListResult{Posts: posts, Total: total}, nil
},
)
if err != nil {
return nil, 0, err
}
if result == nil {
return []*model.Post{}, 0, nil
}
// 兼容历史脏缓存:旧缓存序列化会丢失 Post.User导致前端显示“匿名用户”
// 这里检测并回源重建一次缓存,避免在 TTL 内持续返回缺失作者的数据。
missingAuthor := false
for _, post := range result.Posts {
if post != nil && post.UserID != "" && post.User == nil {
missingAuthor = true
break
}
}
if missingAuthor {
posts, total, loadErr := s.postRepo.List(page, pageSize, userID)
if loadErr != nil {
return nil, 0, loadErr
}
result = &PostListResult{Posts: posts, Total: total}
cache.SetWithJitter(s.cache, cacheKey, result, postListTTL, jitter)
}
return result.Posts, result.Total, nil
}
// GetLatestPosts 获取最新帖子(语义化别名)
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
return s.List(ctx, page, pageSize, userID)
}
// GetUserPosts 获取用户帖子
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
return s.postRepo.GetUserPosts(userID, page, pageSize)
}
// Like 点赞
func (s *PostService) Like(ctx context.Context, postID, userID string) error {
// 获取帖子信息用于发送通知
post, err := s.postRepo.GetByID(postID)
if err != nil {
return err
}
err = s.postRepo.Like(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
// 发送点赞通知(不给自己发通知)
if s.systemMessageService != nil && post.UserID != userID {
go func() {
notifyErr := s.systemMessageService.SendLikeNotification(context.Background(), post.UserID, userID, postID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
} else {
fmt.Printf("[DEBUG] Like notification sent successfully\n")
}
}()
}
// 推送点赞行为到Gorse异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
log.Printf("[WARN] Failed to insert like feedback to Gorse: %v", err)
}
}
}()
return nil
}
// Unlike 取消点赞
func (s *PostService) Unlike(ctx context.Context, postID, userID string) error {
err := s.postRepo.Unlike(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
// 删除Gorse中的点赞反馈异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
log.Printf("[WARN] Failed to delete like feedback from Gorse: %v", err)
}
}
}()
return nil
}
// IsLiked 检查是否点赞
func (s *PostService) IsLiked(ctx context.Context, postID, userID string) bool {
return s.postRepo.IsLiked(postID, userID)
}
// Favorite 收藏
func (s *PostService) Favorite(ctx context.Context, postID, userID string) error {
// 获取帖子信息用于发送通知
post, err := s.postRepo.GetByID(postID)
if err != nil {
return err
}
err = s.postRepo.Favorite(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
// 发送收藏通知(不给自己发通知)
if s.systemMessageService != nil && post.UserID != userID {
go func() {
notifyErr := s.systemMessageService.SendFavoriteNotification(context.Background(), post.UserID, userID, postID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending favorite notification: %v\n", notifyErr)
} else {
fmt.Printf("[DEBUG] Favorite notification sent successfully\n")
}
}()
}
// 推送收藏行为到Gorse异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
log.Printf("[WARN] Failed to insert favorite feedback to Gorse: %v", err)
}
}
}()
return nil
}
// Unfavorite 取消收藏
func (s *PostService) Unfavorite(ctx context.Context, postID, userID string) error {
err := s.postRepo.Unfavorite(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
// 删除Gorse中的收藏反馈异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
log.Printf("[WARN] Failed to delete favorite feedback from Gorse: %v", err)
}
}
}()
return nil
}
// IsFavorited 检查是否收藏
func (s *PostService) IsFavorited(ctx context.Context, postID, userID string) bool {
return s.postRepo.IsFavorited(postID, userID)
}
// IncrementViews 增加帖子观看量并同步到Gorse
func (s *PostService) IncrementViews(ctx context.Context, postID, userID string) error {
if err := s.postRepo.IncrementViews(postID); err != nil {
return err
}
// 同步浏览行为到Gorse异步
go func() {
if !s.gorseClient.IsEnabled() {
return
}
feedbackUserID := userID
if feedbackUserID == "" {
feedbackUserID = anonymousViewUserID
}
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeRead, feedbackUserID, postID); err != nil {
log.Printf("[WARN] Failed to insert read feedback to Gorse: %v", err)
}
}()
return nil
}
// GetFavorites 获取收藏列表
func (s *PostService) GetFavorites(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
return s.postRepo.GetFavorites(userID, page, pageSize)
}
// Search 搜索帖子
func (s *PostService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.Post, int64, error) {
return s.postRepo.Search(keyword, page, pageSize)
}
// GetFollowingPosts 获取关注用户的帖子(带缓存)
func (s *PostService) GetFollowingPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
cacheSettings := cache.GetSettings()
postListTTL := cacheSettings.PostListTTL
if postListTTL <= 0 {
postListTTL = PostListTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = PostListNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = PostListJitterRatio
}
// 生成缓存键
cacheKey := cache.PostListKey("follow", userID, page, pageSize)
result, err := cache.GetOrLoadTyped[*PostListResult](
s.cache,
cacheKey,
postListTTL,
jitter,
nullTTL,
func() (*PostListResult, error) {
posts, total, err := s.postRepo.GetFollowingPosts(userID, page, pageSize)
if err != nil {
return nil, err
}
return &PostListResult{Posts: posts, Total: total}, nil
},
)
if err != nil {
return nil, 0, err
}
if result == nil {
return []*model.Post{}, 0, nil
}
return result.Posts, result.Total, nil
}
// GetHotPosts 获取热门帖子使用Gorse非个性化推荐
func (s *PostService) GetHotPosts(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
// 如果Gorse启用使用自定义的非个性化推荐器
if s.gorseClient.IsEnabled() {
offset := (page - 1) * pageSize
// 使用 most_liked_weekly 推荐器获取周热门
// 多取1条用于判断是否还有下一页
itemIDs, err := s.gorseClient.GetNonPersonalized(ctx, "most_liked_weekly", pageSize+1, offset, "")
if err != nil {
log.Printf("[WARN] Gorse GetNonPersonalized failed: %v, fallback to database", err)
return s.getHotPostsFromDB(ctx, page, pageSize)
}
if len(itemIDs) > 0 {
hasNext := len(itemIDs) > pageSize
if hasNext {
itemIDs = itemIDs[:pageSize]
}
posts, err := s.postRepo.GetByIDs(itemIDs)
if err != nil {
return nil, 0, err
}
// 近似 total当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
estimatedTotal := int64(offset + len(posts))
if hasNext {
estimatedTotal = int64(offset + pageSize + 1)
}
return posts, estimatedTotal, nil
}
}
// 降级:从数据库获取
return s.getHotPostsFromDB(ctx, page, pageSize)
}
// getHotPostsFromDB 从数据库获取热门帖子(降级路径)
func (s *PostService) getHotPostsFromDB(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
// 直接查询数据库不再使用本地缓存Gorse失败降级时使用
posts, total, err := s.postRepo.GetHotPosts(page, pageSize)
if err != nil {
return nil, 0, err
}
return posts, total, nil
}
// GetRecommendedPosts 获取推荐帖子
func (s *PostService) GetRecommendedPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
// 如果Gorse未启用或用户未登录降级为热门帖子
if !s.gorseClient.IsEnabled() || userID == "" {
return s.GetHotPosts(ctx, page, pageSize)
}
// 计算偏移量
offset := (page - 1) * pageSize
// 从Gorse获取推荐列表
// 多取1条用于判断是否还有下一页
itemIDs, err := s.gorseClient.GetRecommend(ctx, userID, pageSize+1, offset)
if err != nil {
log.Printf("[WARN] Gorse recommendation failed: %v, fallback to hot posts", err)
return s.GetHotPosts(ctx, page, pageSize)
}
// 如果没有推荐结果,降级为热门帖子
if len(itemIDs) == 0 {
return s.GetHotPosts(ctx, page, pageSize)
}
hasNext := len(itemIDs) > pageSize
if hasNext {
itemIDs = itemIDs[:pageSize]
}
// 根据ID列表查询帖子详情
posts, err := s.postRepo.GetByIDs(itemIDs)
if err != nil {
return nil, 0, err
}
// 近似 total当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
estimatedTotal := int64(offset + len(posts))
if hasNext {
estimatedTotal = int64(offset + pageSize + 1)
}
return posts, estimatedTotal, nil
}
// buildPostCategories 构建帖子的类别标签
func (s *PostService) buildPostCategories(post *model.Post) []string {
var categories []string
// 热度标签
if post.ViewsCount > 1000 {
categories = append(categories, "hot_high")
} else if post.ViewsCount > 100 {
categories = append(categories, "hot_medium")
}
// 点赞标签
if post.LikesCount > 100 {
categories = append(categories, "likes_100+")
} else if post.LikesCount > 50 {
categories = append(categories, "likes_50+")
} else if post.LikesCount > 10 {
categories = append(categories, "likes_10+")
}
// 评论标签
if post.CommentsCount > 50 {
categories = append(categories, "comments_50+")
} else if post.CommentsCount > 10 {
categories = append(categories, "comments_10+")
}
// 时间标签
age := time.Since(post.CreatedAt)
if age < 24*time.Hour {
categories = append(categories, "today")
} else if age < 7*24*time.Hour {
categories = append(categories, "this_week")
} else if age < 30*24*time.Hour {
categories = append(categories, "this_month")
}
return categories
}

View File

@@ -0,0 +1,575 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
)
// 推送相关常量
const (
// DefaultPushTimeout 默认推送超时时间
DefaultPushTimeout = 30 * time.Second
// MaxRetryCount 最大重试次数
MaxRetryCount = 3
// DefaultExpiredTime 默认消息过期时间24小时
DefaultExpiredTime = 24 * time.Hour
// PushQueueSize 推送队列大小
PushQueueSize = 1000
)
// PushPriority 推送优先级
type PushPriority int
const (
PriorityLow PushPriority = 1 // 低优先级(营销消息等)
PriorityNormal PushPriority = 5 // 普通优先级(系统通知)
PriorityHigh PushPriority = 8 // 高优先级(聊天消息)
PriorityCritical PushPriority = 10 // 最高优先级(重要系统通知)
)
// PushService 推送服务接口
type PushService interface {
// 推送核心方法
PushMessage(ctx context.Context, userID string, message *model.Message) error
PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error
// 系统消息推送
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error
PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
// 设备管理
RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error
UnregisterDevice(ctx context.Context, deviceID string) error
UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error
// 推送记录管理
CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error)
GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error)
// 后台任务
StartPushWorker(ctx context.Context)
StopPushWorker()
}
// pushServiceImpl 推送服务实现
type pushServiceImpl struct {
pushRepo *repository.PushRecordRepository
deviceRepo *repository.DeviceTokenRepository
messageRepo *repository.MessageRepository
wsManager *websocket.WebSocketManager
// 推送队列
pushQueue chan *pushTask
stopChan chan struct{}
}
// pushTask 推送任务
type pushTask struct {
userID string
message *model.Message
priority int
}
// NewPushService 创建推送服务
func NewPushService(
pushRepo *repository.PushRecordRepository,
deviceRepo *repository.DeviceTokenRepository,
messageRepo *repository.MessageRepository,
wsManager *websocket.WebSocketManager,
) PushService {
return &pushServiceImpl{
pushRepo: pushRepo,
deviceRepo: deviceRepo,
messageRepo: messageRepo,
wsManager: wsManager,
pushQueue: make(chan *pushTask, PushQueueSize),
stopChan: make(chan struct{}),
}
}
// PushMessage 推送消息给用户
func (s *pushServiceImpl) PushMessage(ctx context.Context, userID string, message *model.Message) error {
return s.PushToUser(ctx, userID, message, int(PriorityNormal))
}
// PushToUser 带优先级的推送
func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error {
// 首先尝试WebSocket推送实时推送
if s.pushViaWebSocket(ctx, userID, message) {
// WebSocket推送成功记录推送状态
record, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelWebSocket)
if err != nil {
return fmt.Errorf("failed to create push record: %w", err)
}
record.MarkPushed()
if err := s.pushRepo.Update(record); err != nil {
return fmt.Errorf("failed to update push record: %w", err)
}
return nil
}
// WebSocket推送失败加入推送队列等待移动端推送
select {
case s.pushQueue <- &pushTask{
userID: userID,
message: message,
priority: priority,
}:
return nil
default:
// 队列已满,直接创建待推送记录
_, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelFCM)
if err != nil {
return fmt.Errorf("failed to create pending push record: %w", err)
}
return errors.New("push queue is full, message queued for later delivery")
}
}
// pushViaWebSocket 通过WebSocket推送消息
// 返回true表示推送成功false表示用户不在线
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
// 判断是否为系统消息/通知消息
if message.IsSystemMessage() || message.Category == model.CategoryNotification {
// 使用 NotificationMessage 格式推送系统通知
// 从 segments 中提取文本内容
content := dto.ExtractTextContentFromModel(message.Segments)
notification := &websocket.NotificationMessage{
ID: fmt.Sprintf("%s", message.ID),
Type: string(message.SystemType),
Content: content,
Extra: make(map[string]interface{}),
CreatedAt: message.CreatedAt.UnixMilli(),
}
// 填充额外数据
if message.ExtraData != nil {
notification.Extra["actor_id"] = message.ExtraData.ActorID
notification.Extra["actor_name"] = message.ExtraData.ActorName
notification.Extra["avatar_url"] = message.ExtraData.AvatarURL
notification.Extra["target_id"] = message.ExtraData.TargetID
notification.Extra["target_type"] = message.ExtraData.TargetType
notification.Extra["action_url"] = message.ExtraData.ActionURL
notification.Extra["action_time"] = message.ExtraData.ActionTime
// 设置触发用户信息
if message.ExtraData.ActorID > 0 {
notification.TriggerUser = &websocket.NotificationUser{
ID: fmt.Sprintf("%d", message.ExtraData.ActorID),
Username: message.ExtraData.ActorName,
Avatar: message.ExtraData.AvatarURL,
}
}
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// 构建普通聊天消息的 WebSocket 消息 - 使用新的 WSEventResponse 格式
// 获取会话类型 (private/group)
detailType := "private"
if message.ConversationID != "" {
// 从会话中获取类型,需要查询数据库或从消息中判断
// 这里暂时默认为 privategroup 类型需要额外逻辑
}
// 直接使用 message.Segments
segments := message.Segments
event := &dto.WSEventResponse{
ID: fmt.Sprintf("%s", message.ID),
Time: message.CreatedAt.UnixMilli(),
Type: "message",
DetailType: detailType,
Seq: fmt.Sprintf("%d", message.Seq),
Segments: segments,
SenderID: message.SenderID,
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// pushViaFCM 通过FCM推送预留接口
func (s *pushServiceImpl) pushViaFCM(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
// TODO: 实现FCM推送
// 1. 构建FCM消息
// 2. 调用Firebase Admin SDK发送消息
// 3. 处理发送结果
return errors.New("FCM push not implemented")
}
// pushViaAPNs 通过APNs推送预留接口
func (s *pushServiceImpl) pushViaAPNs(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
// TODO: 实现APNs推送
// 1. 构建APNs消息
// 2. 调用APNs SDK发送消息
// 3. 处理发送结果
return errors.New("APNs push not implemented")
}
// RegisterDevice 注册设备
func (s *pushServiceImpl) RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error {
deviceToken := &model.DeviceToken{
UserID: userID,
DeviceID: deviceID,
DeviceType: deviceType,
PushToken: pushToken,
IsActive: true,
}
deviceToken.UpdateLastUsed()
return s.deviceRepo.Upsert(deviceToken)
}
// UnregisterDevice 注销设备
func (s *pushServiceImpl) UnregisterDevice(ctx context.Context, deviceID string) error {
return s.deviceRepo.Deactivate(deviceID)
}
// UpdateDeviceToken 更新设备Token
func (s *pushServiceImpl) UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error {
deviceToken, err := s.deviceRepo.GetByDeviceID(deviceID)
if err != nil {
return fmt.Errorf("device not found: %w", err)
}
deviceToken.PushToken = newPushToken
deviceToken.Activate()
return s.deviceRepo.Update(deviceToken)
}
// CreatePushRecord 创建推送记录
func (s *pushServiceImpl) CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error) {
expiredAt := time.Now().Add(DefaultExpiredTime)
record := &model.PushRecord{
UserID: userID,
MessageID: messageID,
PushChannel: channel,
PushStatus: model.PushStatusPending,
MaxRetry: MaxRetryCount,
ExpiredAt: &expiredAt,
}
if err := s.pushRepo.Create(record); err != nil {
return nil, fmt.Errorf("failed to create push record: %w", err)
}
return record, nil
}
// GetPendingPushes 获取待推送记录
func (s *pushServiceImpl) GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error) {
return s.pushRepo.GetByUserID(userID, 100, 0)
}
// StartPushWorker 启动推送工作协程
func (s *pushServiceImpl) StartPushWorker(ctx context.Context) {
go s.processPushQueue()
go s.retryFailedPushes()
}
// StopPushWorker 停止推送工作协程
func (s *pushServiceImpl) StopPushWorker() {
close(s.stopChan)
}
// processPushQueue 处理推送队列
func (s *pushServiceImpl) processPushQueue() {
for {
select {
case <-s.stopChan:
return
case task := <-s.pushQueue:
s.processPushTask(task)
}
}
}
// processPushTask 处理单个推送任务
func (s *pushServiceImpl) processPushTask(task *pushTask) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultPushTimeout)
defer cancel()
// 获取用户活跃设备
devices, err := s.deviceRepo.GetActiveByUserID(task.userID)
if err != nil || len(devices) == 0 {
// 没有可用设备,创建待推送记录
s.CreatePushRecord(ctx, task.userID, task.message.ID, model.PushChannelFCM)
return
}
// 对每个设备创建推送记录并尝试推送
for _, device := range devices {
record, err := s.CreatePushRecord(ctx, task.userID, task.message.ID, s.getChannelForDevice(device))
if err != nil {
continue
}
var pushErr error
switch {
case device.IsIOS():
pushErr = s.pushViaAPNs(ctx, device, task.message)
case device.IsAndroid():
pushErr = s.pushViaFCM(ctx, device, task.message)
default:
// Web设备只支持WebSocket
continue
}
if pushErr != nil {
record.MarkFailed(pushErr.Error())
} else {
record.MarkPushed()
}
s.pushRepo.Update(record)
}
}
// getChannelForDevice 根据设备类型获取推送通道
func (s *pushServiceImpl) getChannelForDevice(device *model.DeviceToken) model.PushChannel {
switch device.DeviceType {
case model.DeviceTypeIOS:
return model.PushChannelAPNs
case model.DeviceTypeAndroid:
return model.PushChannelFCM
default:
return model.PushChannelWebSocket
}
}
// retryFailedPushes 重试失败的推送
func (s *pushServiceImpl) retryFailedPushes() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopChan:
return
case <-ticker.C:
s.doRetry()
}
}
}
// doRetry 执行重试
func (s *pushServiceImpl) doRetry() {
ctx := context.Background()
// 获取失败待重试的推送
records, err := s.pushRepo.GetFailedPushesForRetry(100)
if err != nil {
return
}
for _, record := range records {
// 检查是否过期
if record.IsExpired() {
record.MarkExpired()
s.pushRepo.Update(record)
continue
}
// 获取消息
message, err := s.messageRepo.GetMessageByID(record.MessageID)
if err != nil {
record.MarkFailed("message not found")
s.pushRepo.Update(record)
continue
}
// 尝试WebSocket推送
if s.pushViaWebSocket(ctx, record.UserID, message) {
record.MarkDelivered()
s.pushRepo.Update(record)
continue
}
// 获取设备并尝试移动端推送
if record.DeviceToken != "" {
device, err := s.deviceRepo.GetByPushToken(record.DeviceToken)
if err != nil {
record.MarkFailed("device not found")
s.pushRepo.Update(record)
continue
}
var pushErr error
switch {
case device.IsIOS():
pushErr = s.pushViaAPNs(ctx, device, message)
case device.IsAndroid():
pushErr = s.pushViaFCM(ctx, device, message)
}
if pushErr != nil {
record.MarkFailed(pushErr.Error())
} else {
record.MarkPushed()
}
s.pushRepo.Update(record)
}
}
}
// PushSystemMessage 推送系统消息
func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error {
// 首先尝试WebSocket推送
if s.pushSystemViaWebSocket(ctx, userID, msgType, title, content, data) {
return nil
}
// 用户不在线,创建待推送记录(移动端上线后可通过其他方式获取)
// 系统消息通常不需要离线推送,客户端上线后会主动拉取
return errors.New("user is offline, system message will be available on next sync")
}
// pushSystemViaWebSocket 通过WebSocket推送系统消息
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
sysMsg := &websocket.SystemMessage{
Type: msgType,
Title: title,
Content: content,
Data: data,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// PushNotification 推送通知消息
func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error {
// 首先尝试WebSocket推送
if s.pushNotificationViaWebSocket(ctx, userID, notification) {
return nil
}
// 用户不在线,创建待推送记录
// 通知消息可以等用户上线后拉取
return errors.New("user is offline, notification will be available on next sync")
}
// pushNotificationViaWebSocket 通过WebSocket推送通知消息
func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
if notification.CreatedAt == 0 {
notification.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// PushAnnouncement 广播公告消息
func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if announcement.CreatedAt == 0 {
announcement.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
s.wsManager.Broadcast(wsMsg)
return nil
}
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
// 首先尝试WebSocket推送
if s.pushSystemNotificationViaWebSocket(ctx, userID, notification) {
return nil
}
// 用户不在线,系统通知已存储在数据库中,用户上线后会主动拉取
return nil
}
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
// 构建 WebSocket 通知消息
wsNotification := &websocket.NotificationMessage{
ID: fmt.Sprintf("%d", notification.ID),
Type: string(notification.Type),
Title: notification.Title,
Content: notification.Content,
Extra: make(map[string]interface{}),
CreatedAt: notification.CreatedAt.UnixMilli(),
}
// 填充额外数据
if notification.ExtraData != nil {
wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr
wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName
wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL
wsNotification.Extra["target_id"] = notification.ExtraData.TargetID
wsNotification.Extra["target_type"] = notification.ExtraData.TargetType
wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL
wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime
// 设置触发用户信息
if notification.ExtraData.ActorIDStr != "" {
wsNotification.TriggerUser = &websocket.NotificationUser{
ID: notification.ExtraData.ActorIDStr,
Username: notification.ExtraData.ActorName,
Avatar: notification.ExtraData.AvatarURL,
}
}
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification)
s.wsManager.SendToUser(userID, wsMsg)
return true
}

View File

@@ -0,0 +1,559 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log"
"regexp"
"strings"
"sync"
"time"
"unicode/utf8"
"carrot_bbs/internal/model"
redisclient "carrot_bbs/internal/pkg/redis"
"gorm.io/gorm"
)
// ==================== DFA 敏感词过滤实现 ====================
// SensitiveNode 敏感词树节点
type SensitiveNode struct {
// 子节点映射
Children map[rune]*SensitiveNode
// 是否为敏感词结尾
IsEnd bool
// 敏感词信息(仅在 IsEnd 为 true 时有效)
Word string
Level model.SensitiveWordLevel
Category model.SensitiveWordCategory
}
// NewSensitiveNode 创建新的敏感词节点
func NewSensitiveNode() *SensitiveNode {
return &SensitiveNode{
Children: make(map[rune]*SensitiveNode),
IsEnd: false,
}
}
// SensitiveWordTree 敏感词树
type SensitiveWordTree struct {
root *SensitiveNode
wordCount int
mu sync.RWMutex
lastReload time.Time
}
// NewSensitiveWordTree 创建新的敏感词树
func NewSensitiveWordTree() *SensitiveWordTree {
return &SensitiveWordTree{
root: NewSensitiveNode(),
wordCount: 0,
lastReload: time.Now(),
}
}
// AddWord 添加敏感词到树中
func (t *SensitiveWordTree) AddWord(word string, level model.SensitiveWordLevel, category model.SensitiveWordCategory) {
if word == "" {
return
}
t.mu.Lock()
defer t.mu.Unlock()
node := t.root
// 转换为小写进行匹配(不区分大小写)
lowerWord := strings.ToLower(word)
runes := []rune(lowerWord)
for _, r := range runes {
child, exists := node.Children[r]
if !exists {
child = NewSensitiveNode()
node.Children[r] = child
}
node = child
}
// 如果不是已存在的敏感词,则计数+1
if !node.IsEnd {
t.wordCount++
}
node.IsEnd = true
node.Word = word
node.Level = level
node.Category = category
}
// RemoveWord 从树中移除敏感词
func (t *SensitiveWordTree) RemoveWord(word string) {
if word == "" {
return
}
t.mu.Lock()
defer t.mu.Unlock()
lowerWord := strings.ToLower(word)
runes := []rune(lowerWord)
// 查找节点
node := t.root
for _, r := range runes {
child, exists := node.Children[r]
if !exists {
return // 敏感词不存在
}
node = child
}
if node.IsEnd {
node.IsEnd = false
node.Word = ""
t.wordCount--
}
}
// Check 检查文本是否包含敏感词,返回是否包含及敏感词列表
func (t *SensitiveWordTree) Check(text string) (bool, []string) {
if text == "" {
return false, nil
}
t.mu.RLock()
defer t.mu.RUnlock()
var foundWords []string
runes := []rune(strings.ToLower(text))
length := len(runes)
// 用于标记已找到的敏感词位置,避免重复计算
marked := make([]bool, length)
for i := 0; i < length; i++ {
// 从当前位置开始搜索
node := t.root
matchEnd := -1
matchWord := ""
for j := i; j < length; j++ {
child, exists := node.Children[runes[j]]
if !exists {
break
}
node = child
if node.IsEnd {
matchEnd = j
matchWord = node.Word
}
}
// 标记找到的敏感词位置
if matchEnd >= 0 && !marked[i] {
for k := i; k <= matchEnd; k++ {
marked[k] = true
}
foundWords = append(foundWords, matchWord)
}
}
return len(foundWords) > 0, foundWords
}
// Replace 替换文本中的敏感词
func (t *SensitiveWordTree) Replace(text string, repl string) string {
if text == "" {
return text
}
t.mu.RLock()
defer t.mu.RUnlock()
runes := []rune(text)
length := len(runes)
result := make([]rune, 0, length)
// 用于标记已替换的位置
marked := make([]bool, length)
for i := 0; i < length; i++ {
if marked[i] {
continue
}
// 从当前位置开始搜索
node := t.root
matchEnd := -1
for j := i; j < length; j++ {
child, exists := node.Children[runes[j]]
if !exists {
break
}
node = child
if node.IsEnd {
matchEnd = j
}
}
if matchEnd >= 0 {
// 标记已替换的位置
for k := i; k <= matchEnd; k++ {
marked[k] = true
}
// 追加替换符
replRunes := []rune(repl)
result = append(result, replRunes...)
// 跳过已匹配的字符
i = matchEnd
} else {
// 追加原字符
result = append(result, runes[i])
}
}
return string(result)
}
// WordCount 获取敏感词数量
func (t *SensitiveWordTree) WordCount() int {
t.mu.RLock()
defer t.mu.RUnlock()
return t.wordCount
}
// ==================== 敏感词服务实现 ====================
// SensitiveService 敏感词服务接口
type SensitiveService interface {
// Check 检查文本是否包含敏感词
Check(ctx context.Context, text string) (bool, []string)
// Replace 替换敏感词
Replace(ctx context.Context, text string, repl string) string
// AddWord 添加敏感词
AddWord(ctx context.Context, word string, category string, level int) error
// RemoveWord 移除敏感词
RemoveWord(ctx context.Context, word string) error
// Reload 重新加载敏感词库
Reload(ctx context.Context) error
// GetWordCount 获取敏感词数量
GetWordCount(ctx context.Context) int
}
// sensitiveServiceImpl 敏感词服务实现
type sensitiveServiceImpl struct {
tree *SensitiveWordTree
db *gorm.DB
redis *redisclient.Client
config *SensitiveConfig
mu sync.RWMutex
replaceStr string
}
// SensitiveConfig 敏感词服务配置
type SensitiveConfig struct {
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
ReplaceStr string `mapstructure:"replace_str" yaml:"replace_str"`
// 最小匹配长度
MinMatchLen int `mapstructure:"min_match_len" yaml:"min_match_len"`
// 是否从数据库加载
LoadFromDB bool `mapstructure:"load_from_db" yaml:"load_from_db"`
// 是否从Redis加载
LoadFromRedis bool `mapstructure:"load_from_redis" yaml:"load_from_redis"`
// Redis键前缀
RedisKeyPrefix string `mapstructure:"redis_key_prefix" yaml:"redis_key_prefix"`
}
// NewSensitiveService 创建敏感词服务
func NewSensitiveService(db *gorm.DB, redisClient *redisclient.Client, config *SensitiveConfig) SensitiveService {
s := &sensitiveServiceImpl{
tree: NewSensitiveWordTree(),
db: db,
redis: redisClient,
config: config,
replaceStr: config.ReplaceStr,
}
// 如果未设置替换符,默认使用 ***
if s.replaceStr == "" {
s.replaceStr = "***"
}
// 初始化加载敏感词
if config.LoadFromDB {
if err := s.loadFromDB(context.Background()); err != nil {
log.Printf("Failed to load sensitive words from database: %v", err)
}
}
if config.LoadFromRedis && redisClient != nil {
if err := s.loadFromRedis(context.Background()); err != nil {
log.Printf("Failed to load sensitive words from redis: %v", err)
}
}
return s
}
// Check 检查文本是否包含敏感词
func (s *sensitiveServiceImpl) Check(ctx context.Context, text string) (bool, []string) {
if !s.config.Enabled {
return false, nil
}
if text == "" {
return false, nil
}
return s.tree.Check(text)
}
// Replace 替换敏感词
func (s *sensitiveServiceImpl) Replace(ctx context.Context, text string, repl string) string {
if !s.config.Enabled {
return text
}
if text == "" {
return text
}
// 如果未指定替换符,使用默认替换符
if repl == "" {
repl = s.replaceStr
}
return s.tree.Replace(text, repl)
}
// AddWord 添加敏感词
func (s *sensitiveServiceImpl) AddWord(ctx context.Context, word string, category string, level int) error {
if word == "" {
return fmt.Errorf("word cannot be empty")
}
// 转换为敏感词级别
wordLevel := model.SensitiveWordLevel(level)
if wordLevel < 1 || wordLevel > 3 {
wordLevel = model.SensitiveWordLevelLow
}
// 转换为敏感词分类
wordCategory := model.SensitiveWordCategory(category)
if wordCategory == "" {
wordCategory = model.SensitiveWordCategoryOther
}
// 添加到树
s.tree.AddWord(word, wordLevel, wordCategory)
// 持久化到数据库
if s.db != nil {
sensitiveWord := model.SensitiveWord{
Word: word,
Category: wordCategory,
Level: wordLevel,
IsActive: true,
}
// 使用 upsert 逻辑
var existing model.SensitiveWord
result := s.db.Where("word = ?", word).First(&existing)
if result.Error == gorm.ErrRecordNotFound {
if err := s.db.Create(&sensitiveWord).Error; err != nil {
log.Printf("Failed to save sensitive word to database: %v", err)
}
} else if result.Error == nil {
// 更新已存在的记录
existing.Category = wordCategory
existing.Level = wordLevel
existing.IsActive = true
if err := s.db.Save(&existing).Error; err != nil {
log.Printf("Failed to update sensitive word in database: %v", err)
}
}
}
// 同步到 Redis
if s.redis != nil && s.config.RedisKeyPrefix != "" {
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
data := map[string]interface{}{
"word": word,
"category": category,
"level": level,
}
jsonData, _ := json.Marshal(data)
s.redis.Set(ctx, key, jsonData, 0)
}
return nil
}
// RemoveWord 移除敏感词
func (s *sensitiveServiceImpl) RemoveWord(ctx context.Context, word string) error {
if word == "" {
return fmt.Errorf("word cannot be empty")
}
// 从树中移除
s.tree.RemoveWord(word)
// 从数据库中标记为不活跃
if s.db != nil {
result := s.db.Model(&model.SensitiveWord{}).Where("word = ?", word).Update("is_active", false)
if result.Error != nil {
log.Printf("Failed to deactivate sensitive word in database: %v", result.Error)
}
}
// 从 Redis 中删除
if s.redis != nil && s.config.RedisKeyPrefix != "" {
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
s.redis.Del(ctx, key)
}
return nil
}
// Reload 重新加载敏感词库
func (s *sensitiveServiceImpl) Reload(ctx context.Context) error {
// 清空现有树
s.tree = NewSensitiveWordTree()
// 从数据库加载
if s.config.LoadFromDB {
if err := s.loadFromDB(ctx); err != nil {
return fmt.Errorf("failed to load from database: %w", err)
}
}
// 从 Redis 加载
if s.config.LoadFromRedis && s.redis != nil {
if err := s.loadFromRedis(ctx); err != nil {
return fmt.Errorf("failed to load from redis: %w", err)
}
}
return nil
}
// GetWordCount 获取敏感词数量
func (s *sensitiveServiceImpl) GetWordCount(ctx context.Context) int {
return s.tree.WordCount()
}
// loadFromDB 从数据库加载敏感词
func (s *sensitiveServiceImpl) loadFromDB(ctx context.Context) error {
if s.db == nil {
return nil
}
var words []model.SensitiveWord
if err := s.db.Where("is_active = ?", true).Find(&words).Error; err != nil {
return err
}
for _, word := range words {
s.tree.AddWord(word.Word, word.Level, word.Category)
}
log.Printf("Loaded %d sensitive words from database", len(words))
return nil
}
// loadFromRedis 从 Redis 加载敏感词
func (s *sensitiveServiceImpl) loadFromRedis(ctx context.Context) error {
if s.redis == nil || s.config.RedisKeyPrefix == "" {
return nil
}
// 使用 SCAN 命令代替 KEYS避免阻塞
pattern := fmt.Sprintf("%s:*", s.config.RedisKeyPrefix)
var cursor uint64
for {
keys, nextCursor, err := s.redis.GetClient().Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return err
}
for _, key := range keys {
data, err := s.redis.Get(ctx, key)
if err != nil {
continue
}
var wordData map[string]interface{}
if err := json.Unmarshal([]byte(data), &wordData); err != nil {
continue
}
word, _ := wordData["word"].(string)
category, _ := wordData["category"].(string)
level, _ := wordData["level"].(float64)
if word != "" {
s.tree.AddWord(word, model.SensitiveWordLevel(int(level)), model.SensitiveWordCategory(category))
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return nil
}
// ==================== 辅助函数 ====================
// ContainsSensitiveWord 快速检查文本是否包含敏感词
func ContainsSensitiveWord(text string, tree *SensitiveWordTree) bool {
if tree == nil || text == "" {
return false
}
hasSensitive, _ := tree.Check(text)
return hasSensitive
}
// FilterSensitiveWords 过滤敏感词并返回替换后的文本
func FilterSensitiveWords(text string, tree *SensitiveWordTree, repl string) string {
if tree == nil || text == "" {
return text
}
if repl == "" {
repl = "***"
}
return tree.Replace(text, repl)
}
// ValidateTextLength 验证文本长度是否合法
func ValidateTextLength(text string, minLen, maxLen int) bool {
length := utf8.RuneCountInString(text)
return length >= minLen && length <= maxLen
}
// SanitizeText 清理文本,移除多余空白字符
func SanitizeText(text string) string {
// 替换多个连续空白字符为单个空格
spaceReg := regexp.MustCompile(`\s+`)
text = spaceReg.ReplaceAllString(text, " ")
// 去除首尾空白
return strings.TrimSpace(text)
}
// ==================== 默认敏感词列表 ====================
// DefaultSensitiveWords 返回默认敏感词列表(示例)
func DefaultSensitiveWords() map[string]struct{} {
return map[string]struct{}{
// 示例敏感词,实际需要从数据库或配置加载
"测试敏感词1": {},
"测试敏感词2": {},
"测试敏感词3": {},
}
}

View File

@@ -0,0 +1,139 @@
package service
import (
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
"errors"
"net/url"
"strings"
)
var (
ErrStickerAlreadyExists = errors.New("sticker already exists")
ErrInvalidStickerURL = errors.New("invalid sticker url")
)
// StickerService 自定义表情服务接口
type StickerService interface {
// 获取用户的所有表情
GetUserStickers(userID string) ([]model.UserSticker, error)
// 添加表情
AddSticker(userID string, url string, width, height int) (*model.UserSticker, error)
// 删除表情
DeleteSticker(userID string, stickerID string) error
// 检查表情是否已存在
CheckExists(userID string, url string) (bool, error)
// 重新排序
ReorderStickers(userID string, orders map[string]int) error
// 获取用户表情数量
GetStickerCount(userID string) (int64, error)
}
// stickerService 自定义表情服务实现
type stickerService struct {
stickerRepo repository.StickerRepository
}
// NewStickerService 创建自定义表情服务
func NewStickerService(stickerRepo repository.StickerRepository) StickerService {
return &stickerService{
stickerRepo: stickerRepo,
}
}
// GetUserStickers 获取用户的所有表情
func (s *stickerService) GetUserStickers(userID string) ([]model.UserSticker, error) {
stickers, err := s.stickerRepo.GetByUserID(userID)
if err != nil {
return nil, err
}
// 兼容历史脏数据:过滤本地文件 URI避免客户端加载 file:// 报错
filtered := make([]model.UserSticker, 0, len(stickers))
for _, sticker := range stickers {
if isValidStickerURL(sticker.URL) {
filtered = append(filtered, sticker)
}
}
return filtered, nil
}
// AddSticker 添加表情
func (s *stickerService) AddSticker(userID string, url string, width, height int) (*model.UserSticker, error) {
if !isValidStickerURL(url) {
return nil, ErrInvalidStickerURL
}
// 检查是否已存在
exists, err := s.stickerRepo.Exists(userID, url)
if err != nil {
return nil, err
}
if exists {
return nil, ErrStickerAlreadyExists
}
// 获取当前数量用于设置排序
count, err := s.stickerRepo.CountByUserID(userID)
if err != nil {
return nil, err
}
sticker := &model.UserSticker{
UserID: userID,
URL: url,
Width: width,
Height: height,
SortOrder: int(count), // 新表情添加到末尾
}
if err := s.stickerRepo.Create(sticker); err != nil {
return nil, err
}
return sticker, nil
}
func isValidStickerURL(raw string) bool {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return false
}
parsed, err := url.Parse(trimmed)
if err != nil {
return false
}
scheme := strings.ToLower(parsed.Scheme)
return scheme == "http" || scheme == "https"
}
// DeleteSticker 删除表情
func (s *stickerService) DeleteSticker(userID string, stickerID string) error {
// 先检查表情是否属于该用户
sticker, err := s.stickerRepo.GetByID(stickerID)
if err != nil {
return err
}
if sticker.UserID != userID {
return errors.New("sticker not found")
}
return s.stickerRepo.Delete(stickerID)
}
// CheckExists 检查表情是否已存在
func (s *stickerService) CheckExists(userID string, url string) (bool, error) {
return s.stickerRepo.Exists(userID, url)
}
// ReorderStickers 重新排序
func (s *stickerService) ReorderStickers(userID string, orders map[string]int) error {
return s.stickerRepo.BatchUpdateSortOrder(userID, orders)
}
// GetStickerCount 获取用户表情数量
func (s *stickerService) GetStickerCount(userID string) (int64, error) {
return s.stickerRepo.CountByUserID(userID)
}

View File

@@ -0,0 +1,462 @@
package service
import (
"context"
"fmt"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/utils"
"carrot_bbs/internal/repository"
)
// SystemMessageService 系统消息服务接口
type SystemMessageService interface {
// 发送互动通知
SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error
SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error
SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error
SendFollowNotification(ctx context.Context, userID string, operatorID string) error
SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error
SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error
SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error
SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error
// 发送系统公告
SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error
SendBroadcastAnnouncement(ctx context.Context, title string, content string) error
}
type systemMessageServiceImpl struct {
notifyRepo *repository.SystemNotificationRepository
pushService PushService
userRepo *repository.UserRepository
postRepo *repository.PostRepository
cache cache.Cache
}
// NewSystemMessageService 创建系统消息服务
func NewSystemMessageService(
notifyRepo *repository.SystemNotificationRepository,
pushService PushService,
userRepo *repository.UserRepository,
postRepo *repository.PostRepository,
) SystemMessageService {
return &systemMessageServiceImpl{
notifyRepo: notifyRepo,
pushService: pushService,
userRepo: userRepo,
postRepo: postRepo,
cache: cache.GetCache(),
}
}
// SendLikeNotification 发送点赞通知
func (s *systemMessageServiceImpl) SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: postTitle,
TargetType: "post",
ActionURL: fmt.Sprintf("/posts/%s", postID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 赞了「%s」", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikePost, content, extraData)
if err != nil {
return fmt.Errorf("failed to create like notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendCommentNotification 发送评论通知
func (s *systemMessageServiceImpl) SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: postTitle,
TargetType: "comment",
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 评论了「%s」", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyComment, content, extraData)
if err != nil {
return fmt.Errorf("failed to create comment notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendReplyNotification 发送回复通知
func (s *systemMessageServiceImpl) SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: replyID,
TargetTitle: postTitle,
TargetType: "reply",
ActionURL: fmt.Sprintf("/posts/%s?comment=%s&reply=%s", postID, commentID, replyID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 回复了您在「%s」的评论", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyReply, content, extraData)
if err != nil {
return fmt.Errorf("failed to create reply notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendFollowNotification 发送关注通知
func (s *systemMessageServiceImpl) SendFollowNotification(ctx context.Context, userID string, operatorID string) error {
fmt.Printf("[DEBUG] SendFollowNotification: userID=%s, operatorID=%s\n", userID, operatorID)
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
fmt.Printf("[DEBUG] SendFollowNotification: getActorInfo error: %v\n", err)
return err
}
fmt.Printf("[DEBUG] SendFollowNotification: actorName=%s, avatarURL=%s\n", actorName, avatarURL)
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: "",
TargetType: "user",
ActionURL: fmt.Sprintf("/users/%s", operatorID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 关注了你", actorName)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyFollow, content, extraData)
if err != nil {
return fmt.Errorf("failed to create follow notification: %w", err)
}
fmt.Printf("[DEBUG] SendFollowNotification: notification created, ID=%d, Content=%s\n", notification.ID, notification.Content)
// 推送通知
pushErr := s.pushService.PushSystemNotification(ctx, userID, notification)
if pushErr != nil {
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification error: %v\n", pushErr)
} else {
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification success\n")
}
return pushErr
}
// SendFavoriteNotification 发送收藏通知
func (s *systemMessageServiceImpl) SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: postTitle,
TargetType: "post",
ActionURL: fmt.Sprintf("/posts/%s", postID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 收藏了「%s」", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyFavoritePost, content, extraData)
if err != nil {
return fmt.Errorf("failed to create favorite notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendLikeCommentNotification 发送评论点赞通知
func (s *systemMessageServiceImpl) SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 截取评论内容预览最多50字
preview := commentContent
runes := []rune(preview)
if len(runes) > 50 {
preview = string(runes[:50]) + "..."
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: preview,
TargetType: "comment",
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 赞了您的评论", actorName)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeComment, content, extraData)
if err != nil {
return fmt.Errorf("failed to create like comment notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendLikeReplyNotification 发送回复点赞通知
func (s *systemMessageServiceImpl) SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 截取回复内容预览最多50字
preview := replyContent
runes := []rune(preview)
if len(runes) > 50 {
preview = string(runes[:50]) + "..."
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: preview,
TargetType: "reply",
ActionURL: fmt.Sprintf("/posts/%s?reply=%s", postID, replyID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 赞了您的回复", actorName)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeReply, content, extraData)
if err != nil {
return fmt.Errorf("failed to create like reply notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendMentionNotification 发送@提及通知
func (s *systemMessageServiceImpl) SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: postTitle,
TargetType: "post",
ActionURL: fmt.Sprintf("/posts/%s", postID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 在「%s」中提到了你", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyMention, content, extraData)
if err != nil {
return fmt.Errorf("failed to create mention notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendSystemAnnouncement 发送系统公告给指定用户
func (s *systemMessageServiceImpl) SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error {
for _, userID := range userIDs {
extraData := &model.SystemNotificationExtra{
TargetType: "announcement",
ActionTime: time.Now().Format(time.RFC3339),
}
notification, err := s.createNotification(ctx, userID, model.SysNotifyAnnounce, fmt.Sprintf("【%s】%s", title, content), extraData)
if err != nil {
continue // 单个失败不影响其他用户
}
// 推送通知(使用高优先级)
if err := s.pushService.PushSystemNotification(ctx, userID, notification); err != nil {
continue
}
}
return nil
}
// SendBroadcastAnnouncement 发送广播公告给所有在线用户
func (s *systemMessageServiceImpl) SendBroadcastAnnouncement(ctx context.Context, title string, content string) error {
// TODO: 实现广播公告
// 1. 获取所有在线用户
// 2. 批量发送公告
// 3. 对于离线用户,存储为待推送记录
return fmt.Errorf("broadcast announcement not implemented")
}
// createNotification 创建系统通知(存储到独立表)
func (s *systemMessageServiceImpl) createNotification(ctx context.Context, userID string, notifyType model.SystemNotificationType, content string, extraData *model.SystemNotificationExtra) (*model.SystemNotification, error) {
fmt.Printf("[DEBUG] createNotification: userID=%s, notifyType=%s\n", userID, notifyType)
// 生成雪花算法ID
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
fmt.Printf("[DEBUG] createNotification: failed to generate ID: %v\n", err)
return nil, fmt.Errorf("failed to generate notification ID: %w", err)
}
notification := &model.SystemNotification{
ID: id,
ReceiverID: userID,
Type: notifyType,
Content: content,
ExtraData: extraData,
IsRead: false,
}
fmt.Printf("[DEBUG] createNotification: notification created with ID=%d, ReceiverID=%s\n", id, userID)
// 保存通知到数据库
if err := s.notifyRepo.Create(notification); err != nil {
fmt.Printf("[DEBUG] createNotification: failed to save notification: %v\n", err)
return nil, fmt.Errorf("failed to save notification: %w", err)
}
// 失效系统消息未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
fmt.Printf("[DEBUG] createNotification: notification saved successfully, ID=%d\n", notification.ID)
return notification, nil
}
// getActorInfo 获取操作者信息
func (s *systemMessageServiceImpl) getActorInfo(ctx context.Context, operatorID string) (string, string, error) {
// 从用户仓储获取用户信息
if s.userRepo != nil {
user, err := s.userRepo.GetByID(operatorID)
if err != nil {
fmt.Printf("[DEBUG] getActorInfo: failed to get user %s: %v\n", operatorID, err)
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil // 返回默认值,不阻断流程
}
avatar := utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar)
return user.Nickname, avatar, nil
}
// 如果没有用户仓储,返回默认值
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil
}
// getPostTitle 获取帖子标题
func (s *systemMessageServiceImpl) getPostTitle(postID string) (string, error) {
if s.postRepo == nil {
if len(postID) >= 8 {
return fmt.Sprintf("帖子#%s", postID[:8]), nil
}
return fmt.Sprintf("帖子#%s", postID), nil
}
post, err := s.postRepo.GetByID(postID)
if err != nil {
if len(postID) >= 8 {
return fmt.Sprintf("帖子#%s", postID[:8]), nil
}
return fmt.Sprintf("帖子#%s", postID), nil
}
if post.Title != "" {
return post.Title, nil
}
// 如果没有标题返回内容前20个字符
if len(post.Content) > 20 {
return post.Content[:20] + "...", nil
}
return post.Content, nil
}

View File

@@ -0,0 +1,273 @@
package service
import (
"bytes"
"context"
"crypto/sha256"
"fmt"
"image"
"image/jpeg"
"image/png"
"io"
"mime"
"mime/multipart"
"net/http"
"path/filepath"
"strings"
"carrot_bbs/internal/pkg/s3"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
)
// UploadService 上传服务
type UploadService struct {
s3Client *s3.Client
userService *UserService
}
// NewUploadService 创建上传服务
func NewUploadService(s3Client *s3.Client, userService *UserService) *UploadService {
return &UploadService{
s3Client: s3Client,
userService: userService,
}
}
// UploadImage 上传图片
func (s *UploadService) UploadImage(ctx context.Context, file *multipart.FileHeader) (string, error) {
processedData, contentType, ext, err := prepareImageForUpload(file)
if err != nil {
return "", err
}
// 压缩后再计算哈希,确保同一压缩结果映射同一对象名
hash := sha256.Sum256(processedData)
hashStr := fmt.Sprintf("%x", hash)
objectName := fmt.Sprintf("images/%s%s", hashStr, ext)
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
if err != nil {
return "", fmt.Errorf("failed to upload to S3: %w", err)
}
return url, nil
}
// getExtFromContentType 根据Content-Type获取文件扩展名
func getExtFromContentType(contentType string) string {
baseType, _, err := mime.ParseMediaType(contentType)
if err == nil && baseType != "" {
contentType = baseType
}
switch contentType {
case "image/jpg", "image/jpeg":
return ".jpg"
case "image/png":
return ".png"
case "image/gif":
return ".gif"
case "image/webp":
return ".webp"
case "image/bmp", "image/x-ms-bmp":
return ".bmp"
case "image/tiff":
return ".tiff"
default:
return ""
}
}
// UploadAvatar 上传头像
func (s *UploadService) UploadAvatar(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
processedData, contentType, ext, err := prepareImageForUpload(file)
if err != nil {
return "", err
}
// 压缩后再计算哈希
hash := sha256.Sum256(processedData)
hashStr := fmt.Sprintf("%x", hash)
objectName := fmt.Sprintf("avatars/%s%s", hashStr, ext)
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
if err != nil {
return "", fmt.Errorf("failed to upload to S3: %w", err)
}
// 更新用户头像
if s.userService != nil {
user, err := s.userService.GetUserByID(ctx, userID)
if err == nil && user != nil {
user.Avatar = url
err = s.userService.UpdateUser(ctx, user)
if err != nil {
// 更新失败不影响上传结果,只记录日志
fmt.Printf("[UploadAvatar] failed to update user avatar: %v\n", err)
}
}
}
return url, nil
}
// UploadCover 上传头图(个人主页封面)
func (s *UploadService) UploadCover(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
processedData, contentType, ext, err := prepareImageForUpload(file)
if err != nil {
return "", err
}
// 压缩后再计算哈希
hash := sha256.Sum256(processedData)
hashStr := fmt.Sprintf("%x", hash)
objectName := fmt.Sprintf("covers/%s%s", hashStr, ext)
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
if err != nil {
return "", fmt.Errorf("failed to upload to S3: %w", err)
}
// 更新用户头图
if s.userService != nil {
user, err := s.userService.GetUserByID(ctx, userID)
if err == nil && user != nil {
user.CoverURL = url
err = s.userService.UpdateUser(ctx, user)
if err != nil {
// 更新失败不影响上传结果,只记录日志
fmt.Printf("[UploadCover] failed to update user cover: %v\n", err)
}
}
}
return url, nil
}
// GetURL 获取文件URL
func (s *UploadService) GetURL(ctx context.Context, objectName string) (string, error) {
return s.s3Client.GetURL(ctx, objectName)
}
// Delete 删除文件
func (s *UploadService) Delete(ctx context.Context, objectName string) error {
return s.s3Client.Delete(ctx, objectName)
}
func prepareImageForUpload(file *multipart.FileHeader) ([]byte, string, string, error) {
f, err := file.Open()
if err != nil {
return nil, "", "", fmt.Errorf("failed to open file: %w", err)
}
defer f.Close()
originalData, err := io.ReadAll(f)
if err != nil {
return nil, "", "", fmt.Errorf("failed to read file: %w", err)
}
// 优先从文件字节探测真实类型,避免前端压缩/转码后 header 与实际格式不一致
detectedType := normalizeImageContentType(http.DetectContentType(originalData))
headerType := normalizeImageContentType(file.Header.Get("Content-Type"))
contentType := detectedType
if contentType == "" || contentType == "application/octet-stream" {
contentType = headerType
}
compressedData, compressedType, err := compressImageData(originalData, contentType)
if err != nil {
// 压缩失败时回退到原图,保证上传可用性
compressedData = originalData
compressedType = contentType
}
if compressedType == "" {
compressedType = contentType
}
if compressedType == "" {
compressedType = http.DetectContentType(compressedData)
}
ext := getExtFromContentType(compressedType)
if ext == "" {
ext = strings.ToLower(filepath.Ext(file.Filename))
}
if ext == "" {
// 最终兜底,避免对象名无扩展名导致 URL 语义不明确
ext = ".jpg"
}
return compressedData, compressedType, ext, nil
}
func compressImageData(data []byte, contentType string) ([]byte, string, error) {
contentType = normalizeImageContentType(contentType)
// GIF/WebP 等格式先保留原图,避免动画和透明通道丢失
if contentType == "image/gif" || contentType == "image/webp" {
return data, contentType, nil
}
if contentType != "image/jpeg" &&
contentType != "image/png" &&
contentType != "image/bmp" &&
contentType != "image/x-ms-bmp" &&
contentType != "image/tiff" {
return data, contentType, nil
}
img, _, err := image.Decode(bytes.NewReader(data))
if err != nil {
return nil, "", fmt.Errorf("failed to decode image: %w", err)
}
var buf bytes.Buffer
switch contentType {
case "image/png":
encoder := png.Encoder{CompressionLevel: png.BestCompression}
if err := encoder.Encode(&buf, img); err != nil {
return nil, "", fmt.Errorf("failed to encode png: %w", err)
}
return buf.Bytes(), "image/png", nil
default:
// BMP/TIFF 等无损大图统一压缩为 JPEG控制体积
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 82}); err != nil {
return nil, "", fmt.Errorf("failed to encode jpeg: %w", err)
}
return buf.Bytes(), "image/jpeg", nil
}
}
func normalizeImageContentType(contentType string) string {
if contentType == "" {
return ""
}
baseType, _, err := mime.ParseMediaType(contentType)
if err == nil && baseType != "" {
contentType = baseType
}
switch strings.ToLower(contentType) {
case "image/jpg":
return "image/jpeg"
case "image/jpeg":
return "image/jpeg"
case "image/png":
return "image/png"
case "image/gif":
return "image/gif"
case "image/webp":
return "image/webp"
case "image/bmp", "image/x-ms-bmp":
return "image/bmp"
case "image/tiff":
return "image/tiff"
default:
return contentType
}
}

View File

@@ -0,0 +1,592 @@
package service
import (
"context"
"fmt"
"strings"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/utils"
"carrot_bbs/internal/repository"
)
// UserService 用户服务
type UserService struct {
userRepo *repository.UserRepository
systemMessageService SystemMessageService
emailCodeService EmailCodeService
}
// NewUserService 创建用户服务
func NewUserService(
userRepo *repository.UserRepository,
systemMessageService SystemMessageService,
emailService EmailService,
cacheBackend cache.Cache,
) *UserService {
return &UserService{
userRepo: userRepo,
systemMessageService: systemMessageService,
emailCodeService: NewEmailCodeService(emailService, cacheBackend),
}
}
// SendRegisterCode 发送注册验证码
func (s *UserService) SendRegisterCode(ctx context.Context, email string) error {
user, err := s.userRepo.GetByEmail(email)
if err == nil && user != nil {
return ErrEmailExists
}
return s.emailCodeService.SendCode(ctx, CodePurposeRegister, email)
}
// SendPasswordResetCode 发送找回密码验证码
func (s *UserService) SendPasswordResetCode(ctx context.Context, email string) error {
user, err := s.userRepo.GetByEmail(email)
if err != nil || user == nil {
return ErrUserNotFound
}
return s.emailCodeService.SendCode(ctx, CodePurposePasswordReset, email)
}
// SendCurrentUserEmailVerifyCode 发送当前用户邮箱验证验证码
func (s *UserService) SendCurrentUserEmailVerifyCode(ctx context.Context, userID, email string) error {
user, err := s.userRepo.GetByID(userID)
if err != nil || user == nil {
return ErrUserNotFound
}
targetEmail := strings.TrimSpace(email)
if targetEmail == "" && user.Email != nil {
targetEmail = strings.TrimSpace(*user.Email)
}
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
return ErrInvalidEmail
}
if user.EmailVerified && user.Email != nil && strings.EqualFold(strings.TrimSpace(*user.Email), targetEmail) {
return ErrEmailAlreadyVerified
}
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
return ErrEmailExists
}
return s.emailCodeService.SendCode(ctx, CodePurposeEmailVerify, targetEmail)
}
// VerifyCurrentUserEmail 验证当前用户邮箱
func (s *UserService) VerifyCurrentUserEmail(ctx context.Context, userID, email, verificationCode string) error {
user, err := s.userRepo.GetByID(userID)
if err != nil || user == nil {
return ErrUserNotFound
}
targetEmail := strings.TrimSpace(email)
if targetEmail == "" && user.Email != nil {
targetEmail = strings.TrimSpace(*user.Email)
}
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
return ErrInvalidEmail
}
if err := s.emailCodeService.VerifyCode(CodePurposeEmailVerify, targetEmail, verificationCode); err != nil {
return err
}
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
return ErrEmailExists
}
user.Email = &targetEmail
user.EmailVerified = true
return s.userRepo.Update(user)
}
// SendChangePasswordCode 发送修改密码验证码
func (s *UserService) SendChangePasswordCode(ctx context.Context, userID string) error {
user, err := s.userRepo.GetByID(userID)
if err != nil || user == nil {
return ErrUserNotFound
}
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
return ErrEmailNotBound
}
return s.emailCodeService.SendCode(ctx, CodePurposeChangePassword, *user.Email)
}
// Register 用户注册
func (s *UserService) Register(ctx context.Context, username, email, password, nickname, phone, verificationCode string) (*model.User, error) {
// 验证用户名
if !utils.ValidateUsername(username) {
return nil, ErrInvalidUsername
}
// 注册必须提供邮箱并完成验证码校验
if email == "" || !utils.ValidateEmail(email) {
return nil, ErrInvalidEmail
}
if err := s.emailCodeService.VerifyCode(CodePurposeRegister, email, verificationCode); err != nil {
return nil, err
}
// 验证密码
if !utils.ValidatePassword(password) {
return nil, ErrWeakPassword
}
// 验证手机号(如果提供)
if phone != "" && !utils.ValidatePhone(phone) {
return nil, ErrInvalidPhone
}
// 检查用户名是否已存在
existingUser, err := s.userRepo.GetByUsername(username)
if err == nil && existingUser != nil {
return nil, ErrUsernameExists
}
// 检查邮箱是否已存在(如果提供)
if email != "" {
existingUser, err = s.userRepo.GetByEmail(email)
if err == nil && existingUser != nil {
return nil, ErrEmailExists
}
}
// 检查手机号是否已存在(如果提供)
if phone != "" {
existingUser, err = s.userRepo.GetByPhone(phone)
if err == nil && existingUser != nil {
return nil, ErrPhoneExists
}
}
// 密码哈希
hashedPassword, err := utils.HashPassword(password)
if err != nil {
return nil, err
}
// 创建用户
user := &model.User{
Username: username,
Nickname: nickname,
EmailVerified: true,
PasswordHash: hashedPassword,
Status: model.UserStatusActive,
}
// 如果提供了邮箱,设置指针值
if email != "" {
user.Email = &email
}
// 如果提供了手机号,设置指针值
if phone != "" {
user.Phone = &phone
}
err = s.userRepo.Create(user)
if err != nil {
return nil, err
}
return user, nil
}
// Login 用户登录
func (s *UserService) Login(ctx context.Context, account, password string) (*model.User, error) {
account = strings.TrimSpace(account)
var (
user *model.User
err error
)
if utils.ValidateEmail(account) {
user, err = s.userRepo.GetByEmail(account)
} else if utils.ValidatePhone(account) {
user, err = s.userRepo.GetByPhone(account)
} else {
user, err = s.userRepo.GetByUsername(account)
}
if err != nil || user == nil {
return nil, ErrInvalidCredentials
}
if !utils.CheckPasswordHash(password, user.PasswordHash) {
return nil, ErrInvalidCredentials
}
if user.Status != model.UserStatusActive {
return nil, ErrUserBanned
}
return user, nil
}
// GetUserByID 根据ID获取用户
func (s *UserService) GetUserByID(ctx context.Context, id string) (*model.User, error) {
return s.userRepo.GetByID(id)
}
// GetUserPostCount 获取用户帖子数(实时计算)
func (s *UserService) GetUserPostCount(ctx context.Context, userID string) (int64, error) {
return s.userRepo.GetPostsCount(userID)
}
// GetUserPostCountBatch 批量获取用户帖子数(实时计算)
func (s *UserService) GetUserPostCountBatch(ctx context.Context, userIDs []string) (map[string]int64, error) {
return s.userRepo.GetPostsCountBatch(userIDs)
}
// GetUserByIDWithFollowingStatus 根据ID获取用户包含当前用户是否关注的状态
func (s *UserService) GetUserByIDWithFollowingStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, error) {
user, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, false, err
}
// 如果查询的是当前用户自己,不需要检查关注状态
if userID == currentUserID {
return user, false, nil
}
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
if err != nil {
return user, false, err
}
return user, isFollowing, nil
}
// GetUserByIDWithMutualFollowStatus 根据ID获取用户包含双向关注状态
func (s *UserService) GetUserByIDWithMutualFollowStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, bool, error) {
user, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, false, false, err
}
// 如果查询的是当前用户自己,不需要检查关注状态
if userID == currentUserID {
return user, false, false, nil
}
// 当前用户是否关注了该用户
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
if err != nil {
return user, false, false, err
}
// 该用户是否关注了当前用户
isFollowingMe, err := s.userRepo.IsFollowing(userID, currentUserID)
if err != nil {
return user, isFollowing, false, err
}
return user, isFollowing, isFollowingMe, nil
}
// UpdateUser 更新用户
func (s *UserService) UpdateUser(ctx context.Context, user *model.User) error {
return s.userRepo.Update(user)
}
// GetFollowers 获取粉丝
func (s *UserService) GetFollowers(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
return s.userRepo.GetFollowers(userID, page, pageSize)
}
// GetFollowing 获取关注
func (s *UserService) GetFollowing(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
return s.userRepo.GetFollowing(userID, page, pageSize)
}
// FollowUser 关注用户
func (s *UserService) FollowUser(ctx context.Context, followerID, followeeID string) error {
fmt.Printf("[DEBUG] FollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
blocked, err := s.userRepo.IsBlockedEitherDirection(followerID, followeeID)
if err != nil {
return err
}
if blocked {
return ErrUserBlocked
}
// 检查是否已经关注
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
if err != nil {
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
return err
}
if isFollowing {
fmt.Printf("[DEBUG] Already following, skip creation\n")
return nil // 已经关注,直接返回成功
}
// 创建关注关系
follow := &model.Follow{
FollowerID: followerID,
FollowingID: followeeID,
}
err = s.userRepo.CreateFollow(follow)
if err != nil {
fmt.Printf("[DEBUG] CreateFollow error: %v\n", err)
return err
}
fmt.Printf("[DEBUG] Follow record created successfully\n")
// 刷新关注者的关注数(通过实际计数,更可靠)
err = s.userRepo.RefreshFollowingCount(followerID)
if err != nil {
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
// 不回滚,计数可以通过其他方式修复
}
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
err = s.userRepo.RefreshFollowersCount(followeeID)
if err != nil {
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
// 不回滚,计数可以通过其他方式修复
}
// 发送关注通知给被关注者
if s.systemMessageService != nil {
// 异步发送通知,不阻塞主流程
go func() {
notifyErr := s.systemMessageService.SendFollowNotification(context.Background(), followeeID, followerID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending follow notification: %v\n", notifyErr)
} else {
fmt.Printf("[DEBUG] Follow notification sent successfully to %s\n", followeeID)
}
}()
}
fmt.Printf("[DEBUG] FollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
return nil
}
// UnfollowUser 取消关注用户
func (s *UserService) UnfollowUser(ctx context.Context, followerID, followeeID string) error {
fmt.Printf("[DEBUG] UnfollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
// 检查是否已经关注
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
if err != nil {
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
return err
}
if !isFollowing {
fmt.Printf("[DEBUG] Not following, skip deletion\n")
return nil // 没有关注,直接返回成功
}
// 删除关注关系
err = s.userRepo.DeleteFollow(followerID, followeeID)
if err != nil {
fmt.Printf("[DEBUG] DeleteFollow error: %v\n", err)
return err
}
fmt.Printf("[DEBUG] Follow record deleted successfully\n")
// 刷新关注者的关注数(通过实际计数,更可靠)
err = s.userRepo.RefreshFollowingCount(followerID)
if err != nil {
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
}
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
err = s.userRepo.RefreshFollowersCount(followeeID)
if err != nil {
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
}
fmt.Printf("[DEBUG] UnfollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
return nil
}
// BlockUser 拉黑用户,并自动清理双向关注/粉丝关系
func (s *UserService) BlockUser(ctx context.Context, blockerID, blockedID string) error {
if blockerID == blockedID {
return ErrInvalidOperation
}
return s.userRepo.BlockUserAndCleanupRelations(blockerID, blockedID)
}
// UnblockUser 取消拉黑
func (s *UserService) UnblockUser(ctx context.Context, blockerID, blockedID string) error {
if blockerID == blockedID {
return ErrInvalidOperation
}
return s.userRepo.UnblockUser(blockerID, blockedID)
}
// GetBlockedUsers 获取黑名单列表
func (s *UserService) GetBlockedUsers(ctx context.Context, blockerID string, page, pageSize int) ([]*model.User, int64, error) {
return s.userRepo.GetBlockedUsers(blockerID, page, pageSize)
}
// IsBlocked 检查当前用户是否已拉黑目标用户
func (s *UserService) IsBlocked(ctx context.Context, blockerID, blockedID string) (bool, error) {
return s.userRepo.IsBlocked(blockerID, blockedID)
}
// GetFollowingList 获取关注列表(字符串参数版本)
func (s *UserService) GetFollowingList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
// 转换字符串参数为整数
pageInt := 1
pageSizeInt := 20
if page != "" {
_, err := fmt.Sscanf(page, "%d", &pageInt)
if err != nil {
pageInt = 1
}
}
if pageSize != "" {
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
if err != nil {
pageSizeInt = 20
}
}
users, _, err := s.userRepo.GetFollowing(userID, pageInt, pageSizeInt)
return users, err
}
// GetFollowersList 获取粉丝列表(字符串参数版本)
func (s *UserService) GetFollowersList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
// 转换字符串参数为整数
pageInt := 1
pageSizeInt := 20
if page != "" {
_, err := fmt.Sscanf(page, "%d", &pageInt)
if err != nil {
pageInt = 1
}
}
if pageSize != "" {
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
if err != nil {
pageSizeInt = 20
}
}
users, _, err := s.userRepo.GetFollowers(userID, pageInt, pageSizeInt)
return users, err
}
// GetMutualFollowStatus 批量获取双向关注状态
func (s *UserService) GetMutualFollowStatus(ctx context.Context, currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
return s.userRepo.GetMutualFollowStatus(currentUserID, targetUserIDs)
}
// CheckUsernameAvailable 检查用户名是否可用
func (s *UserService) CheckUsernameAvailable(ctx context.Context, username string) (bool, error) {
user, err := s.userRepo.GetByUsername(username)
if err != nil {
return true, nil // 用户不存在,可用
}
return user == nil, nil
}
// ChangePassword 修改密码
func (s *UserService) ChangePassword(ctx context.Context, userID, oldPassword, newPassword, verificationCode string) error {
// 获取用户
user, err := s.userRepo.GetByID(userID)
if err != nil {
return ErrUserNotFound
}
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
return ErrEmailNotBound
}
if err := s.emailCodeService.VerifyCode(CodePurposeChangePassword, *user.Email, verificationCode); err != nil {
return err
}
// 验证旧密码
if !utils.CheckPasswordHash(oldPassword, user.PasswordHash) {
return ErrInvalidCredentials
}
// 哈希新密码
hashedPassword, err := utils.HashPassword(newPassword)
if err != nil {
return err
}
// 更新密码
user.PasswordHash = hashedPassword
return s.userRepo.Update(user)
}
// ResetPasswordByEmail 通过邮箱重置密码
func (s *UserService) ResetPasswordByEmail(ctx context.Context, email, verificationCode, newPassword string) error {
email = strings.TrimSpace(email)
if !utils.ValidateEmail(email) {
return ErrInvalidEmail
}
if !utils.ValidatePassword(newPassword) {
return ErrWeakPassword
}
if err := s.emailCodeService.VerifyCode(CodePurposePasswordReset, email, verificationCode); err != nil {
return err
}
user, err := s.userRepo.GetByEmail(email)
if err != nil || user == nil {
return ErrUserNotFound
}
hashedPassword, err := utils.HashPassword(newPassword)
if err != nil {
return err
}
user.PasswordHash = hashedPassword
return s.userRepo.Update(user)
}
// Search 搜索用户
func (s *UserService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.User, int64, error) {
return s.userRepo.Search(keyword, page, pageSize)
}
// 错误定义
var (
ErrInvalidUsername = &ServiceError{Code: 400, Message: "invalid username"}
ErrInvalidEmail = &ServiceError{Code: 400, Message: "invalid email"}
ErrInvalidPhone = &ServiceError{Code: 400, Message: "invalid phone number"}
ErrWeakPassword = &ServiceError{Code: 400, Message: "password too weak"}
ErrUsernameExists = &ServiceError{Code: 400, Message: "username already exists"}
ErrEmailExists = &ServiceError{Code: 400, Message: "email already exists"}
ErrPhoneExists = &ServiceError{Code: 400, Message: "phone number already exists"}
ErrUserNotFound = &ServiceError{Code: 404, Message: "user not found"}
ErrUserBanned = &ServiceError{Code: 403, Message: "user is banned"}
ErrUserBlocked = &ServiceError{Code: 403, Message: "blocked relationship exists"}
ErrInvalidOperation = &ServiceError{Code: 400, Message: "invalid operation"}
ErrEmailServiceUnavailable = &ServiceError{Code: 503, Message: "email service unavailable"}
ErrVerificationCodeTooFrequent = &ServiceError{Code: 429, Message: "verification code sent too frequently"}
ErrVerificationCodeInvalid = &ServiceError{Code: 400, Message: "invalid verification code"}
ErrVerificationCodeExpired = &ServiceError{Code: 400, Message: "verification code expired"}
ErrVerificationCodeUnavailable = &ServiceError{Code: 500, Message: "verification code storage unavailable"}
ErrEmailAlreadyVerified = &ServiceError{Code: 400, Message: "email already verified"}
ErrEmailNotBound = &ServiceError{Code: 400, Message: "email not bound"}
)
// ServiceError 服务错误
type ServiceError struct {
Code int
Message string
}
func (e *ServiceError) Error() string {
return e.Message
}
var ErrInvalidCredentials = &ServiceError{Code: 401, Message: "invalid username or password"}

View File

@@ -0,0 +1,282 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"strings"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
)
// VoteService 投票服务
type VoteService struct {
voteRepo *repository.VoteRepository
postRepo *repository.PostRepository
cache cache.Cache
postAIService *PostAIService
systemMessageService SystemMessageService
}
// NewVoteService 创建投票服务
func NewVoteService(
voteRepo *repository.VoteRepository,
postRepo *repository.PostRepository,
cache cache.Cache,
postAIService *PostAIService,
systemMessageService SystemMessageService,
) *VoteService {
return &VoteService{
voteRepo: voteRepo,
postRepo: postRepo,
cache: cache,
postAIService: postAIService,
systemMessageService: systemMessageService,
}
}
// CreateVotePost 创建投票帖子
func (s *VoteService) CreateVotePost(ctx context.Context, userID string, req *dto.CreateVotePostRequest) (*dto.PostResponse, error) {
// 验证投票选项数量
if len(req.VoteOptions) < 2 {
return nil, errors.New("投票选项至少需要2个")
}
if len(req.VoteOptions) > 10 {
return nil, errors.New("投票选项最多10个")
}
// 创建普通帖子设置IsVote=true
post := &model.Post{
UserID: userID,
CommunityID: req.CommunityID,
Title: req.Title,
Content: req.Content,
Status: model.PostStatusPending,
IsVote: true,
}
err := s.postRepo.Create(post, req.Images)
if err != nil {
return nil, err
}
// 创建投票选项
err = s.voteRepo.CreateOptions(post.ID, req.VoteOptions)
if err != nil {
return nil, err
}
// 异步审核
go s.reviewVotePostAsync(post.ID, userID, req.Title, req.Content, req.Images)
// 重新查询以获取关联的User和Images
createdPost, err := s.postRepo.GetByID(post.ID)
if err != nil {
return nil, err
}
// 转换为响应DTO
return s.convertToPostResponse(createdPost, userID), nil
}
func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string, images []string) {
if s.postAIService == nil || !s.postAIService.IsEnabled() {
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
log.Printf("[WARN] Failed to publish vote post without AI moderation: %v", err)
}
return
}
err := s.postAIService.ModeratePost(context.Background(), title, content, images)
if err != nil {
var rejectedErr *PostModerationRejectedError
if errors.As(err, &rejectedErr) {
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
log.Printf("[WARN] Failed to reject vote post %s: %v", postID, updateErr)
}
s.notifyModerationRejected(userID, rejectedErr.Reason)
return
}
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
log.Printf("[WARN] Failed to publish vote post %s after moderation error: %v", postID, updateErr)
}
return
}
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
log.Printf("[WARN] Failed to publish vote post %s: %v", postID, err)
}
}
func (s *VoteService) notifyModerationRejected(userID, reason string) {
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
return
}
content := "您发布的投票帖未通过AI审核请修改后重试。"
if strings.TrimSpace(reason) != "" {
content = fmt.Sprintf("您发布的投票帖未通过AI审核原因%s。请修改后重试。", reason)
}
go func() {
_ = s.systemMessageService.SendSystemAnnouncement(
context.Background(),
[]string{userID},
"投票帖审核未通过",
content,
)
}()
}
// GetVoteOptions 获取投票选项
func (s *VoteService) GetVoteOptions(postID string) ([]dto.VoteOptionDTO, error) {
options, err := s.voteRepo.GetOptionsByPostID(postID)
if err != nil {
return nil, err
}
result := make([]dto.VoteOptionDTO, 0, len(options))
for _, option := range options {
result = append(result, dto.VoteOptionDTO{
ID: option.ID,
Content: option.Content,
VotesCount: option.VotesCount,
})
}
return result, nil
}
// GetVoteResult 获取投票结果(包含用户投票状态)
func (s *VoteService) GetVoteResult(postID, userID string) (*dto.VoteResultDTO, error) {
// 获取所有投票选项
options, err := s.voteRepo.GetOptionsByPostID(postID)
if err != nil {
return nil, err
}
// 获取用户的投票记录
userVote, err := s.voteRepo.GetUserVote(postID, userID)
if err != nil {
return nil, err
}
// 构建结果
result := &dto.VoteResultDTO{
Options: make([]dto.VoteOptionDTO, 0, len(options)),
TotalVotes: 0,
HasVoted: userVote != nil,
}
if userVote != nil {
result.VotedOptionID = userVote.OptionID
}
for _, option := range options {
result.Options = append(result.Options, dto.VoteOptionDTO{
ID: option.ID,
Content: option.Content,
VotesCount: option.VotesCount,
})
result.TotalVotes += option.VotesCount
}
return result, nil
}
// Vote 投票
func (s *VoteService) Vote(ctx context.Context, postID, userID, optionID string) error {
// 调用voteRepo.Vote
err := s.voteRepo.Vote(postID, userID, optionID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
return nil
}
// Unvote 取消投票
func (s *VoteService) Unvote(ctx context.Context, postID, userID string) error {
// 调用voteRepo.Unvote
err := s.voteRepo.Unvote(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
return nil
}
// UpdateVoteOption 更新投票选项(作者权限)
func (s *VoteService) UpdateVoteOption(ctx context.Context, postID, optionID, userID, content string) error {
// 获取帖子信息
post, err := s.postRepo.GetByID(postID)
if err != nil {
return err
}
// 验证用户是否为帖子作者
if post.UserID != userID {
return errors.New("只有帖子作者可以更新投票选项")
}
// 调用voteRepo.UpdateOption
return s.voteRepo.UpdateOption(optionID, content)
}
// convertToPostResponse 将Post模型转换为PostResponse DTO
func (s *VoteService) convertToPostResponse(post *model.Post, currentUserID string) *dto.PostResponse {
if post == nil {
return nil
}
response := &dto.PostResponse{
ID: post.ID,
UserID: post.UserID,
Title: post.Title,
Content: post.Content,
LikesCount: post.LikesCount,
CommentsCount: post.CommentsCount,
FavoritesCount: post.FavoritesCount,
SharesCount: post.SharesCount,
ViewsCount: post.ViewsCount,
IsPinned: post.IsPinned,
IsLocked: post.IsLocked,
IsVote: post.IsVote,
CreatedAt: dto.FormatTime(post.CreatedAt),
Images: make([]dto.PostImageResponse, 0, len(post.Images)),
}
// 转换图片
for _, img := range post.Images {
response.Images = append(response.Images, dto.PostImageResponse{
ID: img.ID,
URL: img.URL,
ThumbnailURL: img.ThumbnailURL,
Width: img.Width,
Height: img.Height,
})
}
// 转换作者信息
if post.User != nil {
response.Author = &dto.UserResponse{
ID: post.User.ID,
Username: post.User.Username,
Nickname: post.User.Nickname,
Avatar: post.User.Avatar,
}
}
return response
}