package repository import ( "time" "carrot_bbs/internal/model" "gorm.io/gorm" ) // PushRecordRepository 推送记录仓储 type PushRecordRepository struct { db *gorm.DB } // NewPushRecordRepository 创建推送记录仓储 func NewPushRecordRepository(db *gorm.DB) *PushRecordRepository { return &PushRecordRepository{db: db} } // Create 创建推送记录 func (r *PushRecordRepository) Create(record *model.PushRecord) error { return r.db.Create(record).Error } // GetByID 根据ID获取推送记录 func (r *PushRecordRepository) GetByID(id int64) (*model.PushRecord, error) { var record model.PushRecord err := r.db.First(&record, "id = ?", id).Error if err != nil { return nil, err } return &record, nil } // Update 更新推送记录 func (r *PushRecordRepository) Update(record *model.PushRecord) error { return r.db.Save(record).Error } // GetPendingPushes 获取待推送记录 func (r *PushRecordRepository) GetPendingPushes(limit int) ([]*model.PushRecord, error) { var records []*model.PushRecord err := r.db.Where("push_status = ?", model.PushStatusPending). Where("expired_at IS NULL OR expired_at > ?", time.Now()). Order("created_at ASC"). Limit(limit). Find(&records).Error return records, err } // GetByUserID 根据用户ID获取推送记录 // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (r *PushRecordRepository) GetByUserID(userID string, limit, offset int) ([]*model.PushRecord, error) { var records []*model.PushRecord err := r.db.Where("user_id = ?", userID). Order("created_at DESC"). Offset(offset). Limit(limit). Find(&records).Error return records, err } // GetByMessageID 根据消息ID获取推送记录 func (r *PushRecordRepository) GetByMessageID(messageID int64) ([]*model.PushRecord, error) { var records []*model.PushRecord err := r.db.Where("message_id = ?", messageID). Order("created_at DESC"). Find(&records).Error return records, err } // GetFailedPushesForRetry 获取失败待重试的推送 func (r *PushRecordRepository) GetFailedPushesForRetry(limit int) ([]*model.PushRecord, error) { var records []*model.PushRecord err := r.db.Where("push_status = ?", model.PushStatusFailed). Where("retry_count < max_retry"). Where("expired_at IS NULL OR expired_at > ?", time.Now()). Order("created_at ASC"). Limit(limit). Find(&records).Error return records, err } // BatchCreate 批量创建推送记录 func (r *PushRecordRepository) BatchCreate(records []*model.PushRecord) error { if len(records) == 0 { return nil } return r.db.Create(&records).Error } // BatchUpdateStatus 批量更新推送状态 func (r *PushRecordRepository) BatchUpdateStatus(ids []int64, status model.PushStatus) error { if len(ids) == 0 { return nil } updates := map[string]interface{}{ "push_status": status, } if status == model.PushStatusPushed { updates["pushed_at"] = time.Now() } return r.db.Model(&model.PushRecord{}). Where("id IN ?", ids). Updates(updates).Error } // UpdateStatus 更新单条记录状态 func (r *PushRecordRepository) UpdateStatus(id int64, status model.PushStatus) error { updates := map[string]interface{}{ "push_status": status, } if status == model.PushStatusPushed { updates["pushed_at"] = time.Now() } return r.db.Model(&model.PushRecord{}). Where("id = ?", id). Updates(updates).Error } // MarkAsFailed 标记为失败 func (r *PushRecordRepository) MarkAsFailed(id int64, errMsg string) error { return r.db.Model(&model.PushRecord{}). Where("id = ?", id). Updates(map[string]interface{}{ "push_status": model.PushStatusFailed, "error_message": errMsg, "retry_count": gorm.Expr("retry_count + 1"), }).Error } // MarkAsDelivered 标记为已送达 func (r *PushRecordRepository) MarkAsDelivered(id int64) error { return r.db.Model(&model.PushRecord{}). Where("id = ?", id). Updates(map[string]interface{}{ "push_status": model.PushStatusDelivered, "delivered_at": time.Now(), }).Error } // DeleteExpiredRecords 删除过期的推送记录(软删除) func (r *PushRecordRepository) DeleteExpiredRecords() error { return r.db.Where("expired_at IS NOT NULL AND expired_at < ?", time.Now()). Delete(&model.PushRecord{}).Error } // GetStatsByUserID 获取用户推送统计 func (r *PushRecordRepository) GetStatsByUserID(userID int64) (map[model.PushStatus]int64, error) { type statusCount struct { Status model.PushStatus Count int64 } var results []statusCount err := r.db.Model(&model.PushRecord{}). Select("push_status as status, count(*) as count"). Where("user_id = ?", userID). Group("push_status"). Scan(&results).Error if err != nil { return nil, err } stats := make(map[model.PushStatus]int64) for _, r := range results { stats[r.Status] = r.Count } return stats, nil }