173 lines
4.7 KiB
Go
173 lines
4.7 KiB
Go
|
|
package repository
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"carrot_bbs/internal/model"
|
|||
|
|
|
|||
|
|
"gorm.io/gorm"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// PushRecordRepository 推送记录仓储
|
|||
|
|
type PushRecordRepository struct {
|
|||
|
|
db *gorm.DB
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewPushRecordRepository 创建推送记录仓储
|
|||
|
|
func NewPushRecordRepository(db *gorm.DB) *PushRecordRepository {
|
|||
|
|
return &PushRecordRepository{db: db}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create 创建推送记录
|
|||
|
|
func (r *PushRecordRepository) Create(record *model.PushRecord) error {
|
|||
|
|
return r.db.Create(record).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByID 根据ID获取推送记录
|
|||
|
|
func (r *PushRecordRepository) GetByID(id int64) (*model.PushRecord, error) {
|
|||
|
|
var record model.PushRecord
|
|||
|
|
err := r.db.First(&record, "id = ?", id).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return &record, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Update 更新推送记录
|
|||
|
|
func (r *PushRecordRepository) Update(record *model.PushRecord) error {
|
|||
|
|
return r.db.Save(record).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetPendingPushes 获取待推送记录
|
|||
|
|
func (r *PushRecordRepository) GetPendingPushes(limit int) ([]*model.PushRecord, error) {
|
|||
|
|
var records []*model.PushRecord
|
|||
|
|
err := r.db.Where("push_status = ?", model.PushStatusPending).
|
|||
|
|
Where("expired_at IS NULL OR expired_at > ?", time.Now()).
|
|||
|
|
Order("created_at ASC").
|
|||
|
|
Limit(limit).
|
|||
|
|
Find(&records).Error
|
|||
|
|
return records, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByUserID 根据用户ID获取推送记录
|
|||
|
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
|||
|
|
func (r *PushRecordRepository) GetByUserID(userID string, limit, offset int) ([]*model.PushRecord, error) {
|
|||
|
|
var records []*model.PushRecord
|
|||
|
|
err := r.db.Where("user_id = ?", userID).
|
|||
|
|
Order("created_at DESC").
|
|||
|
|
Offset(offset).
|
|||
|
|
Limit(limit).
|
|||
|
|
Find(&records).Error
|
|||
|
|
return records, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByMessageID 根据消息ID获取推送记录
|
|||
|
|
func (r *PushRecordRepository) GetByMessageID(messageID int64) ([]*model.PushRecord, error) {
|
|||
|
|
var records []*model.PushRecord
|
|||
|
|
err := r.db.Where("message_id = ?", messageID).
|
|||
|
|
Order("created_at DESC").
|
|||
|
|
Find(&records).Error
|
|||
|
|
return records, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetFailedPushesForRetry 获取失败待重试的推送
|
|||
|
|
func (r *PushRecordRepository) GetFailedPushesForRetry(limit int) ([]*model.PushRecord, error) {
|
|||
|
|
var records []*model.PushRecord
|
|||
|
|
err := r.db.Where("push_status = ?", model.PushStatusFailed).
|
|||
|
|
Where("retry_count < max_retry").
|
|||
|
|
Where("expired_at IS NULL OR expired_at > ?", time.Now()).
|
|||
|
|
Order("created_at ASC").
|
|||
|
|
Limit(limit).
|
|||
|
|
Find(&records).Error
|
|||
|
|
return records, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// BatchCreate 批量创建推送记录
|
|||
|
|
func (r *PushRecordRepository) BatchCreate(records []*model.PushRecord) error {
|
|||
|
|
if len(records) == 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return r.db.Create(&records).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// BatchUpdateStatus 批量更新推送状态
|
|||
|
|
func (r *PushRecordRepository) BatchUpdateStatus(ids []int64, status model.PushStatus) error {
|
|||
|
|
if len(ids) == 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
updates := map[string]interface{}{
|
|||
|
|
"push_status": status,
|
|||
|
|
}
|
|||
|
|
if status == model.PushStatusPushed {
|
|||
|
|
updates["pushed_at"] = time.Now()
|
|||
|
|
}
|
|||
|
|
return r.db.Model(&model.PushRecord{}).
|
|||
|
|
Where("id IN ?", ids).
|
|||
|
|
Updates(updates).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateStatus 更新单条记录状态
|
|||
|
|
func (r *PushRecordRepository) UpdateStatus(id int64, status model.PushStatus) error {
|
|||
|
|
updates := map[string]interface{}{
|
|||
|
|
"push_status": status,
|
|||
|
|
}
|
|||
|
|
if status == model.PushStatusPushed {
|
|||
|
|
updates["pushed_at"] = time.Now()
|
|||
|
|
}
|
|||
|
|
return r.db.Model(&model.PushRecord{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Updates(updates).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// MarkAsFailed 标记为失败
|
|||
|
|
func (r *PushRecordRepository) MarkAsFailed(id int64, errMsg string) error {
|
|||
|
|
return r.db.Model(&model.PushRecord{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Updates(map[string]interface{}{
|
|||
|
|
"push_status": model.PushStatusFailed,
|
|||
|
|
"error_message": errMsg,
|
|||
|
|
"retry_count": gorm.Expr("retry_count + 1"),
|
|||
|
|
}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// MarkAsDelivered 标记为已送达
|
|||
|
|
func (r *PushRecordRepository) MarkAsDelivered(id int64) error {
|
|||
|
|
return r.db.Model(&model.PushRecord{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Updates(map[string]interface{}{
|
|||
|
|
"push_status": model.PushStatusDelivered,
|
|||
|
|
"delivered_at": time.Now(),
|
|||
|
|
}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DeleteExpiredRecords 删除过期的推送记录(软删除)
|
|||
|
|
func (r *PushRecordRepository) DeleteExpiredRecords() error {
|
|||
|
|
return r.db.Where("expired_at IS NOT NULL AND expired_at < ?", time.Now()).
|
|||
|
|
Delete(&model.PushRecord{}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetStatsByUserID 获取用户推送统计
|
|||
|
|
func (r *PushRecordRepository) GetStatsByUserID(userID int64) (map[model.PushStatus]int64, error) {
|
|||
|
|
type statusCount struct {
|
|||
|
|
Status model.PushStatus
|
|||
|
|
Count int64
|
|||
|
|
}
|
|||
|
|
var results []statusCount
|
|||
|
|
|
|||
|
|
err := r.db.Model(&model.PushRecord{}).
|
|||
|
|
Select("push_status as status, count(*) as count").
|
|||
|
|
Where("user_id = ?", userID).
|
|||
|
|
Group("push_status").
|
|||
|
|
Scan(&results).Error
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
stats := make(map[model.PushStatus]int64)
|
|||
|
|
for _, r := range results {
|
|||
|
|
stats[r.Status] = r.Count
|
|||
|
|
}
|
|||
|
|
return stats, nil
|
|||
|
|
}
|