Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor
This commit is contained in:
759
internal/service/audit_service.go
Normal file
759
internal/service/audit_service.go
Normal file
@@ -0,0 +1,759 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ==================== 内容审核服务接口和实现 ====================
|
||||
|
||||
// AuditServiceProvider 内容审核服务提供商接口
|
||||
type AuditServiceProvider interface {
|
||||
// AuditText 审核文本
|
||||
AuditText(ctx context.Context, text string, scene string) (*AuditResult, error)
|
||||
// AuditImage 审核图片
|
||||
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
|
||||
// GetName 获取提供商名称
|
||||
GetName() string
|
||||
}
|
||||
|
||||
// AuditResult 审核结果
|
||||
type AuditResult struct {
|
||||
Pass bool `json:"pass"` // 是否通过
|
||||
Risk string `json:"risk"` // 风险等级: low, medium, high
|
||||
Labels []string `json:"labels"` // 标签列表
|
||||
Suggest string `json:"suggest"` // 建议: pass, review, block
|
||||
Detail string `json:"detail"` // 详细说明
|
||||
Provider string `json:"provider"` // 服务提供商
|
||||
}
|
||||
|
||||
// AuditService 内容审核服务接口
|
||||
type AuditService interface {
|
||||
// AuditText 审核文本
|
||||
AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error)
|
||||
// AuditImage 审核图片
|
||||
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
|
||||
// GetAuditResult 获取审核结果
|
||||
GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error)
|
||||
// SetProvider 设置审核服务提供商
|
||||
SetProvider(provider AuditServiceProvider)
|
||||
// GetProvider 获取当前审核服务提供商
|
||||
GetProvider() AuditServiceProvider
|
||||
}
|
||||
|
||||
// auditServiceImpl 内容审核服务实现
|
||||
type auditServiceImpl struct {
|
||||
db *gorm.DB
|
||||
provider AuditServiceProvider
|
||||
config *AuditConfig
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AuditConfig 内容审核服务配置
|
||||
type AuditConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
|
||||
// 审核服务提供商: local, aliyun, tencent, baidu
|
||||
Provider string `mapstructure:"provider" yaml:"provider"`
|
||||
// 阿里云配置
|
||||
AliyunAccessKey string `mapstructure:"aliyun_access_key" yaml:"aliyun_access_key"`
|
||||
AliyunSecretKey string `mapstructure:"aliyun_secret_key" yaml:"aliyun_secret_key"`
|
||||
AliyunRegion string `mapstructure:"aliyun_region" yaml:"aliyun_region"`
|
||||
// 腾讯云配置
|
||||
TencentSecretID string `mapstructure:"tencent_secret_id" yaml:"tencent_secret_id"`
|
||||
TencentSecretKey string `mapstructure:"tencent_secret_key" yaml:"tencent_secret_key"`
|
||||
// 百度云配置
|
||||
BaiduAPIKey string `mapstructure:"baidu_api_key" yaml:"baidu_api_key"`
|
||||
BaiduSecretKey string `mapstructure:"baidu_secret_key" yaml:"baidu_secret_key"`
|
||||
// 是否自动审核
|
||||
AutoAudit bool `mapstructure:"auto_audit" yaml:"auto_audit"`
|
||||
// 审核超时时间(秒)
|
||||
Timeout int `mapstructure:"timeout" yaml:"timeout"`
|
||||
}
|
||||
|
||||
// NewAuditService 创建内容审核服务
|
||||
func NewAuditService(db *gorm.DB, config *AuditConfig) AuditService {
|
||||
s := &auditServiceImpl{
|
||||
db: db,
|
||||
config: config,
|
||||
}
|
||||
|
||||
// 根据配置初始化提供商
|
||||
if config.Enabled {
|
||||
provider := s.initProvider(config.Provider)
|
||||
s.provider = provider
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// initProvider 根据配置初始化审核服务提供商
|
||||
func (s *auditServiceImpl) initProvider(providerType string) AuditServiceProvider {
|
||||
switch strings.ToLower(providerType) {
|
||||
case "aliyun":
|
||||
return NewAliyunAuditProvider(s.config.AliyunAccessKey, s.config.AliyunSecretKey, s.config.AliyunRegion)
|
||||
case "tencent":
|
||||
return NewTencentAuditProvider(s.config.TencentSecretID, s.config.TencentSecretKey)
|
||||
case "baidu":
|
||||
return NewBaiduAuditProvider(s.config.BaiduAPIKey, s.config.BaiduSecretKey)
|
||||
case "local":
|
||||
fallthrough
|
||||
default:
|
||||
// 默认使用本地审核服务
|
||||
return NewLocalAuditProvider()
|
||||
}
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (s *auditServiceImpl) AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error) {
|
||||
if !s.config.Enabled {
|
||||
// 如果审核服务未启用,直接返回通过
|
||||
return &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Detail: "Audit service disabled",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
return &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Detail: "Empty text",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result *AuditResult
|
||||
var err error
|
||||
|
||||
// 使用提供商审核
|
||||
if s.provider != nil {
|
||||
result, err = s.provider.AuditText(ctx, text, auditType)
|
||||
} else {
|
||||
// 如果没有设置提供商,使用本地审核
|
||||
localProvider := NewLocalAuditProvider()
|
||||
result, err = localProvider.AuditText(ctx, text, auditType)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Audit text error: %v", err)
|
||||
return &AuditResult{
|
||||
Pass: false,
|
||||
Risk: "high",
|
||||
Suggest: "review",
|
||||
Detail: fmt.Sprintf("Audit error: %v", err),
|
||||
}, err
|
||||
}
|
||||
|
||||
// 记录审核日志
|
||||
go s.saveAuditLog(ctx, "text", "", text, auditType, result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (s *auditServiceImpl) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
if !s.config.Enabled {
|
||||
return &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Detail: "Audit service disabled",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if imageURL == "" {
|
||||
return &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Detail: "Empty image URL",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result *AuditResult
|
||||
var err error
|
||||
|
||||
// 使用提供商审核
|
||||
if s.provider != nil {
|
||||
result, err = s.provider.AuditImage(ctx, imageURL)
|
||||
} else {
|
||||
// 如果没有设置提供商,使用本地审核
|
||||
localProvider := NewLocalAuditProvider()
|
||||
result, err = localProvider.AuditImage(ctx, imageURL)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Audit image error: %v", err)
|
||||
return &AuditResult{
|
||||
Pass: false,
|
||||
Risk: "high",
|
||||
Suggest: "review",
|
||||
Detail: fmt.Sprintf("Audit error: %v", err),
|
||||
}, err
|
||||
}
|
||||
|
||||
// 记录审核日志
|
||||
go s.saveAuditLog(ctx, "image", "", "", "image", result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetAuditResult 获取审核结果
|
||||
func (s *auditServiceImpl) GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error) {
|
||||
if s.db == nil || auditID == "" {
|
||||
return nil, fmt.Errorf("invalid audit ID")
|
||||
}
|
||||
|
||||
var auditLog model.AuditLog
|
||||
if err := s.db.Where("id = ?", auditID).First(&auditLog).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &AuditResult{
|
||||
Pass: auditLog.Result == model.AuditResultPass,
|
||||
Risk: string(auditLog.RiskLevel),
|
||||
Suggest: auditLog.Suggestion,
|
||||
Detail: auditLog.Detail,
|
||||
}
|
||||
|
||||
// 解析标签
|
||||
if auditLog.Labels != "" {
|
||||
json.Unmarshal([]byte(auditLog.Labels), &result.Labels)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetProvider 设置审核服务提供商
|
||||
func (s *auditServiceImpl) SetProvider(provider AuditServiceProvider) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.provider = provider
|
||||
}
|
||||
|
||||
// GetProvider 获取当前审核服务提供商
|
||||
func (s *auditServiceImpl) GetProvider() AuditServiceProvider {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.provider
|
||||
}
|
||||
|
||||
// saveAuditLog 保存审核日志
|
||||
func (s *auditServiceImpl) saveAuditLog(ctx context.Context, contentType, content, imageURL, auditType string, result *AuditResult) {
|
||||
if s.db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
auditLog := model.AuditLog{
|
||||
ContentType: contentType,
|
||||
Content: content,
|
||||
ContentURL: imageURL,
|
||||
AuditType: auditType,
|
||||
Labels: strings.Join(result.Labels, ","),
|
||||
Suggestion: result.Suggest,
|
||||
Detail: result.Detail,
|
||||
Source: model.AuditSourceAuto,
|
||||
Status: "completed",
|
||||
}
|
||||
|
||||
if result.Pass {
|
||||
auditLog.Result = model.AuditResultPass
|
||||
} else if result.Suggest == "review" {
|
||||
auditLog.Result = model.AuditResultReview
|
||||
} else {
|
||||
auditLog.Result = model.AuditResultBlock
|
||||
}
|
||||
|
||||
switch result.Risk {
|
||||
case "low":
|
||||
auditLog.RiskLevel = model.AuditRiskLevelLow
|
||||
case "medium":
|
||||
auditLog.RiskLevel = model.AuditRiskLevelMedium
|
||||
case "high":
|
||||
auditLog.RiskLevel = model.AuditRiskLevelHigh
|
||||
default:
|
||||
auditLog.RiskLevel = model.AuditRiskLevelLow
|
||||
}
|
||||
|
||||
if err := s.db.Create(&auditLog).Error; err != nil {
|
||||
log.Printf("Failed to save audit log: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 本地审核服务提供商 ====================
|
||||
|
||||
// localAuditProvider 本地审核服务提供商
|
||||
type localAuditProvider struct {
|
||||
// 可以注入敏感词服务进行本地审核
|
||||
sensitiveService SensitiveService
|
||||
}
|
||||
|
||||
// NewLocalAuditProvider 创建本地审核服务提供商
|
||||
func NewLocalAuditProvider() AuditServiceProvider {
|
||||
return &localAuditProvider{
|
||||
sensitiveService: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取提供商名称
|
||||
func (p *localAuditProvider) GetName() string {
|
||||
return "local"
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (p *localAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||
// 本地审核逻辑
|
||||
// 1. 敏感词检查
|
||||
// 2. 规则匹配
|
||||
// 3. 简单的关键词检测
|
||||
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "local",
|
||||
}
|
||||
|
||||
// 如果有敏感词服务,使用它进行检测
|
||||
if p.sensitiveService != nil {
|
||||
hasSensitive, words := p.sensitiveService.Check(ctx, text)
|
||||
if hasSensitive {
|
||||
result.Pass = false
|
||||
result.Risk = "high"
|
||||
result.Suggest = "block"
|
||||
result.Detail = fmt.Sprintf("包含敏感词: %s", strings.Join(words, ","))
|
||||
result.Labels = append(result.Labels, "sensitive")
|
||||
}
|
||||
}
|
||||
|
||||
// 简单的关键词检测规则
|
||||
// 实际项目中应该从数据库加载
|
||||
suspiciousPatterns := []string{
|
||||
"诈骗",
|
||||
"钓鱼",
|
||||
"木马",
|
||||
"病毒",
|
||||
}
|
||||
|
||||
for _, pattern := range suspiciousPatterns {
|
||||
if strings.Contains(text, pattern) {
|
||||
result.Pass = false
|
||||
result.Risk = "high"
|
||||
result.Suggest = "block"
|
||||
result.Labels = append(result.Labels, "suspicious")
|
||||
if result.Detail == "" {
|
||||
result.Detail = fmt.Sprintf("包含可疑内容: %s", pattern)
|
||||
} else {
|
||||
result.Detail += fmt.Sprintf(", %s", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (p *localAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
// 本地图片审核逻辑
|
||||
// 1. 图片URL合法性检查
|
||||
// 2. 图片格式检查
|
||||
// 3. 可以扩展接入本地图片识别服务
|
||||
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "local",
|
||||
}
|
||||
|
||||
// 检查URL是否为空
|
||||
if imageURL == "" {
|
||||
result.Detail = "Empty image URL"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 检查是否为支持的图片URL格式
|
||||
validPrefixes := []string{"http://", "https://", "s3://", "oss://", "cos://"}
|
||||
isValid := false
|
||||
for _, prefix := range validPrefixes {
|
||||
if strings.HasPrefix(strings.ToLower(imageURL), prefix) {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
result.Pass = false
|
||||
result.Risk = "medium"
|
||||
result.Suggest = "review"
|
||||
result.Detail = "Invalid image URL format"
|
||||
result.Labels = append(result.Labels, "invalid_url")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetSensitiveService 设置敏感词服务
|
||||
func (p *localAuditProvider) SetSensitiveService(ss SensitiveService) {
|
||||
p.sensitiveService = ss
|
||||
}
|
||||
|
||||
// ==================== 阿里云审核服务提供商 ====================
|
||||
|
||||
// aliyunAuditProvider 阿里云审核服务提供商
|
||||
type aliyunAuditProvider struct {
|
||||
accessKey string
|
||||
secretKey string
|
||||
region string
|
||||
}
|
||||
|
||||
// NewAliyunAuditProvider 创建阿里云审核服务提供商
|
||||
func NewAliyunAuditProvider(accessKey, secretKey, region string) AuditServiceProvider {
|
||||
return &aliyunAuditProvider{
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
region: region,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取提供商名称
|
||||
func (p *aliyunAuditProvider) GetName() string {
|
||||
return "aliyun"
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (p *aliyunAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||
// 阿里云内容安全API调用
|
||||
// 实际项目中需要实现阿里云SDK调用
|
||||
// 这里预留接口
|
||||
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "aliyun",
|
||||
Detail: "Aliyun audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现阿里云内容安全API调用
|
||||
// 具体参考: https://help.aliyun.com/document_detail/28417.html
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (p *aliyunAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "aliyun",
|
||||
Detail: "Aliyun image audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现阿里云图片审核API调用
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ==================== 腾讯云审核服务提供商 ====================
|
||||
|
||||
// tencentAuditProvider 腾讯云审核服务提供商
|
||||
type tencentAuditProvider struct {
|
||||
secretID string
|
||||
secretKey string
|
||||
}
|
||||
|
||||
// NewTencentAuditProvider 创建腾讯云审核服务提供商
|
||||
func NewTencentAuditProvider(secretID, secretKey string) AuditServiceProvider {
|
||||
return &tencentAuditProvider{
|
||||
secretID: secretID,
|
||||
secretKey: secretKey,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取提供商名称
|
||||
func (p *tencentAuditProvider) GetName() string {
|
||||
return "tencent"
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (p *tencentAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "tencent",
|
||||
Detail: "Tencent audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现腾讯云内容审核API调用
|
||||
// 具体参考: https://cloud.tencent.com/document/product/1124/64508
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (p *tencentAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "tencent",
|
||||
Detail: "Tencent image audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现腾讯云图片审核API调用
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ==================== 百度云审核服务提供商 ====================
|
||||
|
||||
// baiduAuditProvider 百度云审核服务提供商
|
||||
type baiduAuditProvider struct {
|
||||
apiKey string
|
||||
secretKey string
|
||||
}
|
||||
|
||||
// NewBaiduAuditProvider 创建百度云审核服务提供商
|
||||
func NewBaiduAuditProvider(apiKey, secretKey string) AuditServiceProvider {
|
||||
return &baiduAuditProvider{
|
||||
apiKey: apiKey,
|
||||
secretKey: secretKey,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取提供商名称
|
||||
func (p *baiduAuditProvider) GetName() string {
|
||||
return "baidu"
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (p *baiduAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "baidu",
|
||||
Detail: "Baidu audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现百度云内容审核API调用
|
||||
// 具体参考: https://cloud.baidu.com/doc/ANTISPAM/s/Jjw0r1iF6
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (p *baiduAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "baidu",
|
||||
Detail: "Baidu image audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现百度云图片审核API调用
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ==================== 审核结果回调处理 ====================
|
||||
|
||||
// AuditCallback 审核回调处理
|
||||
type AuditCallback struct {
|
||||
service AuditService
|
||||
}
|
||||
|
||||
// NewAuditCallback 创建审核回调处理
|
||||
func NewAuditCallback(service AuditService) *AuditCallback {
|
||||
return &AuditCallback{
|
||||
service: service,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleTextCallback 处理文本审核回调
|
||||
func (c *AuditCallback) HandleTextCallback(ctx context.Context, auditID string, result *AuditResult) error {
|
||||
if c.service == nil || auditID == "" || result == nil {
|
||||
return fmt.Errorf("invalid parameters")
|
||||
}
|
||||
|
||||
log.Printf("Processing text audit callback: auditID=%s, result=%+v", auditID, result)
|
||||
|
||||
// 根据审核结果执行相应操作
|
||||
// 例如: 更新帖子状态、发送通知等
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleImageCallback 处理图片审核回调
|
||||
func (c *AuditCallback) HandleImageCallback(ctx context.Context, auditID string, result *AuditResult) error {
|
||||
if c.service == nil || auditID == "" || result == nil {
|
||||
return fmt.Errorf("invalid parameters")
|
||||
}
|
||||
|
||||
log.Printf("Processing image audit callback: auditID=%s, result=%+v", auditID, result)
|
||||
|
||||
// 根据审核结果执行相应操作
|
||||
// 例如: 更新图片状态、删除违规图片等
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
// IsContentSafe 判断内容是否安全
|
||||
func IsContentSafe(result *AuditResult) bool {
|
||||
if result == nil {
|
||||
return true
|
||||
}
|
||||
return result.Pass && result.Suggest != "block"
|
||||
}
|
||||
|
||||
// NeedReview 判断内容是否需要人工复审
|
||||
func NeedReview(result *AuditResult) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
return result.Suggest == "review"
|
||||
}
|
||||
|
||||
// GetRiskLevel 获取风险等级
|
||||
func GetRiskLevel(result *AuditResult) string {
|
||||
if result == nil {
|
||||
return "low"
|
||||
}
|
||||
return result.Risk
|
||||
}
|
||||
|
||||
// FormatAuditResult 格式化审核结果为字符串
|
||||
func FormatAuditResult(result *AuditResult) string {
|
||||
if result == nil {
|
||||
return "{}"
|
||||
}
|
||||
data, _ := json.Marshal(result)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// ParseAuditResult 从字符串解析审核结果
|
||||
func ParseAuditResult(data string) (*AuditResult, error) {
|
||||
if data == "" {
|
||||
return nil, fmt.Errorf("empty data")
|
||||
}
|
||||
var result AuditResult
|
||||
if err := json.Unmarshal([]byte(data), &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ==================== 审核日志查询 ====================
|
||||
|
||||
// GetAuditLogs 获取审核日志列表
|
||||
func GetAuditLogs(db *gorm.DB, targetType string, targetID string, result string, page, pageSize int) ([]model.AuditLog, int64, error) {
|
||||
query := db.Model(&model.AuditLog{})
|
||||
|
||||
if targetType != "" {
|
||||
query = query.Where("target_type = ?", targetType)
|
||||
}
|
||||
if targetID != "" {
|
||||
query = query.Where("target_id = ?", targetID)
|
||||
}
|
||||
if result != "" {
|
||||
query = query.Where("result = ?", result)
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var logs []model.AuditLog
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// ==================== 定时任务 ====================
|
||||
|
||||
// AuditScheduler 审核调度器
|
||||
type AuditScheduler struct {
|
||||
db *gorm.DB
|
||||
service AuditService
|
||||
interval time.Duration
|
||||
stopCh chan bool
|
||||
}
|
||||
|
||||
// NewAuditScheduler 创建审核调度器
|
||||
func NewAuditScheduler(db *gorm.DB, service AuditService, interval time.Duration) *AuditScheduler {
|
||||
return &AuditScheduler{
|
||||
db: db,
|
||||
service: service,
|
||||
interval: interval,
|
||||
stopCh: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动调度器
|
||||
func (s *AuditScheduler) Start() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.processPendingAudits()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop 停止调度器
|
||||
func (s *AuditScheduler) Stop() {
|
||||
s.stopCh <- true
|
||||
}
|
||||
|
||||
// processPendingAudits 处理待审核内容
|
||||
func (s *AuditScheduler) processPendingAudits() {
|
||||
// 查询待审核的内容
|
||||
// 1. 查询审核状态为 pending 的记录
|
||||
// 2. 调用审核服务
|
||||
// 3. 更新审核状态
|
||||
|
||||
// 示例逻辑,实际需要根据业务需求实现
|
||||
log.Println("Processing pending audits...")
|
||||
}
|
||||
|
||||
// CleanupOldLogs 清理旧的审核日志
|
||||
func CleanupOldLogs(db *gorm.DB, days int) error {
|
||||
// 清理指定天数之前的审核日志
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||
return db.Where("created_at < ? AND result = ?", cutoffTime, model.AuditResultPass).Delete(&model.AuditLog{}).Error
|
||||
}
|
||||
622
internal/service/chat_service.go
Normal file
622
internal/service/chat_service.go
Normal file
@@ -0,0 +1,622 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 撤回消息的时间限制(2分钟)
|
||||
const RecallMessageTimeout = 2 * time.Minute
|
||||
|
||||
// ChatService 聊天服务接口
|
||||
type ChatService interface {
|
||||
// 会话管理
|
||||
GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
|
||||
GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
|
||||
GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
|
||||
DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
|
||||
SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error
|
||||
|
||||
// 消息操作
|
||||
SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||||
GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
|
||||
GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
|
||||
GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)
|
||||
|
||||
// 已读管理
|
||||
MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
|
||||
GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
|
||||
GetAllUnreadCount(ctx context.Context, userID string) (int64, error)
|
||||
|
||||
// 消息扩展功能
|
||||
RecallMessage(ctx context.Context, messageID string, userID string) error
|
||||
DeleteMessage(ctx context.Context, messageID string, userID string) error
|
||||
|
||||
// WebSocket相关
|
||||
SendTyping(ctx context.Context, senderID string, conversationID string)
|
||||
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
|
||||
|
||||
// 系统消息推送
|
||||
IsUserOnline(userID string) bool
|
||||
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
|
||||
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
|
||||
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
|
||||
|
||||
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
|
||||
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||||
}
|
||||
|
||||
// chatServiceImpl 聊天服务实现
|
||||
type chatServiceImpl struct {
|
||||
db *gorm.DB
|
||||
repo *repository.MessageRepository
|
||||
userRepo *repository.UserRepository
|
||||
sensitive SensitiveService
|
||||
wsManager *websocket.WebSocketManager
|
||||
}
|
||||
|
||||
// NewChatService 创建聊天服务
|
||||
func NewChatService(
|
||||
db *gorm.DB,
|
||||
repo *repository.MessageRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
sensitive SensitiveService,
|
||||
wsManager *websocket.WebSocketManager,
|
||||
) ChatService {
|
||||
return &chatServiceImpl{
|
||||
db: db,
|
||||
repo: repo,
|
||||
userRepo: userRepo,
|
||||
sensitive: sensitive,
|
||||
wsManager: wsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreateConversation 获取或创建私聊会话
|
||||
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||
}
|
||||
|
||||
// GetConversationList 获取用户的会话列表
|
||||
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||
return s.repo.GetConversations(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetConversationByID 获取会话详情
|
||||
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
// 获取会话信息
|
||||
conv, err := s.repo.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||
}
|
||||
|
||||
// 填充用户的已读位置信息
|
||||
_ = participant // 可以用于返回已读位置等信息
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// DeleteConversationForSelf 仅自己删除会话
|
||||
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
|
||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
return fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
if participant.ConversationID == "" {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
|
||||
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
|
||||
return fmt.Errorf("failed to hide conversation: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConversationPinned 设置会话置顶(用户维度)
|
||||
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
|
||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
return fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
if participant.ConversationID == "" {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
|
||||
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
|
||||
return fmt.Errorf("failed to update pinned status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessage 发送消息(使用 segments)
|
||||
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||
// 首先验证会话是否存在
|
||||
conv, err := s.repo.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("会话不存在,请重新创建会话")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||
}
|
||||
|
||||
// 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向
|
||||
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
|
||||
participants, pErr := s.repo.GetConversationParticipants(conversationID)
|
||||
if pErr != nil {
|
||||
return nil, fmt.Errorf("failed to get participants: %w", pErr)
|
||||
}
|
||||
var sentCount *int64
|
||||
for _, p := range participants {
|
||||
if p.UserID == senderID {
|
||||
continue
|
||||
}
|
||||
blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
|
||||
if bErr != nil {
|
||||
return nil, fmt.Errorf("failed to check block status: %w", bErr)
|
||||
}
|
||||
if blocked {
|
||||
return nil, ErrUserBlocked
|
||||
}
|
||||
|
||||
// 陌生人限制:对方未回关前,只允许发送一条文本消息,且禁止发送图片
|
||||
isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
|
||||
if fErr != nil {
|
||||
return nil, fmt.Errorf("failed to check follow status: %w", fErr)
|
||||
}
|
||||
if !isFollowedBack {
|
||||
if containsImageSegment(segments) {
|
||||
return nil, errors.New("对方未关注你,暂不支持发送图片")
|
||||
}
|
||||
if sentCount == nil {
|
||||
c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
|
||||
if cErr != nil {
|
||||
return nil, fmt.Errorf("failed to count sender messages: %w", cErr)
|
||||
}
|
||||
sentCount = &c
|
||||
}
|
||||
if *sentCount >= 1 {
|
||||
return nil, errors.New("对方未关注你前,仅允许发送一条消息")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证用户是否是会话参与者
|
||||
participant, err := s.repo.GetParticipant(conversationID, senderID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("您不是该会话的参与者")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
// 创建消息
|
||||
message := &model.Message{
|
||||
ConversationID: conversationID,
|
||||
SenderID: senderID, // 直接使用string类型的UUID
|
||||
Segments: segments,
|
||||
ReplyToID: replyToID,
|
||||
Status: model.MessageStatusNormal,
|
||||
}
|
||||
|
||||
// 使用事务创建消息并更新seq
|
||||
if err := s.repo.CreateMessageWithSeq(message); err != nil {
|
||||
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||
}
|
||||
|
||||
// 发送消息给接收者
|
||||
log.Printf("[DEBUG SendMessage] 私聊消息 segments 类型: %T, 值: %+v", message.Segments, message.Segments)
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
|
||||
ID: message.ID,
|
||||
ConversationID: message.ConversationID,
|
||||
SenderID: senderID,
|
||||
Segments: message.Segments,
|
||||
Seq: message.Seq,
|
||||
CreatedAt: message.CreatedAt.UnixMilli(),
|
||||
})
|
||||
|
||||
// 获取会话中的其他参与者
|
||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||
if err == nil {
|
||||
for _, p := range participants {
|
||||
// 不发给自己
|
||||
if p.UserID == senderID {
|
||||
continue
|
||||
}
|
||||
// 如果接收者在线,发送实时消息
|
||||
if s.wsManager != nil {
|
||||
isOnline := s.wsManager.IsUserOnline(p.UserID)
|
||||
log.Printf("[DEBUG SendMessage] 接收者 UserID=%s, 在线状态=%v", p.UserID, isOnline)
|
||||
if isOnline {
|
||||
log.Printf("[DEBUG SendMessage] 发送WebSocket消息给 UserID=%s, 消息类型=%s", p.UserID, wsMsg.Type)
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("[DEBUG SendMessage] 获取参与者失败: %v", err)
|
||||
}
|
||||
|
||||
_ = participant // 避免未使用变量警告
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func containsImageSegment(segments model.MessageSegments) bool {
|
||||
for _, seg := range segments {
|
||||
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetMessages 获取消息历史(分页)
|
||||
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, 0, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
return s.repo.GetMessages(conversationID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
||||
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
||||
}
|
||||
|
||||
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
||||
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记已读
|
||||
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
return fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
// 更新参与者的已读位置
|
||||
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update last read seq: %w", err)
|
||||
}
|
||||
|
||||
// 发送已读回执(作为 meta 事件)
|
||||
if s.wsManager != nil {
|
||||
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": websocket.MetaDetailTypeRead,
|
||||
"conversation_id": conversationID,
|
||||
"seq": seq,
|
||||
"user_id": userID,
|
||||
})
|
||||
|
||||
// 获取会话中的所有参与者
|
||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||
if err == nil {
|
||||
// 推送给会话中的所有参与者(包括自己)
|
||||
for _, p := range participants {
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取指定会话的未读消息数
|
||||
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return 0, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return 0, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
return s.repo.GetUnreadCount(conversationID, userID)
|
||||
}
|
||||
|
||||
// GetAllUnreadCount 获取所有会话的未读消息总数
|
||||
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
|
||||
return s.repo.GetAllUnreadCount(userID)
|
||||
}
|
||||
|
||||
// RecallMessage 撤回消息(2分钟内)
|
||||
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
|
||||
// 获取消息
|
||||
var message model.Message
|
||||
err := s.db.First(&message, "id = ?", messageID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("message not found")
|
||||
}
|
||||
return fmt.Errorf("failed to get message: %w", err)
|
||||
}
|
||||
|
||||
// 验证是否是消息发送者
|
||||
if message.SenderIDStr() != userID {
|
||||
return errors.New("can only recall your own messages")
|
||||
}
|
||||
|
||||
// 验证消息是否已被撤回
|
||||
if message.Status == model.MessageStatusRecalled {
|
||||
return errors.New("message already recalled")
|
||||
}
|
||||
|
||||
// 验证是否在2分钟内
|
||||
if time.Since(message.CreatedAt) > RecallMessageTimeout {
|
||||
return errors.New("message recall timeout (2 minutes)")
|
||||
}
|
||||
|
||||
// 更新消息状态为已撤回
|
||||
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to recall message: %w", err)
|
||||
}
|
||||
|
||||
// 发送撤回通知
|
||||
if s.wsManager != nil {
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
|
||||
"messageId": messageID,
|
||||
"conversationId": message.ConversationID,
|
||||
"senderId": userID,
|
||||
})
|
||||
|
||||
// 通知会话中的所有参与者
|
||||
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
|
||||
if err == nil {
|
||||
for _, p := range participants {
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息(仅对自己可见)
|
||||
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
|
||||
// 获取消息
|
||||
var message model.Message
|
||||
err := s.db.First(&message, "id = ?", messageID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("message not found")
|
||||
}
|
||||
return fmt.Errorf("failed to get message: %w", err)
|
||||
}
|
||||
|
||||
// 验证用户是否是会话参与者
|
||||
_, err = s.repo.GetParticipant(message.ConversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("no permission to delete this message")
|
||||
}
|
||||
return fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
// 对于删除消息,我们使用软删除,但需要确保只对当前用户隐藏
|
||||
// 这里简化处理:只有发送者可以删除自己的消息
|
||||
if message.SenderIDStr() != userID {
|
||||
return errors.New("can only delete your own messages")
|
||||
}
|
||||
|
||||
// 更新消息状态为已删除
|
||||
err = s.db.Model(&message).Update("status", model.MessageStatusDeleted).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendTyping 发送正在输入状态
|
||||
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
|
||||
if s.wsManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, senderID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取会话中的其他参与者
|
||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range participants {
|
||||
if p.UserID == senderID {
|
||||
continue
|
||||
}
|
||||
// 发送正在输入状态
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
|
||||
"conversationId": conversationID,
|
||||
"senderId": senderID,
|
||||
})
|
||||
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastMessage 广播消息给用户
|
||||
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
|
||||
if s.wsManager != nil {
|
||||
s.wsManager.SendToUser(targetUser, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// IsUserOnline 检查用户是否在线
|
||||
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
return s.wsManager.IsUserOnline(userID)
|
||||
}
|
||||
|
||||
// PushSystemMessage 推送系统消息给指定用户
|
||||
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return errors.New("user is offline")
|
||||
}
|
||||
|
||||
sysMsg := &websocket.SystemMessage{
|
||||
ID: "", // 由调用方生成
|
||||
Type: msgType,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Data: data,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushNotificationMessage 推送通知消息给指定用户
|
||||
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return errors.New("user is offline")
|
||||
}
|
||||
|
||||
// 确保时间戳已设置
|
||||
if notification.CreatedAt == 0 {
|
||||
notification.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushAnnouncementMessage 广播公告消息给所有在线用户
|
||||
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
// 确保时间戳已设置
|
||||
if announcement.CreatedAt == 0 {
|
||||
announcement.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||||
s.wsManager.Broadcast(wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
|
||||
// 适用于群聊等由调用方自行负责推送的场景
|
||||
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||
// 验证会话是否存在
|
||||
_, err := s.repo.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("会话不存在,请重新创建会话")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||
}
|
||||
|
||||
// 验证用户是否是会话参与者
|
||||
_, err = s.repo.GetParticipant(conversationID, senderID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("您不是该会话的参与者")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
message := &model.Message{
|
||||
ConversationID: conversationID,
|
||||
SenderID: senderID,
|
||||
Segments: segments,
|
||||
ReplyToID: replyToID,
|
||||
Status: model.MessageStatusNormal,
|
||||
}
|
||||
|
||||
if err := s.repo.CreateMessageWithSeq(message); err != nil {
|
||||
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
273
internal/service/comment_service.go
Normal file
273
internal/service/comment_service.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/gorse"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// CommentService 评论服务
|
||||
type CommentService struct {
|
||||
commentRepo *repository.CommentRepository
|
||||
postRepo *repository.PostRepository
|
||||
systemMessageService SystemMessageService
|
||||
gorseClient gorse.Client
|
||||
postAIService *PostAIService
|
||||
}
|
||||
|
||||
// NewCommentService 创建评论服务
|
||||
func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *CommentService {
|
||||
return &CommentService{
|
||||
commentRepo: commentRepo,
|
||||
postRepo: postRepo,
|
||||
systemMessageService: systemMessageService,
|
||||
gorseClient: gorseClient,
|
||||
postAIService: postAIService,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建评论
|
||||
func (s *CommentService) Create(ctx context.Context, postID, userID, content string, parentID *string, images string, imageURLs []string) (*model.Comment, error) {
|
||||
if s.postAIService != nil {
|
||||
// 采用异步审核,前端先立即返回
|
||||
}
|
||||
|
||||
// 获取帖子信息用于发送通知
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
comment := &model.Comment{
|
||||
PostID: postID,
|
||||
UserID: userID,
|
||||
Content: content,
|
||||
ParentID: parentID,
|
||||
Images: images,
|
||||
Status: model.CommentStatusPending,
|
||||
}
|
||||
|
||||
// 如果有父评论,设置根评论ID
|
||||
var parentUserID string
|
||||
if parentID != nil {
|
||||
parent, err := s.commentRepo.GetByID(*parentID)
|
||||
if err == nil && parent != nil {
|
||||
if parent.RootID != nil {
|
||||
comment.RootID = parent.RootID
|
||||
} else {
|
||||
comment.RootID = parentID
|
||||
}
|
||||
parentUserID = parent.UserID
|
||||
}
|
||||
}
|
||||
|
||||
err = s.commentRepo.Create(comment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 重新查询以获取关联的 User
|
||||
comment, err = s.commentRepo.GetByID(comment.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.reviewCommentAsync(comment.ID, userID, postID, content, imageURLs, parentID, parentUserID, post.UserID)
|
||||
|
||||
return comment, nil
|
||||
}
|
||||
|
||||
func (s *CommentService) reviewCommentAsync(
|
||||
commentID, userID, postID, content string,
|
||||
imageURLs []string,
|
||||
parentID *string,
|
||||
parentUserID string,
|
||||
postOwnerID string,
|
||||
) {
|
||||
// 未启用AI时,直接通过审核并发送后续通知
|
||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||
if err := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); err != nil {
|
||||
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
|
||||
return
|
||||
}
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
return
|
||||
}
|
||||
|
||||
err := s.postAIService.ModerateComment(context.Background(), content, imageURLs)
|
||||
if err != nil {
|
||||
var rejectedErr *CommentModerationRejectedError
|
||||
if errors.As(err, &rejectedErr) {
|
||||
if delErr := s.commentRepo.Delete(commentID); delErr != nil {
|
||||
log.Printf("[WARN] Failed to delete rejected comment %s: %v", commentID, delErr)
|
||||
}
|
||||
s.notifyCommentModerationRejected(userID, rejectedErr.Reason)
|
||||
return
|
||||
}
|
||||
|
||||
// 审核服务异常时降级放行,避免评论长期pending
|
||||
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
|
||||
return
|
||||
}
|
||||
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
return
|
||||
}
|
||||
|
||||
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
|
||||
return
|
||||
}
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
}
|
||||
|
||||
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
|
||||
// 发送系统消息通知
|
||||
if s.systemMessageService != nil {
|
||||
go func() {
|
||||
if parentID != nil && parentUserID != "" {
|
||||
// 回复评论,通知被回复的人
|
||||
if parentUserID != userID {
|
||||
notifyErr := s.systemMessageService.SendReplyNotification(context.Background(), parentUserID, userID, postID, *parentID, commentID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending reply notification: %v\n", notifyErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 评论帖子,通知帖子作者
|
||||
if postOwnerID != userID {
|
||||
notifyErr := s.systemMessageService.SendCommentNotification(context.Background(), postOwnerID, userID, postID, commentID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending comment notification: %v\n", notifyErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 推送评论行为到Gorse(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeComment, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to insert comment feedback to Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *CommentService) notifyCommentModerationRejected(userID, reason string) {
|
||||
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
content := "您发布的评论未通过AI审核,请修改后重试。"
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
content = fmt.Sprintf("您发布的评论未通过AI审核,原因:%s。请修改后重试。", reason)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.systemMessageService.SendSystemAnnouncement(
|
||||
context.Background(),
|
||||
[]string{userID},
|
||||
"评论审核未通过",
|
||||
content,
|
||||
); err != nil {
|
||||
log.Printf("[WARN] Failed to send comment moderation reject notification: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取评论
|
||||
func (s *CommentService) GetByID(ctx context.Context, id string) (*model.Comment, error) {
|
||||
return s.commentRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// GetByPostID 获取帖子评论
|
||||
func (s *CommentService) GetByPostID(ctx context.Context, postID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||
// 使用带回复的查询,默认加载前3条回复
|
||||
return s.commentRepo.GetByPostIDWithReplies(postID, page, pageSize, 3)
|
||||
}
|
||||
|
||||
// GetRepliesByRootID 根据根评论ID分页获取回复
|
||||
func (s *CommentService) GetRepliesByRootID(ctx context.Context, rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||
return s.commentRepo.GetRepliesByRootID(rootID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetReplies 获取回复
|
||||
func (s *CommentService) GetReplies(ctx context.Context, parentID string) ([]*model.Comment, error) {
|
||||
return s.commentRepo.GetReplies(parentID)
|
||||
}
|
||||
|
||||
// Update 更新评论
|
||||
func (s *CommentService) Update(ctx context.Context, comment *model.Comment) error {
|
||||
return s.commentRepo.Update(comment)
|
||||
}
|
||||
|
||||
// Delete 删除评论
|
||||
func (s *CommentService) Delete(ctx context.Context, id string) error {
|
||||
return s.commentRepo.Delete(id)
|
||||
}
|
||||
|
||||
// Like 点赞评论
|
||||
func (s *CommentService) Like(ctx context.Context, commentID, userID string) error {
|
||||
// 获取评论信息用于发送通知
|
||||
comment, err := s.commentRepo.GetByID(commentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.commentRepo.Like(commentID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送评论/回复点赞通知(只有不是给自己点赞时才发送)
|
||||
if s.systemMessageService != nil && comment.UserID != userID {
|
||||
go func() {
|
||||
var notifyErr error
|
||||
if comment.ParentID != nil {
|
||||
notifyErr = s.systemMessageService.SendLikeReplyNotification(
|
||||
context.Background(),
|
||||
comment.UserID,
|
||||
userID,
|
||||
comment.PostID,
|
||||
commentID,
|
||||
comment.Content,
|
||||
)
|
||||
} else {
|
||||
notifyErr = s.systemMessageService.SendLikeCommentNotification(
|
||||
context.Background(),
|
||||
comment.UserID,
|
||||
userID,
|
||||
comment.PostID,
|
||||
commentID,
|
||||
comment.Content,
|
||||
)
|
||||
}
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] Like notification sent successfully\n")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlike 取消点赞评论
|
||||
func (s *CommentService) Unlike(ctx context.Context, commentID, userID string) error {
|
||||
return s.commentRepo.Unlike(commentID, userID)
|
||||
}
|
||||
|
||||
// IsLiked 检查是否已点赞
|
||||
func (s *CommentService) IsLiked(ctx context.Context, commentID, userID string) bool {
|
||||
return s.commentRepo.IsLiked(commentID, userID)
|
||||
}
|
||||
234
internal/service/email_code_service.go
Normal file
234
internal/service/email_code_service.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
verifyCodeTTL = 10 * time.Minute
|
||||
verifyCodeRateLimitTTL = 60 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
CodePurposeRegister = "register"
|
||||
CodePurposePasswordReset = "password_reset"
|
||||
CodePurposeEmailVerify = "email_verify"
|
||||
CodePurposeChangePassword = "change_password"
|
||||
)
|
||||
|
||||
type verificationCodePayload struct {
|
||||
Code string `json:"code"`
|
||||
Purpose string `json:"purpose"`
|
||||
Email string `json:"email"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
type EmailCodeService interface {
|
||||
SendCode(ctx context.Context, purpose, email string) error
|
||||
VerifyCode(purpose, email, code string) error
|
||||
}
|
||||
|
||||
type emailCodeServiceImpl struct {
|
||||
emailService EmailService
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
func NewEmailCodeService(emailService EmailService, cacheBackend cache.Cache) EmailCodeService {
|
||||
if cacheBackend == nil {
|
||||
cacheBackend = cache.GetCache()
|
||||
}
|
||||
return &emailCodeServiceImpl{
|
||||
emailService: emailService,
|
||||
cache: cacheBackend,
|
||||
}
|
||||
}
|
||||
|
||||
func verificationCodeCacheKey(purpose, email string) string {
|
||||
return fmt.Sprintf("auth:verify_code:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
|
||||
}
|
||||
|
||||
func verificationCodeRateLimitKey(purpose, email string) string {
|
||||
return fmt.Sprintf("auth:verify_code_rate_limit:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
|
||||
}
|
||||
|
||||
func generateNumericCode(length int) (string, error) {
|
||||
if length <= 0 {
|
||||
return "", fmt.Errorf("invalid code length")
|
||||
}
|
||||
max := big.NewInt(10)
|
||||
result := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
n, err := rand.Int(rand.Reader, max)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result[i] = byte('0' + n.Int64())
|
||||
}
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
func (s *emailCodeServiceImpl) SendCode(ctx context.Context, purpose, email string) error {
|
||||
if strings.TrimSpace(email) == "" || !utils.ValidateEmail(email) {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
if s.emailService == nil || !s.emailService.IsEnabled() {
|
||||
return ErrEmailServiceUnavailable
|
||||
}
|
||||
if s.cache == nil {
|
||||
return ErrVerificationCodeUnavailable
|
||||
}
|
||||
|
||||
rateLimitKey := verificationCodeRateLimitKey(purpose, email)
|
||||
if s.cache.Exists(rateLimitKey) {
|
||||
return ErrVerificationCodeTooFrequent
|
||||
}
|
||||
|
||||
code, err := generateNumericCode(6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate verification code failed: %w", err)
|
||||
}
|
||||
payload := verificationCodePayload{
|
||||
Code: code,
|
||||
Purpose: purpose,
|
||||
Email: strings.ToLower(strings.TrimSpace(email)),
|
||||
ExpiresAt: time.Now().Add(verifyCodeTTL).Unix(),
|
||||
}
|
||||
cacheKey := verificationCodeCacheKey(purpose, email)
|
||||
s.cache.Set(cacheKey, payload, verifyCodeTTL)
|
||||
s.cache.Set(rateLimitKey, "1", verifyCodeRateLimitTTL)
|
||||
|
||||
subject, sceneText := verificationEmailMeta(purpose)
|
||||
textBody := fmt.Sprintf("【%s】验证码:%s\n有效期:10分钟\n请勿将验证码泄露给他人。", sceneText, code)
|
||||
htmlBody := buildVerificationEmailHTML(sceneText, code)
|
||||
if err := s.emailService.Send(ctx, SendEmailRequest{
|
||||
To: []string{email},
|
||||
Subject: subject,
|
||||
TextBody: textBody,
|
||||
HTMLBody: htmlBody,
|
||||
}); err != nil {
|
||||
s.cache.Delete(cacheKey)
|
||||
return fmt.Errorf("send verification email failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCodeServiceImpl) VerifyCode(purpose, email, code string) error {
|
||||
if strings.TrimSpace(email) == "" || strings.TrimSpace(code) == "" {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
if s.cache == nil {
|
||||
return ErrVerificationCodeUnavailable
|
||||
}
|
||||
|
||||
cacheKey := verificationCodeCacheKey(purpose, email)
|
||||
raw, ok := s.cache.Get(cacheKey)
|
||||
if !ok {
|
||||
return ErrVerificationCodeExpired
|
||||
}
|
||||
|
||||
var payload verificationCodePayload
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
if err := json.Unmarshal([]byte(v), &payload); err != nil {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
case []byte:
|
||||
if err := json.Unmarshal(v, &payload); err != nil {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
case verificationCodePayload:
|
||||
payload = v
|
||||
default:
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
}
|
||||
|
||||
if payload.Purpose != purpose || payload.Email != strings.ToLower(strings.TrimSpace(email)) {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
if payload.ExpiresAt > 0 && time.Now().Unix() > payload.ExpiresAt {
|
||||
s.cache.Delete(cacheKey)
|
||||
return ErrVerificationCodeExpired
|
||||
}
|
||||
if payload.Code != strings.TrimSpace(code) {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
|
||||
s.cache.Delete(cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func verificationEmailMeta(purpose string) (subject string, sceneText string) {
|
||||
switch purpose {
|
||||
case CodePurposeRegister:
|
||||
return "Carrot BBS 注册验证码", "注册账号"
|
||||
case CodePurposePasswordReset:
|
||||
return "Carrot BBS 找回密码验证码", "找回密码"
|
||||
case CodePurposeEmailVerify:
|
||||
return "Carrot BBS 邮箱验证验证码", "验证邮箱"
|
||||
case CodePurposeChangePassword:
|
||||
return "Carrot BBS 修改密码验证码", "修改密码"
|
||||
default:
|
||||
return "Carrot BBS 验证码", "身份验证"
|
||||
}
|
||||
}
|
||||
|
||||
func buildVerificationEmailHTML(sceneText, code string) string {
|
||||
return fmt.Sprintf(`<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Carrot BBS 验证码</title>
|
||||
</head>
|
||||
<body style="margin:0;padding:0;background:#f4f6fb;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,'PingFang SC','Microsoft YaHei',sans-serif;color:#1f2937;">
|
||||
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="background:#f4f6fb;padding:24px 12px;">
|
||||
<tr>
|
||||
<td align="center">
|
||||
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="max-width:560px;background:#ffffff;border-radius:14px;overflow:hidden;box-shadow:0 8px 30px rgba(15,23,42,0.08);">
|
||||
<tr>
|
||||
<td style="background:linear-gradient(135deg,#ff6b35,#ff8f66);padding:24px 28px;color:#ffffff;">
|
||||
<div style="font-size:22px;font-weight:700;line-height:1.2;">Carrot BBS</div>
|
||||
<div style="margin-top:6px;font-size:14px;opacity:0.95;">%s 验证</div>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding:28px;">
|
||||
<p style="margin:0 0 14px;font-size:15px;line-height:1.75;">你好,</p>
|
||||
<p style="margin:0 0 20px;font-size:15px;line-height:1.75;">你正在进行 <strong>%s</strong> 操作,请使用下方验证码完成验证:</p>
|
||||
<div style="margin:0 auto 18px;max-width:320px;border:1px dashed #ff8f66;background:#fff8f4;border-radius:12px;padding:14px 12px;text-align:center;">
|
||||
<div style="font-size:13px;color:#9a3412;letter-spacing:0.5px;">验证码(10分钟内有效)</div>
|
||||
<div style="margin-top:8px;font-size:34px;line-height:1;font-weight:800;letter-spacing:8px;color:#ea580c;">%s</div>
|
||||
</div>
|
||||
<p style="margin:0 0 8px;font-size:13px;color:#6b7280;line-height:1.7;">如果不是你本人操作,请忽略此邮件,并及时检查账号安全。</p>
|
||||
<p style="margin:0;font-size:13px;color:#6b7280;line-height:1.7;">请勿向任何人透露验证码,平台不会以任何理由索取验证码。</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding:14px 28px;background:#f8fafc;border-top:1px solid #e5e7eb;color:#94a3b8;font-size:12px;line-height:1.7;">
|
||||
此邮件由系统自动发送,请勿直接回复。<br/>
|
||||
© Carrot BBS
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>`, sceneText, sceneText, code)
|
||||
}
|
||||
82
internal/service/email_service.go
Normal file
82
internal/service/email_service.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
emailpkg "carrot_bbs/internal/pkg/email"
|
||||
)
|
||||
|
||||
// SendEmailRequest 发信请求
|
||||
type SendEmailRequest struct {
|
||||
To []string
|
||||
Cc []string
|
||||
Bcc []string
|
||||
ReplyTo []string
|
||||
Subject string
|
||||
TextBody string
|
||||
HTMLBody string
|
||||
Attachments []string
|
||||
}
|
||||
|
||||
type EmailService interface {
|
||||
IsEnabled() bool
|
||||
Send(ctx context.Context, req SendEmailRequest) error
|
||||
SendText(ctx context.Context, to []string, subject, body string) error
|
||||
SendHTML(ctx context.Context, to []string, subject, html string) error
|
||||
}
|
||||
|
||||
type emailServiceImpl struct {
|
||||
client emailpkg.Client
|
||||
}
|
||||
|
||||
func NewEmailService(client emailpkg.Client) EmailService {
|
||||
return &emailServiceImpl{client: client}
|
||||
}
|
||||
|
||||
func (s *emailServiceImpl) IsEnabled() bool {
|
||||
return s.client != nil && s.client.IsEnabled()
|
||||
}
|
||||
|
||||
func (s *emailServiceImpl) Send(ctx context.Context, req SendEmailRequest) error {
|
||||
if s.client == nil {
|
||||
return fmt.Errorf("email client is nil")
|
||||
}
|
||||
if !s.client.IsEnabled() {
|
||||
return fmt.Errorf("email service is disabled")
|
||||
}
|
||||
if len(req.To) == 0 {
|
||||
return fmt.Errorf("email recipient is empty")
|
||||
}
|
||||
if strings.TrimSpace(req.Subject) == "" {
|
||||
return fmt.Errorf("email subject is empty")
|
||||
}
|
||||
|
||||
return s.client.Send(ctx, emailpkg.Message{
|
||||
To: req.To,
|
||||
Cc: req.Cc,
|
||||
Bcc: req.Bcc,
|
||||
ReplyTo: req.ReplyTo,
|
||||
Subject: req.Subject,
|
||||
TextBody: req.TextBody,
|
||||
HTMLBody: req.HTMLBody,
|
||||
Attachments: req.Attachments,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *emailServiceImpl) SendText(ctx context.Context, to []string, subject, body string) error {
|
||||
return s.Send(ctx, SendEmailRequest{
|
||||
To: to,
|
||||
Subject: subject,
|
||||
TextBody: body,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *emailServiceImpl) SendHTML(ctx context.Context, to []string, subject, html string) error {
|
||||
return s.Send(ctx, SendEmailRequest{
|
||||
To: to,
|
||||
Subject: subject,
|
||||
HTMLBody: html,
|
||||
})
|
||||
}
|
||||
1491
internal/service/group_service.go
Normal file
1491
internal/service/group_service.go
Normal file
File diff suppressed because it is too large
Load Diff
38
internal/service/jwt_service.go
Normal file
38
internal/service/jwt_service.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/pkg/jwt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWTService JWT服务
|
||||
type JWTService struct {
|
||||
jwt *jwt.JWT
|
||||
}
|
||||
|
||||
// NewJWTService 创建JWT服务
|
||||
func NewJWTService(secret string, accessExpire, refreshExpire int64) *JWTService {
|
||||
return &JWTService{
|
||||
jwt: jwt.New(secret, time.Duration(accessExpire)*time.Second, time.Duration(refreshExpire)*time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAccessToken 生成访问令牌
|
||||
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
|
||||
return s.jwt.GenerateAccessToken(userID, username)
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新令牌
|
||||
func (s *JWTService) GenerateRefreshToken(userID, username string) (string, error) {
|
||||
return s.jwt.GenerateRefreshToken(userID, username)
|
||||
}
|
||||
|
||||
// ParseToken 解析令牌
|
||||
func (s *JWTService) ParseToken(tokenString string) (*jwt.Claims, error) {
|
||||
return s.jwt.ParseToken(tokenString)
|
||||
}
|
||||
|
||||
// ValidateToken 验证令牌
|
||||
func (s *JWTService) ValidateToken(tokenString string) error {
|
||||
return s.jwt.ValidateToken(tokenString)
|
||||
}
|
||||
215
internal/service/message_service.go
Normal file
215
internal/service/message_service.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// 缓存TTL常量
|
||||
const (
|
||||
ConversationListTTL = 60 * time.Second // 会话列表缓存60秒
|
||||
ConversationDetailTTL = 60 * time.Second // 会话详情缓存60秒
|
||||
UnreadCountTTL = 30 * time.Second // 未读数缓存30秒
|
||||
ConversationNullTTL = 5 * time.Second
|
||||
UnreadNullTTL = 5 * time.Second
|
||||
CacheJitterRatio = 0.1
|
||||
)
|
||||
|
||||
// MessageService 消息服务
|
||||
type MessageService struct {
|
||||
messageRepo *repository.MessageRepository
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
// NewMessageService 创建消息服务
|
||||
func NewMessageService(messageRepo *repository.MessageRepository) *MessageService {
|
||||
return &MessageService{
|
||||
messageRepo: messageRepo,
|
||||
cache: cache.GetCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// ConversationListResult 会话列表缓存结果
|
||||
type ConversationListResult struct {
|
||||
Conversations []*model.Conversation
|
||||
Total int64
|
||||
}
|
||||
|
||||
// SendMessage 发送消息(使用 segments)
|
||||
// senderID 和 receiverID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID string, segments model.MessageSegments) (*model.Message, error) {
|
||||
// 获取或创建会话
|
||||
conv, err := s.messageRepo.GetOrCreatePrivateConversation(senderID, receiverID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg := &model.Message{
|
||||
ConversationID: conv.ID,
|
||||
SenderID: senderID,
|
||||
Segments: segments,
|
||||
Status: model.MessageStatusNormal,
|
||||
}
|
||||
|
||||
// 使用事务创建消息并更新seq
|
||||
err = s.messageRepo.CreateMessageWithSeq(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效会话列表缓存(发送者和接收者)
|
||||
cache.InvalidateConversationList(s.cache, senderID)
|
||||
cache.InvalidateConversationList(s.cache, receiverID)
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadConversation(s.cache, receiverID)
|
||||
cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// GetConversations 获取会话列表(带缓存)
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
conversationTTL := cacheSettings.ConversationTTL
|
||||
if conversationTTL <= 0 {
|
||||
conversationTTL = ConversationListTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = ConversationNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = CacheJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
cacheKey := cache.ConversationListKey(userID, page, pageSize)
|
||||
result, err := cache.GetOrLoadTyped[*ConversationListResult](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
conversationTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (*ConversationListResult, error) {
|
||||
conversations, total, err := s.messageRepo.GetConversations(userID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ConversationListResult{
|
||||
Conversations: conversations,
|
||||
Total: total,
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result == nil {
|
||||
return []*model.Conversation{}, 0, nil
|
||||
}
|
||||
return result.Conversations, result.Total, nil
|
||||
}
|
||||
|
||||
// GetMessages 获取消息列表
|
||||
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||
return s.messageRepo.GetMessages(conversationID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetMessagesAfterSeq 获取指定seq之后的消息(增量同步)
|
||||
func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||
return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
|
||||
err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
||||
|
||||
// 失效会话列表缓存
|
||||
cache.InvalidateConversationList(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读消息数(带缓存)
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
unreadTTL := cacheSettings.UnreadCountTTL
|
||||
if unreadTTL <= 0 {
|
||||
unreadTTL = UnreadCountTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = UnreadNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = CacheJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
cacheKey := cache.UnreadDetailKey(userID, conversationID)
|
||||
|
||||
return cache.GetOrLoadTyped[int64](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
unreadTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (int64, error) {
|
||||
return s.messageRepo.GetUnreadCount(conversationID, userID)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// GetOrCreateConversation 获取或创建私聊会话
|
||||
// user1ID 和 user2ID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||
conv, err := s.messageRepo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效会话列表缓存
|
||||
cache.InvalidateConversationList(s.cache, user1ID)
|
||||
cache.InvalidateConversationList(s.cache, user2ID)
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// GetConversationParticipants 获取会话参与者列表
|
||||
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
|
||||
return s.messageRepo.GetConversationParticipants(conversationID)
|
||||
}
|
||||
|
||||
// ParseConversationID 辅助函数:直接返回字符串ID(已经是string类型)
|
||||
func ParseConversationID(idStr string) (string, error) {
|
||||
return idStr, nil
|
||||
}
|
||||
|
||||
// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用)
|
||||
func (s *MessageService) InvalidateUserConversationCache(userID string) {
|
||||
cache.InvalidateConversationList(s.cache, userID)
|
||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||
}
|
||||
|
||||
// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用)
|
||||
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
|
||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
||||
}
|
||||
169
internal/service/notification_service.go
Normal file
169
internal/service/notification_service.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// 缓存TTL常量
|
||||
const (
|
||||
NotificationUnreadCountTTL = 30 * time.Second // 通知未读数缓存30秒
|
||||
NotificationNullTTL = 5 * time.Second
|
||||
NotificationCacheJitter = 0.1
|
||||
)
|
||||
|
||||
// NotificationService 通知服务
|
||||
type NotificationService struct {
|
||||
notificationRepo *repository.NotificationRepository
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
// NewNotificationService 创建通知服务
|
||||
func NewNotificationService(notificationRepo *repository.NotificationRepository) *NotificationService {
|
||||
return &NotificationService{
|
||||
notificationRepo: notificationRepo,
|
||||
cache: cache.GetCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建通知
|
||||
func (s *NotificationService) Create(ctx context.Context, userID string, notificationType model.NotificationType, title, content string) (*model.Notification, error) {
|
||||
notification := &model.Notification{
|
||||
UserID: userID,
|
||||
Type: notificationType,
|
||||
Title: title,
|
||||
Content: content,
|
||||
IsRead: false,
|
||||
}
|
||||
|
||||
err := s.notificationRepo.Create(notification)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return notification, nil
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户通知
|
||||
func (s *NotificationService) GetByUserID(ctx context.Context, userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
|
||||
return s.notificationRepo.GetByUserID(userID, page, pageSize, unreadOnly)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
func (s *NotificationService) MarkAsRead(ctx context.Context, id string) error {
|
||||
err := s.notificationRepo.MarkAsRead(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 注意:这里无法获取userID,所以不在缓存中失效
|
||||
// 调用方应该使用MarkAsReadWithUserID方法
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsReadWithUserID 标记为已读(带用户ID,用于缓存失效)
|
||||
func (s *NotificationService) MarkAsReadWithUserID(ctx context.Context, id, userID string) error {
|
||||
err := s.notificationRepo.MarkAsRead(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAllAsRead 标记所有为已读
|
||||
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID string) error {
|
||||
err := s.notificationRepo.MarkAllAsRead(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除通知
|
||||
func (s *NotificationService) Delete(ctx context.Context, id string) error {
|
||||
return s.notificationRepo.Delete(id)
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读数量(带缓存)
|
||||
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID string) (int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
unreadTTL := cacheSettings.UnreadCountTTL
|
||||
if unreadTTL <= 0 {
|
||||
unreadTTL = NotificationUnreadCountTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = NotificationNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = NotificationCacheJitter
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
cacheKey := cache.UnreadSystemKey(userID)
|
||||
return cache.GetOrLoadTyped[int64](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
unreadTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (int64, error) {
|
||||
return s.notificationRepo.GetUnreadCount(userID)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// DeleteNotification 删除通知(带用户验证)
|
||||
func (s *NotificationService) DeleteNotification(ctx context.Context, id, userID string) error {
|
||||
// 先检查通知是否属于该用户
|
||||
notification, err := s.notificationRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if notification.UserID != userID {
|
||||
return ErrUnauthorizedNotification
|
||||
}
|
||||
|
||||
err = s.notificationRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearAllNotifications 清空所有通知
|
||||
func (s *NotificationService) ClearAllNotifications(ctx context.Context, userID string) error {
|
||||
err := s.notificationRepo.DeleteAllByUserID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 错误定义
|
||||
var ErrUnauthorizedNotification = &ServiceError{Code: 403, Message: "unauthorized to delete this notification"}
|
||||
103
internal/service/post_ai_service.go
Normal file
103
internal/service/post_ai_service.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// PostModerationRejectedError 帖子审核拒绝错误
|
||||
type PostModerationRejectedError struct {
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e *PostModerationRejectedError) Error() string {
|
||||
if strings.TrimSpace(e.Reason) == "" {
|
||||
return "post rejected by moderation"
|
||||
}
|
||||
return "post rejected by moderation: " + e.Reason
|
||||
}
|
||||
|
||||
// UserMessage 返回给前端的用户可读文案
|
||||
func (e *PostModerationRejectedError) UserMessage() string {
|
||||
if strings.TrimSpace(e.Reason) == "" {
|
||||
return "内容未通过审核,请修改后重试"
|
||||
}
|
||||
return strings.TrimSpace(e.Reason)
|
||||
}
|
||||
|
||||
// CommentModerationRejectedError 评论审核拒绝错误
|
||||
type CommentModerationRejectedError struct {
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e *CommentModerationRejectedError) Error() string {
|
||||
if strings.TrimSpace(e.Reason) == "" {
|
||||
return "comment rejected by moderation"
|
||||
}
|
||||
return "comment rejected by moderation: " + e.Reason
|
||||
}
|
||||
|
||||
// UserMessage 返回给前端的用户可读文案
|
||||
func (e *CommentModerationRejectedError) UserMessage() string {
|
||||
if strings.TrimSpace(e.Reason) == "" {
|
||||
return "评论未通过审核,请修改后重试"
|
||||
}
|
||||
return strings.TrimSpace(e.Reason)
|
||||
}
|
||||
|
||||
type PostAIService struct {
|
||||
openAIClient openai.Client
|
||||
}
|
||||
|
||||
func NewPostAIService(openAIClient openai.Client) *PostAIService {
|
||||
return &PostAIService{
|
||||
openAIClient: openAIClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostAIService) IsEnabled() bool {
|
||||
return s != nil && s.openAIClient != nil && s.openAIClient.IsEnabled()
|
||||
}
|
||||
|
||||
// ModeratePost 审核帖子内容,返回 nil 表示通过
|
||||
func (s *PostAIService) ModeratePost(ctx context.Context, title, content string, images []string) error {
|
||||
if !s.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
approved, reason, err := s.openAIClient.ModeratePost(ctx, title, content, images)
|
||||
if err != nil {
|
||||
if s.openAIClient.Config().StrictModeration {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] AI moderation failed, fallback allow: %v", err)
|
||||
return nil
|
||||
}
|
||||
if !approved {
|
||||
return &PostModerationRejectedError{Reason: reason}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ModerateComment 审核评论内容,返回 nil 表示通过
|
||||
func (s *PostAIService) ModerateComment(ctx context.Context, content string, images []string) error {
|
||||
if !s.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
approved, reason, err := s.openAIClient.ModerateComment(ctx, content, images)
|
||||
if err != nil {
|
||||
if s.openAIClient.Config().StrictModeration {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] AI comment moderation failed, fallback allow: %v", err)
|
||||
return nil
|
||||
}
|
||||
if !approved {
|
||||
return &CommentModerationRejectedError{Reason: reason}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
593
internal/service/post_service.go
Normal file
593
internal/service/post_service.go
Normal file
@@ -0,0 +1,593 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/gorse"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// 缓存TTL常量
|
||||
const (
|
||||
PostListTTL = 30 * time.Second // 帖子列表缓存30秒
|
||||
PostListNullTTL = 5 * time.Second
|
||||
PostListJitterRatio = 0.15
|
||||
anonymousViewUserID = "_anon_view"
|
||||
)
|
||||
|
||||
// PostService 帖子服务
|
||||
type PostService struct {
|
||||
postRepo *repository.PostRepository
|
||||
systemMessageService SystemMessageService
|
||||
cache cache.Cache
|
||||
gorseClient gorse.Client
|
||||
postAIService *PostAIService
|
||||
}
|
||||
|
||||
// NewPostService 创建帖子服务
|
||||
func NewPostService(postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *PostService {
|
||||
return &PostService{
|
||||
postRepo: postRepo,
|
||||
systemMessageService: systemMessageService,
|
||||
cache: cache.GetCache(),
|
||||
gorseClient: gorseClient,
|
||||
postAIService: postAIService,
|
||||
}
|
||||
}
|
||||
|
||||
// PostListResult 帖子列表缓存结果
|
||||
type PostListResult struct {
|
||||
Posts []*model.Post
|
||||
Total int64
|
||||
}
|
||||
|
||||
// Create 创建帖子
|
||||
func (s *PostService) Create(ctx context.Context, userID, title, content string, images []string) (*model.Post, error) {
|
||||
post := &model.Post{
|
||||
UserID: userID,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Status: model.PostStatusPending,
|
||||
}
|
||||
|
||||
err := s.postRepo.Create(post, images)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效帖子列表缓存
|
||||
cache.InvalidatePostList(s.cache)
|
||||
|
||||
// 同步到Gorse推荐系统(异步)
|
||||
go s.reviewPostAsync(post.ID, userID, title, content, images)
|
||||
|
||||
// 重新查询以获取关联的 User 和 Images
|
||||
return s.postRepo.GetByID(post.ID)
|
||||
}
|
||||
|
||||
func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) {
|
||||
// 未启用AI时,直接发布
|
||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := s.postAIService.ModeratePost(context.Background(), title, content, images)
|
||||
if err != nil {
|
||||
var rejectedErr *PostModerationRejectedError
|
||||
if errors.As(err, &rejectedErr) {
|
||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
|
||||
}
|
||||
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
||||
return
|
||||
}
|
||||
|
||||
// 规则审核不可用时,降级为发布,避免长时间pending
|
||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
|
||||
}
|
||||
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
|
||||
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if s.gorseClient.IsEnabled() {
|
||||
post, getErr := s.postRepo.GetByID(postID)
|
||||
if getErr != nil {
|
||||
log.Printf("[WARN] Failed to load published post for gorse sync: %v", getErr)
|
||||
return
|
||||
}
|
||||
categories := s.buildPostCategories(post)
|
||||
comment := post.Title
|
||||
textToEmbed := post.Title + " " + post.Content
|
||||
if upsertErr := s.gorseClient.UpsertItemWithEmbedding(context.Background(), post.ID, categories, comment, textToEmbed); upsertErr != nil {
|
||||
log.Printf("[WARN] Failed to upsert item to Gorse: %v", upsertErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostService) notifyModerationRejected(userID, reason string) {
|
||||
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
content := "您发布的帖子未通过AI审核,请修改后重试。"
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
content = fmt.Sprintf("您发布的帖子未通过AI审核,原因:%s。请修改后重试。", reason)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.systemMessageService.SendSystemAnnouncement(
|
||||
context.Background(),
|
||||
[]string{userID},
|
||||
"帖子审核未通过",
|
||||
content,
|
||||
); err != nil {
|
||||
log.Printf("[WARN] Failed to send moderation reject notification: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取帖子
|
||||
func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, error) {
|
||||
return s.postRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// Update 更新帖子
|
||||
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
|
||||
err := s.postRepo.Update(post)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存和列表缓存
|
||||
cache.InvalidatePostDetail(s.cache, post.ID)
|
||||
cache.InvalidatePostList(s.cache)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除帖子
|
||||
func (s *PostService) Delete(ctx context.Context, id string) error {
|
||||
err := s.postRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存和列表缓存
|
||||
cache.InvalidatePostDetail(s.cache, id)
|
||||
cache.InvalidatePostList(s.cache)
|
||||
|
||||
// 从Gorse中删除帖子(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.DeleteItem(context.Background(), id); err != nil {
|
||||
log.Printf("[WARN] Failed to delete item from Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取帖子列表(带缓存)
|
||||
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
postListTTL := cacheSettings.PostListTTL
|
||||
if postListTTL <= 0 {
|
||||
postListTTL = PostListTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = PostListNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = PostListJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染)
|
||||
cacheKey := cache.PostListKey("latest", userID, page, pageSize)
|
||||
|
||||
result, err := cache.GetOrLoadTyped[*PostListResult](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
postListTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (*PostListResult, error) {
|
||||
posts, total, err := s.postRepo.List(page, pageSize, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PostListResult{Posts: posts, Total: total}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result == nil {
|
||||
return []*model.Post{}, 0, nil
|
||||
}
|
||||
|
||||
// 兼容历史脏缓存:旧缓存序列化会丢失 Post.User,导致前端显示“匿名用户”
|
||||
// 这里检测并回源重建一次缓存,避免在 TTL 内持续返回缺失作者的数据。
|
||||
missingAuthor := false
|
||||
for _, post := range result.Posts {
|
||||
if post != nil && post.UserID != "" && post.User == nil {
|
||||
missingAuthor = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if missingAuthor {
|
||||
posts, total, loadErr := s.postRepo.List(page, pageSize, userID)
|
||||
if loadErr != nil {
|
||||
return nil, 0, loadErr
|
||||
}
|
||||
result = &PostListResult{Posts: posts, Total: total}
|
||||
cache.SetWithJitter(s.cache, cacheKey, result, postListTTL, jitter)
|
||||
}
|
||||
|
||||
return result.Posts, result.Total, nil
|
||||
}
|
||||
|
||||
// GetLatestPosts 获取最新帖子(语义化别名)
|
||||
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
return s.List(ctx, page, pageSize, userID)
|
||||
}
|
||||
|
||||
// GetUserPosts 获取用户帖子
|
||||
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
return s.postRepo.GetUserPosts(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// Like 点赞
|
||||
func (s *PostService) Like(ctx context.Context, postID, userID string) error {
|
||||
// 获取帖子信息用于发送通知
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.postRepo.Like(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
// 发送点赞通知(不给自己发通知)
|
||||
if s.systemMessageService != nil && post.UserID != userID {
|
||||
go func() {
|
||||
notifyErr := s.systemMessageService.SendLikeNotification(context.Background(), post.UserID, userID, postID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] Like notification sent successfully\n")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 推送点赞行为到Gorse(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to insert like feedback to Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlike 取消点赞
|
||||
func (s *PostService) Unlike(ctx context.Context, postID, userID string) error {
|
||||
err := s.postRepo.Unlike(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
// 删除Gorse中的点赞反馈(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to delete like feedback from Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsLiked 检查是否点赞
|
||||
func (s *PostService) IsLiked(ctx context.Context, postID, userID string) bool {
|
||||
return s.postRepo.IsLiked(postID, userID)
|
||||
}
|
||||
|
||||
// Favorite 收藏
|
||||
func (s *PostService) Favorite(ctx context.Context, postID, userID string) error {
|
||||
// 获取帖子信息用于发送通知
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.postRepo.Favorite(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
// 发送收藏通知(不给自己发通知)
|
||||
if s.systemMessageService != nil && post.UserID != userID {
|
||||
go func() {
|
||||
notifyErr := s.systemMessageService.SendFavoriteNotification(context.Background(), post.UserID, userID, postID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending favorite notification: %v\n", notifyErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] Favorite notification sent successfully\n")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 推送收藏行为到Gorse(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to insert favorite feedback to Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unfavorite 取消收藏
|
||||
func (s *PostService) Unfavorite(ctx context.Context, postID, userID string) error {
|
||||
err := s.postRepo.Unfavorite(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
// 删除Gorse中的收藏反馈(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to delete favorite feedback from Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsFavorited 检查是否收藏
|
||||
func (s *PostService) IsFavorited(ctx context.Context, postID, userID string) bool {
|
||||
return s.postRepo.IsFavorited(postID, userID)
|
||||
}
|
||||
|
||||
// IncrementViews 增加帖子观看量并同步到Gorse
|
||||
func (s *PostService) IncrementViews(ctx context.Context, postID, userID string) error {
|
||||
if err := s.postRepo.IncrementViews(postID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 同步浏览行为到Gorse(异步)
|
||||
go func() {
|
||||
if !s.gorseClient.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
feedbackUserID := userID
|
||||
if feedbackUserID == "" {
|
||||
feedbackUserID = anonymousViewUserID
|
||||
}
|
||||
|
||||
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeRead, feedbackUserID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to insert read feedback to Gorse: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFavorites 获取收藏列表
|
||||
func (s *PostService) GetFavorites(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
return s.postRepo.GetFavorites(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// Search 搜索帖子
|
||||
func (s *PostService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
return s.postRepo.Search(keyword, page, pageSize)
|
||||
}
|
||||
|
||||
// GetFollowingPosts 获取关注用户的帖子(带缓存)
|
||||
func (s *PostService) GetFollowingPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
postListTTL := cacheSettings.PostListTTL
|
||||
if postListTTL <= 0 {
|
||||
postListTTL = PostListTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = PostListNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = PostListJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
cacheKey := cache.PostListKey("follow", userID, page, pageSize)
|
||||
|
||||
result, err := cache.GetOrLoadTyped[*PostListResult](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
postListTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (*PostListResult, error) {
|
||||
posts, total, err := s.postRepo.GetFollowingPosts(userID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PostListResult{Posts: posts, Total: total}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result == nil {
|
||||
return []*model.Post{}, 0, nil
|
||||
}
|
||||
return result.Posts, result.Total, nil
|
||||
}
|
||||
|
||||
// GetHotPosts 获取热门帖子(使用Gorse非个性化推荐)
|
||||
func (s *PostService) GetHotPosts(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
// 如果Gorse启用,使用自定义的非个性化推荐器
|
||||
if s.gorseClient.IsEnabled() {
|
||||
offset := (page - 1) * pageSize
|
||||
// 使用 most_liked_weekly 推荐器获取周热门
|
||||
// 多取1条用于判断是否还有下一页
|
||||
itemIDs, err := s.gorseClient.GetNonPersonalized(ctx, "most_liked_weekly", pageSize+1, offset, "")
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Gorse GetNonPersonalized failed: %v, fallback to database", err)
|
||||
return s.getHotPostsFromDB(ctx, page, pageSize)
|
||||
}
|
||||
if len(itemIDs) > 0 {
|
||||
hasNext := len(itemIDs) > pageSize
|
||||
if hasNext {
|
||||
itemIDs = itemIDs[:pageSize]
|
||||
}
|
||||
posts, err := s.postRepo.GetByIDs(itemIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
// 近似 total:当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
|
||||
estimatedTotal := int64(offset + len(posts))
|
||||
if hasNext {
|
||||
estimatedTotal = int64(offset + pageSize + 1)
|
||||
}
|
||||
return posts, estimatedTotal, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 降级:从数据库获取
|
||||
return s.getHotPostsFromDB(ctx, page, pageSize)
|
||||
}
|
||||
|
||||
// getHotPostsFromDB 从数据库获取热门帖子(降级路径)
|
||||
func (s *PostService) getHotPostsFromDB(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
// 直接查询数据库,不再使用本地缓存(Gorse失败降级时使用)
|
||||
posts, total, err := s.postRepo.GetHotPosts(page, pageSize)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return posts, total, nil
|
||||
}
|
||||
|
||||
// GetRecommendedPosts 获取推荐帖子
|
||||
func (s *PostService) GetRecommendedPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
// 如果Gorse未启用或用户未登录,降级为热门帖子
|
||||
if !s.gorseClient.IsEnabled() || userID == "" {
|
||||
return s.GetHotPosts(ctx, page, pageSize)
|
||||
}
|
||||
|
||||
// 计算偏移量
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 从Gorse获取推荐列表
|
||||
// 多取1条用于判断是否还有下一页
|
||||
itemIDs, err := s.gorseClient.GetRecommend(ctx, userID, pageSize+1, offset)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Gorse recommendation failed: %v, fallback to hot posts", err)
|
||||
return s.GetHotPosts(ctx, page, pageSize)
|
||||
}
|
||||
|
||||
// 如果没有推荐结果,降级为热门帖子
|
||||
if len(itemIDs) == 0 {
|
||||
return s.GetHotPosts(ctx, page, pageSize)
|
||||
}
|
||||
|
||||
hasNext := len(itemIDs) > pageSize
|
||||
if hasNext {
|
||||
itemIDs = itemIDs[:pageSize]
|
||||
}
|
||||
|
||||
// 根据ID列表查询帖子详情
|
||||
posts, err := s.postRepo.GetByIDs(itemIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 近似 total:当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
|
||||
estimatedTotal := int64(offset + len(posts))
|
||||
if hasNext {
|
||||
estimatedTotal = int64(offset + pageSize + 1)
|
||||
}
|
||||
return posts, estimatedTotal, nil
|
||||
}
|
||||
|
||||
// buildPostCategories 构建帖子的类别标签
|
||||
func (s *PostService) buildPostCategories(post *model.Post) []string {
|
||||
var categories []string
|
||||
|
||||
// 热度标签
|
||||
if post.ViewsCount > 1000 {
|
||||
categories = append(categories, "hot_high")
|
||||
} else if post.ViewsCount > 100 {
|
||||
categories = append(categories, "hot_medium")
|
||||
}
|
||||
|
||||
// 点赞标签
|
||||
if post.LikesCount > 100 {
|
||||
categories = append(categories, "likes_100+")
|
||||
} else if post.LikesCount > 50 {
|
||||
categories = append(categories, "likes_50+")
|
||||
} else if post.LikesCount > 10 {
|
||||
categories = append(categories, "likes_10+")
|
||||
}
|
||||
|
||||
// 评论标签
|
||||
if post.CommentsCount > 50 {
|
||||
categories = append(categories, "comments_50+")
|
||||
} else if post.CommentsCount > 10 {
|
||||
categories = append(categories, "comments_10+")
|
||||
}
|
||||
|
||||
// 时间标签
|
||||
age := time.Since(post.CreatedAt)
|
||||
if age < 24*time.Hour {
|
||||
categories = append(categories, "today")
|
||||
} else if age < 7*24*time.Hour {
|
||||
categories = append(categories, "this_week")
|
||||
} else if age < 30*24*time.Hour {
|
||||
categories = append(categories, "this_month")
|
||||
}
|
||||
|
||||
return categories
|
||||
}
|
||||
575
internal/service/push_service.go
Normal file
575
internal/service/push_service.go
Normal file
@@ -0,0 +1,575 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// 推送相关常量
|
||||
const (
|
||||
// DefaultPushTimeout 默认推送超时时间
|
||||
DefaultPushTimeout = 30 * time.Second
|
||||
// MaxRetryCount 最大重试次数
|
||||
MaxRetryCount = 3
|
||||
// DefaultExpiredTime 默认消息过期时间(24小时)
|
||||
DefaultExpiredTime = 24 * time.Hour
|
||||
// PushQueueSize 推送队列大小
|
||||
PushQueueSize = 1000
|
||||
)
|
||||
|
||||
// PushPriority 推送优先级
|
||||
type PushPriority int
|
||||
|
||||
const (
|
||||
PriorityLow PushPriority = 1 // 低优先级(营销消息等)
|
||||
PriorityNormal PushPriority = 5 // 普通优先级(系统通知)
|
||||
PriorityHigh PushPriority = 8 // 高优先级(聊天消息)
|
||||
PriorityCritical PushPriority = 10 // 最高优先级(重要系统通知)
|
||||
)
|
||||
|
||||
// PushService 推送服务接口
|
||||
type PushService interface {
|
||||
// 推送核心方法
|
||||
PushMessage(ctx context.Context, userID string, message *model.Message) error
|
||||
PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error
|
||||
|
||||
// 系统消息推送
|
||||
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
|
||||
PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error
|
||||
PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error
|
||||
|
||||
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
|
||||
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
|
||||
|
||||
// 设备管理
|
||||
RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error
|
||||
UnregisterDevice(ctx context.Context, deviceID string) error
|
||||
UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error
|
||||
|
||||
// 推送记录管理
|
||||
CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error)
|
||||
GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error)
|
||||
|
||||
// 后台任务
|
||||
StartPushWorker(ctx context.Context)
|
||||
StopPushWorker()
|
||||
}
|
||||
|
||||
// pushServiceImpl 推送服务实现
|
||||
type pushServiceImpl struct {
|
||||
pushRepo *repository.PushRecordRepository
|
||||
deviceRepo *repository.DeviceTokenRepository
|
||||
messageRepo *repository.MessageRepository
|
||||
wsManager *websocket.WebSocketManager
|
||||
|
||||
// 推送队列
|
||||
pushQueue chan *pushTask
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// pushTask 推送任务
|
||||
type pushTask struct {
|
||||
userID string
|
||||
message *model.Message
|
||||
priority int
|
||||
}
|
||||
|
||||
// NewPushService 创建推送服务
|
||||
func NewPushService(
|
||||
pushRepo *repository.PushRecordRepository,
|
||||
deviceRepo *repository.DeviceTokenRepository,
|
||||
messageRepo *repository.MessageRepository,
|
||||
wsManager *websocket.WebSocketManager,
|
||||
) PushService {
|
||||
return &pushServiceImpl{
|
||||
pushRepo: pushRepo,
|
||||
deviceRepo: deviceRepo,
|
||||
messageRepo: messageRepo,
|
||||
wsManager: wsManager,
|
||||
pushQueue: make(chan *pushTask, PushQueueSize),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// PushMessage 推送消息给用户
|
||||
func (s *pushServiceImpl) PushMessage(ctx context.Context, userID string, message *model.Message) error {
|
||||
return s.PushToUser(ctx, userID, message, int(PriorityNormal))
|
||||
}
|
||||
|
||||
// PushToUser 带优先级的推送
|
||||
func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error {
|
||||
// 首先尝试WebSocket推送(实时推送)
|
||||
if s.pushViaWebSocket(ctx, userID, message) {
|
||||
// WebSocket推送成功,记录推送状态
|
||||
record, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelWebSocket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create push record: %w", err)
|
||||
}
|
||||
record.MarkPushed()
|
||||
if err := s.pushRepo.Update(record); err != nil {
|
||||
return fmt.Errorf("failed to update push record: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WebSocket推送失败,加入推送队列等待移动端推送
|
||||
select {
|
||||
case s.pushQueue <- &pushTask{
|
||||
userID: userID,
|
||||
message: message,
|
||||
priority: priority,
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// 队列已满,直接创建待推送记录
|
||||
_, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelFCM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create pending push record: %w", err)
|
||||
}
|
||||
return errors.New("push queue is full, message queued for later delivery")
|
||||
}
|
||||
}
|
||||
|
||||
// pushViaWebSocket 通过WebSocket推送消息
|
||||
// 返回true表示推送成功,false表示用户不在线
|
||||
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 判断是否为系统消息/通知消息
|
||||
if message.IsSystemMessage() || message.Category == model.CategoryNotification {
|
||||
// 使用 NotificationMessage 格式推送系统通知
|
||||
// 从 segments 中提取文本内容
|
||||
content := dto.ExtractTextContentFromModel(message.Segments)
|
||||
|
||||
notification := &websocket.NotificationMessage{
|
||||
ID: fmt.Sprintf("%s", message.ID),
|
||||
Type: string(message.SystemType),
|
||||
Content: content,
|
||||
Extra: make(map[string]interface{}),
|
||||
CreatedAt: message.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 填充额外数据
|
||||
if message.ExtraData != nil {
|
||||
notification.Extra["actor_id"] = message.ExtraData.ActorID
|
||||
notification.Extra["actor_name"] = message.ExtraData.ActorName
|
||||
notification.Extra["avatar_url"] = message.ExtraData.AvatarURL
|
||||
notification.Extra["target_id"] = message.ExtraData.TargetID
|
||||
notification.Extra["target_type"] = message.ExtraData.TargetType
|
||||
notification.Extra["action_url"] = message.ExtraData.ActionURL
|
||||
notification.Extra["action_time"] = message.ExtraData.ActionTime
|
||||
|
||||
// 设置触发用户信息
|
||||
if message.ExtraData.ActorID > 0 {
|
||||
notification.TriggerUser = &websocket.NotificationUser{
|
||||
ID: fmt.Sprintf("%d", message.ExtraData.ActorID),
|
||||
Username: message.ExtraData.ActorName,
|
||||
Avatar: message.ExtraData.AvatarURL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// 构建普通聊天消息的 WebSocket 消息 - 使用新的 WSEventResponse 格式
|
||||
// 获取会话类型 (private/group)
|
||||
detailType := "private"
|
||||
if message.ConversationID != "" {
|
||||
// 从会话中获取类型,需要查询数据库或从消息中判断
|
||||
// 这里暂时默认为 private,group 类型需要额外逻辑
|
||||
}
|
||||
|
||||
// 直接使用 message.Segments
|
||||
segments := message.Segments
|
||||
|
||||
event := &dto.WSEventResponse{
|
||||
ID: fmt.Sprintf("%s", message.ID),
|
||||
Time: message.CreatedAt.UnixMilli(),
|
||||
Type: "message",
|
||||
DetailType: detailType,
|
||||
Seq: fmt.Sprintf("%d", message.Seq),
|
||||
Segments: segments,
|
||||
SenderID: message.SenderID,
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// pushViaFCM 通过FCM推送(预留接口)
|
||||
func (s *pushServiceImpl) pushViaFCM(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
|
||||
// TODO: 实现FCM推送
|
||||
// 1. 构建FCM消息
|
||||
// 2. 调用Firebase Admin SDK发送消息
|
||||
// 3. 处理发送结果
|
||||
return errors.New("FCM push not implemented")
|
||||
}
|
||||
|
||||
// pushViaAPNs 通过APNs推送(预留接口)
|
||||
func (s *pushServiceImpl) pushViaAPNs(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
|
||||
// TODO: 实现APNs推送
|
||||
// 1. 构建APNs消息
|
||||
// 2. 调用APNs SDK发送消息
|
||||
// 3. 处理发送结果
|
||||
return errors.New("APNs push not implemented")
|
||||
}
|
||||
|
||||
// RegisterDevice 注册设备
|
||||
func (s *pushServiceImpl) RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error {
|
||||
deviceToken := &model.DeviceToken{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
DeviceType: deviceType,
|
||||
PushToken: pushToken,
|
||||
IsActive: true,
|
||||
}
|
||||
deviceToken.UpdateLastUsed()
|
||||
|
||||
return s.deviceRepo.Upsert(deviceToken)
|
||||
}
|
||||
|
||||
// UnregisterDevice 注销设备
|
||||
func (s *pushServiceImpl) UnregisterDevice(ctx context.Context, deviceID string) error {
|
||||
return s.deviceRepo.Deactivate(deviceID)
|
||||
}
|
||||
|
||||
// UpdateDeviceToken 更新设备Token
|
||||
func (s *pushServiceImpl) UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error {
|
||||
deviceToken, err := s.deviceRepo.GetByDeviceID(deviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("device not found: %w", err)
|
||||
}
|
||||
|
||||
deviceToken.PushToken = newPushToken
|
||||
deviceToken.Activate()
|
||||
|
||||
return s.deviceRepo.Update(deviceToken)
|
||||
}
|
||||
|
||||
// CreatePushRecord 创建推送记录
|
||||
func (s *pushServiceImpl) CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error) {
|
||||
expiredAt := time.Now().Add(DefaultExpiredTime)
|
||||
record := &model.PushRecord{
|
||||
UserID: userID,
|
||||
MessageID: messageID,
|
||||
PushChannel: channel,
|
||||
PushStatus: model.PushStatusPending,
|
||||
MaxRetry: MaxRetryCount,
|
||||
ExpiredAt: &expiredAt,
|
||||
}
|
||||
|
||||
if err := s.pushRepo.Create(record); err != nil {
|
||||
return nil, fmt.Errorf("failed to create push record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetPendingPushes 获取待推送记录
|
||||
func (s *pushServiceImpl) GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error) {
|
||||
return s.pushRepo.GetByUserID(userID, 100, 0)
|
||||
}
|
||||
|
||||
// StartPushWorker 启动推送工作协程
|
||||
func (s *pushServiceImpl) StartPushWorker(ctx context.Context) {
|
||||
go s.processPushQueue()
|
||||
go s.retryFailedPushes()
|
||||
}
|
||||
|
||||
// StopPushWorker 停止推送工作协程
|
||||
func (s *pushServiceImpl) StopPushWorker() {
|
||||
close(s.stopChan)
|
||||
}
|
||||
|
||||
// processPushQueue 处理推送队列
|
||||
func (s *pushServiceImpl) processPushQueue() {
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case task := <-s.pushQueue:
|
||||
s.processPushTask(task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processPushTask 处理单个推送任务
|
||||
func (s *pushServiceImpl) processPushTask(task *pushTask) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultPushTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 获取用户活跃设备
|
||||
devices, err := s.deviceRepo.GetActiveByUserID(task.userID)
|
||||
if err != nil || len(devices) == 0 {
|
||||
// 没有可用设备,创建待推送记录
|
||||
s.CreatePushRecord(ctx, task.userID, task.message.ID, model.PushChannelFCM)
|
||||
return
|
||||
}
|
||||
|
||||
// 对每个设备创建推送记录并尝试推送
|
||||
for _, device := range devices {
|
||||
record, err := s.CreatePushRecord(ctx, task.userID, task.message.ID, s.getChannelForDevice(device))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var pushErr error
|
||||
switch {
|
||||
case device.IsIOS():
|
||||
pushErr = s.pushViaAPNs(ctx, device, task.message)
|
||||
case device.IsAndroid():
|
||||
pushErr = s.pushViaFCM(ctx, device, task.message)
|
||||
default:
|
||||
// Web设备只支持WebSocket
|
||||
continue
|
||||
}
|
||||
|
||||
if pushErr != nil {
|
||||
record.MarkFailed(pushErr.Error())
|
||||
} else {
|
||||
record.MarkPushed()
|
||||
}
|
||||
s.pushRepo.Update(record)
|
||||
}
|
||||
}
|
||||
|
||||
// getChannelForDevice 根据设备类型获取推送通道
|
||||
func (s *pushServiceImpl) getChannelForDevice(device *model.DeviceToken) model.PushChannel {
|
||||
switch device.DeviceType {
|
||||
case model.DeviceTypeIOS:
|
||||
return model.PushChannelAPNs
|
||||
case model.DeviceTypeAndroid:
|
||||
return model.PushChannelFCM
|
||||
default:
|
||||
return model.PushChannelWebSocket
|
||||
}
|
||||
}
|
||||
|
||||
// retryFailedPushes 重试失败的推送
|
||||
func (s *pushServiceImpl) retryFailedPushes() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.doRetry()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// doRetry 执行重试
|
||||
func (s *pushServiceImpl) doRetry() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 获取失败待重试的推送
|
||||
records, err := s.pushRepo.GetFailedPushesForRetry(100)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
// 检查是否过期
|
||||
if record.IsExpired() {
|
||||
record.MarkExpired()
|
||||
s.pushRepo.Update(record)
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取消息
|
||||
message, err := s.messageRepo.GetMessageByID(record.MessageID)
|
||||
if err != nil {
|
||||
record.MarkFailed("message not found")
|
||||
s.pushRepo.Update(record)
|
||||
continue
|
||||
}
|
||||
|
||||
// 尝试WebSocket推送
|
||||
if s.pushViaWebSocket(ctx, record.UserID, message) {
|
||||
record.MarkDelivered()
|
||||
s.pushRepo.Update(record)
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取设备并尝试移动端推送
|
||||
if record.DeviceToken != "" {
|
||||
device, err := s.deviceRepo.GetByPushToken(record.DeviceToken)
|
||||
if err != nil {
|
||||
record.MarkFailed("device not found")
|
||||
s.pushRepo.Update(record)
|
||||
continue
|
||||
}
|
||||
|
||||
var pushErr error
|
||||
switch {
|
||||
case device.IsIOS():
|
||||
pushErr = s.pushViaAPNs(ctx, device, message)
|
||||
case device.IsAndroid():
|
||||
pushErr = s.pushViaFCM(ctx, device, message)
|
||||
}
|
||||
|
||||
if pushErr != nil {
|
||||
record.MarkFailed(pushErr.Error())
|
||||
} else {
|
||||
record.MarkPushed()
|
||||
}
|
||||
s.pushRepo.Update(record)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PushSystemMessage 推送系统消息
|
||||
func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error {
|
||||
// 首先尝试WebSocket推送
|
||||
if s.pushSystemViaWebSocket(ctx, userID, msgType, title, content, data) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 用户不在线,创建待推送记录(移动端上线后可通过其他方式获取)
|
||||
// 系统消息通常不需要离线推送,客户端上线后会主动拉取
|
||||
return errors.New("user is offline, system message will be available on next sync")
|
||||
}
|
||||
|
||||
// pushSystemViaWebSocket 通过WebSocket推送系统消息
|
||||
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
sysMsg := &websocket.SystemMessage{
|
||||
Type: msgType,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Data: data,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// PushNotification 推送通知消息
|
||||
func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error {
|
||||
// 首先尝试WebSocket推送
|
||||
if s.pushNotificationViaWebSocket(ctx, userID, notification) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 用户不在线,创建待推送记录
|
||||
// 通知消息可以等用户上线后拉取
|
||||
return errors.New("user is offline, notification will be available on next sync")
|
||||
}
|
||||
|
||||
// pushNotificationViaWebSocket 通过WebSocket推送通知消息
|
||||
func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
if notification.CreatedAt == 0 {
|
||||
notification.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// PushAnnouncement 广播公告消息
|
||||
func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if announcement.CreatedAt == 0 {
|
||||
announcement.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||||
s.wsManager.Broadcast(wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
|
||||
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
|
||||
// 首先尝试WebSocket推送
|
||||
if s.pushSystemNotificationViaWebSocket(ctx, userID, notification) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 用户不在线,系统通知已存储在数据库中,用户上线后会主动拉取
|
||||
return nil
|
||||
}
|
||||
|
||||
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
|
||||
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 构建 WebSocket 通知消息
|
||||
wsNotification := &websocket.NotificationMessage{
|
||||
ID: fmt.Sprintf("%d", notification.ID),
|
||||
Type: string(notification.Type),
|
||||
Title: notification.Title,
|
||||
Content: notification.Content,
|
||||
Extra: make(map[string]interface{}),
|
||||
CreatedAt: notification.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 填充额外数据
|
||||
if notification.ExtraData != nil {
|
||||
wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr
|
||||
wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName
|
||||
wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL
|
||||
wsNotification.Extra["target_id"] = notification.ExtraData.TargetID
|
||||
wsNotification.Extra["target_type"] = notification.ExtraData.TargetType
|
||||
wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL
|
||||
wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime
|
||||
|
||||
// 设置触发用户信息
|
||||
if notification.ExtraData.ActorIDStr != "" {
|
||||
wsNotification.TriggerUser = &websocket.NotificationUser{
|
||||
ID: notification.ExtraData.ActorIDStr,
|
||||
Username: notification.ExtraData.ActorName,
|
||||
Avatar: notification.ExtraData.AvatarURL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
559
internal/service/sensitive_service.go
Normal file
559
internal/service/sensitive_service.go
Normal file
@@ -0,0 +1,559 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
redisclient "carrot_bbs/internal/pkg/redis"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ==================== DFA 敏感词过滤实现 ====================
|
||||
|
||||
// SensitiveNode 敏感词树节点
|
||||
type SensitiveNode struct {
|
||||
// 子节点映射
|
||||
Children map[rune]*SensitiveNode
|
||||
// 是否为敏感词结尾
|
||||
IsEnd bool
|
||||
// 敏感词信息(仅在 IsEnd 为 true 时有效)
|
||||
Word string
|
||||
Level model.SensitiveWordLevel
|
||||
Category model.SensitiveWordCategory
|
||||
}
|
||||
|
||||
// NewSensitiveNode 创建新的敏感词节点
|
||||
func NewSensitiveNode() *SensitiveNode {
|
||||
return &SensitiveNode{
|
||||
Children: make(map[rune]*SensitiveNode),
|
||||
IsEnd: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SensitiveWordTree 敏感词树
|
||||
type SensitiveWordTree struct {
|
||||
root *SensitiveNode
|
||||
wordCount int
|
||||
mu sync.RWMutex
|
||||
lastReload time.Time
|
||||
}
|
||||
|
||||
// NewSensitiveWordTree 创建新的敏感词树
|
||||
func NewSensitiveWordTree() *SensitiveWordTree {
|
||||
return &SensitiveWordTree{
|
||||
root: NewSensitiveNode(),
|
||||
wordCount: 0,
|
||||
lastReload: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddWord 添加敏感词到树中
|
||||
func (t *SensitiveWordTree) AddWord(word string, level model.SensitiveWordLevel, category model.SensitiveWordCategory) {
|
||||
if word == "" {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
node := t.root
|
||||
// 转换为小写进行匹配(不区分大小写)
|
||||
lowerWord := strings.ToLower(word)
|
||||
runes := []rune(lowerWord)
|
||||
|
||||
for _, r := range runes {
|
||||
child, exists := node.Children[r]
|
||||
if !exists {
|
||||
child = NewSensitiveNode()
|
||||
node.Children[r] = child
|
||||
}
|
||||
node = child
|
||||
}
|
||||
|
||||
// 如果不是已存在的敏感词,则计数+1
|
||||
if !node.IsEnd {
|
||||
t.wordCount++
|
||||
}
|
||||
|
||||
node.IsEnd = true
|
||||
node.Word = word
|
||||
node.Level = level
|
||||
node.Category = category
|
||||
}
|
||||
|
||||
// RemoveWord 从树中移除敏感词
|
||||
func (t *SensitiveWordTree) RemoveWord(word string) {
|
||||
if word == "" {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
lowerWord := strings.ToLower(word)
|
||||
runes := []rune(lowerWord)
|
||||
|
||||
// 查找节点
|
||||
node := t.root
|
||||
for _, r := range runes {
|
||||
child, exists := node.Children[r]
|
||||
if !exists {
|
||||
return // 敏感词不存在
|
||||
}
|
||||
node = child
|
||||
}
|
||||
|
||||
if node.IsEnd {
|
||||
node.IsEnd = false
|
||||
node.Word = ""
|
||||
t.wordCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Check 检查文本是否包含敏感词,返回是否包含及敏感词列表
|
||||
func (t *SensitiveWordTree) Check(text string) (bool, []string) {
|
||||
if text == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
var foundWords []string
|
||||
runes := []rune(strings.ToLower(text))
|
||||
length := len(runes)
|
||||
|
||||
// 用于标记已找到的敏感词位置,避免重复计算
|
||||
marked := make([]bool, length)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
// 从当前位置开始搜索
|
||||
node := t.root
|
||||
matchEnd := -1
|
||||
matchWord := ""
|
||||
|
||||
for j := i; j < length; j++ {
|
||||
child, exists := node.Children[runes[j]]
|
||||
if !exists {
|
||||
break
|
||||
}
|
||||
node = child
|
||||
|
||||
if node.IsEnd {
|
||||
matchEnd = j
|
||||
matchWord = node.Word
|
||||
}
|
||||
}
|
||||
|
||||
// 标记找到的敏感词位置
|
||||
if matchEnd >= 0 && !marked[i] {
|
||||
for k := i; k <= matchEnd; k++ {
|
||||
marked[k] = true
|
||||
}
|
||||
foundWords = append(foundWords, matchWord)
|
||||
}
|
||||
}
|
||||
|
||||
return len(foundWords) > 0, foundWords
|
||||
}
|
||||
|
||||
// Replace 替换文本中的敏感词
|
||||
func (t *SensitiveWordTree) Replace(text string, repl string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
runes := []rune(text)
|
||||
length := len(runes)
|
||||
result := make([]rune, 0, length)
|
||||
|
||||
// 用于标记已替换的位置
|
||||
marked := make([]bool, length)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
if marked[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
// 从当前位置开始搜索
|
||||
node := t.root
|
||||
matchEnd := -1
|
||||
|
||||
for j := i; j < length; j++ {
|
||||
child, exists := node.Children[runes[j]]
|
||||
if !exists {
|
||||
break
|
||||
}
|
||||
node = child
|
||||
|
||||
if node.IsEnd {
|
||||
matchEnd = j
|
||||
}
|
||||
}
|
||||
|
||||
if matchEnd >= 0 {
|
||||
// 标记已替换的位置
|
||||
for k := i; k <= matchEnd; k++ {
|
||||
marked[k] = true
|
||||
}
|
||||
// 追加替换符
|
||||
replRunes := []rune(repl)
|
||||
result = append(result, replRunes...)
|
||||
// 跳过已匹配的字符
|
||||
i = matchEnd
|
||||
} else {
|
||||
// 追加原字符
|
||||
result = append(result, runes[i])
|
||||
}
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// WordCount 获取敏感词数量
|
||||
func (t *SensitiveWordTree) WordCount() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.wordCount
|
||||
}
|
||||
|
||||
// ==================== 敏感词服务实现 ====================
|
||||
|
||||
// SensitiveService 敏感词服务接口
|
||||
type SensitiveService interface {
|
||||
// Check 检查文本是否包含敏感词
|
||||
Check(ctx context.Context, text string) (bool, []string)
|
||||
// Replace 替换敏感词
|
||||
Replace(ctx context.Context, text string, repl string) string
|
||||
// AddWord 添加敏感词
|
||||
AddWord(ctx context.Context, word string, category string, level int) error
|
||||
// RemoveWord 移除敏感词
|
||||
RemoveWord(ctx context.Context, word string) error
|
||||
// Reload 重新加载敏感词库
|
||||
Reload(ctx context.Context) error
|
||||
// GetWordCount 获取敏感词数量
|
||||
GetWordCount(ctx context.Context) int
|
||||
}
|
||||
|
||||
// sensitiveServiceImpl 敏感词服务实现
|
||||
type sensitiveServiceImpl struct {
|
||||
tree *SensitiveWordTree
|
||||
db *gorm.DB
|
||||
redis *redisclient.Client
|
||||
config *SensitiveConfig
|
||||
mu sync.RWMutex
|
||||
replaceStr string
|
||||
}
|
||||
|
||||
// SensitiveConfig 敏感词服务配置
|
||||
type SensitiveConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
|
||||
ReplaceStr string `mapstructure:"replace_str" yaml:"replace_str"`
|
||||
// 最小匹配长度
|
||||
MinMatchLen int `mapstructure:"min_match_len" yaml:"min_match_len"`
|
||||
// 是否从数据库加载
|
||||
LoadFromDB bool `mapstructure:"load_from_db" yaml:"load_from_db"`
|
||||
// 是否从Redis加载
|
||||
LoadFromRedis bool `mapstructure:"load_from_redis" yaml:"load_from_redis"`
|
||||
// Redis键前缀
|
||||
RedisKeyPrefix string `mapstructure:"redis_key_prefix" yaml:"redis_key_prefix"`
|
||||
}
|
||||
|
||||
// NewSensitiveService 创建敏感词服务
|
||||
func NewSensitiveService(db *gorm.DB, redisClient *redisclient.Client, config *SensitiveConfig) SensitiveService {
|
||||
s := &sensitiveServiceImpl{
|
||||
tree: NewSensitiveWordTree(),
|
||||
db: db,
|
||||
redis: redisClient,
|
||||
config: config,
|
||||
replaceStr: config.ReplaceStr,
|
||||
}
|
||||
|
||||
// 如果未设置替换符,默认使用 ***
|
||||
if s.replaceStr == "" {
|
||||
s.replaceStr = "***"
|
||||
}
|
||||
|
||||
// 初始化加载敏感词
|
||||
if config.LoadFromDB {
|
||||
if err := s.loadFromDB(context.Background()); err != nil {
|
||||
log.Printf("Failed to load sensitive words from database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if config.LoadFromRedis && redisClient != nil {
|
||||
if err := s.loadFromRedis(context.Background()); err != nil {
|
||||
log.Printf("Failed to load sensitive words from redis: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Check 检查文本是否包含敏感词
|
||||
func (s *sensitiveServiceImpl) Check(ctx context.Context, text string) (bool, []string) {
|
||||
if !s.config.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
if text == "" {
|
||||
return false, nil
|
||||
}
|
||||
return s.tree.Check(text)
|
||||
}
|
||||
|
||||
// Replace 替换敏感词
|
||||
func (s *sensitiveServiceImpl) Replace(ctx context.Context, text string, repl string) string {
|
||||
if !s.config.Enabled {
|
||||
return text
|
||||
}
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
// 如果未指定替换符,使用默认替换符
|
||||
if repl == "" {
|
||||
repl = s.replaceStr
|
||||
}
|
||||
|
||||
return s.tree.Replace(text, repl)
|
||||
}
|
||||
|
||||
// AddWord 添加敏感词
|
||||
func (s *sensitiveServiceImpl) AddWord(ctx context.Context, word string, category string, level int) error {
|
||||
if word == "" {
|
||||
return fmt.Errorf("word cannot be empty")
|
||||
}
|
||||
|
||||
// 转换为敏感词级别
|
||||
wordLevel := model.SensitiveWordLevel(level)
|
||||
if wordLevel < 1 || wordLevel > 3 {
|
||||
wordLevel = model.SensitiveWordLevelLow
|
||||
}
|
||||
|
||||
// 转换为敏感词分类
|
||||
wordCategory := model.SensitiveWordCategory(category)
|
||||
if wordCategory == "" {
|
||||
wordCategory = model.SensitiveWordCategoryOther
|
||||
}
|
||||
|
||||
// 添加到树
|
||||
s.tree.AddWord(word, wordLevel, wordCategory)
|
||||
|
||||
// 持久化到数据库
|
||||
if s.db != nil {
|
||||
sensitiveWord := model.SensitiveWord{
|
||||
Word: word,
|
||||
Category: wordCategory,
|
||||
Level: wordLevel,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
// 使用 upsert 逻辑
|
||||
var existing model.SensitiveWord
|
||||
result := s.db.Where("word = ?", word).First(&existing)
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
if err := s.db.Create(&sensitiveWord).Error; err != nil {
|
||||
log.Printf("Failed to save sensitive word to database: %v", err)
|
||||
}
|
||||
} else if result.Error == nil {
|
||||
// 更新已存在的记录
|
||||
existing.Category = wordCategory
|
||||
existing.Level = wordLevel
|
||||
existing.IsActive = true
|
||||
if err := s.db.Save(&existing).Error; err != nil {
|
||||
log.Printf("Failed to update sensitive word in database: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 同步到 Redis
|
||||
if s.redis != nil && s.config.RedisKeyPrefix != "" {
|
||||
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
|
||||
data := map[string]interface{}{
|
||||
"word": word,
|
||||
"category": category,
|
||||
"level": level,
|
||||
}
|
||||
jsonData, _ := json.Marshal(data)
|
||||
s.redis.Set(ctx, key, jsonData, 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveWord 移除敏感词
|
||||
func (s *sensitiveServiceImpl) RemoveWord(ctx context.Context, word string) error {
|
||||
if word == "" {
|
||||
return fmt.Errorf("word cannot be empty")
|
||||
}
|
||||
|
||||
// 从树中移除
|
||||
s.tree.RemoveWord(word)
|
||||
|
||||
// 从数据库中标记为不活跃
|
||||
if s.db != nil {
|
||||
result := s.db.Model(&model.SensitiveWord{}).Where("word = ?", word).Update("is_active", false)
|
||||
if result.Error != nil {
|
||||
log.Printf("Failed to deactivate sensitive word in database: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// 从 Redis 中删除
|
||||
if s.redis != nil && s.config.RedisKeyPrefix != "" {
|
||||
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
|
||||
s.redis.Del(ctx, key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reload 重新加载敏感词库
|
||||
func (s *sensitiveServiceImpl) Reload(ctx context.Context) error {
|
||||
// 清空现有树
|
||||
s.tree = NewSensitiveWordTree()
|
||||
|
||||
// 从数据库加载
|
||||
if s.config.LoadFromDB {
|
||||
if err := s.loadFromDB(ctx); err != nil {
|
||||
return fmt.Errorf("failed to load from database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 从 Redis 加载
|
||||
if s.config.LoadFromRedis && s.redis != nil {
|
||||
if err := s.loadFromRedis(ctx); err != nil {
|
||||
return fmt.Errorf("failed to load from redis: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWordCount 获取敏感词数量
|
||||
func (s *sensitiveServiceImpl) GetWordCount(ctx context.Context) int {
|
||||
return s.tree.WordCount()
|
||||
}
|
||||
|
||||
// loadFromDB 从数据库加载敏感词
|
||||
func (s *sensitiveServiceImpl) loadFromDB(ctx context.Context) error {
|
||||
if s.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var words []model.SensitiveWord
|
||||
if err := s.db.Where("is_active = ?", true).Find(&words).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, word := range words {
|
||||
s.tree.AddWord(word.Word, word.Level, word.Category)
|
||||
}
|
||||
|
||||
log.Printf("Loaded %d sensitive words from database", len(words))
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadFromRedis 从 Redis 加载敏感词
|
||||
func (s *sensitiveServiceImpl) loadFromRedis(ctx context.Context) error {
|
||||
if s.redis == nil || s.config.RedisKeyPrefix == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 SCAN 命令代替 KEYS,避免阻塞
|
||||
pattern := fmt.Sprintf("%s:*", s.config.RedisKeyPrefix)
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := s.redis.GetClient().Scan(ctx, cursor, pattern, 100).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
data, err := s.redis.Get(ctx, key)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var wordData map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &wordData); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
word, _ := wordData["word"].(string)
|
||||
category, _ := wordData["category"].(string)
|
||||
level, _ := wordData["level"].(float64)
|
||||
|
||||
if word != "" {
|
||||
s.tree.AddWord(word, model.SensitiveWordLevel(int(level)), model.SensitiveWordCategory(category))
|
||||
}
|
||||
}
|
||||
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
// ContainsSensitiveWord 快速检查文本是否包含敏感词
|
||||
func ContainsSensitiveWord(text string, tree *SensitiveWordTree) bool {
|
||||
if tree == nil || text == "" {
|
||||
return false
|
||||
}
|
||||
hasSensitive, _ := tree.Check(text)
|
||||
return hasSensitive
|
||||
}
|
||||
|
||||
// FilterSensitiveWords 过滤敏感词并返回替换后的文本
|
||||
func FilterSensitiveWords(text string, tree *SensitiveWordTree, repl string) string {
|
||||
if tree == nil || text == "" {
|
||||
return text
|
||||
}
|
||||
if repl == "" {
|
||||
repl = "***"
|
||||
}
|
||||
return tree.Replace(text, repl)
|
||||
}
|
||||
|
||||
// ValidateTextLength 验证文本长度是否合法
|
||||
func ValidateTextLength(text string, minLen, maxLen int) bool {
|
||||
length := utf8.RuneCountInString(text)
|
||||
return length >= minLen && length <= maxLen
|
||||
}
|
||||
|
||||
// SanitizeText 清理文本,移除多余空白字符
|
||||
func SanitizeText(text string) string {
|
||||
// 替换多个连续空白字符为单个空格
|
||||
spaceReg := regexp.MustCompile(`\s+`)
|
||||
text = spaceReg.ReplaceAllString(text, " ")
|
||||
// 去除首尾空白
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// ==================== 默认敏感词列表 ====================
|
||||
|
||||
// DefaultSensitiveWords 返回默认敏感词列表(示例)
|
||||
func DefaultSensitiveWords() map[string]struct{} {
|
||||
return map[string]struct{}{
|
||||
// 示例敏感词,实际需要从数据库或配置加载
|
||||
"测试敏感词1": {},
|
||||
"测试敏感词2": {},
|
||||
"测试敏感词3": {},
|
||||
}
|
||||
}
|
||||
139
internal/service/sticker_service.go
Normal file
139
internal/service/sticker_service.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/repository"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrStickerAlreadyExists = errors.New("sticker already exists")
|
||||
ErrInvalidStickerURL = errors.New("invalid sticker url")
|
||||
)
|
||||
|
||||
// StickerService 自定义表情服务接口
|
||||
type StickerService interface {
|
||||
// 获取用户的所有表情
|
||||
GetUserStickers(userID string) ([]model.UserSticker, error)
|
||||
// 添加表情
|
||||
AddSticker(userID string, url string, width, height int) (*model.UserSticker, error)
|
||||
// 删除表情
|
||||
DeleteSticker(userID string, stickerID string) error
|
||||
// 检查表情是否已存在
|
||||
CheckExists(userID string, url string) (bool, error)
|
||||
// 重新排序
|
||||
ReorderStickers(userID string, orders map[string]int) error
|
||||
// 获取用户表情数量
|
||||
GetStickerCount(userID string) (int64, error)
|
||||
}
|
||||
|
||||
// stickerService 自定义表情服务实现
|
||||
type stickerService struct {
|
||||
stickerRepo repository.StickerRepository
|
||||
}
|
||||
|
||||
// NewStickerService 创建自定义表情服务
|
||||
func NewStickerService(stickerRepo repository.StickerRepository) StickerService {
|
||||
return &stickerService{
|
||||
stickerRepo: stickerRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserStickers 获取用户的所有表情
|
||||
func (s *stickerService) GetUserStickers(userID string) ([]model.UserSticker, error) {
|
||||
stickers, err := s.stickerRepo.GetByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 兼容历史脏数据:过滤本地文件 URI,避免客户端加载 file:// 报错
|
||||
filtered := make([]model.UserSticker, 0, len(stickers))
|
||||
for _, sticker := range stickers {
|
||||
if isValidStickerURL(sticker.URL) {
|
||||
filtered = append(filtered, sticker)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// AddSticker 添加表情
|
||||
func (s *stickerService) AddSticker(userID string, url string, width, height int) (*model.UserSticker, error) {
|
||||
if !isValidStickerURL(url) {
|
||||
return nil, ErrInvalidStickerURL
|
||||
}
|
||||
|
||||
// 检查是否已存在
|
||||
exists, err := s.stickerRepo.Exists(userID, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrStickerAlreadyExists
|
||||
}
|
||||
|
||||
// 获取当前数量用于设置排序
|
||||
count, err := s.stickerRepo.CountByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sticker := &model.UserSticker{
|
||||
UserID: userID,
|
||||
URL: url,
|
||||
Width: width,
|
||||
Height: height,
|
||||
SortOrder: int(count), // 新表情添加到末尾
|
||||
}
|
||||
|
||||
if err := s.stickerRepo.Create(sticker); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sticker, nil
|
||||
}
|
||||
|
||||
func isValidStickerURL(raw string) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
return scheme == "http" || scheme == "https"
|
||||
}
|
||||
|
||||
// DeleteSticker 删除表情
|
||||
func (s *stickerService) DeleteSticker(userID string, stickerID string) error {
|
||||
// 先检查表情是否属于该用户
|
||||
sticker, err := s.stickerRepo.GetByID(stickerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sticker.UserID != userID {
|
||||
return errors.New("sticker not found")
|
||||
}
|
||||
|
||||
return s.stickerRepo.Delete(stickerID)
|
||||
}
|
||||
|
||||
// CheckExists 检查表情是否已存在
|
||||
func (s *stickerService) CheckExists(userID string, url string) (bool, error) {
|
||||
return s.stickerRepo.Exists(userID, url)
|
||||
}
|
||||
|
||||
// ReorderStickers 重新排序
|
||||
func (s *stickerService) ReorderStickers(userID string, orders map[string]int) error {
|
||||
return s.stickerRepo.BatchUpdateSortOrder(userID, orders)
|
||||
}
|
||||
|
||||
// GetStickerCount 获取用户表情数量
|
||||
func (s *stickerService) GetStickerCount(userID string) (int64, error) {
|
||||
return s.stickerRepo.CountByUserID(userID)
|
||||
}
|
||||
462
internal/service/system_message_service.go
Normal file
462
internal/service/system_message_service.go
Normal file
@@ -0,0 +1,462 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// SystemMessageService 系统消息服务接口
|
||||
type SystemMessageService interface {
|
||||
// 发送互动通知
|
||||
SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||
SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error
|
||||
SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error
|
||||
SendFollowNotification(ctx context.Context, userID string, operatorID string) error
|
||||
SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||
SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||
SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error
|
||||
SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error
|
||||
|
||||
// 发送系统公告
|
||||
SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error
|
||||
SendBroadcastAnnouncement(ctx context.Context, title string, content string) error
|
||||
}
|
||||
|
||||
type systemMessageServiceImpl struct {
|
||||
notifyRepo *repository.SystemNotificationRepository
|
||||
pushService PushService
|
||||
userRepo *repository.UserRepository
|
||||
postRepo *repository.PostRepository
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
// NewSystemMessageService 创建系统消息服务
|
||||
func NewSystemMessageService(
|
||||
notifyRepo *repository.SystemNotificationRepository,
|
||||
pushService PushService,
|
||||
userRepo *repository.UserRepository,
|
||||
postRepo *repository.PostRepository,
|
||||
) SystemMessageService {
|
||||
return &systemMessageServiceImpl{
|
||||
notifyRepo: notifyRepo,
|
||||
pushService: pushService,
|
||||
userRepo: userRepo,
|
||||
postRepo: postRepo,
|
||||
cache: cache.GetCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// SendLikeNotification 发送点赞通知
|
||||
func (s *systemMessageServiceImpl) SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "post",
|
||||
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 赞了「%s」", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikePost, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create like notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendCommentNotification 发送评论通知
|
||||
func (s *systemMessageServiceImpl) SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "comment",
|
||||
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 评论了「%s」", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyComment, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create comment notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendReplyNotification 发送回复通知
|
||||
func (s *systemMessageServiceImpl) SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: replyID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "reply",
|
||||
ActionURL: fmt.Sprintf("/posts/%s?comment=%s&reply=%s", postID, commentID, replyID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 回复了您在「%s」的评论", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyReply, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create reply notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendFollowNotification 发送关注通知
|
||||
func (s *systemMessageServiceImpl) SendFollowNotification(ctx context.Context, userID string, operatorID string) error {
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: userID=%s, operatorID=%s\n", userID, operatorID)
|
||||
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: getActorInfo error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: actorName=%s, avatarURL=%s\n", actorName, avatarURL)
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: "",
|
||||
TargetType: "user",
|
||||
ActionURL: fmt.Sprintf("/users/%s", operatorID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 关注了你", actorName)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyFollow, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create follow notification: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: notification created, ID=%d, Content=%s\n", notification.ID, notification.Content)
|
||||
|
||||
// 推送通知
|
||||
pushErr := s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
if pushErr != nil {
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification error: %v\n", pushErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification success\n")
|
||||
}
|
||||
return pushErr
|
||||
}
|
||||
|
||||
// SendFavoriteNotification 发送收藏通知
|
||||
func (s *systemMessageServiceImpl) SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "post",
|
||||
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 收藏了「%s」", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyFavoritePost, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create favorite notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendLikeCommentNotification 发送评论点赞通知
|
||||
func (s *systemMessageServiceImpl) SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 截取评论内容预览(最多50字)
|
||||
preview := commentContent
|
||||
runes := []rune(preview)
|
||||
if len(runes) > 50 {
|
||||
preview = string(runes[:50]) + "..."
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: preview,
|
||||
TargetType: "comment",
|
||||
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 赞了您的评论", actorName)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeComment, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create like comment notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendLikeReplyNotification 发送回复点赞通知
|
||||
func (s *systemMessageServiceImpl) SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 截取回复内容预览(最多50字)
|
||||
preview := replyContent
|
||||
runes := []rune(preview)
|
||||
if len(runes) > 50 {
|
||||
preview = string(runes[:50]) + "..."
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: preview,
|
||||
TargetType: "reply",
|
||||
ActionURL: fmt.Sprintf("/posts/%s?reply=%s", postID, replyID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 赞了您的回复", actorName)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeReply, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create like reply notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendMentionNotification 发送@提及通知
|
||||
func (s *systemMessageServiceImpl) SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "post",
|
||||
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 在「%s」中提到了你", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyMention, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create mention notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendSystemAnnouncement 发送系统公告给指定用户
|
||||
func (s *systemMessageServiceImpl) SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error {
|
||||
for _, userID := range userIDs {
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
TargetType: "announcement",
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyAnnounce, fmt.Sprintf("【%s】%s", title, content), extraData)
|
||||
if err != nil {
|
||||
continue // 单个失败不影响其他用户
|
||||
}
|
||||
|
||||
// 推送通知(使用高优先级)
|
||||
if err := s.pushService.PushSystemNotification(ctx, userID, notification); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendBroadcastAnnouncement 发送广播公告给所有在线用户
|
||||
func (s *systemMessageServiceImpl) SendBroadcastAnnouncement(ctx context.Context, title string, content string) error {
|
||||
// TODO: 实现广播公告
|
||||
// 1. 获取所有在线用户
|
||||
// 2. 批量发送公告
|
||||
// 3. 对于离线用户,存储为待推送记录
|
||||
return fmt.Errorf("broadcast announcement not implemented")
|
||||
}
|
||||
|
||||
// createNotification 创建系统通知(存储到独立表)
|
||||
func (s *systemMessageServiceImpl) createNotification(ctx context.Context, userID string, notifyType model.SystemNotificationType, content string, extraData *model.SystemNotificationExtra) (*model.SystemNotification, error) {
|
||||
fmt.Printf("[DEBUG] createNotification: userID=%s, notifyType=%s\n", userID, notifyType)
|
||||
|
||||
// 生成雪花算法ID
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] createNotification: failed to generate ID: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to generate notification ID: %w", err)
|
||||
}
|
||||
|
||||
notification := &model.SystemNotification{
|
||||
ID: id,
|
||||
ReceiverID: userID,
|
||||
Type: notifyType,
|
||||
Content: content,
|
||||
ExtraData: extraData,
|
||||
IsRead: false,
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] createNotification: notification created with ID=%d, ReceiverID=%s\n", id, userID)
|
||||
|
||||
// 保存通知到数据库
|
||||
if err := s.notifyRepo.Create(notification); err != nil {
|
||||
fmt.Printf("[DEBUG] createNotification: failed to save notification: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to save notification: %w", err)
|
||||
}
|
||||
|
||||
// 失效系统消息未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
fmt.Printf("[DEBUG] createNotification: notification saved successfully, ID=%d\n", notification.ID)
|
||||
return notification, nil
|
||||
}
|
||||
|
||||
// getActorInfo 获取操作者信息
|
||||
func (s *systemMessageServiceImpl) getActorInfo(ctx context.Context, operatorID string) (string, string, error) {
|
||||
// 从用户仓储获取用户信息
|
||||
if s.userRepo != nil {
|
||||
user, err := s.userRepo.GetByID(operatorID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] getActorInfo: failed to get user %s: %v\n", operatorID, err)
|
||||
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil // 返回默认值,不阻断流程
|
||||
}
|
||||
avatar := utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar)
|
||||
return user.Nickname, avatar, nil
|
||||
}
|
||||
// 如果没有用户仓储,返回默认值
|
||||
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil
|
||||
}
|
||||
|
||||
// getPostTitle 获取帖子标题
|
||||
func (s *systemMessageServiceImpl) getPostTitle(postID string) (string, error) {
|
||||
if s.postRepo == nil {
|
||||
if len(postID) >= 8 {
|
||||
return fmt.Sprintf("帖子#%s", postID[:8]), nil
|
||||
}
|
||||
return fmt.Sprintf("帖子#%s", postID), nil
|
||||
}
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
if len(postID) >= 8 {
|
||||
return fmt.Sprintf("帖子#%s", postID[:8]), nil
|
||||
}
|
||||
return fmt.Sprintf("帖子#%s", postID), nil
|
||||
}
|
||||
if post.Title != "" {
|
||||
return post.Title, nil
|
||||
}
|
||||
// 如果没有标题,返回内容前20个字符
|
||||
if len(post.Content) > 20 {
|
||||
return post.Content[:20] + "...", nil
|
||||
}
|
||||
return post.Content, nil
|
||||
}
|
||||
273
internal/service/upload_service.go
Normal file
273
internal/service/upload_service.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/pkg/s3"
|
||||
_ "golang.org/x/image/bmp"
|
||||
_ "golang.org/x/image/tiff"
|
||||
)
|
||||
|
||||
// UploadService 上传服务
|
||||
type UploadService struct {
|
||||
s3Client *s3.Client
|
||||
userService *UserService
|
||||
}
|
||||
|
||||
// NewUploadService 创建上传服务
|
||||
func NewUploadService(s3Client *s3.Client, userService *UserService) *UploadService {
|
||||
return &UploadService{
|
||||
s3Client: s3Client,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// UploadImage 上传图片
|
||||
func (s *UploadService) UploadImage(ctx context.Context, file *multipart.FileHeader) (string, error) {
|
||||
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 压缩后再计算哈希,确保同一压缩结果映射同一对象名
|
||||
hash := sha256.Sum256(processedData)
|
||||
hashStr := fmt.Sprintf("%x", hash)
|
||||
|
||||
objectName := fmt.Sprintf("images/%s%s", hashStr, ext)
|
||||
|
||||
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// getExtFromContentType 根据Content-Type获取文件扩展名
|
||||
func getExtFromContentType(contentType string) string {
|
||||
baseType, _, err := mime.ParseMediaType(contentType)
|
||||
if err == nil && baseType != "" {
|
||||
contentType = baseType
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case "image/jpg", "image/jpeg":
|
||||
return ".jpg"
|
||||
case "image/png":
|
||||
return ".png"
|
||||
case "image/gif":
|
||||
return ".gif"
|
||||
case "image/webp":
|
||||
return ".webp"
|
||||
case "image/bmp", "image/x-ms-bmp":
|
||||
return ".bmp"
|
||||
case "image/tiff":
|
||||
return ".tiff"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// UploadAvatar 上传头像
|
||||
func (s *UploadService) UploadAvatar(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
|
||||
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 压缩后再计算哈希
|
||||
hash := sha256.Sum256(processedData)
|
||||
hashStr := fmt.Sprintf("%x", hash)
|
||||
|
||||
objectName := fmt.Sprintf("avatars/%s%s", hashStr, ext)
|
||||
|
||||
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||
}
|
||||
|
||||
// 更新用户头像
|
||||
if s.userService != nil {
|
||||
user, err := s.userService.GetUserByID(ctx, userID)
|
||||
if err == nil && user != nil {
|
||||
user.Avatar = url
|
||||
err = s.userService.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
// 更新失败不影响上传结果,只记录日志
|
||||
fmt.Printf("[UploadAvatar] failed to update user avatar: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// UploadCover 上传头图(个人主页封面)
|
||||
func (s *UploadService) UploadCover(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
|
||||
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 压缩后再计算哈希
|
||||
hash := sha256.Sum256(processedData)
|
||||
hashStr := fmt.Sprintf("%x", hash)
|
||||
|
||||
objectName := fmt.Sprintf("covers/%s%s", hashStr, ext)
|
||||
|
||||
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||
}
|
||||
|
||||
// 更新用户头图
|
||||
if s.userService != nil {
|
||||
user, err := s.userService.GetUserByID(ctx, userID)
|
||||
if err == nil && user != nil {
|
||||
user.CoverURL = url
|
||||
err = s.userService.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
// 更新失败不影响上传结果,只记录日志
|
||||
fmt.Printf("[UploadCover] failed to update user cover: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// GetURL 获取文件URL
|
||||
func (s *UploadService) GetURL(ctx context.Context, objectName string) (string, error) {
|
||||
return s.s3Client.GetURL(ctx, objectName)
|
||||
}
|
||||
|
||||
// Delete 删除文件
|
||||
func (s *UploadService) Delete(ctx context.Context, objectName string) error {
|
||||
return s.s3Client.Delete(ctx, objectName)
|
||||
}
|
||||
|
||||
func prepareImageForUpload(file *multipart.FileHeader) ([]byte, string, string, error) {
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
originalData, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
// 优先从文件字节探测真实类型,避免前端压缩/转码后 header 与实际格式不一致
|
||||
detectedType := normalizeImageContentType(http.DetectContentType(originalData))
|
||||
headerType := normalizeImageContentType(file.Header.Get("Content-Type"))
|
||||
contentType := detectedType
|
||||
if contentType == "" || contentType == "application/octet-stream" {
|
||||
contentType = headerType
|
||||
}
|
||||
|
||||
compressedData, compressedType, err := compressImageData(originalData, contentType)
|
||||
if err != nil {
|
||||
// 压缩失败时回退到原图,保证上传可用性
|
||||
compressedData = originalData
|
||||
compressedType = contentType
|
||||
}
|
||||
|
||||
if compressedType == "" {
|
||||
compressedType = contentType
|
||||
}
|
||||
if compressedType == "" {
|
||||
compressedType = http.DetectContentType(compressedData)
|
||||
}
|
||||
|
||||
ext := getExtFromContentType(compressedType)
|
||||
if ext == "" {
|
||||
ext = strings.ToLower(filepath.Ext(file.Filename))
|
||||
}
|
||||
if ext == "" {
|
||||
// 最终兜底,避免对象名无扩展名导致 URL 语义不明确
|
||||
ext = ".jpg"
|
||||
}
|
||||
|
||||
return compressedData, compressedType, ext, nil
|
||||
}
|
||||
|
||||
func compressImageData(data []byte, contentType string) ([]byte, string, error) {
|
||||
contentType = normalizeImageContentType(contentType)
|
||||
|
||||
// GIF/WebP 等格式先保留原图,避免动画和透明通道丢失
|
||||
if contentType == "image/gif" || contentType == "image/webp" {
|
||||
return data, contentType, nil
|
||||
}
|
||||
|
||||
if contentType != "image/jpeg" &&
|
||||
contentType != "image/png" &&
|
||||
contentType != "image/bmp" &&
|
||||
contentType != "image/x-ms-bmp" &&
|
||||
contentType != "image/tiff" {
|
||||
return data, contentType, nil
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
switch contentType {
|
||||
case "image/png":
|
||||
encoder := png.Encoder{CompressionLevel: png.BestCompression}
|
||||
if err := encoder.Encode(&buf, img); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to encode png: %w", err)
|
||||
}
|
||||
return buf.Bytes(), "image/png", nil
|
||||
default:
|
||||
// BMP/TIFF 等无损大图统一压缩为 JPEG,控制体积
|
||||
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 82}); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to encode jpeg: %w", err)
|
||||
}
|
||||
return buf.Bytes(), "image/jpeg", nil
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeImageContentType(contentType string) string {
|
||||
if contentType == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseType, _, err := mime.ParseMediaType(contentType)
|
||||
if err == nil && baseType != "" {
|
||||
contentType = baseType
|
||||
}
|
||||
|
||||
switch strings.ToLower(contentType) {
|
||||
case "image/jpg":
|
||||
return "image/jpeg"
|
||||
case "image/jpeg":
|
||||
return "image/jpeg"
|
||||
case "image/png":
|
||||
return "image/png"
|
||||
case "image/gif":
|
||||
return "image/gif"
|
||||
case "image/webp":
|
||||
return "image/webp"
|
||||
case "image/bmp", "image/x-ms-bmp":
|
||||
return "image/bmp"
|
||||
case "image/tiff":
|
||||
return "image/tiff"
|
||||
default:
|
||||
return contentType
|
||||
}
|
||||
}
|
||||
592
internal/service/user_service.go
Normal file
592
internal/service/user_service.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo *repository.UserRepository
|
||||
systemMessageService SystemMessageService
|
||||
emailCodeService EmailCodeService
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务
|
||||
func NewUserService(
|
||||
userRepo *repository.UserRepository,
|
||||
systemMessageService SystemMessageService,
|
||||
emailService EmailService,
|
||||
cacheBackend cache.Cache,
|
||||
) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
systemMessageService: systemMessageService,
|
||||
emailCodeService: NewEmailCodeService(emailService, cacheBackend),
|
||||
}
|
||||
}
|
||||
|
||||
// SendRegisterCode 发送注册验证码
|
||||
func (s *UserService) SendRegisterCode(ctx context.Context, email string) error {
|
||||
user, err := s.userRepo.GetByEmail(email)
|
||||
if err == nil && user != nil {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return s.emailCodeService.SendCode(ctx, CodePurposeRegister, email)
|
||||
}
|
||||
|
||||
// SendPasswordResetCode 发送找回密码验证码
|
||||
func (s *UserService) SendPasswordResetCode(ctx context.Context, email string) error {
|
||||
user, err := s.userRepo.GetByEmail(email)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return s.emailCodeService.SendCode(ctx, CodePurposePasswordReset, email)
|
||||
}
|
||||
|
||||
// SendCurrentUserEmailVerifyCode 发送当前用户邮箱验证验证码
|
||||
func (s *UserService) SendCurrentUserEmailVerifyCode(ctx context.Context, userID, email string) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
targetEmail := strings.TrimSpace(email)
|
||||
if targetEmail == "" && user.Email != nil {
|
||||
targetEmail = strings.TrimSpace(*user.Email)
|
||||
}
|
||||
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
|
||||
if user.EmailVerified && user.Email != nil && strings.EqualFold(strings.TrimSpace(*user.Email), targetEmail) {
|
||||
return ErrEmailAlreadyVerified
|
||||
}
|
||||
|
||||
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
|
||||
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
|
||||
return ErrEmailExists
|
||||
}
|
||||
|
||||
return s.emailCodeService.SendCode(ctx, CodePurposeEmailVerify, targetEmail)
|
||||
}
|
||||
|
||||
// VerifyCurrentUserEmail 验证当前用户邮箱
|
||||
func (s *UserService) VerifyCurrentUserEmail(ctx context.Context, userID, email, verificationCode string) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
targetEmail := strings.TrimSpace(email)
|
||||
if targetEmail == "" && user.Email != nil {
|
||||
targetEmail = strings.TrimSpace(*user.Email)
|
||||
}
|
||||
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
|
||||
if err := s.emailCodeService.VerifyCode(CodePurposeEmailVerify, targetEmail, verificationCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
|
||||
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
|
||||
return ErrEmailExists
|
||||
}
|
||||
|
||||
user.Email = &targetEmail
|
||||
user.EmailVerified = true
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
// SendChangePasswordCode 发送修改密码验证码
|
||||
func (s *UserService) SendChangePasswordCode(ctx context.Context, userID string) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
|
||||
return ErrEmailNotBound
|
||||
}
|
||||
return s.emailCodeService.SendCode(ctx, CodePurposeChangePassword, *user.Email)
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func (s *UserService) Register(ctx context.Context, username, email, password, nickname, phone, verificationCode string) (*model.User, error) {
|
||||
// 验证用户名
|
||||
if !utils.ValidateUsername(username) {
|
||||
return nil, ErrInvalidUsername
|
||||
}
|
||||
|
||||
// 注册必须提供邮箱并完成验证码校验
|
||||
if email == "" || !utils.ValidateEmail(email) {
|
||||
return nil, ErrInvalidEmail
|
||||
}
|
||||
if err := s.emailCodeService.VerifyCode(CodePurposeRegister, email, verificationCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !utils.ValidatePassword(password) {
|
||||
return nil, ErrWeakPassword
|
||||
}
|
||||
|
||||
// 验证手机号(如果提供)
|
||||
if phone != "" && !utils.ValidatePhone(phone) {
|
||||
return nil, ErrInvalidPhone
|
||||
}
|
||||
|
||||
// 检查用户名是否已存在
|
||||
existingUser, err := s.userRepo.GetByUsername(username)
|
||||
if err == nil && existingUser != nil {
|
||||
return nil, ErrUsernameExists
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在(如果提供)
|
||||
if email != "" {
|
||||
existingUser, err = s.userRepo.GetByEmail(email)
|
||||
if err == nil && existingUser != nil {
|
||||
return nil, ErrEmailExists
|
||||
}
|
||||
}
|
||||
|
||||
// 检查手机号是否已存在(如果提供)
|
||||
if phone != "" {
|
||||
existingUser, err = s.userRepo.GetByPhone(phone)
|
||||
if err == nil && existingUser != nil {
|
||||
return nil, ErrPhoneExists
|
||||
}
|
||||
}
|
||||
|
||||
// 密码哈希
|
||||
hashedPassword, err := utils.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: username,
|
||||
Nickname: nickname,
|
||||
EmailVerified: true,
|
||||
PasswordHash: hashedPassword,
|
||||
Status: model.UserStatusActive,
|
||||
}
|
||||
|
||||
// 如果提供了邮箱,设置指针值
|
||||
if email != "" {
|
||||
user.Email = &email
|
||||
}
|
||||
|
||||
// 如果提供了手机号,设置指针值
|
||||
if phone != "" {
|
||||
user.Phone = &phone
|
||||
}
|
||||
|
||||
err = s.userRepo.Create(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
func (s *UserService) Login(ctx context.Context, account, password string) (*model.User, error) {
|
||||
account = strings.TrimSpace(account)
|
||||
var (
|
||||
user *model.User
|
||||
err error
|
||||
)
|
||||
if utils.ValidateEmail(account) {
|
||||
user, err = s.userRepo.GetByEmail(account)
|
||||
} else if utils.ValidatePhone(account) {
|
||||
user, err = s.userRepo.GetByPhone(account)
|
||||
} else {
|
||||
user, err = s.userRepo.GetByUsername(account)
|
||||
}
|
||||
if err != nil || user == nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
if !utils.CheckPasswordHash(password, user.PasswordHash) {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
if user.Status != model.UserStatusActive {
|
||||
return nil, ErrUserBanned
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
func (s *UserService) GetUserByID(ctx context.Context, id string) (*model.User, error) {
|
||||
return s.userRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// GetUserPostCount 获取用户帖子数(实时计算)
|
||||
func (s *UserService) GetUserPostCount(ctx context.Context, userID string) (int64, error) {
|
||||
return s.userRepo.GetPostsCount(userID)
|
||||
}
|
||||
|
||||
// GetUserPostCountBatch 批量获取用户帖子数(实时计算)
|
||||
func (s *UserService) GetUserPostCountBatch(ctx context.Context, userIDs []string) (map[string]int64, error) {
|
||||
return s.userRepo.GetPostsCountBatch(userIDs)
|
||||
}
|
||||
|
||||
// GetUserByIDWithFollowingStatus 根据ID获取用户(包含当前用户是否关注的状态)
|
||||
func (s *UserService) GetUserByIDWithFollowingStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, error) {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// 如果查询的是当前用户自己,不需要检查关注状态
|
||||
if userID == currentUserID {
|
||||
return user, false, nil
|
||||
}
|
||||
|
||||
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
|
||||
if err != nil {
|
||||
return user, false, err
|
||||
}
|
||||
|
||||
return user, isFollowing, nil
|
||||
}
|
||||
|
||||
// GetUserByIDWithMutualFollowStatus 根据ID获取用户(包含双向关注状态)
|
||||
func (s *UserService) GetUserByIDWithMutualFollowStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, bool, error) {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
|
||||
// 如果查询的是当前用户自己,不需要检查关注状态
|
||||
if userID == currentUserID {
|
||||
return user, false, false, nil
|
||||
}
|
||||
|
||||
// 当前用户是否关注了该用户
|
||||
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
|
||||
if err != nil {
|
||||
return user, false, false, err
|
||||
}
|
||||
|
||||
// 该用户是否关注了当前用户
|
||||
isFollowingMe, err := s.userRepo.IsFollowing(userID, currentUserID)
|
||||
if err != nil {
|
||||
return user, isFollowing, false, err
|
||||
}
|
||||
|
||||
return user, isFollowing, isFollowingMe, nil
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户
|
||||
func (s *UserService) UpdateUser(ctx context.Context, user *model.User) error {
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
// GetFollowers 获取粉丝
|
||||
func (s *UserService) GetFollowers(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
return s.userRepo.GetFollowers(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetFollowing 获取关注
|
||||
func (s *UserService) GetFollowing(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
return s.userRepo.GetFollowing(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// FollowUser 关注用户
|
||||
func (s *UserService) FollowUser(ctx context.Context, followerID, followeeID string) error {
|
||||
fmt.Printf("[DEBUG] FollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||
|
||||
blocked, err := s.userRepo.IsBlockedEitherDirection(followerID, followeeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if blocked {
|
||||
return ErrUserBlocked
|
||||
}
|
||||
|
||||
// 检查是否已经关注
|
||||
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
|
||||
return err
|
||||
}
|
||||
if isFollowing {
|
||||
fmt.Printf("[DEBUG] Already following, skip creation\n")
|
||||
return nil // 已经关注,直接返回成功
|
||||
}
|
||||
|
||||
// 创建关注关系
|
||||
follow := &model.Follow{
|
||||
FollowerID: followerID,
|
||||
FollowingID: followeeID,
|
||||
}
|
||||
|
||||
err = s.userRepo.CreateFollow(follow)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] CreateFollow error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] Follow record created successfully\n")
|
||||
|
||||
// 刷新关注者的关注数(通过实际计数,更可靠)
|
||||
err = s.userRepo.RefreshFollowingCount(followerID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
|
||||
// 不回滚,计数可以通过其他方式修复
|
||||
}
|
||||
|
||||
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
|
||||
err = s.userRepo.RefreshFollowersCount(followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
|
||||
// 不回滚,计数可以通过其他方式修复
|
||||
}
|
||||
|
||||
// 发送关注通知给被关注者
|
||||
if s.systemMessageService != nil {
|
||||
// 异步发送通知,不阻塞主流程
|
||||
go func() {
|
||||
notifyErr := s.systemMessageService.SendFollowNotification(context.Background(), followeeID, followerID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending follow notification: %v\n", notifyErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] Follow notification sent successfully to %s\n", followeeID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] FollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnfollowUser 取消关注用户
|
||||
func (s *UserService) UnfollowUser(ctx context.Context, followerID, followeeID string) error {
|
||||
fmt.Printf("[DEBUG] UnfollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||
|
||||
// 检查是否已经关注
|
||||
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
|
||||
return err
|
||||
}
|
||||
if !isFollowing {
|
||||
fmt.Printf("[DEBUG] Not following, skip deletion\n")
|
||||
return nil // 没有关注,直接返回成功
|
||||
}
|
||||
|
||||
// 删除关注关系
|
||||
err = s.userRepo.DeleteFollow(followerID, followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] DeleteFollow error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] Follow record deleted successfully\n")
|
||||
|
||||
// 刷新关注者的关注数(通过实际计数,更可靠)
|
||||
err = s.userRepo.RefreshFollowingCount(followerID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
|
||||
}
|
||||
|
||||
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
|
||||
err = s.userRepo.RefreshFollowersCount(followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] UnfollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BlockUser 拉黑用户,并自动清理双向关注/粉丝关系
|
||||
func (s *UserService) BlockUser(ctx context.Context, blockerID, blockedID string) error {
|
||||
if blockerID == blockedID {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
return s.userRepo.BlockUserAndCleanupRelations(blockerID, blockedID)
|
||||
}
|
||||
|
||||
// UnblockUser 取消拉黑
|
||||
func (s *UserService) UnblockUser(ctx context.Context, blockerID, blockedID string) error {
|
||||
if blockerID == blockedID {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
return s.userRepo.UnblockUser(blockerID, blockedID)
|
||||
}
|
||||
|
||||
// GetBlockedUsers 获取黑名单列表
|
||||
func (s *UserService) GetBlockedUsers(ctx context.Context, blockerID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
return s.userRepo.GetBlockedUsers(blockerID, page, pageSize)
|
||||
}
|
||||
|
||||
// IsBlocked 检查当前用户是否已拉黑目标用户
|
||||
func (s *UserService) IsBlocked(ctx context.Context, blockerID, blockedID string) (bool, error) {
|
||||
return s.userRepo.IsBlocked(blockerID, blockedID)
|
||||
}
|
||||
|
||||
// GetFollowingList 获取关注列表(字符串参数版本)
|
||||
func (s *UserService) GetFollowingList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
|
||||
// 转换字符串参数为整数
|
||||
pageInt := 1
|
||||
pageSizeInt := 20
|
||||
if page != "" {
|
||||
_, err := fmt.Sscanf(page, "%d", &pageInt)
|
||||
if err != nil {
|
||||
pageInt = 1
|
||||
}
|
||||
}
|
||||
if pageSize != "" {
|
||||
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
|
||||
if err != nil {
|
||||
pageSizeInt = 20
|
||||
}
|
||||
}
|
||||
|
||||
users, _, err := s.userRepo.GetFollowing(userID, pageInt, pageSizeInt)
|
||||
return users, err
|
||||
}
|
||||
|
||||
// GetFollowersList 获取粉丝列表(字符串参数版本)
|
||||
func (s *UserService) GetFollowersList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
|
||||
// 转换字符串参数为整数
|
||||
pageInt := 1
|
||||
pageSizeInt := 20
|
||||
if page != "" {
|
||||
_, err := fmt.Sscanf(page, "%d", &pageInt)
|
||||
if err != nil {
|
||||
pageInt = 1
|
||||
}
|
||||
}
|
||||
if pageSize != "" {
|
||||
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
|
||||
if err != nil {
|
||||
pageSizeInt = 20
|
||||
}
|
||||
}
|
||||
|
||||
users, _, err := s.userRepo.GetFollowers(userID, pageInt, pageSizeInt)
|
||||
return users, err
|
||||
}
|
||||
|
||||
// GetMutualFollowStatus 批量获取双向关注状态
|
||||
func (s *UserService) GetMutualFollowStatus(ctx context.Context, currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
|
||||
return s.userRepo.GetMutualFollowStatus(currentUserID, targetUserIDs)
|
||||
}
|
||||
|
||||
// CheckUsernameAvailable 检查用户名是否可用
|
||||
func (s *UserService) CheckUsernameAvailable(ctx context.Context, username string) (bool, error) {
|
||||
user, err := s.userRepo.GetByUsername(username)
|
||||
if err != nil {
|
||||
return true, nil // 用户不存在,可用
|
||||
}
|
||||
return user == nil, nil
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *UserService) ChangePassword(ctx context.Context, userID, oldPassword, newPassword, verificationCode string) error {
|
||||
// 获取用户
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
|
||||
return ErrEmailNotBound
|
||||
}
|
||||
if err := s.emailCodeService.VerifyCode(CodePurposeChangePassword, *user.Email, verificationCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !utils.CheckPasswordHash(oldPassword, user.PasswordHash) {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// 哈希新密码
|
||||
hashedPassword, err := utils.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
user.PasswordHash = hashedPassword
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
// ResetPasswordByEmail 通过邮箱重置密码
|
||||
func (s *UserService) ResetPasswordByEmail(ctx context.Context, email, verificationCode, newPassword string) error {
|
||||
email = strings.TrimSpace(email)
|
||||
if !utils.ValidateEmail(email) {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
if !utils.ValidatePassword(newPassword) {
|
||||
return ErrWeakPassword
|
||||
}
|
||||
if err := s.emailCodeService.VerifyCode(CodePurposePasswordReset, email, verificationCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(email)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
hashedPassword, err := utils.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.PasswordHash = hashedPassword
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (s *UserService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
return s.userRepo.Search(keyword, page, pageSize)
|
||||
}
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrInvalidUsername = &ServiceError{Code: 400, Message: "invalid username"}
|
||||
ErrInvalidEmail = &ServiceError{Code: 400, Message: "invalid email"}
|
||||
ErrInvalidPhone = &ServiceError{Code: 400, Message: "invalid phone number"}
|
||||
ErrWeakPassword = &ServiceError{Code: 400, Message: "password too weak"}
|
||||
ErrUsernameExists = &ServiceError{Code: 400, Message: "username already exists"}
|
||||
ErrEmailExists = &ServiceError{Code: 400, Message: "email already exists"}
|
||||
ErrPhoneExists = &ServiceError{Code: 400, Message: "phone number already exists"}
|
||||
ErrUserNotFound = &ServiceError{Code: 404, Message: "user not found"}
|
||||
ErrUserBanned = &ServiceError{Code: 403, Message: "user is banned"}
|
||||
ErrUserBlocked = &ServiceError{Code: 403, Message: "blocked relationship exists"}
|
||||
ErrInvalidOperation = &ServiceError{Code: 400, Message: "invalid operation"}
|
||||
ErrEmailServiceUnavailable = &ServiceError{Code: 503, Message: "email service unavailable"}
|
||||
ErrVerificationCodeTooFrequent = &ServiceError{Code: 429, Message: "verification code sent too frequently"}
|
||||
ErrVerificationCodeInvalid = &ServiceError{Code: 400, Message: "invalid verification code"}
|
||||
ErrVerificationCodeExpired = &ServiceError{Code: 400, Message: "verification code expired"}
|
||||
ErrVerificationCodeUnavailable = &ServiceError{Code: 500, Message: "verification code storage unavailable"}
|
||||
ErrEmailAlreadyVerified = &ServiceError{Code: 400, Message: "email already verified"}
|
||||
ErrEmailNotBound = &ServiceError{Code: 400, Message: "email not bound"}
|
||||
)
|
||||
|
||||
// ServiceError 服务错误
|
||||
type ServiceError struct {
|
||||
Code int
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ServiceError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
var ErrInvalidCredentials = &ServiceError{Code: 401, Message: "invalid username or password"}
|
||||
282
internal/service/vote_service.go
Normal file
282
internal/service/vote_service.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// VoteService 投票服务
|
||||
type VoteService struct {
|
||||
voteRepo *repository.VoteRepository
|
||||
postRepo *repository.PostRepository
|
||||
cache cache.Cache
|
||||
postAIService *PostAIService
|
||||
systemMessageService SystemMessageService
|
||||
}
|
||||
|
||||
// NewVoteService 创建投票服务
|
||||
func NewVoteService(
|
||||
voteRepo *repository.VoteRepository,
|
||||
postRepo *repository.PostRepository,
|
||||
cache cache.Cache,
|
||||
postAIService *PostAIService,
|
||||
systemMessageService SystemMessageService,
|
||||
) *VoteService {
|
||||
return &VoteService{
|
||||
voteRepo: voteRepo,
|
||||
postRepo: postRepo,
|
||||
cache: cache,
|
||||
postAIService: postAIService,
|
||||
systemMessageService: systemMessageService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateVotePost 创建投票帖子
|
||||
func (s *VoteService) CreateVotePost(ctx context.Context, userID string, req *dto.CreateVotePostRequest) (*dto.PostResponse, error) {
|
||||
// 验证投票选项数量
|
||||
if len(req.VoteOptions) < 2 {
|
||||
return nil, errors.New("投票选项至少需要2个")
|
||||
}
|
||||
if len(req.VoteOptions) > 10 {
|
||||
return nil, errors.New("投票选项最多10个")
|
||||
}
|
||||
|
||||
// 创建普通帖子(设置IsVote=true)
|
||||
post := &model.Post{
|
||||
UserID: userID,
|
||||
CommunityID: req.CommunityID,
|
||||
Title: req.Title,
|
||||
Content: req.Content,
|
||||
Status: model.PostStatusPending,
|
||||
IsVote: true,
|
||||
}
|
||||
|
||||
err := s.postRepo.Create(post, req.Images)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建投票选项
|
||||
err = s.voteRepo.CreateOptions(post.ID, req.VoteOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 异步审核
|
||||
go s.reviewVotePostAsync(post.ID, userID, req.Title, req.Content, req.Images)
|
||||
|
||||
// 重新查询以获取关联的User和Images
|
||||
createdPost, err := s.postRepo.GetByID(post.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为响应DTO
|
||||
return s.convertToPostResponse(createdPost, userID), nil
|
||||
}
|
||||
|
||||
func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string, images []string) {
|
||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||
log.Printf("[WARN] Failed to publish vote post without AI moderation: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := s.postAIService.ModeratePost(context.Background(), title, content, images)
|
||||
if err != nil {
|
||||
var rejectedErr *PostModerationRejectedError
|
||||
if errors.As(err, &rejectedErr) {
|
||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to reject vote post %s: %v", postID, updateErr)
|
||||
}
|
||||
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
||||
return
|
||||
}
|
||||
|
||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to publish vote post %s after moderation error: %v", postID, updateErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
|
||||
log.Printf("[WARN] Failed to publish vote post %s: %v", postID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VoteService) notifyModerationRejected(userID, reason string) {
|
||||
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
content := "您发布的投票帖未通过AI审核,请修改后重试。"
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
content = fmt.Sprintf("您发布的投票帖未通过AI审核,原因:%s。请修改后重试。", reason)
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = s.systemMessageService.SendSystemAnnouncement(
|
||||
context.Background(),
|
||||
[]string{userID},
|
||||
"投票帖审核未通过",
|
||||
content,
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
// GetVoteOptions 获取投票选项
|
||||
func (s *VoteService) GetVoteOptions(postID string) ([]dto.VoteOptionDTO, error) {
|
||||
options, err := s.voteRepo.GetOptionsByPostID(postID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]dto.VoteOptionDTO, 0, len(options))
|
||||
for _, option := range options {
|
||||
result = append(result, dto.VoteOptionDTO{
|
||||
ID: option.ID,
|
||||
Content: option.Content,
|
||||
VotesCount: option.VotesCount,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetVoteResult 获取投票结果(包含用户投票状态)
|
||||
func (s *VoteService) GetVoteResult(postID, userID string) (*dto.VoteResultDTO, error) {
|
||||
// 获取所有投票选项
|
||||
options, err := s.voteRepo.GetOptionsByPostID(postID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取用户的投票记录
|
||||
userVote, err := s.voteRepo.GetUserVote(postID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建结果
|
||||
result := &dto.VoteResultDTO{
|
||||
Options: make([]dto.VoteOptionDTO, 0, len(options)),
|
||||
TotalVotes: 0,
|
||||
HasVoted: userVote != nil,
|
||||
}
|
||||
|
||||
if userVote != nil {
|
||||
result.VotedOptionID = userVote.OptionID
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
result.Options = append(result.Options, dto.VoteOptionDTO{
|
||||
ID: option.ID,
|
||||
Content: option.Content,
|
||||
VotesCount: option.VotesCount,
|
||||
})
|
||||
result.TotalVotes += option.VotesCount
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Vote 投票
|
||||
func (s *VoteService) Vote(ctx context.Context, postID, userID, optionID string) error {
|
||||
// 调用voteRepo.Vote
|
||||
err := s.voteRepo.Vote(postID, userID, optionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unvote 取消投票
|
||||
func (s *VoteService) Unvote(ctx context.Context, postID, userID string) error {
|
||||
// 调用voteRepo.Unvote
|
||||
err := s.voteRepo.Unvote(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateVoteOption 更新投票选项(作者权限)
|
||||
func (s *VoteService) UpdateVoteOption(ctx context.Context, postID, optionID, userID, content string) error {
|
||||
// 获取帖子信息
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证用户是否为帖子作者
|
||||
if post.UserID != userID {
|
||||
return errors.New("只有帖子作者可以更新投票选项")
|
||||
}
|
||||
|
||||
// 调用voteRepo.UpdateOption
|
||||
return s.voteRepo.UpdateOption(optionID, content)
|
||||
}
|
||||
|
||||
// convertToPostResponse 将Post模型转换为PostResponse DTO
|
||||
func (s *VoteService) convertToPostResponse(post *model.Post, currentUserID string) *dto.PostResponse {
|
||||
if post == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
response := &dto.PostResponse{
|
||||
ID: post.ID,
|
||||
UserID: post.UserID,
|
||||
Title: post.Title,
|
||||
Content: post.Content,
|
||||
LikesCount: post.LikesCount,
|
||||
CommentsCount: post.CommentsCount,
|
||||
FavoritesCount: post.FavoritesCount,
|
||||
SharesCount: post.SharesCount,
|
||||
ViewsCount: post.ViewsCount,
|
||||
IsPinned: post.IsPinned,
|
||||
IsLocked: post.IsLocked,
|
||||
IsVote: post.IsVote,
|
||||
CreatedAt: dto.FormatTime(post.CreatedAt),
|
||||
Images: make([]dto.PostImageResponse, 0, len(post.Images)),
|
||||
}
|
||||
|
||||
// 转换图片
|
||||
for _, img := range post.Images {
|
||||
response.Images = append(response.Images, dto.PostImageResponse{
|
||||
ID: img.ID,
|
||||
URL: img.URL,
|
||||
ThumbnailURL: img.ThumbnailURL,
|
||||
Width: img.Width,
|
||||
Height: img.Height,
|
||||
})
|
||||
}
|
||||
|
||||
// 转换作者信息
|
||||
if post.User != nil {
|
||||
response.Author = &dto.UserResponse{
|
||||
ID: post.User.ID,
|
||||
Username: post.User.Username,
|
||||
Nickname: post.User.Nickname,
|
||||
Avatar: post.User.Avatar,
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
Reference in New Issue
Block a user