Add complete schedule functionality including: - Schedule screen with weekly course table view - Course detail screen with transparent modal presentation - New ScheduleStack navigator integrated into main tab bar - Schedule service for API interactions - Type definitions for course entities Also includes bug fixes for group invite/request handlers to include required groupId parameter.
1097 lines
26 KiB
Go
1097 lines
26 KiB
Go
package cache
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"math"
|
||
"math/rand"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/redis/go-redis/v9"
|
||
|
||
redisPkg "carrot_bbs/internal/pkg/redis"
|
||
)
|
||
|
||
// Cache is the backend-agnostic cache contract implemented by MemoryCache and
// RedisCache in this file.
type Cache interface {
	// Set stores value under key; ttl controls expiry.
	Set(key string, value interface{}, ttl time.Duration)
	// Get returns the cached value and whether it was found.
	Get(key string) (interface{}, bool)
	// Delete removes a single key.
	Delete(key string)
	// DeleteByPrefix removes every key that starts with prefix.
	DeleteByPrefix(prefix string)
	// Clear removes all cached entries.
	Clear()
	// Exists reports whether key is present.
	Exists(key string) bool
	// Increment adds 1 to the counter stored at key and returns the new value.
	Increment(key string) int64
	// IncrementBy adds value to the counter stored at key and returns the new value.
	IncrementBy(key string, value int64) int64

	// ==================== Hash operations ====================
	// HSet sets one field of the hash stored at key.
	HSet(ctx context.Context, key string, field string, value interface{}) error
	// HMSet sets several fields of the hash stored at key.
	HMSet(ctx context.Context, key string, values map[string]interface{}) error
	// HGet returns the value of one hash field.
	HGet(ctx context.Context, key string, field string) (string, error)
	// HMGet returns the values of several hash fields, in the order requested.
	HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error)
	// HGetAll returns every field of the hash stored at key.
	HGetAll(ctx context.Context, key string) (map[string]string, error)
	// HDel removes the given fields from the hash stored at key.
	HDel(ctx context.Context, key string, fields ...string) error

	// ==================== Sorted set operations ====================
	// ZAdd adds member with the given score to the sorted set at key.
	ZAdd(ctx context.Context, key string, score float64, member string) error
	// ZRangeByScore returns members whose score lies in [min, max], ascending.
	ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error)
	// ZRevRangeByScore returns members whose score lies in [min, max], descending.
	ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error)
	// ZRem removes members from the sorted set at key.
	ZRem(ctx context.Context, key string, members ...interface{}) error
	// ZCard returns the number of members in the sorted set at key.
	ZCard(ctx context.Context, key string) (int64, error)

	// ==================== Counter operations ====================
	// Incr increments the counter at key and returns the new value.
	Incr(ctx context.Context, key string) (int64, error)
	// Expire sets the time-to-live of key.
	Expire(ctx context.Context, key string, ttl time.Duration) error
}
|
||
|
||
// cacheItem is one entry of the in-memory fallback cache.
type cacheItem struct {
	// value is the stored payload (plain value, *hashItem or *zsetItem).
	value interface{}
	// expiration is the expiry instant as Unix nanoseconds; 0 means the entry
	// never expires (see isExpired).
	expiration int64
}
|
||
|
||
// nullMarkerValue is the sentinel stored by SetNull to remember that a loader
// produced a null result; GetTyped treats it as "cached, but no value".
const nullMarkerValue = "__carrot_cache_null__"
|
||
|
||
// cacheMetrics holds the process-wide cache counters. Every counter is an
// atomic so it can be bumped from any goroutine without extra locking.
type cacheMetrics struct {
	hit, miss, decodeError, setError, invalidate atomic.Int64
}
|
||
|
||
var metrics cacheMetrics
|
||
var loadLocks sync.Map
|
||
|
||
// MetricsSnapshot is a plain-value copy of the cache counters, as returned by
// GetMetricsSnapshot.
type MetricsSnapshot struct {
	Hit, Miss, DecodeError, SetError, Invalidate int64
}
|
||
|
||
// Settings is the tunable configuration of the cache package.
type Settings struct {
	// Enabled switches the typed helpers (GetTyped, SetWithJitter, SetNull) on or off.
	Enabled bool
	// KeyPrefix, when non-empty, is prepended to every key as "<prefix>:<key>".
	KeyPrefix string
	// DefaultTTL and NullTTL are the default lifetimes for regular and
	// null-marker entries.
	DefaultTTL, NullTTL time.Duration
	// JitterRatio is the fraction of the TTL that may be added as random jitter.
	JitterRatio float64
	// Per-domain TTLs.
	PostListTTL, ConversationTTL, UnreadCountTTL, GroupMembersTTL time.Duration
	// DisableFlushDB makes RedisCache.Clear a no-op instead of issuing FlushDB.
	DisableFlushDB bool
}
|
||
|
||
var settings = Settings{
|
||
Enabled: true,
|
||
DefaultTTL: 30 * time.Second,
|
||
NullTTL: 5 * time.Second,
|
||
JitterRatio: 0.1,
|
||
PostListTTL: 30 * time.Second,
|
||
ConversationTTL: 60 * time.Second,
|
||
UnreadCountTTL: 30 * time.Second,
|
||
GroupMembersTTL: 120 * time.Second,
|
||
DisableFlushDB: true,
|
||
}
|
||
|
||
func Configure(s Settings) {
|
||
settings.Enabled = s.Enabled
|
||
if s.KeyPrefix != "" {
|
||
settings.KeyPrefix = s.KeyPrefix
|
||
}
|
||
if s.DefaultTTL > 0 {
|
||
settings.DefaultTTL = s.DefaultTTL
|
||
}
|
||
if s.NullTTL > 0 {
|
||
settings.NullTTL = s.NullTTL
|
||
}
|
||
if s.JitterRatio > 0 {
|
||
settings.JitterRatio = s.JitterRatio
|
||
}
|
||
if s.PostListTTL > 0 {
|
||
settings.PostListTTL = s.PostListTTL
|
||
}
|
||
if s.ConversationTTL > 0 {
|
||
settings.ConversationTTL = s.ConversationTTL
|
||
}
|
||
if s.UnreadCountTTL > 0 {
|
||
settings.UnreadCountTTL = s.UnreadCountTTL
|
||
}
|
||
if s.GroupMembersTTL > 0 {
|
||
settings.GroupMembersTTL = s.GroupMembersTTL
|
||
}
|
||
settings.DisableFlushDB = s.DisableFlushDB
|
||
}
|
||
|
||
func GetSettings() Settings {
|
||
return settings
|
||
}
|
||
|
||
func normalizeKey(key string) string {
|
||
if settings.KeyPrefix == "" {
|
||
return key
|
||
}
|
||
return settings.KeyPrefix + ":" + key
|
||
}
|
||
|
||
func normalizePrefix(prefix string) string {
|
||
if settings.KeyPrefix == "" {
|
||
return prefix
|
||
}
|
||
return settings.KeyPrefix + ":" + prefix
|
||
}
|
||
|
||
func GetMetricsSnapshot() MetricsSnapshot {
|
||
return MetricsSnapshot{
|
||
Hit: metrics.hit.Load(),
|
||
Miss: metrics.miss.Load(),
|
||
DecodeError: metrics.decodeError.Load(),
|
||
SetError: metrics.setError.Load(),
|
||
Invalidate: metrics.invalidate.Load(),
|
||
}
|
||
}
|
||
|
||
// isExpired 检查是否过期
|
||
func (item *cacheItem) isExpired() bool {
|
||
if item.expiration == 0 {
|
||
return false
|
||
}
|
||
return time.Now().UnixNano() > item.expiration
|
||
}
|
||
|
||
// MemoryCache is the in-memory Cache implementation, used as a fallback when
// Redis is not available.
type MemoryCache struct {
	// items maps a normalized key to its *cacheItem.
	items sync.Map
	// cleanupInterval is how often expired entries are swept.
	cleanupInterval time.Duration
	// stopCleanup is closed to stop the background sweeper goroutine.
	stopCleanup chan struct{}
}
|
||
|
||
// NewMemoryCache 创建内存缓存
|
||
func NewMemoryCache() *MemoryCache {
|
||
c := &MemoryCache{
|
||
cleanupInterval: 1 * time.Minute,
|
||
stopCleanup: make(chan struct{}),
|
||
}
|
||
// 启动后台清理协程
|
||
go c.cleanup()
|
||
return c
|
||
}
|
||
|
||
// Set 设置缓存值
|
||
func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) {
|
||
key = normalizeKey(key)
|
||
var expiration int64
|
||
if ttl > 0 {
|
||
expiration = time.Now().Add(ttl).UnixNano()
|
||
}
|
||
c.items.Store(key, &cacheItem{
|
||
value: value,
|
||
expiration: expiration,
|
||
})
|
||
}
|
||
|
||
// Get 获取缓存值
|
||
func (c *MemoryCache) Get(key string) (interface{}, bool) {
|
||
key = normalizeKey(key)
|
||
val, ok := c.items.Load(key)
|
||
if !ok {
|
||
return nil, false
|
||
}
|
||
|
||
item := val.(*cacheItem)
|
||
if item.isExpired() {
|
||
c.items.Delete(key)
|
||
return nil, false
|
||
}
|
||
|
||
return item.value, true
|
||
}
|
||
|
||
// Delete 删除缓存
|
||
func (c *MemoryCache) Delete(key string) {
|
||
key = normalizeKey(key)
|
||
metrics.invalidate.Add(1)
|
||
c.items.Delete(key)
|
||
}
|
||
|
||
// DeleteByPrefix 根据前缀删除缓存
|
||
func (c *MemoryCache) DeleteByPrefix(prefix string) {
|
||
prefix = normalizePrefix(prefix)
|
||
c.items.Range(func(key, value interface{}) bool {
|
||
if keyStr, ok := key.(string); ok {
|
||
if strings.HasPrefix(keyStr, prefix) {
|
||
metrics.invalidate.Add(1)
|
||
c.items.Delete(key)
|
||
}
|
||
}
|
||
return true
|
||
})
|
||
}
|
||
|
||
// Clear 清空所有缓存
|
||
func (c *MemoryCache) Clear() {
|
||
c.items.Range(func(key, value interface{}) bool {
|
||
metrics.invalidate.Add(1)
|
||
c.items.Delete(key)
|
||
return true
|
||
})
|
||
}
|
||
|
||
// Exists 检查键是否存在
|
||
func (c *MemoryCache) Exists(key string) bool {
|
||
_, ok := c.Get(key)
|
||
return ok
|
||
}
|
||
|
||
// Increment 增加计数器的值
|
||
func (c *MemoryCache) Increment(key string) int64 {
|
||
return c.IncrementBy(key, 1)
|
||
}
|
||
|
||
// IncrementBy 增加指定值
|
||
func (c *MemoryCache) IncrementBy(key string, value int64) int64 {
|
||
key = normalizeKey(key)
|
||
for {
|
||
val, ok := c.items.Load(key)
|
||
if !ok {
|
||
// 键不存在,创建新值
|
||
c.items.Store(key, &cacheItem{
|
||
value: value,
|
||
expiration: 0,
|
||
})
|
||
return value
|
||
}
|
||
|
||
item := val.(*cacheItem)
|
||
if item.isExpired() {
|
||
// 已过期,创建新值
|
||
c.items.Store(key, &cacheItem{
|
||
value: value,
|
||
expiration: 0,
|
||
})
|
||
return value
|
||
}
|
||
|
||
// 尝试更新
|
||
currentValue, ok := item.value.(int64)
|
||
if !ok {
|
||
// 类型不匹配,覆盖为新值
|
||
c.items.Store(key, &cacheItem{
|
||
value: value,
|
||
expiration: item.expiration,
|
||
})
|
||
return value
|
||
}
|
||
|
||
newValue := currentValue + value
|
||
// 使用 CAS 操作确保并发安全
|
||
if c.items.CompareAndSwap(key, val, &cacheItem{
|
||
value: newValue,
|
||
expiration: item.expiration,
|
||
}) {
|
||
return newValue
|
||
}
|
||
// CAS 失败,重试
|
||
}
|
||
}
|
||
|
||
// cleanup 定期清理过期缓存
|
||
func (c *MemoryCache) cleanup() {
|
||
ticker := time.NewTicker(c.cleanupInterval)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
c.cleanExpired()
|
||
case <-c.stopCleanup:
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// cleanExpired 清理过期缓存
|
||
func (c *MemoryCache) cleanExpired() {
|
||
count := 0
|
||
c.items.Range(func(key, value interface{}) bool {
|
||
item := value.(*cacheItem)
|
||
if item.isExpired() {
|
||
c.items.Delete(key)
|
||
count++
|
||
}
|
||
return true
|
||
})
|
||
if count > 0 {
|
||
log.Printf("[Cache] Cleaned %d expired items", count)
|
||
}
|
||
}
|
||
|
||
// Stop 停止缓存清理协程
|
||
func (c *MemoryCache) Stop() {
|
||
close(c.stopCleanup)
|
||
}
|
||
|
||
// ==================== MemoryCache Hash 操作 ====================
|
||
|
||
// hashItem is the payload stored in a cacheItem for a hash key.
type hashItem struct {
	// fields maps a field name to its stored value.
	fields sync.Map
}
|
||
|
||
// HSet 设置 Hash 字段
|
||
func (c *MemoryCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
|
||
key = normalizeKey(key)
|
||
item, _ := c.items.Load(key)
|
||
var h *hashItem
|
||
if item == nil {
|
||
h = &hashItem{}
|
||
c.items.Store(key, &cacheItem{value: h, expiration: 0})
|
||
} else {
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
h = &hashItem{}
|
||
c.items.Store(key, &cacheItem{value: h, expiration: 0})
|
||
} else {
|
||
h = ci.value.(*hashItem)
|
||
}
|
||
}
|
||
h.fields.Store(field, value)
|
||
return nil
|
||
}
|
||
|
||
// HMSet 批量设置 Hash 字段
|
||
func (c *MemoryCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
|
||
for field, value := range values {
|
||
if err := c.HSet(ctx, key, field, value); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// HGet 获取 Hash 字段值
|
||
func (c *MemoryCache) HGet(ctx context.Context, key string, field string) (string, error) {
|
||
key = normalizeKey(key)
|
||
item, ok := c.items.Load(key)
|
||
if !ok {
|
||
return "", fmt.Errorf("key not found")
|
||
}
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
c.items.Delete(key)
|
||
return "", fmt.Errorf("key not found")
|
||
}
|
||
h, ok := ci.value.(*hashItem)
|
||
if !ok {
|
||
return "", fmt.Errorf("key is not a hash")
|
||
}
|
||
val, ok := h.fields.Load(field)
|
||
if !ok {
|
||
return "", fmt.Errorf("field not found")
|
||
}
|
||
switch v := val.(type) {
|
||
case string:
|
||
return v, nil
|
||
case []byte:
|
||
return string(v), nil
|
||
default:
|
||
data, _ := json.Marshal(v)
|
||
return string(data), nil
|
||
}
|
||
}
|
||
|
||
// HMGet 批量获取 Hash 字段值
|
||
func (c *MemoryCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
|
||
result := make([]interface{}, len(fields))
|
||
for i, field := range fields {
|
||
val, err := c.HGet(ctx, key, field)
|
||
if err != nil {
|
||
result[i] = nil
|
||
} else {
|
||
result[i] = val
|
||
}
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// HGetAll 获取 Hash 所有字段
|
||
func (c *MemoryCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||
key = normalizeKey(key)
|
||
item, ok := c.items.Load(key)
|
||
if !ok {
|
||
return nil, fmt.Errorf("key not found")
|
||
}
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
c.items.Delete(key)
|
||
return nil, fmt.Errorf("key not found")
|
||
}
|
||
h, ok := ci.value.(*hashItem)
|
||
if !ok {
|
||
return nil, fmt.Errorf("key is not a hash")
|
||
}
|
||
result := make(map[string]string)
|
||
h.fields.Range(func(k, v interface{}) bool {
|
||
keyStr := k.(string)
|
||
switch val := v.(type) {
|
||
case string:
|
||
result[keyStr] = val
|
||
case []byte:
|
||
result[keyStr] = string(val)
|
||
default:
|
||
data, _ := json.Marshal(val)
|
||
result[keyStr] = string(data)
|
||
}
|
||
return true
|
||
})
|
||
return result, nil
|
||
}
|
||
|
||
// HDel 删除 Hash 字段
|
||
func (c *MemoryCache) HDel(ctx context.Context, key string, fields ...string) error {
|
||
key = normalizeKey(key)
|
||
item, ok := c.items.Load(key)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
c.items.Delete(key)
|
||
return nil
|
||
}
|
||
h, ok := ci.value.(*hashItem)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
for _, field := range fields {
|
||
h.fields.Delete(field)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ==================== MemoryCache Sorted Set 操作 ====================
|
||
|
||
// zItem is one member of an in-memory sorted set.
type zItem struct {
	// score is the ordering key of the member.
	score float64
	// member is the member's value.
	member string
}
|
||
|
||
// zsetItem is the payload stored in a cacheItem for a sorted-set key.
type zsetItem struct {
	members sync.Map     // member -> *zItem
	byScore *sortedSlice // members ordered by ascending score; rebuilt by ZAdd and ZRem
}
|
||
|
||
// sortedSlice is a simple score-ordered slice of sorted-set members.
type sortedSlice struct {
	// items holds the members in ascending score order.
	items []*zItem
	// mu guards items: writers rebuild under Lock, readers iterate under RLock.
	mu sync.RWMutex
}
|
||
|
||
// ZAdd 添加 Sorted Set 成员
|
||
func (c *MemoryCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
|
||
key = normalizeKey(key)
|
||
item, _ := c.items.Load(key)
|
||
var z *zsetItem
|
||
if item == nil {
|
||
z = &zsetItem{byScore: &sortedSlice{}}
|
||
c.items.Store(key, &cacheItem{value: z, expiration: 0})
|
||
} else {
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
z = &zsetItem{byScore: &sortedSlice{}}
|
||
c.items.Store(key, &cacheItem{value: z, expiration: 0})
|
||
} else {
|
||
z = ci.value.(*zsetItem)
|
||
}
|
||
}
|
||
z.members.Store(member, &zItem{score: score, member: member})
|
||
z.byScore.mu.Lock()
|
||
// 简单实现:重新构建排序切片
|
||
z.byScore.items = nil
|
||
z.members.Range(func(k, v interface{}) bool {
|
||
z.byScore.items = append(z.byScore.items, v.(*zItem))
|
||
return true
|
||
})
|
||
// 按分数排序
|
||
sort.Slice(z.byScore.items, func(i, j int) bool {
|
||
return z.byScore.items[i].score < z.byScore.items[j].score
|
||
})
|
||
z.byScore.mu.Unlock()
|
||
return nil
|
||
}
|
||
|
||
// ZRangeByScore 按分数范围获取成员(升序)
|
||
func (c *MemoryCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
|
||
key = normalizeKey(key)
|
||
item, ok := c.items.Load(key)
|
||
if !ok {
|
||
return nil, nil
|
||
}
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
c.items.Delete(key)
|
||
return nil, nil
|
||
}
|
||
z, ok := ci.value.(*zsetItem)
|
||
if !ok {
|
||
return nil, fmt.Errorf("key is not a sorted set")
|
||
}
|
||
|
||
minScore, _ := strconv.ParseFloat(min, 64)
|
||
maxScore, _ := strconv.ParseFloat(max, 64)
|
||
if min == "-inf" {
|
||
minScore = math.Inf(-1)
|
||
}
|
||
if max == "+inf" {
|
||
maxScore = math.Inf(1)
|
||
}
|
||
|
||
z.byScore.mu.RLock()
|
||
defer z.byScore.mu.RUnlock()
|
||
|
||
var result []string
|
||
var skipped int64 = 0
|
||
for _, item := range z.byScore.items {
|
||
if item.score < minScore || item.score > maxScore {
|
||
continue
|
||
}
|
||
if skipped < offset {
|
||
skipped++
|
||
continue
|
||
}
|
||
if count > 0 && int64(len(result)) >= count {
|
||
break
|
||
}
|
||
result = append(result, item.member)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// ZRevRangeByScore 按分数范围获取成员(降序)
|
||
func (c *MemoryCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
|
||
key = normalizeKey(key)
|
||
item, ok := c.items.Load(key)
|
||
if !ok {
|
||
return nil, nil
|
||
}
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
c.items.Delete(key)
|
||
return nil, nil
|
||
}
|
||
z, ok := ci.value.(*zsetItem)
|
||
if !ok {
|
||
return nil, fmt.Errorf("key is not a sorted set")
|
||
}
|
||
|
||
minScore, _ := strconv.ParseFloat(min, 64)
|
||
maxScore, _ := strconv.ParseFloat(max, 64)
|
||
if min == "-inf" {
|
||
minScore = math.Inf(-1)
|
||
}
|
||
if max == "+inf" {
|
||
maxScore = math.Inf(1)
|
||
}
|
||
|
||
z.byScore.mu.RLock()
|
||
defer z.byScore.mu.RUnlock()
|
||
|
||
var result []string
|
||
var skipped int64 = 0
|
||
// 从后往前遍历
|
||
for i := len(z.byScore.items) - 1; i >= 0; i-- {
|
||
item := z.byScore.items[i]
|
||
if item.score < minScore || item.score > maxScore {
|
||
continue
|
||
}
|
||
if skipped < offset {
|
||
skipped++
|
||
continue
|
||
}
|
||
if count > 0 && int64(len(result)) >= count {
|
||
break
|
||
}
|
||
result = append(result, item.member)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// ZRem 删除 Sorted Set 成员
|
||
func (c *MemoryCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
|
||
key = normalizeKey(key)
|
||
item, ok := c.items.Load(key)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
c.items.Delete(key)
|
||
return nil
|
||
}
|
||
z, ok := ci.value.(*zsetItem)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
for _, m := range members {
|
||
if member, ok := m.(string); ok {
|
||
z.members.Delete(member)
|
||
}
|
||
}
|
||
// 重建排序切片
|
||
z.byScore.mu.Lock()
|
||
z.byScore.items = nil
|
||
z.members.Range(func(k, v interface{}) bool {
|
||
z.byScore.items = append(z.byScore.items, v.(*zItem))
|
||
return true
|
||
})
|
||
sort.Slice(z.byScore.items, func(i, j int) bool {
|
||
return z.byScore.items[i].score < z.byScore.items[j].score
|
||
})
|
||
z.byScore.mu.Unlock()
|
||
return nil
|
||
}
|
||
|
||
// ZCard 获取 Sorted Set 成员数量
|
||
func (c *MemoryCache) ZCard(ctx context.Context, key string) (int64, error) {
|
||
key = normalizeKey(key)
|
||
item, ok := c.items.Load(key)
|
||
if !ok {
|
||
return 0, nil
|
||
}
|
||
ci := item.(*cacheItem)
|
||
if ci.isExpired() {
|
||
c.items.Delete(key)
|
||
return 0, nil
|
||
}
|
||
z, ok := ci.value.(*zsetItem)
|
||
if !ok {
|
||
return 0, fmt.Errorf("key is not a sorted set")
|
||
}
|
||
var count int64 = 0
|
||
z.members.Range(func(k, v interface{}) bool {
|
||
count++
|
||
return true
|
||
})
|
||
return count, nil
|
||
}
|
||
|
||
// ==================== MemoryCache 计数器操作 ====================
|
||
|
||
// Incr 原子递增(返回新值)
|
||
func (c *MemoryCache) Incr(ctx context.Context, key string) (int64, error) {
|
||
return c.IncrementBy(key, 1), nil
|
||
}
|
||
|
||
// Expire 设置过期时间
|
||
func (c *MemoryCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||
key = normalizeKey(key)
|
||
item, ok := c.items.Load(key)
|
||
if !ok {
|
||
return fmt.Errorf("key not found")
|
||
}
|
||
ci := item.(*cacheItem)
|
||
var expiration int64
|
||
if ttl > 0 {
|
||
expiration = time.Now().Add(ttl).UnixNano()
|
||
}
|
||
c.items.Store(key, &cacheItem{
|
||
value: ci.value,
|
||
expiration: expiration,
|
||
})
|
||
return nil
|
||
}
|
||
|
||
// RedisCache is the Redis-backed Cache implementation.
type RedisCache struct {
	// client is the project's Redis client wrapper.
	client *redisPkg.Client
	// ctx is the context used by the methods that take no context argument;
	// NewRedisCache sets it to context.Background().
	ctx context.Context
}
|
||
|
||
// NewRedisCache 创建Redis缓存
|
||
func NewRedisCache(client *redisPkg.Client) *RedisCache {
|
||
return &RedisCache{
|
||
client: client,
|
||
ctx: context.Background(),
|
||
}
|
||
}
|
||
|
||
// Set 设置缓存值
|
||
func (c *RedisCache) Set(key string, value interface{}, ttl time.Duration) {
|
||
key = normalizeKey(key)
|
||
// 将值序列化为JSON
|
||
data, err := json.Marshal(value)
|
||
if err != nil {
|
||
metrics.setError.Add(1)
|
||
log.Printf("[RedisCache] Failed to marshal value for key %s: %v", key, err)
|
||
return
|
||
}
|
||
|
||
if err := c.client.Set(c.ctx, key, data, ttl); err != nil {
|
||
metrics.setError.Add(1)
|
||
log.Printf("[RedisCache] Failed to set key %s: %v", key, err)
|
||
}
|
||
}
|
||
|
||
// Get 获取缓存值
|
||
func (c *RedisCache) Get(key string) (interface{}, bool) {
|
||
key = normalizeKey(key)
|
||
data, err := c.client.Get(c.ctx, key)
|
||
if err != nil {
|
||
if err == redis.Nil {
|
||
return nil, false
|
||
}
|
||
log.Printf("[RedisCache] Failed to get key %s: %v", key, err)
|
||
return nil, false
|
||
}
|
||
|
||
// 返回原始字符串,由调用侧决定如何解码为目标类型
|
||
return data, true
|
||
}
|
||
|
||
// Delete 删除缓存
|
||
func (c *RedisCache) Delete(key string) {
|
||
key = normalizeKey(key)
|
||
metrics.invalidate.Add(1)
|
||
if err := c.client.Del(c.ctx, key); err != nil {
|
||
log.Printf("[RedisCache] Failed to delete key %s: %v", key, err)
|
||
}
|
||
}
|
||
|
||
// DeleteByPrefix 根据前缀删除缓存
|
||
func (c *RedisCache) DeleteByPrefix(prefix string) {
|
||
prefix = normalizePrefix(prefix)
|
||
// 使用原生客户端执行SCAN命令
|
||
rdb := c.client.GetClient()
|
||
var cursor uint64
|
||
for {
|
||
keys, nextCursor, err := rdb.Scan(c.ctx, cursor, prefix+"*", 100).Result()
|
||
if err != nil {
|
||
log.Printf("[RedisCache] Failed to scan keys with prefix %s: %v", prefix, err)
|
||
return
|
||
}
|
||
|
||
if len(keys) > 0 {
|
||
metrics.invalidate.Add(int64(len(keys)))
|
||
if err := c.client.Del(c.ctx, keys...); err != nil {
|
||
log.Printf("[RedisCache] Failed to delete keys with prefix %s: %v", prefix, err)
|
||
}
|
||
}
|
||
|
||
cursor = nextCursor
|
||
if cursor == 0 {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// Clear 清空所有缓存
|
||
func (c *RedisCache) Clear() {
|
||
if settings.DisableFlushDB {
|
||
log.Printf("[RedisCache] Skip FlushDB because cache.disable_flushdb=true")
|
||
return
|
||
}
|
||
metrics.invalidate.Add(1)
|
||
rdb := c.client.GetClient()
|
||
if err := rdb.FlushDB(c.ctx).Err(); err != nil {
|
||
log.Printf("[RedisCache] Failed to clear cache: %v", err)
|
||
}
|
||
}
|
||
|
||
// Exists 检查键是否存在
|
||
func (c *RedisCache) Exists(key string) bool {
|
||
key = normalizeKey(key)
|
||
n, err := c.client.Exists(c.ctx, key)
|
||
if err != nil {
|
||
log.Printf("[RedisCache] Failed to check existence of key %s: %v", key, err)
|
||
return false
|
||
}
|
||
return n > 0
|
||
}
|
||
|
||
// Increment 增加计数器的值
|
||
func (c *RedisCache) Increment(key string) int64 {
|
||
return c.IncrementBy(key, 1)
|
||
}
|
||
|
||
// IncrementBy 增加指定值
|
||
func (c *RedisCache) IncrementBy(key string, value int64) int64 {
|
||
key = normalizeKey(key)
|
||
rdb := c.client.GetClient()
|
||
result, err := rdb.IncrBy(c.ctx, key, value).Result()
|
||
if err != nil {
|
||
log.Printf("[RedisCache] Failed to increment key %s: %v", key, err)
|
||
return 0
|
||
}
|
||
return result
|
||
}
|
||
|
||
// ==================== RedisCache Hash 操作 ====================
|
||
|
||
// HSet 设置 Hash 字段
|
||
func (c *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
|
||
key = normalizeKey(key)
|
||
return c.client.HSet(ctx, key, field, value)
|
||
}
|
||
|
||
// HMSet 批量设置 Hash 字段
|
||
func (c *RedisCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
|
||
key = normalizeKey(key)
|
||
return c.client.HMSet(ctx, key, values)
|
||
}
|
||
|
||
// HGet 获取 Hash 字段值
|
||
func (c *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
|
||
key = normalizeKey(key)
|
||
return c.client.HGet(ctx, key, field)
|
||
}
|
||
|
||
// HMGet 批量获取 Hash 字段值
|
||
func (c *RedisCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
|
||
key = normalizeKey(key)
|
||
return c.client.HMGet(ctx, key, fields...)
|
||
}
|
||
|
||
// HGetAll 获取 Hash 所有字段
|
||
func (c *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||
key = normalizeKey(key)
|
||
return c.client.HGetAll(ctx, key)
|
||
}
|
||
|
||
// HDel 删除 Hash 字段
|
||
func (c *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
|
||
key = normalizeKey(key)
|
||
return c.client.HDel(ctx, key, fields...)
|
||
}
|
||
|
||
// ==================== RedisCache Sorted Set 操作 ====================
|
||
|
||
// ZAdd 添加 Sorted Set 成员
|
||
func (c *RedisCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
|
||
key = normalizeKey(key)
|
||
return c.client.ZAdd(ctx, key, score, member)
|
||
}
|
||
|
||
// ZRangeByScore 按分数范围获取成员(升序)
|
||
func (c *RedisCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
|
||
key = normalizeKey(key)
|
||
return c.client.ZRangeByScore(ctx, key, min, max, offset, count)
|
||
}
|
||
|
||
// ZRevRangeByScore 按分数范围获取成员(降序)
|
||
func (c *RedisCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
|
||
key = normalizeKey(key)
|
||
return c.client.ZRevRangeByScore(ctx, key, max, min, offset, count)
|
||
}
|
||
|
||
// ZRem 删除 Sorted Set 成员
|
||
func (c *RedisCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
|
||
key = normalizeKey(key)
|
||
return c.client.ZRem(ctx, key, members...)
|
||
}
|
||
|
||
// ZCard 获取 Sorted Set 成员数量
|
||
func (c *RedisCache) ZCard(ctx context.Context, key string) (int64, error) {
|
||
key = normalizeKey(key)
|
||
return c.client.ZCard(ctx, key)
|
||
}
|
||
|
||
// ==================== RedisCache 计数器操作 ====================
|
||
|
||
// Incr 原子递增(返回新值)
|
||
func (c *RedisCache) Incr(ctx context.Context, key string) (int64, error) {
|
||
key = normalizeKey(key)
|
||
return c.client.Incr(ctx, key)
|
||
}
|
||
|
||
// Expire 设置过期时间
|
||
func (c *RedisCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||
key = normalizeKey(key)
|
||
_, err := c.client.Expire(ctx, key, ttl)
|
||
return err
|
||
}
|
||
|
||
// 全局缓存实例
|
||
var globalCache Cache
|
||
var once sync.Once
|
||
|
||
// InitCache 初始化全局缓存实例(使用Redis)
|
||
func InitCache(redisClient *redisPkg.Client) {
|
||
once.Do(func() {
|
||
if redisClient != nil {
|
||
globalCache = NewRedisCache(redisClient)
|
||
log.Println("[Cache] Initialized Redis cache")
|
||
} else {
|
||
globalCache = NewMemoryCache()
|
||
log.Println("[Cache] Initialized Memory cache (Redis not available)")
|
||
}
|
||
})
|
||
}
|
||
|
||
// GetCache 获取全局缓存实例
|
||
func GetCache() Cache {
|
||
if globalCache == nil {
|
||
// 如果未初始化,返回内存缓存作为降级
|
||
log.Println("[Cache] Warning: Cache not initialized, using Memory cache")
|
||
return NewMemoryCache()
|
||
}
|
||
return globalCache
|
||
}
|
||
|
||
// GetRedisClient 从缓存中获取Redis客户端(仅在Redis模式下有效)
|
||
func GetRedisClient() (*redisPkg.Client, error) {
|
||
if redisCache, ok := globalCache.(*RedisCache); ok {
|
||
return redisCache.client, nil
|
||
}
|
||
return nil, fmt.Errorf("cache is not using Redis backend")
|
||
}
|
||
|
||
func SetWithJitter(c Cache, key string, value interface{}, ttl time.Duration, jitterRatio float64) {
|
||
if !settings.Enabled {
|
||
return
|
||
}
|
||
c.Set(key, value, ApplyTTLJitter(ttl, jitterRatio))
|
||
}
|
||
|
||
func SetNull(c Cache, key string, ttl time.Duration) {
|
||
if !settings.Enabled {
|
||
return
|
||
}
|
||
c.Set(key, nullMarkerValue, ttl)
|
||
}
|
||
|
||
// ApplyTTLJitter returns ttl plus a random, non-negative jitter of at most
// ttl*jitterRatio. Spreading expirations this way keeps many keys from
// expiring at the same instant.
//
// ttl is returned unchanged when ttl <= 0, when jitterRatio is <= 0 or NaN,
// or when the jitter window rounds down to zero. jitterRatio is clamped to 1.
// The jitter window is also capped so that ttl plus jitter can never overflow
// time.Duration; without the cap a very large ttl could wrap around to a
// negative value, or make the random-number call panic.
func ApplyTTLJitter(ttl time.Duration, jitterRatio float64) time.Duration {
	if ttl <= 0 || jitterRatio <= 0 || math.IsNaN(jitterRatio) {
		return ttl
	}
	if jitterRatio > 1 {
		jitterRatio = 1
	}

	// headroom is the largest jitter that still fits in an int64 duration.
	headroom := math.MaxInt64 - int64(ttl)
	maxJitter := headroom
	if span := float64(ttl) * jitterRatio; span < float64(headroom) {
		maxJitter = int64(span)
	}
	if maxJitter <= 0 {
		return ttl
	}

	// ttl >= 1, so headroom <= MaxInt64-1 and maxJitter+1 cannot overflow.
	delta := rand.Int63n(maxJitter + 1)
	return ttl + time.Duration(delta)
}
|
||
|
||
func GetTyped[T any](c Cache, key string) (T, bool) {
|
||
var zero T
|
||
if !settings.Enabled {
|
||
return zero, false
|
||
}
|
||
raw, ok := c.Get(key)
|
||
if !ok {
|
||
metrics.miss.Add(1)
|
||
return zero, false
|
||
}
|
||
if str, ok := raw.(string); ok && str == nullMarkerValue {
|
||
metrics.hit.Add(1)
|
||
return zero, false
|
||
}
|
||
|
||
if typed, ok := raw.(T); ok {
|
||
metrics.hit.Add(1)
|
||
return typed, true
|
||
}
|
||
|
||
var out T
|
||
switch v := raw.(type) {
|
||
case string:
|
||
if err := json.Unmarshal([]byte(v), &out); err != nil {
|
||
metrics.decodeError.Add(1)
|
||
return zero, false
|
||
}
|
||
metrics.hit.Add(1)
|
||
return out, true
|
||
case []byte:
|
||
if err := json.Unmarshal(v, &out); err != nil {
|
||
metrics.decodeError.Add(1)
|
||
return zero, false
|
||
}
|
||
metrics.hit.Add(1)
|
||
return out, true
|
||
default:
|
||
data, err := json.Marshal(v)
|
||
if err != nil {
|
||
metrics.decodeError.Add(1)
|
||
return zero, false
|
||
}
|
||
if err := json.Unmarshal(data, &out); err != nil {
|
||
metrics.decodeError.Add(1)
|
||
return zero, false
|
||
}
|
||
metrics.hit.Add(1)
|
||
return out, true
|
||
}
|
||
}
|
||
|
||
func GetOrLoadTyped[T any](
|
||
c Cache,
|
||
key string,
|
||
ttl time.Duration,
|
||
jitterRatio float64,
|
||
nullTTL time.Duration,
|
||
loader func() (T, error),
|
||
) (T, error) {
|
||
if cached, ok := GetTyped[T](c, key); ok {
|
||
return cached, nil
|
||
}
|
||
|
||
lockValue, _ := loadLocks.LoadOrStore(key, &sync.Mutex{})
|
||
lock := lockValue.(*sync.Mutex)
|
||
lock.Lock()
|
||
defer lock.Unlock()
|
||
|
||
if cached, ok := GetTyped[T](c, key); ok {
|
||
return cached, nil
|
||
}
|
||
|
||
loaded, err := loader()
|
||
if err != nil {
|
||
var zero T
|
||
return zero, err
|
||
}
|
||
|
||
encoded, marshalErr := json.Marshal(loaded)
|
||
if marshalErr == nil && string(encoded) == "null" && nullTTL > 0 {
|
||
SetNull(c, key, nullTTL)
|
||
return loaded, nil
|
||
}
|
||
|
||
SetWithJitter(c, key, loaded, ttl, jitterRatio)
|
||
return loaded, nil
|
||
}
|