package repository
|
||
|
|
|
||
|
|
import (
|
||
|
|
"carrot_bbs/internal/model"
|
||
|
|
|
||
|
|
"gorm.io/gorm"
|
||
|
|
)
|
||
|
|
|
||
|
|
// NotificationRepository 通知仓储
|
||
|
|
type NotificationRepository struct {
|
||
|
|
db *gorm.DB
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewNotificationRepository 创建通知仓储
|
||
|
|
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
|
||
|
|
return &NotificationRepository{db: db}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create 创建通知
|
||
|
|
func (r *NotificationRepository) Create(notification *model.Notification) error {
|
||
|
|
return r.db.Create(notification).Error
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetByID 根据ID获取通知
|
||
|
|
func (r *NotificationRepository) GetByID(id string) (*model.Notification, error) {
|
||
|
|
var notification model.Notification
|
||
|
|
err := r.db.First(¬ification, "id = ?", id).Error
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return ¬ification, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetByUserID 获取用户通知
|
||
|
|
func (r *NotificationRepository) GetByUserID(userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
|
||
|
|
var notifications []*model.Notification
|
||
|
|
var total int64
|
||
|
|
|
||
|
|
query := r.db.Model(&model.Notification{}).Where("user_id = ?", userID)
|
||
|
|
|
||
|
|
if unreadOnly {
|
||
|
|
query = query.Where("is_read = ?", false)
|
||
|
|
}
|
||
|
|
|
||
|
|
query.Count(&total)
|
||
|
|
|
||
|
|
offset := (page - 1) * pageSize
|
||
|
|
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(¬ifications).Error
|
||
|
|
|
||
|
|
return notifications, total, err
|
||
|
|
}
|
||
|
|
|
||
|
|
// MarkAsRead 标记为已读
|
||
|
|
func (r *NotificationRepository) MarkAsRead(id string) error {
|
||
|
|
return r.db.Model(&model.Notification{}).Where("id = ?", id).Update("is_read", true).Error
|
||
|
|
}
|
||
|
|
|
||
|
|
// MarkAllAsRead 标记所有为已读
|
||
|
|
func (r *NotificationRepository) MarkAllAsRead(userID string) error {
|
||
|
|
return r.db.Model(&model.Notification{}).Where("user_id = ?", userID).Update("is_read", true).Error
|
||
|
|
}
|
||
|
|
|
||
|
|
// Delete 删除通知
|
||
|
|
func (r *NotificationRepository) Delete(id string) error {
|
||
|
|
return r.db.Delete(&model.Notification{}, "id = ?", id).Error
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetUnreadCount 获取未读数量
|
||
|
|
func (r *NotificationRepository) GetUnreadCount(userID string) (int64, error) {
|
||
|
|
var count int64
|
||
|
|
err := r.db.Model(&model.Notification{}).Where("user_id = ? AND is_read = ?", userID, false).Count(&count).Error
|
||
|
|
return count, err
|
||
|
|
}
|
||
|
|
|
||
|
|
// DeleteAllByUserID 删除用户所有通知
|
||
|
|
func (r *NotificationRepository) DeleteAllByUserID(userID string) error {
|
||
|
|
return r.db.Where("user_id = ?", userID).Delete(&model.Notification{}).Error
|
||
|
|
}
|