Files
backend/internal/repository/push_repo.go
lan 4d8f2ec997 Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
2026-03-09 21:28:58 +08:00

173 lines
4.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"time"
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// PushRecordRepository is the repository (data-access type) for push records.
// All of its methods run their queries through the wrapped GORM handle.
type PushRecordRepository struct {
// db is the GORM database handle supplied to NewPushRecordRepository.
db *gorm.DB
}
// NewPushRecordRepository builds a push-record repository whose queries are
// issued through db. The handle is stored as given; it is not validated here.
func NewPushRecordRepository(db *gorm.DB) *PushRecordRepository {
	repo := new(PushRecordRepository)
	repo.db = db
	return repo
}
// Create inserts a single push record using GORM's Create and returns the
// error reported by that call (nil on success).
func (r *PushRecordRepository) Create(record *model.PushRecord) error {
	result := r.db.Create(record)
	return result.Error
}
// GetByID looks up one push record whose id column equals id.
//
// On any query error (including GORM's not-found error from First) it returns
// a nil record together with that error; otherwise it returns the loaded
// record and a nil error.
func (r *PushRecordRepository) GetByID(id int64) (*model.PushRecord, error) {
	found := new(model.PushRecord)
	if err := r.db.First(found, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// Update writes record back to the database with GORM's Save and returns the
// error reported by that call (nil on success).
func (r *PushRecordRepository) Update(record *model.PushRecord) error {
	result := r.db.Save(record)
	return result.Error
}
// GetPendingPushes returns up to limit push records that are still waiting to
// be pushed.
//
// A record qualifies when its push_status equals model.PushStatusPending and
// it has either no expiry (expired_at IS NULL) or an expiry later than the
// current time. Results are ordered oldest first by created_at.
func (r *PushRecordRepository) GetPendingPushes(limit int) ([]*model.PushRecord, error) {
	now := time.Now()
	query := r.db.
		Where("push_status = ?", model.PushStatusPending).
		Where("expired_at IS NULL OR expired_at > ?", now).
		Order("created_at ASC").
		Limit(limit)

	var pending []*model.PushRecord
	err := query.Find(&pending).Error
	return pending, err
}
// GetByUserID returns one page of push records belonging to a user, newest
// first by created_at.
//
// userID is a string; per the original note it is a UUID-format value kept
// consistent with the user_id carried in the JWT. limit and offset are passed
// straight to GORM's Limit and Offset.
func (r *PushRecordRepository) GetByUserID(userID string, limit, offset int) ([]*model.PushRecord, error) {
	page := r.db.
		Where("user_id = ?", userID).
		Order("created_at DESC").
		Offset(offset).
		Limit(limit)

	var userRecords []*model.PushRecord
	err := page.Find(&userRecords).Error
	return userRecords, err
}
// GetByMessageID returns every push record whose message_id column equals
// messageID, newest first by created_at. No limit is applied.
func (r *PushRecordRepository) GetByMessageID(messageID int64) ([]*model.PushRecord, error) {
	query := r.db.
		Where("message_id = ?", messageID).
		Order("created_at DESC")

	var matches []*model.PushRecord
	err := query.Find(&matches).Error
	return matches, err
}
// GetFailedPushesForRetry returns up to limit failed push records that are
// still eligible for another attempt.
//
// A record qualifies when its push_status equals model.PushStatusFailed, its
// retry_count is below its own max_retry column, and it has either no expiry
// or an expiry later than the current time. Results are ordered oldest first
// by created_at.
func (r *PushRecordRepository) GetFailedPushesForRetry(limit int) ([]*model.PushRecord, error) {
	now := time.Now()
	query := r.db.
		Where("push_status = ?", model.PushStatusFailed).
		Where("retry_count < max_retry").
		Where("expired_at IS NULL OR expired_at > ?", now).
		Order("created_at ASC").
		Limit(limit)

	var retryable []*model.PushRecord
	err := query.Find(&retryable).Error
	return retryable, err
}
// BatchCreate inserts all given push records with a single GORM Create call.
// An empty (or nil) slice is a no-op that returns nil without touching the
// database.
func (r *PushRecordRepository) BatchCreate(records []*model.PushRecord) error {
	if len(records) > 0 {
		result := r.db.Create(&records)
		return result.Error
	}
	return nil
}
// BatchUpdateStatus sets push_status to status on every record whose id is in
// ids. When status is model.PushStatusPushed, pushed_at is also set to the
// current time. An empty ids slice is a no-op that returns nil.
func (r *PushRecordRepository) BatchUpdateStatus(ids []int64, status model.PushStatus) error {
	if len(ids) == 0 {
		return nil
	}

	fields := make(map[string]interface{}, 2)
	fields["push_status"] = status
	if status == model.PushStatusPushed {
		// Record the push time only on the transition to "pushed".
		fields["pushed_at"] = time.Now()
	}

	target := r.db.Model(&model.PushRecord{}).Where("id IN ?", ids)
	return target.Updates(fields).Error
}
// UpdateStatus sets push_status to status on the record whose id column equals
// id. When status is model.PushStatusPushed, pushed_at is also set to the
// current time.
func (r *PushRecordRepository) UpdateStatus(id int64, status model.PushStatus) error {
	fields := make(map[string]interface{}, 2)
	fields["push_status"] = status
	if status == model.PushStatusPushed {
		// Record the push time only on the transition to "pushed".
		fields["pushed_at"] = time.Now()
	}

	target := r.db.Model(&model.PushRecord{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// MarkAsFailed flags the record with the given id as failed: push_status
// becomes model.PushStatusFailed, error_message is set to errMsg, and
// retry_count is incremented by one.
func (r *PushRecordRepository) MarkAsFailed(id int64, errMsg string) error {
	fields := map[string]interface{}{
		"push_status":   model.PushStatusFailed,
		"error_message": errMsg,
		// SQL expression so the increment happens in the database rather
		// than from a value read earlier.
		"retry_count": gorm.Expr("retry_count + 1"),
	}

	target := r.db.Model(&model.PushRecord{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// MarkAsDelivered flags the record with the given id as delivered:
// push_status becomes model.PushStatusDelivered and delivered_at is set to the
// current time.
func (r *PushRecordRepository) MarkAsDelivered(id int64) error {
	fields := map[string]interface{}{
		"push_status":  model.PushStatusDelivered,
		"delivered_at": time.Now(),
	}

	target := r.db.Model(&model.PushRecord{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// DeleteExpiredRecords deletes every push record that has an expiry
// (expired_at IS NOT NULL) earlier than the current time.
//
// NOTE(review): the original comment describes this as a soft delete. That
// holds only if model.PushRecord declares a GORM soft-delete field, which is
// not visible from this file — confirm against the model.
func (r *PushRecordRepository) DeleteExpiredRecords() error {
	cutoff := time.Now()
	result := r.db.
		Where("expired_at IS NOT NULL AND expired_at < ?", cutoff).
		Delete(&model.PushRecord{})
	return result.Error
}
// GetStatsByUserID returns, for one user, the number of push records in each
// push status.
//
// The result maps each push_status value that has at least one matching row
// to its row count; statuses with no rows are absent from the map. On a query
// error it returns a nil map and that error.
//
// NOTE(review): userID is int64 here, whereas GetByUserID takes the user ID as
// a string — confirm which type the user_id column actually holds.
func (r *PushRecordRepository) GetStatsByUserID(userID int64) (map[model.PushStatus]int64, error) {
	// perStatus receives one row of the grouped query; the field names match
	// the "status" and "count" column aliases selected below.
	type perStatus struct {
		Status model.PushStatus
		Count  int64
	}

	grouped := r.db.Model(&model.PushRecord{}).
		Select("push_status as status, count(*) as count").
		Where("user_id = ?", userID).
		Group("push_status")

	var rows []perStatus
	if err := grouped.Scan(&rows).Error; err != nil {
		return nil, err
	}

	stats := make(map[model.PushStatus]int64, len(rows))
	for i := range rows {
		stats[rows[i].Status] = rows[i].Count
	}
	return stats, nil
}