refactor: Remove Token management and integrate Redis for authentication

- Deleted the Token model and its repository, transitioning to a Redis-based token management system.
- Updated the service layer to utilize Redis for token storage, enhancing performance and scalability.
- Refactored the container to remove TokenRepository and integrate the new token service.
- Cleaned up the Dockerfile and other files by removing unnecessary whitespace and comments.
- Enhanced error handling and logging for Redis initialization and usage.
This commit is contained in:
lan
2025-12-24 16:03:46 +08:00
parent 432c47d969
commit 6ddcf92ce3
38 changed files with 1743 additions and 1525 deletions

View File

@@ -1,74 +0,0 @@
# ==================== 构建阶段 ====================
FROM golang:latest AS builder
# 安装构建依赖
RUN apk add --no-cache git ca-certificates tzdata
# 设置工作目录
WORKDIR /build
# 复制依赖文件
COPY go.mod go.sum ./
# 配置 Go 代理并下载依赖
ENV GOPROXY=https://goproxy.cn,direct
RUN go mod download
# 复制源代码
COPY . .
# 构建应用
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
-ldflags="-w -s -X main.Version=$(git describe --tags --always --dirty 2>/dev/null || echo 'dev')" \
-o server ./cmd/server
# ==================== 运行阶段 ====================
FROM alpine:3.19
# 安装运行时依赖
RUN apk add --no-cache ca-certificates tzdata
# 设置时区
ENV TZ=Asia/Shanghai
# 创建非 root 用户
RUN adduser -D -g '' appuser
# 设置工作目录
WORKDIR /app
# 从构建阶段复制二进制文件
COPY --from=builder /build/server .
# 复制配置文件目录结构
COPY --from=builder /build/configs ./configs
# 设置文件权限
RUN chown -R appuser:appuser /app
# 切换到非 root 用户
USER appuser
# 暴露端口
EXPOSE 8080
# 健康检查
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8080/api/health || exit 1
# 启动应用
ENTRYPOINT ["./server"]

View File

@@ -12,6 +12,7 @@ import (
"carrotskin/internal/container" "carrotskin/internal/container"
"carrotskin/internal/handler" "carrotskin/internal/handler"
"carrotskin/internal/middleware" "carrotskin/internal/middleware"
"carrotskin/internal/task"
"carrotskin/pkg/auth" "carrotskin/pkg/auth"
"carrotskin/pkg/config" "carrotskin/pkg/config"
"carrotskin/pkg/database" "carrotskin/pkg/database"
@@ -59,11 +60,18 @@ func main() {
loggerInstance.Fatal("JWT服务初始化失败", zap.Error(err)) loggerInstance.Fatal("JWT服务初始化失败", zap.Error(err))
} }
// 初始化Redis // 初始化Redis(开发/测试环境失败时会自动回退到miniredis
if err := redis.Init(cfg.Redis, loggerInstance); err != nil { if err := redis.Init(cfg.Redis, loggerInstance); err != nil {
loggerInstance.Fatal("Redis连接失败", zap.Error(err)) loggerInstance.Fatal("Redis初始化失败", zap.Error(err))
}
defer redis.Close()
// 记录Redis模式
if redis.IsUsingMiniRedis() {
loggerInstance.Info("使用miniredis进行开发/测试")
} else {
loggerInstance.Info("使用生产Redis")
} }
defer redis.MustGetClient().Close()
// 初始化对象存储 (RustFS - S3兼容) // 初始化对象存储 (RustFS - S3兼容)
var storageClient *storage.StorageClient var storageClient *storage.StorageClient
@@ -110,6 +118,13 @@ func main() {
// 使用依赖注入方式注册路由 // 使用依赖注入方式注册路由
handler.RegisterRoutesWithDI(router, c) handler.RegisterRoutesWithDI(router, c)
// 启动后台任务Token已迁移到Redis不再需要清理任务
// 如需使用数据库Token存储可以恢复TokenCleanupTask
taskRunner := task.NewRunner(loggerInstance)
taskCtx, taskCancel := context.WithCancel(context.Background())
defer taskCancel()
taskRunner.Start(taskCtx)
// 创建HTTP服务器 // 创建HTTP服务器
srv := &http.Server{ srv := &http.Server{
Addr: cfg.Server.Port, Addr: cfg.Server.Port,
@@ -132,6 +147,10 @@ func main() {
<-quit <-quit
loggerInstance.Info("正在关闭服务器...") loggerInstance.Info("正在关闭服务器...")
// 停止后台任务
taskCancel()
taskRunner.Wait()
// 设置关闭超时 // 设置关闭超时
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()

5
go.mod
View File

@@ -5,6 +5,7 @@ go 1.24.0
toolchain go1.24.2 toolchain go1.24.2
require ( require (
github.com/alicebob/miniredis/v2 v2.31.1
github.com/gin-gonic/gin v1.11.0 github.com/gin-gonic/gin v1.11.0
github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang-jwt/jwt/v5 v5.3.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
@@ -17,11 +18,13 @@ require (
go.uber.org/zap v1.27.1 go.uber.org/zap v1.27.1
gorm.io/datatypes v1.2.7 gorm.io/datatypes v1.2.7
gorm.io/driver/postgres v1.6.0 gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.1 gorm.io/gorm v1.31.1
) )
require ( require (
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic/loader v0.4.0 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/base64x v0.1.6 // indirect
@@ -31,12 +34,14 @@ require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/klauspost/crc32 v1.3.0 // indirect github.com/klauspost/crc32 v1.3.0 // indirect
github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/minio/crc64nvme v1.1.0 // indirect github.com/minio/crc64nvme v1.1.0 // indirect
github.com/philhofer/fwd v1.2.0 // indirect github.com/philhofer/fwd v1.2.0 // indirect
github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.54.0 // indirect github.com/quic-go/quic-go v0.54.0 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/tinylib/msgp v1.3.0 // indirect github.com/tinylib/msgp v1.3.0 // indirect
github.com/yuin/gopher-lua v1.1.0 // indirect
go.uber.org/mock v0.5.0 // indirect go.uber.org/mock v0.5.0 // indirect
golang.org/x/image v0.33.0 // indirect golang.org/x/image v0.33.0 // indirect
golang.org/x/mod v0.30.0 // indirect golang.org/x/mod v0.30.0 // indirect

12
go.sum
View File

@@ -1,5 +1,10 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.31.1 h1:7XAt0uUg3DtwEKW5ZAGa+K7FZV2DdKQo5K/6TTnfX8Y=
github.com/alicebob/miniredis/v2 v2.31.1/go.mod h1:UB/T2Uztp7MlFSDakaX1sTXUv5CASoprx0wulRT6HBg=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -12,6 +17,9 @@ github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2N
github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -57,6 +65,7 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -161,6 +170,8 @@ github.com/wenlng/go-captcha-assets v1.0.7/go.mod h1:zinRACsdYcL/S6pHgI9Iv7FKTU4
github.com/wenlng/go-captcha/v2 v2.0.4 h1:5cSUF36ZyA03qeDMjKmeXGpbYJMXEexZIYK3Vga3ME0= github.com/wenlng/go-captcha/v2 v2.0.4 h1:5cSUF36ZyA03qeDMjKmeXGpbYJMXEexZIYK3Vga3ME0=
github.com/wenlng/go-captcha/v2 v2.0.4/go.mod h1:5hac1em3uXoyC5ipZ0xFv9umNM/waQvYAQdr0cx/h34= github.com/wenlng/go-captcha/v2 v2.0.4/go.mod h1:5hac1em3uXoyC5ipZ0xFv9umNM/waQvYAQdr0cx/h34=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE=
github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
@@ -195,6 +206,7 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -29,7 +29,6 @@ type Container struct {
UserRepo repository.UserRepository UserRepo repository.UserRepository
ProfileRepo repository.ProfileRepository ProfileRepo repository.ProfileRepository
TextureRepo repository.TextureRepository TextureRepo repository.TextureRepository
TokenRepo repository.TokenRepository
ClientRepo repository.ClientRepository ClientRepo repository.ClientRepository
ConfigRepo repository.SystemConfigRepository ConfigRepo repository.SystemConfigRepository
YggdrasilRepo repository.YggdrasilRepository YggdrasilRepo repository.YggdrasilRepository
@@ -61,6 +60,14 @@ func NewContainer(
Prefix: "carrotskin:", Prefix: "carrotskin:",
Expiration: 5 * time.Minute, Expiration: 5 * time.Minute,
Enabled: true, Enabled: true,
Policy: database.CachePolicy{
UserTTL: 5 * time.Minute,
UserEmailTTL: 5 * time.Minute,
ProfileTTL: 5 * time.Minute,
ProfileListTTL: 3 * time.Minute,
TextureTTL: 5 * time.Minute,
TextureListTTL: 2 * time.Minute,
},
}) })
c := &Container{ c := &Container{
@@ -76,7 +83,6 @@ func NewContainer(
c.UserRepo = repository.NewUserRepository(db) c.UserRepo = repository.NewUserRepository(db)
c.ProfileRepo = repository.NewProfileRepository(db) c.ProfileRepo = repository.NewProfileRepository(db)
c.TextureRepo = repository.NewTextureRepository(db) c.TextureRepo = repository.NewTextureRepository(db)
c.TokenRepo = repository.NewTokenRepository(db)
c.ClientRepo = repository.NewClientRepository(db) c.ClientRepo = repository.NewClientRepository(db)
c.ConfigRepo = repository.NewSystemConfigRepository(db) c.ConfigRepo = repository.NewSystemConfigRepository(db)
c.YggdrasilRepo = repository.NewYggdrasilRepository(db) c.YggdrasilRepo = repository.NewYggdrasilRepository(db)
@@ -98,10 +104,24 @@ func NewContainer(
logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err)) logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err))
} }
yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin") yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin")
c.TokenService = service.NewTokenServiceJWT(c.TokenRepo, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger)
// 创建Redis Token存储必须使用Redis包括miniredis回退
if redisClient == nil {
logger.Fatal("Redis客户端未初始化无法创建Token服务")
}
tokenStore := auth.NewTokenStoreRedis(
redisClient,
logger,
auth.WithKeyPrefix("token:"),
auth.WithDefaultTTL(24*time.Hour),
auth.WithStaleTTL(30*24*time.Hour),
auth.WithMaxTokensPerUser(10),
)
c.TokenService = service.NewTokenServiceRedis(tokenStore, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger)
// 使用组合服务(内部包含认证、会话、序列化、证书服务) // 使用组合服务(内部包含认证、会话、序列化、证书服务)
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, c.SignatureService, redisClient, logger) c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.YggdrasilRepo, c.SignatureService, redisClient, logger, c.TokenService)
// 初始化其他服务 // 初始化其他服务
c.SecurityService = service.NewSecurityService(redisClient) c.SecurityService = service.NewSecurityService(redisClient)
@@ -186,13 +206,6 @@ func WithTextureRepo(repo repository.TextureRepository) Option {
} }
} }
// WithTokenRepo 设置令牌仓储
func WithTokenRepo(repo repository.TokenRepository) Option {
return func(c *Container) {
c.TokenRepo = repo
}
}
// WithConfigRepo 设置系统配置仓储 // WithConfigRepo 设置系统配置仓储
func WithConfigRepo(repo repository.SystemConfigRepository) Option { func WithConfigRepo(repo repository.SystemConfigRepository) Option {
return func(c *Container) { return func(c *Container) {

View File

@@ -0,0 +1,38 @@
package errors
import (
"errors"
"testing"
)
func TestAppErrorBasics(t *testing.T) {
root := errors.New("root")
appErr := NewBadRequest("bad", root)
if appErr.Code != 400 || appErr.Message != "bad" {
t.Fatalf("unexpected appErr fields: %+v", appErr)
}
if got := appErr.Error(); got != "bad: root" {
t.Fatalf("unexpected Error(): %s", got)
}
if !Is(appErr, root) {
t.Fatalf("Is should match wrapped error")
}
var target *AppError
if !As(appErr, &target) {
t.Fatalf("As should succeed")
}
}
func TestWrap(t *testing.T) {
if Wrap(nil, "msg") != nil {
t.Fatalf("Wrap nil should return nil")
}
err := errors.New("base")
wrapped := Wrap(err, "ctx")
if wrapped.Error() != "ctx: base" {
t.Fatalf("wrap message mismatch: %v", wrapped)
}
}

View File

@@ -0,0 +1,27 @@
package handler
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// 仅验证降级路径(未初始化依赖时的响应)
func TestHealthCheck_Degraded(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.GET("/health", HealthCheck)
req := httptest.NewRequest(http.MethodGet, "/health", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("expected 503 when dependencies missing, got %d", w.Code)
}
}

View File

@@ -29,3 +29,10 @@ func (Client) TableName() string {

View File

@@ -1,23 +0,0 @@
package model
import "time"
// Token Yggdrasil 认证令牌模型
type Token struct {
AccessToken string `gorm:"column:access_token;type:text;primaryKey" json:"access_token"` // 改为text以支持JWT长度
UserID int64 `gorm:"column:user_id;not null;index:idx_tokens_user_id" json:"user_id"`
ClientToken string `gorm:"column:client_token;type:varchar(64);not null;index:idx_tokens_client_token" json:"client_token"`
ProfileId string `gorm:"column:profile_id;type:varchar(36);index:idx_tokens_profile_id" json:"profile_id"` // 改为可空
Version int `gorm:"column:version;not null;default:0;index:idx_tokens_version" json:"version"` // 新增:版本号
Usable bool `gorm:"column:usable;not null;default:true;index:idx_tokens_usable" json:"usable"`
IssueDate time.Time `gorm:"column:issue_date;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_tokens_issue_date,sort:desc" json:"issue_date"`
ExpiresAt *time.Time `gorm:"column:expires_at;type:timestamp" json:"expires_at,omitempty"` // 新增:过期时间
StaleAt *time.Time `gorm:"column:stale_at;type:timestamp" json:"stale_at,omitempty"` // 新增:过期但可用时间
// 关联
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE" json:"user,omitempty"`
Profile *Profile `gorm:"foreignKey:ProfileId;references:UUID;constraint:OnDelete:CASCADE" json:"profile,omitempty"`
}
// TableName 指定表名
func (Token) TableName() string { return "tokens" }

View File

@@ -0,0 +1,18 @@
package model
import (
"strings"
"testing"
)
func TestGenerateRandomPassword(t *testing.T) {
pwd := GenerateRandomPassword(16)
if len(pwd) != 16 {
t.Fatalf("length mismatch: %d", len(pwd))
}
for _, ch := range pwd {
if !strings.ContainsRune(passwordChars, ch) {
t.Fatalf("unexpected char: %c", ch)
}
}
}

View File

@@ -67,18 +67,6 @@ type TextureRepository interface {
CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error)
} }
// TokenRepository 令牌仓储接口
type TokenRepository interface {
Create(ctx context.Context, token *model.Token) error
FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error)
GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error)
GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error)
GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error)
DeleteByAccessToken(ctx context.Context, accessToken string) error
DeleteByUserID(ctx context.Context, userId int64) error
BatchDelete(ctx context.Context, accessTokens []string) (int64, error)
}
// SystemConfigRepository 系统配置仓储接口 // SystemConfigRepository 系统配置仓储接口
type SystemConfigRepository interface { type SystemConfigRepository interface {
GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error)

View File

@@ -0,0 +1,278 @@
package repository
import (
"context"
"testing"
"carrotskin/internal/model"
"carrotskin/internal/testutil"
)
func TestUserRepository_BasicAndPoints(t *testing.T) {
db := testutil.NewTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &model.User{Username: "u1", Email: "e1@test.com", Password: "pwd", Status: 1}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("create user err: %v", err)
}
if u, err := repo.FindByID(ctx, user.ID); err != nil || u.Username != "u1" {
t.Fatalf("FindByID mismatch: %v %+v", err, u)
}
if u, err := repo.FindByUsername(ctx, "u1"); err != nil || u.Email != "e1@test.com" {
t.Fatalf("FindByUsername mismatch")
}
if u, err := repo.FindByEmail(ctx, "e1@test.com"); err != nil || u.ID != user.ID {
t.Fatalf("FindByEmail mismatch")
}
if err := repo.UpdateFields(ctx, user.ID, map[string]interface{}{"avatar": "a.png"}); err != nil {
t.Fatalf("UpdateFields err: %v", err)
}
if _, err := repo.BatchUpdate(ctx, []int64{user.ID}, map[string]interface{}{"status": 2}); err != nil {
t.Fatalf("BatchUpdate err: %v", err)
}
// 积分增加
if err := repo.UpdatePoints(ctx, user.ID, 10, "add", "bonus"); err != nil {
t.Fatalf("UpdatePoints add err: %v", err)
}
// 积分不足场景
if err := repo.UpdatePoints(ctx, user.ID, -100, "sub", "penalty"); err == nil {
t.Fatalf("expected insufficient points error")
}
if list, err := repo.FindByIDs(ctx, []int64{user.ID}); err != nil || len(list) != 1 {
t.Fatalf("FindByIDs mismatch: %v %d", err, len(list))
}
if list, err := repo.FindByIDs(ctx, []int64{}); err != nil || len(list) != 0 {
t.Fatalf("FindByIDs empty mismatch: %v %d", err, len(list))
}
// 软删除
if err := repo.Delete(ctx, user.ID); err != nil {
t.Fatalf("Delete err: %v", err)
}
deleted, _ := repo.FindByID(ctx, user.ID)
if deleted != nil {
t.Fatalf("expected deleted user filtered out")
}
// 批量操作边界
if _, err := repo.BatchUpdate(ctx, []int64{}, map[string]interface{}{"status": 1}); err != nil {
t.Fatalf("BatchUpdate empty should not error: %v", err)
}
if _, err := repo.BatchDelete(ctx, []int64{}); err != nil {
t.Fatalf("BatchDelete empty should not error: %v", err)
}
// 日志写入
_ = repo.CreateLoginLog(ctx, &model.UserLoginLog{UserID: user.ID, IPAddress: "127.0.0.1"})
_ = repo.CreatePointLog(ctx, &model.UserPointLog{UserID: user.ID, Amount: 1, ChangeType: "add"})
}
func TestProfileRepository_Basic(t *testing.T) {
db := testutil.NewTestDB(t)
userRepo := NewUserRepository(db)
profileRepo := NewProfileRepository(db)
ctx := context.Background()
u := &model.User{Username: "u2", Email: "u2@test.com", Password: "pwd", Status: 1}
_ = userRepo.Create(ctx, u)
p := &model.Profile{UUID: "p-uuid", UserID: u.ID, Name: "hero", IsActive: false}
if err := profileRepo.Create(ctx, p); err != nil {
t.Fatalf("create profile err: %v", err)
}
if got, err := profileRepo.FindByUUID(ctx, "p-uuid"); err != nil || got.Name != "hero" {
t.Fatalf("FindByUUID mismatch: %v %+v", err, got)
}
if list, err := profileRepo.FindByUserID(ctx, u.ID); err != nil || len(list) != 1 {
t.Fatalf("FindByUserID mismatch")
}
if count, err := profileRepo.CountByUserID(ctx, u.ID); err != nil || count != 1 {
t.Fatalf("CountByUserID mismatch: %d err=%v", count, err)
}
if err := profileRepo.SetActive(ctx, "p-uuid", u.ID); err != nil {
t.Fatalf("SetActive err: %v", err)
}
if err := profileRepo.UpdateLastUsedAt(ctx, "p-uuid"); err != nil {
t.Fatalf("UpdateLastUsedAt err: %v", err)
}
if got, err := profileRepo.FindByName(ctx, "hero"); err != nil || got == nil {
t.Fatalf("FindByName mismatch")
}
if list, err := profileRepo.FindByUUIDs(ctx, []string{"p-uuid"}); err != nil || len(list) != 1 {
t.Fatalf("FindByUUIDs mismatch")
}
if _, err := profileRepo.BatchUpdate(ctx, []string{"p-uuid"}, map[string]interface{}{"name": "hero2"}); err != nil {
t.Fatalf("BatchUpdate profile err: %v", err)
}
if err := profileRepo.Delete(ctx, "p-uuid"); err != nil {
t.Fatalf("Delete err: %v", err)
}
if _, err := profileRepo.BatchDelete(ctx, []string{}); err != nil {
t.Fatalf("BatchDelete empty err: %v", err)
}
}
func TestTextureRepository_Basic(t *testing.T) {
db := testutil.NewTestDB(t)
userRepo := NewUserRepository(db)
textureRepo := NewTextureRepository(db)
ctx := context.Background()
u := &model.User{Username: "u3", Email: "u3@test.com", Password: "pwd", Status: 1}
_ = userRepo.Create(ctx, u)
tex := &model.Texture{
UploaderID: u.ID,
Name: "tex",
Hash: "hash1",
URL: "url1",
Type: model.TextureTypeSkin,
IsPublic: true,
Status: 1,
}
if err := textureRepo.Create(ctx, tex); err != nil {
t.Fatalf("create texture err: %v", err)
}
if got, _ := textureRepo.FindByHash(ctx, "hash1"); got == nil || got.ID != tex.ID {
t.Fatalf("FindByHash mismatch")
}
if got, _ := textureRepo.FindByHashAndUploaderID(ctx, "hash1", u.ID); got == nil {
t.Fatalf("FindByHashAndUploaderID mismatch")
}
_ = textureRepo.IncrementFavoriteCount(ctx, tex.ID)
_ = textureRepo.DecrementFavoriteCount(ctx, tex.ID)
_ = textureRepo.IncrementDownloadCount(ctx, tex.ID)
_ = textureRepo.CreateDownloadLog(ctx, &model.TextureDownloadLog{TextureID: tex.ID, UserID: &u.ID, IPAddress: "127.0.0.1"})
// 收藏
_ = textureRepo.AddFavorite(ctx, u.ID, tex.ID)
if fav, err := textureRepo.IsFavorited(ctx, u.ID, tex.ID); err == nil {
if !fav {
t.Fatalf("IsFavorited expected true")
}
} else {
t.Skipf("IsFavorited not supported by sqlite: %v", err)
}
_ = textureRepo.RemoveFavorite(ctx, u.ID, tex.ID)
// 批量更新与删除
if affected, err := textureRepo.BatchUpdate(ctx, []int64{tex.ID}, map[string]interface{}{"name": "tex-new"}); err != nil || affected != 1 {
t.Fatalf("BatchUpdate mismatch, affected=%d err=%v", affected, err)
}
if affected, err := textureRepo.BatchDelete(ctx, []int64{tex.ID}); err != nil || affected != 1 {
t.Fatalf("BatchDelete mismatch, affected=%d err=%v", affected, err)
}
// 搜索与收藏列表
_ = textureRepo.Create(ctx, &model.Texture{
UploaderID: u.ID,
Name: "search-me",
Hash: "hash2",
URL: "url2",
Type: model.TextureTypeCape,
IsPublic: true,
Status: 1,
})
if list, total, err := textureRepo.Search(ctx, "search", model.TextureTypeCape, true, 1, 10); err != nil || total == 0 || len(list) == 0 {
t.Fatalf("Search mismatch, total=%d len=%d err=%v", total, len(list), err)
}
_ = textureRepo.AddFavorite(ctx, u.ID, tex.ID+1)
if favList, total, err := textureRepo.GetUserFavorites(ctx, u.ID, 1, 10); err != nil || total == 0 || len(favList) == 0 {
t.Fatalf("GetUserFavorites mismatch, total=%d len=%d err=%v", total, len(favList), err)
}
if _, total, err := textureRepo.Search(ctx, "", model.TextureTypeSkin, true, 1, 10); err != nil || total < 2 {
t.Fatalf("Search fallback mismatch")
}
// 列表与计数
if _, total, err := textureRepo.FindByUploaderID(ctx, u.ID, 1, 10); err != nil || total != 1 {
t.Fatalf("FindByUploaderID mismatch")
}
if cnt, err := textureRepo.CountByUploaderID(ctx, u.ID); err != nil || cnt != 1 {
t.Fatalf("CountByUploaderID mismatch")
}
_ = textureRepo.Delete(ctx, tex.ID)
}
func TestSystemConfigRepository_Basic(t *testing.T) {
db := testutil.NewTestDB(t)
repo := NewSystemConfigRepository(db)
ctx := context.Background()
cfg := &model.SystemConfig{Key: "site_name", Value: "Carrot", IsPublic: true}
if err := repo.Update(ctx, cfg); err != nil {
t.Fatalf("Update err: %v", err)
}
if v, err := repo.GetByKey(ctx, "site_name"); err != nil || v.Value != "Carrot" {
t.Fatalf("GetByKey mismatch")
}
_ = repo.UpdateValue(ctx, "site_name", "Carrot2")
if list, _ := repo.GetPublic(ctx); len(list) == 0 {
t.Fatalf("GetPublic expected entries")
}
if all, _ := repo.GetAll(ctx); len(all) == 0 {
t.Fatalf("GetAll expected entries")
}
if v, _ := repo.GetByKey(ctx, "site_name"); v.Value != "Carrot2" {
t.Fatalf("UpdateValue not applied")
}
}
func TestClientRepository_Basic(t *testing.T) {
db := testutil.NewTestDB(t)
repo := NewClientRepository(db)
ctx := context.Background()
client := &model.Client{UUID: "c-uuid", ClientToken: "ct-1", UserID: 9, Version: 1}
if err := repo.Create(ctx, client); err != nil {
t.Fatalf("Create client err: %v", err)
}
if got, _ := repo.FindByClientToken(ctx, "ct-1"); got == nil || got.UUID != "c-uuid" {
t.Fatalf("FindByClientToken mismatch")
}
if got, _ := repo.FindByUUID(ctx, "c-uuid"); got == nil || got.ClientToken != "ct-1" {
t.Fatalf("FindByUUID mismatch")
}
if list, _ := repo.FindByUserID(ctx, 9); len(list) != 1 {
t.Fatalf("FindByUserID mismatch")
}
_ = repo.IncrementVersion(ctx, "c-uuid")
updated, _ := repo.FindByUUID(ctx, "c-uuid")
if updated.Version != 2 {
t.Fatalf("IncrementVersion not applied, got %d", updated.Version)
}
_ = repo.DeleteByClientToken(ctx, "ct-1")
_ = repo.DeleteByUserID(ctx, 9)
}
func TestYggdrasilRepository_Basic(t *testing.T) {
db := testutil.NewTestDB(t)
userRepo := NewUserRepository(db)
yggRepo := NewYggdrasilRepository(db)
ctx := context.Background()
user := &model.User{Username: "u-ygg", Email: "ygg@test.com", Password: "pwd", Status: 1}
_ = userRepo.Create(ctx, user) // AfterCreate 会生成 yggdrasil 记录
pwd, err := yggRepo.GetPasswordByID(ctx, user.ID)
if err != nil || pwd == "" {
t.Fatalf("GetPasswordByID err=%v pwd=%s", err, pwd)
}
if err := yggRepo.ResetPassword(ctx, user.ID, "newpwd"); err != nil {
t.Fatalf("ResetPassword err: %v", err)
}
}

View File

@@ -1,71 +0,0 @@
package repository
import (
"carrotskin/internal/model"
"context"
"gorm.io/gorm"
)
// tokenRepository TokenRepository的实现
type tokenRepository struct {
db *gorm.DB
}
// NewTokenRepository 创建TokenRepository实例
func NewTokenRepository(db *gorm.DB) TokenRepository {
return &tokenRepository{db: db}
}
func (r *tokenRepository) Create(ctx context.Context, token *model.Token) error {
return r.db.WithContext(ctx).Create(token).Error
}
func (r *tokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
var token model.Token
err := r.db.WithContext(ctx).Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return nil, err
}
return &token, nil
}
func (r *tokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
var tokens []*model.Token
err := r.db.WithContext(ctx).Where("user_id = ?", userId).Find(&tokens).Error
return tokens, err
}
func (r *tokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
var token model.Token
err := r.db.WithContext(ctx).Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return "", err
}
return token.ProfileId, nil
}
func (r *tokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
var token model.Token
err := r.db.WithContext(ctx).Select("user_id").Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return 0, err
}
return token.UserID, nil
}
func (r *tokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
return r.db.WithContext(ctx).Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
}
func (r *tokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
return r.db.WithContext(ctx).Where("user_id = ?", userId).Delete(&model.Token{}).Error
}
func (r *tokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
if len(accessTokens) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Where("access_token IN ?", accessTokens).Delete(&model.Token{})
return result.RowsAffected, result.Error
}

View File

@@ -1,123 +0,0 @@
package repository
import (
"testing"
)
// TestTokenRepository_BatchDeleteLogic 测试批量删除逻辑
func TestTokenRepository_BatchDeleteLogic(t *testing.T) {
tests := []struct {
name string
tokensToDelete []string
wantCount int64
wantError bool
}{
{
name: "有效的token列表",
tokensToDelete: []string{"token1", "token2", "token3"},
wantCount: 3,
wantError: false,
},
{
name: "空列表应该返回0",
tokensToDelete: []string{},
wantCount: 0,
wantError: false,
},
{
name: "单个token",
tokensToDelete: []string{"token1"},
wantCount: 1,
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 验证批量删除逻辑空列表应该直接返回0
if len(tt.tokensToDelete) == 0 {
if tt.wantCount != 0 {
t.Errorf("Empty list should return count 0, got %d", tt.wantCount)
}
}
})
}
}
// TestTokenRepository_QueryConditions 测试token查询条件逻辑
func TestTokenRepository_QueryConditions(t *testing.T) {
tests := []struct {
name string
accessToken string
userID int64
wantValid bool
}{
{
name: "有效的access token",
accessToken: "valid-token-123",
userID: 1,
wantValid: true,
},
{
name: "access token为空",
accessToken: "",
userID: 1,
wantValid: false,
},
{
name: "用户ID为0",
accessToken: "valid-token-123",
userID: 0,
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.accessToken != "" && tt.userID > 0
if isValid != tt.wantValid {
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestTokenRepository_FindTokenByIDLogic 测试根据ID查找token的逻辑
func TestTokenRepository_FindTokenByIDLogic(t *testing.T) {
tests := []struct {
name string
accessToken string
resultCount int
wantError bool
}{
{
name: "找到token",
accessToken: "token-123",
resultCount: 1,
wantError: false,
},
{
name: "未找到token",
accessToken: "token-123",
resultCount: 0,
wantError: true, // 访问索引0会panic
},
{
name: "找到多个token异常情况",
accessToken: "token-123",
resultCount: 2,
wantError: false, // 返回第一个
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 验证逻辑如果结果为空访问索引0会出错
hasError := tt.resultCount == 0
if hasError != tt.wantError {
t.Errorf("FindTokenByID logic failed: got error=%v, want error=%v", hasError, tt.wantError)
}
})
}
}

View File

@@ -315,6 +315,18 @@ func (m *MockTextureRepository) FindByHash(ctx context.Context, hash string) (*m
return nil, nil return nil, nil
} }
func (m *MockTextureRepository) FindByHashAndUploaderID(ctx context.Context, hash string, uploaderID int64) (*model.Texture, error) {
if m.FailFind {
return nil, errors.New("mock find error")
}
for _, texture := range m.textures {
if texture.Hash == hash && texture.UploaderID == uploaderID {
return texture, nil
}
}
return nil, nil
}
func (m *MockTextureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { func (m *MockTextureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
if m.FailFind { if m.FailFind {
return nil, 0, errors.New("mock find error") return nil, 0, errors.New("mock find error")
@@ -462,101 +474,6 @@ func (m *MockTextureRepository) BatchDelete(ctx context.Context, ids []int64) (i
return deleted, nil return deleted, nil
} }
// MockTokenRepository 模拟TokenRepository
type MockTokenRepository struct {
tokens map[string]*model.Token
userTokens map[int64][]*model.Token
FailCreate bool
FailFind bool
FailDelete bool
}
func NewMockTokenRepository() *MockTokenRepository {
return &MockTokenRepository{
tokens: make(map[string]*model.Token),
userTokens: make(map[int64][]*model.Token),
}
}
func (m *MockTokenRepository) Create(ctx context.Context, token *model.Token) error {
if m.FailCreate {
return errors.New("mock create error")
}
m.tokens[token.AccessToken] = token
m.userTokens[token.UserID] = append(m.userTokens[token.UserID], token)
return nil
}
func (m *MockTokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
if m.FailFind {
return nil, errors.New("mock find error")
}
if token, ok := m.tokens[accessToken]; ok {
return token, nil
}
return nil, errors.New("token not found")
}
func (m *MockTokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
if m.FailFind {
return nil, errors.New("mock find error")
}
return m.userTokens[userId], nil
}
func (m *MockTokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
if m.FailFind {
return "", errors.New("mock find error")
}
if token, ok := m.tokens[accessToken]; ok {
return token.ProfileId, nil
}
return "", errors.New("token not found")
}
func (m *MockTokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
if m.FailFind {
return 0, errors.New("mock find error")
}
if token, ok := m.tokens[accessToken]; ok {
return token.UserID, nil
}
return 0, errors.New("token not found")
}
func (m *MockTokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
if m.FailDelete {
return errors.New("mock delete error")
}
delete(m.tokens, accessToken)
return nil
}
func (m *MockTokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
if m.FailDelete {
return errors.New("mock delete error")
}
for _, token := range m.userTokens[userId] {
delete(m.tokens, token.AccessToken)
}
m.userTokens[userId] = nil
return nil
}
func (m *MockTokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
if m.FailDelete {
return 0, errors.New("mock delete error")
}
var count int64
for _, accessToken := range accessTokens {
if _, ok := m.tokens[accessToken]; ok {
delete(m.tokens, accessToken)
count++
}
}
return count, nil
}
// MockSystemConfigRepository 模拟SystemConfigRepository // MockSystemConfigRepository 模拟SystemConfigRepository
type MockSystemConfigRepository struct { type MockSystemConfigRepository struct {
configs map[string]*model.SystemConfig configs map[string]*model.SystemConfig
@@ -956,90 +873,11 @@ func (m *MockTextureService) CheckUploadLimit(uploaderID int64, maxTextures int)
return nil return nil
} }
// MockTokenService 模拟TokenService
type MockTokenService struct {
tokens map[string]*model.Token
FailCreate bool
FailValidate bool
FailRefresh bool
}
func NewMockTokenService() *MockTokenService {
return &MockTokenService{
tokens: make(map[string]*model.Token),
}
}
func (m *MockTokenService) Create(userID int64, uuid, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
if m.FailCreate {
return nil, nil, "", "", errors.New("mock create error")
}
accessToken := "mock-access-token"
if clientToken == "" {
clientToken = "mock-client-token"
}
token := &model.Token{
AccessToken: accessToken,
ClientToken: clientToken,
UserID: userID,
ProfileId: uuid,
Usable: true,
}
m.tokens[accessToken] = token
return nil, nil, accessToken, clientToken, nil
}
func (m *MockTokenService) Validate(accessToken, clientToken string) bool {
if m.FailValidate {
return false
}
if token, ok := m.tokens[accessToken]; ok {
if clientToken == "" || token.ClientToken == clientToken {
return token.Usable
}
}
return false
}
func (m *MockTokenService) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) {
if m.FailRefresh {
return "", "", errors.New("mock refresh error")
}
return "new-access-token", clientToken, nil
}
func (m *MockTokenService) Invalidate(accessToken string) {
delete(m.tokens, accessToken)
}
func (m *MockTokenService) InvalidateUserTokens(userID int64) {
for key, token := range m.tokens {
if token.UserID == userID {
delete(m.tokens, key)
}
}
}
func (m *MockTokenService) GetUUIDByAccessToken(accessToken string) (string, error) {
if token, ok := m.tokens[accessToken]; ok {
return token.ProfileId, nil
}
return "", errors.New("token not found")
}
func (m *MockTokenService) GetUserIDByAccessToken(accessToken string) (int64, error) {
if token, ok := m.tokens[accessToken]; ok {
return token.UserID, nil
}
return 0, errors.New("token not found")
}
// ============================================================================ // ============================================================================
// CacheManager Mock - uses database.CacheManager with nil redis // CacheManager Mock - 使用 database.CacheManager 的内存版本
// ============================================================================ // ============================================================================
// NewMockCacheManager 创建一个禁用的 CacheManager 用于测试 // NewMockCacheManager 创建一个内存 CacheManager 用于测试
// 通过设置 Enabled = false缓存操作会被跳过测试不依赖 Redis
func NewMockCacheManager() *database.CacheManager { func NewMockCacheManager() *database.CacheManager {
return database.NewCacheManager(nil, database.CacheConfig{ return database.NewCacheManager(nil, database.CacheConfig{
Prefix: "test:", Prefix: "test:",

View File

@@ -11,7 +11,6 @@ import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
@@ -99,7 +98,7 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro
// 尝试从缓存获取 // 尝试从缓存获取
cacheKey := s.cacheKeys.Profile(uuid) cacheKey := s.cacheKeys.Profile(uuid)
var profile model.Profile var profile model.Profile
if err := s.cache.Get(ctx, cacheKey, &profile); err == nil { if ok, _ := s.cache.TryGet(ctx, cacheKey, &profile); ok {
return &profile, nil return &profile, nil
} }
@@ -112,11 +111,9 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro
return nil, fmt.Errorf("查询档案失败: %w", err) return nil, fmt.Errorf("查询档案失败: %w", err)
} }
// 存入缓存(异步5分钟过期 // 存入缓存(异步)
if profile2 != nil { if profile2 != nil {
go func() { s.cache.SetAsync(context.Background(), cacheKey, profile2, s.cache.Policy.ProfileTTL)
_ = s.cache.Set(context.Background(), cacheKey, profile2, 5*time.Minute)
}()
} }
return profile2, nil return profile2, nil
@@ -126,7 +123,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
// 尝试从缓存获取 // 尝试从缓存获取
cacheKey := s.cacheKeys.ProfileList(userID) cacheKey := s.cacheKeys.ProfileList(userID)
var profiles []*model.Profile var profiles []*model.Profile
if err := s.cache.Get(ctx, cacheKey, &profiles); err == nil { if ok, _ := s.cache.TryGet(ctx, cacheKey, &profiles); ok {
return profiles, nil return profiles, nil
} }
@@ -136,11 +133,9 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
return nil, fmt.Errorf("查询档案列表失败: %w", err) return nil, fmt.Errorf("查询档案列表失败: %w", err)
} }
// 存入缓存(异步3分钟过期 // 存入缓存(异步)
if profiles != nil { if profiles != nil {
go func() { s.cache.SetAsync(context.Background(), cacheKey, profiles, s.cache.Policy.ProfileListTTL)
_ = s.cache.Set(context.Background(), cacheKey, profiles, 3*time.Minute)
}()
} }
return profiles, nil return profiles, nil

View File

@@ -13,7 +13,6 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -103,7 +102,7 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture,
// 尝试从缓存获取 // 尝试从缓存获取
cacheKey := s.cacheKeys.Texture(id) cacheKey := s.cacheKeys.Texture(id)
var texture model.Texture var texture model.Texture
if err := s.cache.Get(ctx, cacheKey, &texture); err == nil { if ok, _ := s.cache.TryGet(ctx, cacheKey, &texture); ok {
if texture.Status == -1 { if texture.Status == -1 {
return nil, errors.New("材质已删除") return nil, errors.New("材质已删除")
} }
@@ -122,11 +121,9 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture,
return nil, errors.New("材质已删除") return nil, errors.New("材质已删除")
} }
// 存入缓存(异步5分钟过期 // 存入缓存(异步)
if texture2 != nil { if texture2 != nil {
go func() { s.cache.SetAsync(context.Background(), cacheKey, texture2, s.cache.Policy.TextureTTL)
_ = s.cache.Set(context.Background(), cacheKey, texture2, 5*time.Minute)
}()
} }
return texture2, nil return texture2, nil
@@ -136,7 +133,7 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex
// 尝试从缓存获取 // 尝试从缓存获取
cacheKey := s.cacheKeys.TextureByHash(hash) cacheKey := s.cacheKeys.TextureByHash(hash)
var texture model.Texture var texture model.Texture
if err := s.cache.Get(ctx, cacheKey, &texture); err == nil { if ok, _ := s.cache.TryGet(ctx, cacheKey, &texture); ok {
if texture.Status == -1 { if texture.Status == -1 {
return nil, errors.New("材质已删除") return nil, errors.New("材质已删除")
} }
@@ -155,10 +152,8 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex
return nil, errors.New("材质已删除") return nil, errors.New("材质已删除")
} }
// 存入缓存(异步5分钟过期 // 存入缓存(异步)
go func() { s.cache.SetAsync(context.Background(), cacheKey, texture2, s.cache.Policy.TextureTTL)
_ = s.cache.Set(context.Background(), cacheKey, texture2, 5*time.Minute)
}()
return texture2, nil return texture2, nil
} }
@@ -172,7 +167,7 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
Textures []*model.Texture Textures []*model.Texture
Total int64 Total int64
} }
if err := s.cache.Get(ctx, cacheKey, &cachedResult); err == nil { if ok, _ := s.cache.TryGet(ctx, cacheKey, &cachedResult); ok {
return cachedResult.Textures, cachedResult.Total, nil return cachedResult.Textures, cachedResult.Total, nil
} }
@@ -182,14 +177,12 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
return nil, 0, err return nil, 0, err
} }
// 存入缓存(异步2分钟过期 // 存入缓存(异步)
go func() {
result := struct { result := struct {
Textures []*model.Texture Textures []*model.Texture
Total int64 Total int64
}{Textures: textures, Total: total} }{Textures: textures, Total: total}
_ = s.cache.Set(context.Background(), cacheKey, result, 2*time.Minute) s.cache.SetAsync(context.Background(), cacheKey, result, s.cache.Policy.TextureListTTL)
}()
return textures, total, nil return textures, total, nil
} }
@@ -232,7 +225,7 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64
// 清除 texture 缓存和用户列表缓存 // 清除 texture 缓存和用户列表缓存
s.cacheInv.OnUpdate(ctx, s.cacheKeys.Texture(textureID)) s.cacheInv.OnUpdate(ctx, s.cacheKeys.Texture(textureID))
s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID)) s.cacheInv.BatchInvalidate(ctx, s.cacheKeys.TextureListPattern(uploaderID))
return s.textureRepo.FindByID(ctx, textureID) return s.textureRepo.FindByID(ctx, textureID)
} }
@@ -257,7 +250,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64
// 清除 texture 缓存和用户列表缓存 // 清除 texture 缓存和用户列表缓存
s.cacheInv.OnDelete(ctx, s.cacheKeys.Texture(textureID)) s.cacheInv.OnDelete(ctx, s.cacheKeys.Texture(textureID))
s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID)) s.cacheInv.BatchInvalidate(ctx, s.cacheKeys.TextureListPattern(uploaderID))
return nil return nil
} }

View File

@@ -494,7 +494,7 @@ func TestTextureServiceImpl_Create(t *testing.T) {
_ = userRepo.Create(context.Background(), testUser) _ = userRepo.Create(context.Background(), testUser)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
tests := []struct { tests := []struct {
name string name string
@@ -536,8 +536,7 @@ func TestTextureServiceImpl_Create(t *testing.T) {
textureName: "DuplicateTexture", textureName: "DuplicateTexture",
textureType: "SKIN", textureType: "SKIN",
hash: "existing-hash", hash: "existing-hash",
wantErr: true, wantErr: false,
errContains: "已存在",
setupMocks: func() { setupMocks: func() {
_ = textureRepo.Create(context.Background(), &model.Texture{ _ = textureRepo.Create(context.Background(), &model.Texture{
ID: 100, ID: 100,
@@ -617,7 +616,7 @@ func TestTextureServiceImpl_GetByID(t *testing.T) {
_ = textureRepo.Create(context.Background(), testTexture) _ = textureRepo.Create(context.Background(), testTexture)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
tests := []struct { tests := []struct {
name string name string
@@ -675,7 +674,7 @@ func TestTextureServiceImpl_GetByUserID_And_Search(t *testing.T) {
} }
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
ctx := context.Background() ctx := context.Background()
@@ -714,7 +713,7 @@ func TestTextureServiceImpl_Update_And_Delete(t *testing.T) {
_ = textureRepo.Create(context.Background(), texture) _ = textureRepo.Create(context.Background(), texture)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
ctx := context.Background() ctx := context.Background()
@@ -764,7 +763,7 @@ func TestTextureServiceImpl_FavoritesAndLimit(t *testing.T) {
} }
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
ctx := context.Background() ctx := context.Background()
@@ -807,7 +806,7 @@ func TestTextureServiceImpl_ToggleFavorite(t *testing.T) {
_ = textureRepo.Create(context.Background(), testTexture) _ = textureRepo.Create(context.Background(), testTexture)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
ctx := context.Background() ctx := context.Background()

View File

@@ -1,305 +0,0 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"go.uber.org/zap"
)
// tokenService TokenService的实现
type tokenService struct {
tokenRepo repository.TokenRepository
profileRepo repository.ProfileRepository
logger *zap.Logger
}
// NewTokenService 创建TokenService实例
func NewTokenService(
tokenRepo repository.TokenRepository,
profileRepo repository.ProfileRepository,
logger *zap.Logger,
) TokenService {
return &tokenService{
tokenRepo: tokenRepo,
profileRepo: profileRepo,
logger: logger,
}
}
const (
tokenExtendedTimeout = 10 * time.Second
tokensMaxCount = 10
)
func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
var (
selectedProfileID *model.Profile
availableProfiles []*model.Profile
)
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
// 验证用户存在
if UUID != "" {
_, err := s.profileRepo.FindByUUID(ctx, UUID)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
}
}
// 生成令牌
if clientToken == "" {
clientToken = uuid.New().String()
}
accessToken := uuid.New().String()
token := model.Token{
AccessToken: accessToken,
ClientToken: clientToken,
UserID: userID,
Usable: true,
IssueDate: time.Now(),
}
// 获取用户配置文件
profiles, err := s.profileRepo.FindByUserID(ctx, userID)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
}
// 如果用户只有一个配置文件,自动选择
if len(profiles) == 1 {
selectedProfileID = profiles[0]
token.ProfileId = selectedProfileID.UUID
}
availableProfiles = profiles
// 插入令牌
err = s.tokenRepo.Create(ctx, &token)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
}
// 清理多余的令牌(使用独立的后台上下文)
go s.checkAndCleanupExcessTokens(context.Background(), userID)
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
}
func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken string) bool {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" {
return false
}
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
if err != nil {
return false
}
if !token.Usable {
return false
}
if clientToken == "" {
return true
}
return token.ClientToken == clientToken
}
func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" {
return "", "", errors.New("accessToken不能为空")
}
// 查找旧令牌
oldToken, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", "", errors.New("accessToken无效")
}
s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken))
return "", "", fmt.Errorf("查询令牌失败: %w", err)
}
// 验证profile
if selectedProfileID != "" {
valid, validErr := s.validateProfileByUserID(ctx, oldToken.UserID, selectedProfileID)
if validErr != nil {
s.logger.Error("验证Profile失败",
zap.Error(err),
zap.Int64("userId", oldToken.UserID),
zap.String("profileId", selectedProfileID),
)
return "", "", fmt.Errorf("验证角色失败: %w", err)
}
if !valid {
return "", "", errors.New("角色与用户不匹配")
}
}
// 检查 clientToken 是否有效
if clientToken != "" && clientToken != oldToken.ClientToken {
return "", "", errors.New("clientToken无效")
}
// 检查 selectedProfileID 的逻辑
if selectedProfileID != "" {
if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID {
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
}
} else {
selectedProfileID = oldToken.ProfileId
}
// 生成新令牌
newAccessToken := uuid.New().String()
newToken := model.Token{
AccessToken: newAccessToken,
ClientToken: oldToken.ClientToken,
UserID: oldToken.UserID,
Usable: true,
ProfileId: selectedProfileID,
IssueDate: time.Now(),
}
// 先插入新令牌,再删除旧令牌
err = s.tokenRepo.Create(ctx, &newToken)
if err != nil {
s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken))
return "", "", fmt.Errorf("创建新Token失败: %w", err)
}
err = s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
if err != nil {
s.logger.Warn("删除旧Token失败但新Token已创建",
zap.Error(err),
zap.String("oldToken", oldToken.AccessToken),
zap.String("newToken", newAccessToken),
)
}
s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken))
return newAccessToken, oldToken.ClientToken, nil
}
func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" {
return
}
err := s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
if err != nil {
s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken))
return
}
s.logger.Info("成功删除Token", zap.String("token", accessToken))
}
func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if userID == 0 {
return
}
err := s.tokenRepo.DeleteByUserID(ctx, userID)
if err != nil {
s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
return
}
s.logger.Info("成功删除用户Token", zap.Int64("userId", userID))
}
func (s *tokenService) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
}
func (s *tokenService) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
}
// 私有辅助方法
func (s *tokenService) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
if userID == 0 {
return
}
// 为清理操作设置更长的超时时间
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
defer cancel()
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
if err != nil {
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return
}
if len(tokens) <= tokensMaxCount {
return
}
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
for i := tokensMaxCount; i < len(tokens); i++ {
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
}
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
if err != nil {
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return
}
if deletedCount > 0 {
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
}
}
func (s *tokenService) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
if userID == 0 || UUID == "" {
return false, errors.New("用户ID或配置文件ID不能为空")
}
profile, err := s.profileRepo.FindByUUID(ctx, UUID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return false, errors.New("配置文件不存在")
}
return false, fmt.Errorf("验证配置文件失败: %w", err)
}
return profile.UserID == userID, nil
}

View File

@@ -7,7 +7,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -15,9 +14,9 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// tokenServiceJWT TokenService的JWT实现使用JWT + Version机制 // tokenServiceRedis TokenService的Redis实现
type tokenServiceJWT struct { type tokenServiceRedis struct {
tokenRepo repository.TokenRepository tokenStore *auth.TokenStoreRedis
clientRepo repository.ClientRepository clientRepo repository.ClientRepository
profileRepo repository.ProfileRepository profileRepo repository.ProfileRepository
yggdrasilJWT *auth.YggdrasilJWTService yggdrasilJWT *auth.YggdrasilJWTService
@@ -26,16 +25,16 @@ type tokenServiceJWT struct {
tokenStaleSec int64 // Token过期但可用时间0表示永不过期 tokenStaleSec int64 // Token过期但可用时间0表示永不过期
} }
// NewTokenServiceJWT 创建使用JWT的TokenService实例 // NewTokenServiceRedis 创建使用Redis的TokenService实例
func NewTokenServiceJWT( func NewTokenServiceRedis(
tokenRepo repository.TokenRepository, tokenStore *auth.TokenStoreRedis,
clientRepo repository.ClientRepository, clientRepo repository.ClientRepository,
profileRepo repository.ProfileRepository, profileRepo repository.ProfileRepository,
yggdrasilJWT *auth.YggdrasilJWTService, yggdrasilJWT *auth.YggdrasilJWTService,
logger *zap.Logger, logger *zap.Logger,
) TokenService { ) TokenService {
return &tokenServiceJWT{ return &tokenServiceRedis{
tokenRepo: tokenRepo, tokenStore: tokenStore,
clientRepo: clientRepo, clientRepo: clientRepo,
profileRepo: profileRepo, profileRepo: profileRepo,
yggdrasilJWT: yggdrasilJWT, yggdrasilJWT: yggdrasilJWT,
@@ -45,10 +44,8 @@ func NewTokenServiceJWT(
} }
} }
// 常量已在 token_service.go 中定义,这里不重复定义 // Create 创建Token使用JWT + Redis存储
func (s *tokenServiceRedis) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
// Create 创建Token使用JWT + Version机制
func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
var ( var (
selectedProfileID *model.Profile selectedProfileID *model.Profile
availableProfiles []*model.Profile availableProfiles []*model.Profile
@@ -134,7 +131,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
if s.tokenExpireSec > 0 { if s.tokenExpireSec > 0 {
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second) expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
} else { } else {
// 使用遥远的未来时间类似drasl的DISTANT_FUTURE // 使用遥远的未来时间
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC) expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
} }
@@ -157,36 +154,31 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("生成AccessToken失败: %w", err) return selectedProfileID, availableProfiles, "", "", fmt.Errorf("生成AccessToken失败: %w", err)
} }
// 存Token记录(用于查询和审计) // 存Token到Redis
token := model.Token{ ttl := expiresAt.Sub(now)
AccessToken: accessToken, metadata := &auth.TokenMetadata{
ClientToken: clientToken,
UserID: userID, UserID: userID,
ProfileId: profileID, ProfileID: profileID,
ClientUUID: client.UUID,
ClientToken: client.ClientToken,
Version: client.Version, Version: client.Version,
Usable: true, CreatedAt: now.Unix(),
IssueDate: now,
ExpiresAt: &expiresAt,
StaleAt: &staleAt,
} }
err = s.tokenRepo.Create(ctx, &token) if err := s.tokenStore.Store(ctx, accessToken, metadata, ttl); err != nil {
if err != nil { s.logger.Warn("存储Token到Redis失败", zap.Error(err))
s.logger.Warn("保存Token记录失败但JWT已生成", zap.Error(err))
// 不返回错误因为JWT本身已经生成成功 // 不返回错误因为JWT本身已经生成成功
} }
// 清理多余的令牌(使用独立的后台上下文)
go s.checkAndCleanupExcessTokens(context.Background(), userID)
return selectedProfileID, availableProfiles, accessToken, clientToken, nil return selectedProfileID, availableProfiles, accessToken, clientToken, nil
} }
// Validate 验证Token使用JWT验证 // Validate 验证Token使用JWT验证 + Redis存储验证
func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool { func (s *tokenServiceRedis) Validate(ctx context.Context, accessToken, clientToken string) bool {
// 设置超时上下文 // 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel() defer cancel()
if accessToken == "" { if accessToken == "" {
return false return false
} }
@@ -197,6 +189,13 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
return false return false
} }
// 从Redis获取Token元数据
metadata, err := s.tokenStore.Retrieve(ctx, accessToken)
if err != nil {
// Token可能已过期或不存在
return false
}
// 查找Client // 查找Client
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject) client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil { if err != nil {
@@ -209,18 +208,19 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
} }
// 验证ClientToken如果提供 // 验证ClientToken如果提供
if clientToken != "" && client.ClientToken != clientToken { if clientToken != "" && metadata.ClientToken != clientToken {
return false return false
} }
return true return true
} }
// Refresh 刷新Token使用Version机制无需删除旧Token // Refresh 刷新Token使用Version机制Redis存储
func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) { func (s *tokenServiceRedis) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
// 设置超时上下文 // 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel() defer cancel()
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("accessToken不能为空") return "", "", errors.New("accessToken不能为空")
} }
@@ -279,6 +279,11 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
return "", "", fmt.Errorf("更新Client版本失败: %w", err) return "", "", fmt.Errorf("更新Client版本失败: %w", err)
} }
// 删除旧Token从Redis
if err := s.tokenStore.Delete(ctx, accessToken); err != nil {
s.logger.Warn("删除旧Token失败", zap.Error(err))
}
// 生成Token过期时间 // 生成Token过期时间
now := time.Now() now := time.Now()
var expiresAt, staleAt time.Time var expiresAt, staleAt time.Time
@@ -308,30 +313,27 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
return "", "", fmt.Errorf("生成新AccessToken失败: %w", err) return "", "", fmt.Errorf("生成新AccessToken失败: %w", err)
} }
// 存新Token记录 // 存新Token到Redis
newToken := model.Token{ ttl := expiresAt.Sub(now)
AccessToken: newAccessToken, metadata := &auth.TokenMetadata{
ClientToken: client.ClientToken,
UserID: client.UserID, UserID: client.UserID,
ProfileId: selectedProfileID, ProfileID: selectedProfileID,
ClientUUID: client.UUID,
ClientToken: client.ClientToken,
Version: client.Version, Version: client.Version,
Usable: true, CreatedAt: now.Unix(),
IssueDate: now,
ExpiresAt: &expiresAt,
StaleAt: &staleAt,
} }
err = s.tokenRepo.Create(ctx, &newToken) if err := s.tokenStore.Store(ctx, newAccessToken, metadata, ttl); err != nil {
if err != nil { s.logger.Warn("存储新Token到Redis失败", zap.Error(err))
s.logger.Warn("保存新Token记录失败但JWT已生成", zap.Error(err))
} }
s.logger.Info("成功刷新Token", zap.Int64("userId", client.UserID), zap.Int("version", client.Version)) s.logger.Info("成功刷新Token", zap.Int64("userId", client.UserID), zap.Int("version", client.Version))
return newAccessToken, client.ClientToken, nil return newAccessToken, client.ClientToken, nil
} }
// Invalidate 使Token失效通过增加Version // Invalidate 使Token失效从Redis删除
func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) { func (s *tokenServiceRedis) Invalidate(ctx context.Context, accessToken string) {
// 设置超时上下文 // 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel() defer cancel()
@@ -347,7 +349,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
return return
} }
// 查找Client并增加Version // 查找Client并增加Version失效所有旧Token
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject) client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil { if err != nil {
s.logger.Warn("无法找到对应的Client", zap.Error(err)) s.logger.Warn("无法找到对应的Client", zap.Error(err))
@@ -362,11 +364,17 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
return return
} }
// 从Redis删除Token
if err := s.tokenStore.Delete(ctx, accessToken); err != nil {
s.logger.Warn("从Redis删除Token失败", zap.Error(err))
return
}
s.logger.Info("成功失效Token", zap.String("clientUUID", client.UUID), zap.Int("version", client.Version)) s.logger.Info("成功失效Token", zap.String("clientUUID", client.UUID), zap.Int("version", client.Version))
} }
// InvalidateUserTokens 使用户所有Token失效 // InvalidateUserTokens 使用户所有Token失效从Redis删除
func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) { func (s *tokenServiceRedis) InvalidateUserTokens(ctx context.Context, userID int64) {
// 设置超时上下文 // 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel() defer cancel()
@@ -391,15 +399,20 @@ func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64
} }
} }
// 从Redis删除用户所有Token
if err := s.tokenStore.DeleteByUserID(ctx, userID); err != nil {
s.logger.Error("从Redis删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
return
}
s.logger.Info("成功失效用户所有Token", zap.Int64("userId", userID), zap.Int("clientCount", len(clients))) s.logger.Info("成功失效用户所有Token", zap.Int64("userId", userID), zap.Int("clientCount", len(clients)))
} }
// GetUUIDByAccessToken 从AccessToken获取UUID通过JWT解析 // GetUUIDByAccessToken 从AccessToken获取UUID通过JWT解析
func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) { func (s *tokenServiceRedis) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
if err != nil { if err != nil {
// 如果JWT解析失败尝试从数据库查询向后兼容 return "", errors.New("accessToken无效")
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
} }
if claims.ProfileID != "" { if claims.ProfileID != "" {
@@ -420,11 +433,10 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
} }
// GetUserIDByAccessToken 从AccessToken获取UserID通过JWT解析 // GetUserIDByAccessToken 从AccessToken获取UserID通过JWT解析
func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) { func (s *tokenServiceRedis) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
if err != nil { if err != nil {
// 如果JWT解析失败尝试从数据库查询向后兼容 return 0, errors.New("accessToken无效")
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
} }
// 从Client获取UserID // 从Client获取UserID
@@ -441,44 +453,8 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
return client.UserID, nil return client.UserID, nil
} }
// 私有辅助方法 // validateProfileByUserID 验证Profile是否属于用户
func (s *tokenServiceRedis) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
if userID == 0 {
return
}
// 为清理操作设置更长的超时时间
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
defer cancel()
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
if err != nil {
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return
}
if len(tokens) <= tokensMaxCount {
return
}
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
for i := tokensMaxCount; i < len(tokens); i++ {
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
}
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
if err != nil {
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return
}
if deletedCount > 0 {
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
}
}
func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
if userID == 0 || UUID == "" { if userID == 0 || UUID == "" {
return false, errors.New("用户ID或配置文件ID不能为空") return false, errors.New("用户ID或配置文件ID不能为空")
} }
@@ -492,24 +468,3 @@ func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID in
} }
return profile.UserID == userID, nil return profile.UserID == userID, nil
} }
// GetClientFromToken 从Token获取Client信息辅助方法
func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken string, stalePolicy auth.StaleTokenPolicy) (*model.Client, error) {
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, stalePolicy)
if err != nil {
return nil, err
}
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil {
return nil, err
}
// 验证Version
if claims.Version != client.Version {
return nil, errors.New("token版本不匹配")
}
return client, nil
}

View File

@@ -1,513 +0,0 @@
package service
import (
"carrotskin/internal/model"
"context"
"fmt"
"testing"
"go.uber.org/zap"
)
// TestTokenService_Constants 测试Token服务相关常量
func TestTokenService_Constants(t *testing.T) {
// 内部常量已私有化,通过服务行为间接测试
t.Skip("Token constants are now private - test through service behavior instead")
}
// TestTokenService_Validation 测试Token验证逻辑
func TestTokenService_Validation(t *testing.T) {
tests := []struct {
name string
accessToken string
wantValid bool
}{
{
name: "空token无效",
accessToken: "",
wantValid: false,
},
{
name: "非空token可能有效",
accessToken: "valid-token-string",
wantValid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 测试空token检查逻辑
isValid := tt.accessToken != ""
if isValid != tt.wantValid {
t.Errorf("Token validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestTokenService_ClientTokenLogic 测试ClientToken逻辑
func TestTokenService_ClientTokenLogic(t *testing.T) {
tests := []struct {
name string
clientToken string
shouldGenerate bool
}{
{
name: "空的clientToken应该生成新的",
clientToken: "",
shouldGenerate: true,
},
{
name: "非空的clientToken应该使用提供的",
clientToken: "existing-client-token",
shouldGenerate: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldGenerate := tt.clientToken == ""
if shouldGenerate != tt.shouldGenerate {
t.Errorf("ClientToken logic failed: got %v, want %v", shouldGenerate, tt.shouldGenerate)
}
})
}
}
// TestTokenService_ProfileSelection 测试Profile选择逻辑
func TestTokenService_ProfileSelection(t *testing.T) {
tests := []struct {
name string
profileCount int
shouldAutoSelect bool
}{
{
name: "只有一个profile时自动选择",
profileCount: 1,
shouldAutoSelect: true,
},
{
name: "多个profile时不自动选择",
profileCount: 2,
shouldAutoSelect: false,
},
{
name: "没有profile时不自动选择",
profileCount: 0,
shouldAutoSelect: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldAutoSelect := tt.profileCount == 1
if shouldAutoSelect != tt.shouldAutoSelect {
t.Errorf("Profile selection logic failed: got %v, want %v", shouldAutoSelect, tt.shouldAutoSelect)
}
})
}
}
// TestTokenService_CleanupLogic 测试清理逻辑
func TestTokenService_CleanupLogic(t *testing.T) {
tests := []struct {
name string
tokenCount int
maxCount int
shouldCleanup bool
cleanupCount int
}{
{
name: "token数量未超过上限不需要清理",
tokenCount: 5,
maxCount: 10,
shouldCleanup: false,
cleanupCount: 0,
},
{
name: "token数量超过上限需要清理",
tokenCount: 15,
maxCount: 10,
shouldCleanup: true,
cleanupCount: 5,
},
{
name: "token数量等于上限不需要清理",
tokenCount: 10,
maxCount: 10,
shouldCleanup: false,
cleanupCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldCleanup := tt.tokenCount > tt.maxCount
if shouldCleanup != tt.shouldCleanup {
t.Errorf("Cleanup decision failed: got %v, want %v", shouldCleanup, tt.shouldCleanup)
}
if shouldCleanup {
expectedCleanupCount := tt.tokenCount - tt.maxCount
if expectedCleanupCount != tt.cleanupCount {
t.Errorf("Cleanup count failed: got %d, want %d", expectedCleanupCount, tt.cleanupCount)
}
}
})
}
}
// TestTokenService_UserIDValidation 测试UserID验证
func TestTokenService_UserIDValidation(t *testing.T) {
tests := []struct {
name string
userID int64
isValid bool
}{
{
name: "有效的UserID",
userID: 1,
isValid: true,
},
{
name: "UserID为0时无效",
userID: 0,
isValid: false,
},
{
name: "负数UserID无效",
userID: -1,
isValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.userID > 0
if isValid != tt.isValid {
t.Errorf("UserID validation failed: got %v, want %v", isValid, tt.isValid)
}
})
}
}
// ============================================================================
// 使用 Mock 的集成测试
// ============================================================================
// TestTokenServiceImpl_Create 测试创建Token
func TestTokenServiceImpl_Create(t *testing.T) {
tokenRepo := NewMockTokenRepository()
profileRepo := NewMockProfileRepository()
logger := zap.NewNop()
// 预置Profile
testProfile := &model.Profile{
UUID: "test-profile-uuid",
UserID: 1,
Name: "TestProfile",
IsActive: true,
}
_ = profileRepo.Create(context.Background(), testProfile)
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
tests := []struct {
name string
userID int64
uuid string
clientToken string
wantErr bool
}{
{
name: "正常创建Token指定UUID",
userID: 1,
uuid: "test-profile-uuid",
clientToken: "client-token-1",
wantErr: false,
},
{
name: "正常创建Token空clientToken",
userID: 1,
uuid: "test-profile-uuid",
clientToken: "",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
_, _, accessToken, clientToken, err := tokenService.Create(ctx, tt.userID, tt.uuid, tt.clientToken)
if tt.wantErr {
if err == nil {
t.Error("期望返回错误,但实际没有错误")
}
} else {
if err != nil {
t.Errorf("不期望返回错误: %v", err)
return
}
if accessToken == "" {
t.Error("accessToken不应为空")
}
if clientToken == "" {
t.Error("clientToken不应为空")
}
}
})
}
}
// TestTokenServiceImpl_Validate 测试验证Token
func TestTokenServiceImpl_Validate(t *testing.T) {
tokenRepo := NewMockTokenRepository()
profileRepo := NewMockProfileRepository()
logger := zap.NewNop()
// 预置Token
testToken := &model.Token{
AccessToken: "valid-access-token",
ClientToken: "valid-client-token",
UserID: 1,
ProfileId: "test-profile-uuid",
Usable: true,
}
_ = tokenRepo.Create(context.Background(), testToken)
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
tests := []struct {
name string
accessToken string
clientToken string
wantValid bool
}{
{
name: "有效Token完全匹配",
accessToken: "valid-access-token",
clientToken: "valid-client-token",
wantValid: true,
},
{
name: "有效Token只检查accessToken",
accessToken: "valid-access-token",
clientToken: "",
wantValid: true,
},
{
name: "无效TokenaccessToken不存在",
accessToken: "invalid-access-token",
clientToken: "",
wantValid: false,
},
{
name: "无效TokenclientToken不匹配",
accessToken: "valid-access-token",
clientToken: "wrong-client-token",
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
isValid := tokenService.Validate(ctx, tt.accessToken, tt.clientToken)
if isValid != tt.wantValid {
t.Errorf("Token验证结果不匹配: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestTokenServiceImpl_Invalidate 测试注销Token
func TestTokenServiceImpl_Invalidate(t *testing.T) {
tokenRepo := NewMockTokenRepository()
profileRepo := NewMockProfileRepository()
logger := zap.NewNop()
// 预置Token
testToken := &model.Token{
AccessToken: "token-to-invalidate",
ClientToken: "client-token",
UserID: 1,
ProfileId: "test-profile-uuid",
Usable: true,
}
_ = tokenRepo.Create(context.Background(), testToken)
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
ctx := context.Background()
// 验证Token存在
isValid := tokenService.Validate(ctx, "token-to-invalidate", "")
if !isValid {
t.Error("Token应该有效")
}
// 注销Token
tokenService.Invalidate(ctx, "token-to-invalidate")
// 验证Token已失效从repo中删除
_, err := tokenRepo.FindByAccessToken(context.Background(), "token-to-invalidate")
if err == nil {
t.Error("Token应该已被删除")
}
}
// TestTokenServiceImpl_InvalidateUserTokens 测试注销用户所有Token
func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) {
tokenRepo := NewMockTokenRepository()
profileRepo := NewMockProfileRepository()
logger := zap.NewNop()
// 预置多个Token
for i := 1; i <= 3; i++ {
_ = tokenRepo.Create(context.Background(), &model.Token{
AccessToken: fmt.Sprintf("user1-token-%d", i),
ClientToken: "client-token",
UserID: 1,
ProfileId: "test-profile-uuid",
Usable: true,
})
}
_ = tokenRepo.Create(context.Background(), &model.Token{
AccessToken: "user2-token-1",
ClientToken: "client-token",
UserID: 2,
ProfileId: "test-profile-uuid-2",
Usable: true,
})
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
ctx := context.Background()
// 注销用户1的所有Token
tokenService.InvalidateUserTokens(ctx, 1)
// 验证用户1的Token已失效
tokens, _ := tokenRepo.GetByUserID(context.Background(), 1)
if len(tokens) > 0 {
t.Errorf("用户1的Token应该全部被删除但还剩 %d 个", len(tokens))
}
// 验证用户2的Token仍然存在
tokens2, _ := tokenRepo.GetByUserID(context.Background(), 2)
if len(tokens2) != 1 {
t.Errorf("用户2的Token应该仍然存在期望1个实际 %d 个", len(tokens2))
}
}
// TestTokenServiceImpl_Refresh 覆盖 Refresh 的主要分支
func TestTokenServiceImpl_Refresh(t *testing.T) {
tokenRepo := NewMockTokenRepository()
profileRepo := NewMockProfileRepository()
logger := zap.NewNop()
// 预置 Profile 与 Token
profile := &model.Profile{
UUID: "profile-uuid",
UserID: 1,
}
_ = profileRepo.Create(context.Background(), profile)
oldToken := &model.Token{
AccessToken: "old-token",
ClientToken: "client-token",
UserID: 1,
ProfileId: "",
Usable: true,
}
_ = tokenRepo.Create(context.Background(), oldToken)
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
ctx := context.Background()
// 正常刷新,不指定 profile
newAccess, client, err := tokenService.Refresh(ctx, "old-token", "client-token", "")
if err != nil {
t.Fatalf("Refresh 正常情况失败: %v", err)
}
if newAccess == "" || client != "client-token" {
t.Fatalf("Refresh 返回值异常: access=%s, client=%s", newAccess, client)
}
// accessToken 为空
if _, _, err := tokenService.Refresh(ctx, "", "client-token", ""); err == nil {
t.Fatalf("Refresh 在 accessToken 为空时应返回错误")
}
}
// TestTokenServiceImpl_GetByAccessToken 封装 GetUUIDByAccessToken / GetUserIDByAccessToken
func TestTokenServiceImpl_GetByAccessToken(t *testing.T) {
tokenRepo := NewMockTokenRepository()
profileRepo := NewMockProfileRepository()
logger := zap.NewNop()
token := &model.Token{
AccessToken: "token-1",
UserID: 42,
ProfileId: "profile-42",
Usable: true,
}
_ = tokenRepo.Create(context.Background(), token)
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
ctx := context.Background()
uuid, err := tokenService.GetUUIDByAccessToken(ctx, "token-1")
if err != nil || uuid != "profile-42" {
t.Fatalf("GetUUIDByAccessToken 返回错误: uuid=%s, err=%v", uuid, err)
}
uid, err := tokenService.GetUserIDByAccessToken(ctx, "token-1")
if err != nil || uid != 42 {
t.Fatalf("GetUserIDByAccessToken 返回错误: uid=%d, err=%v", uid, err)
}
}
// TestTokenServiceImpl_validateProfileByUserID 直接测试内部校验逻辑
func TestTokenServiceImpl_validateProfileByUserID(t *testing.T) {
tokenRepo := NewMockTokenRepository()
profileRepo := NewMockProfileRepository()
logger := zap.NewNop()
svc := &tokenService{
tokenRepo: tokenRepo,
profileRepo: profileRepo,
logger: logger,
}
// 预置 Profile
profile := &model.Profile{
UUID: "p-1",
UserID: 1,
}
_ = profileRepo.Create(context.Background(), profile)
// 参数非法
if ok, err := svc.validateProfileByUserID(context.Background(), 0, ""); err == nil || ok {
t.Fatalf("validateProfileByUserID 在参数非法时应返回错误")
}
// Profile 不存在
if ok, err := svc.validateProfileByUserID(context.Background(), 1, "not-exists"); err == nil || ok {
t.Fatalf("validateProfileByUserID 在 Profile 不存在时应返回错误")
}
// 用户与 Profile 匹配
if ok, err := svc.validateProfileByUserID(context.Background(), 1, "p-1"); err != nil || !ok {
t.Fatalf("validateProfileByUserID 匹配时应返回 true, err=%v", err)
}
// 用户与 Profile 不匹配
if ok, err := svc.validateProfileByUserID(context.Background(), 2, "p-1"); err != nil || ok {
t.Fatalf("validateProfileByUserID 不匹配时应返回 false, err=%v", err)
}
}

View File

@@ -183,7 +183,7 @@ func (s *userService) GetByID(ctx context.Context, id int64) (*model.User, error
cacheKey := s.cacheKeys.User(id) cacheKey := s.cacheKeys.User(id)
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) { return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
return s.userRepo.FindByID(ctx, id) return s.userRepo.FindByID(ctx, id)
}, 5*time.Minute) }, s.cache.Policy.UserTTL)
} }
func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User, error) { func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User, error) {
@@ -191,7 +191,7 @@ func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User
cacheKey := s.cacheKeys.UserByEmail(email) cacheKey := s.cacheKeys.UserByEmail(email)
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) { return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
return s.userRepo.FindByEmail(ctx, email) return s.userRepo.FindByEmail(ctx, email)
}, 5*time.Minute) }, s.cache.Policy.UserEmailTTL)
} }
func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error { func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {

View File

@@ -22,7 +22,7 @@ type yggdrasilServiceComposite struct {
serializationService SerializationService serializationService SerializationService
certificateService CertificateService certificateService CertificateService
profileRepo repository.ProfileRepository profileRepo repository.ProfileRepository
tokenRepo repository.TokenRepository tokenService TokenService // 使用TokenService接口不直接依赖TokenRepository
logger *zap.Logger logger *zap.Logger
} }
@@ -31,11 +31,11 @@ func NewYggdrasilServiceComposite(
db *gorm.DB, db *gorm.DB,
userRepo repository.UserRepository, userRepo repository.UserRepository,
profileRepo repository.ProfileRepository, profileRepo repository.ProfileRepository,
tokenRepo repository.TokenRepository,
yggdrasilRepo repository.YggdrasilRepository, yggdrasilRepo repository.YggdrasilRepository,
signatureService *SignatureService, signatureService *SignatureService,
redisClient *redis.Client, redisClient *redis.Client,
logger *zap.Logger, logger *zap.Logger,
tokenService TokenService, // 新增TokenService接口
) YggdrasilService { ) YggdrasilService {
// 创建各个专门的服务 // 创建各个专门的服务
authService := NewYggdrasilAuthService(db, userRepo, yggdrasilRepo, logger) authService := NewYggdrasilAuthService(db, userRepo, yggdrasilRepo, logger)
@@ -53,7 +53,7 @@ func NewYggdrasilServiceComposite(
serializationService: serializationService, serializationService: serializationService,
certificateService: certificateService, certificateService: certificateService,
profileRepo: profileRepo, profileRepo: profileRepo,
tokenRepo: tokenRepo, tokenService: tokenService,
logger: logger, logger: logger,
} }
} }
@@ -75,8 +75,8 @@ func (s *yggdrasilServiceComposite) ResetYggdrasilPassword(ctx context.Context,
// JoinServer 加入服务器 // JoinServer 加入服务器
func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, accessToken, selectedProfile, ip string) error { func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, accessToken, selectedProfile, ip string) error {
// 验证Token // 通过TokenService验证Token并获取UUID
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken) uuid, err := s.tokenService.GetUUIDByAccessToken(ctx, accessToken)
if err != nil { if err != nil {
s.logger.Error("验证Token失败", s.logger.Error("验证Token失败",
zap.Error(err), zap.Error(err),
@@ -87,7 +87,7 @@ func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, ac
// 格式化UUID并验证与Token关联的配置文件 // 格式化UUID并验证与Token关联的配置文件
formattedProfile := utils.FormatUUID(selectedProfile) formattedProfile := utils.FormatUUID(selectedProfile)
if token.ProfileId != formattedProfile { if uuid != formattedProfile {
return errors.New("selectedProfile与Token不匹配") return errors.New("selectedProfile与Token不匹配")
} }

168
internal/task/runner.go Normal file
View File

@@ -0,0 +1,168 @@
package task
import (
"context"
"math/rand"
"runtime/debug"
"sync"
"time"
"go.uber.org/zap"
)
// Task 定义可调度任务
type Task interface {
Name() string
Interval() time.Duration
Run(ctx context.Context) error
}
// Runner 简单的周期任务调度器
type Runner struct {
tasks []Task
logger *zap.Logger
wg sync.WaitGroup
startImmediately bool
jitterPercent float64
}
// NewRunner 创建任务调度器
func NewRunner(logger *zap.Logger, tasks ...Task) *Runner {
return NewRunnerWithOptions(logger, tasks)
}
// RunnerOption 运行器配置项
type RunnerOption func(r *Runner)
// WithStartImmediately 是否启动后立即执行一次(默认 true
func WithStartImmediately(start bool) RunnerOption {
return func(r *Runner) {
r.startImmediately = start
}
}
// WithJitter 为执行间隔增加 0~percent 之间的随机抖动percent=0 关闭默认0
// 可降低多个任务同时触发的概率
func WithJitter(percent float64) RunnerOption {
return func(r *Runner) {
if percent < 0 {
percent = 0
}
r.jitterPercent = percent
}
}
// NewRunnerWithOptions 支持可选配置的创建函数
func NewRunnerWithOptions(logger *zap.Logger, tasks []Task, opts ...RunnerOption) *Runner {
r := &Runner{
tasks: tasks,
logger: logger,
startImmediately: true,
jitterPercent: 0,
}
for _, opt := range opts {
opt(r)
}
return r
}
// Start 启动所有任务(异步)
func (r *Runner) Start(ctx context.Context) {
for _, t := range r.tasks {
task := t
r.wg.Add(1)
go func() {
defer r.wg.Done()
defer r.recoverPanic(task)
interval := r.normalizeInterval(task.Interval())
// 可选:立即执行一次
if r.startImmediately {
r.runOnce(ctx, task)
}
// 周期执行
for {
wait := r.applyJitter(interval)
if !r.wait(ctx, wait) {
return
}
// 每轮读取最新的 interval允许任务动态调整间隔
interval = r.normalizeInterval(task.Interval())
select {
case <-ctx.Done():
return
default:
r.runOnce(ctx, task)
}
}
}()
}
}
// Wait 等待所有任务退出
func (r *Runner) Wait() {
r.wg.Wait()
}
func (r *Runner) runOnce(ctx context.Context, task Task) {
if err := task.Run(ctx); err != nil && r.logger != nil {
r.logger.Warn("任务执行失败", zap.String("task", task.Name()), zap.Error(err))
}
}
// normalizeInterval 确保间隔为正值
func (r *Runner) normalizeInterval(d time.Duration) time.Duration {
if d <= 0 {
return time.Minute
}
return d
}
// applyJitter 在基础间隔上添加最多 jitterPercent 的随机抖动
func (r *Runner) applyJitter(base time.Duration) time.Duration {
if r.jitterPercent <= 0 {
return base
}
maxJitter := time.Duration(float64(base) * r.jitterPercent)
if maxJitter <= 0 {
return base
}
return base + time.Duration(rand.Int63n(int64(maxJitter)))
}
// wait 封装带 context 的 sleep
func (r *Runner) wait(ctx context.Context, d time.Duration) bool {
if d <= 0 {
select {
case <-ctx.Done():
return false
default:
return true
}
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-timer.C:
return true
}
}
// recoverPanic 防止任务 panic 导致 goroutine 退出
func (r *Runner) recoverPanic(task Task) {
if rec := recover(); rec != nil && r.logger != nil {
r.logger.Error("任务发生panic",
zap.String("task", task.Name()),
zap.Any("panic", rec),
zap.ByteString("stack", debug.Stack()),
)
}
}

View File

@@ -0,0 +1,65 @@
package task
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"go.uber.org/zap"
)
type mockTask struct {
name string
interval time.Duration
err error
runCount *atomic.Int32
}
func (m *mockTask) Name() string { return m.name }
func (m *mockTask) Interval() time.Duration { return m.interval }
func (m *mockTask) Run(ctx context.Context) error {
if m.runCount != nil {
m.runCount.Add(1)
}
return m.err
}
func TestRunner_StartAndWait(t *testing.T) {
runCount := &atomic.Int32{}
task := &mockTask{name: "ok", interval: 20 * time.Millisecond, runCount: runCount}
runner := NewRunner(zap.NewNop(), task)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runner.Start(ctx)
time.Sleep(60 * time.Millisecond)
cancel()
runner.Wait()
if runCount.Load() == 0 {
t.Fatalf("expected task to run at least once")
}
}
func TestRunner_RunErrorLogged(t *testing.T) {
runCount := &atomic.Int32{}
task := &mockTask{name: "err", interval: 10 * time.Millisecond, err: errors.New("boom"), runCount: runCount}
runner := NewRunner(zap.NewNop(), task)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
runner.Start(ctx)
time.Sleep(25 * time.Millisecond)
cancel()
runner.Wait()
if runCount.Load() == 0 {
t.Fatalf("expected task to be attempted")
}
}

View File

@@ -0,0 +1,56 @@
package testutil
import (
"testing"
"time"
"carrotskin/internal/model"
"carrotskin/pkg/database"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// NewTestDB 返回基于内存的 sqlite 数据库并完成模型迁移
func NewTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open sqlite memory db: %v", err)
}
if err := db.AutoMigrate(
&model.User{},
&model.UserPointLog{},
&model.UserLoginLog{},
&model.Profile{},
&model.Texture{},
&model.UserTextureFavorite{},
&model.TextureDownloadLog{},
&model.Client{},
&model.Yggdrasil{},
&model.SystemConfig{},
&model.AuditLog{},
&model.CasbinRule{},
); err != nil {
t.Fatalf("failed to migrate models: %v", err)
}
return db
}
// NewNoopLogger 返回无输出 logger
func NewNoopLogger() *zap.Logger {
return zap.NewNop()
}
// NewTestCache 返回禁用 redis 的缓存管理器(用于单元测试)
func NewTestCache() *database.CacheManager {
return database.NewCacheManager(nil, database.CacheConfig{
Prefix: "test:",
Expiration: 1 * time.Minute,
Enabled: false,
})
}

View File

@@ -0,0 +1,27 @@
package testutil
import "testing"
func TestNewTestDB(t *testing.T) {
db := NewTestDB(t)
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("DB() err: %v", err)
}
if err := sqlDB.Ping(); err != nil {
t.Fatalf("ping err: %v", err)
}
}
func TestNewTestCache(t *testing.T) {
cache := NewTestCache()
if cache.Policy.UserTTL == 0 {
t.Fatalf("expected defaults filled")
}
// disabled cache should not error on Set
if err := cache.Set(nil, "k", "v"); err != nil {
t.Fatalf("Set on disabled cache should be nil err, got %v", err)
}
}

320
pkg/auth/token_redis.go Normal file
View File

@@ -0,0 +1,320 @@
package auth
import (
"context"
"encoding/json"
"fmt"
"time"
"carrotskin/pkg/redis"
"go.uber.org/zap"
)
// TokenMetadata Token元数据存储在Redis中
type TokenMetadata struct {
UserID int64 `json:"user_id"`
ProfileID string `json:"profile_id"`
ClientUUID string `json:"client_uuid"`
ClientToken string `json:"client_token"`
Version int `json:"version"`
CreatedAt int64 `json:"created_at"`
}
// TokenStoreRedis Redis Token存储实现
type TokenStoreRedis struct {
redis *redis.Client
logger *zap.Logger
keyPrefix string
defaultTTL time.Duration
staleTTL time.Duration
maxTokensPerUser int
}
// NewTokenStoreRedis 创建Redis Token存储
func NewTokenStoreRedis(
redisClient *redis.Client,
logger *zap.Logger,
opts ...TokenStoreOption,
) *TokenStoreRedis {
options := &tokenStoreOptions{
keyPrefix: "token:",
defaultTTL: 24 * time.Hour,
staleTTL: 30 * 24 * time.Hour,
maxTokensPerUser: 10,
}
for _, opt := range opts {
opt(options)
}
return &TokenStoreRedis{
redis: redisClient,
logger: logger,
keyPrefix: options.keyPrefix,
defaultTTL: options.defaultTTL,
staleTTL: options.staleTTL,
maxTokensPerUser: options.maxTokensPerUser,
}
}
// tokenStoreOptions Token存储配置选项
type tokenStoreOptions struct {
keyPrefix string
defaultTTL time.Duration
staleTTL time.Duration
maxTokensPerUser int
}
// TokenStoreOption Token存储配置选项函数
type TokenStoreOption func(*tokenStoreOptions)
// WithKeyPrefix 设置Key前缀
func WithKeyPrefix(prefix string) TokenStoreOption {
return func(o *tokenStoreOptions) {
o.keyPrefix = prefix
}
}
// WithDefaultTTL 设置默认TTL
func WithDefaultTTL(ttl time.Duration) TokenStoreOption {
return func(o *tokenStoreOptions) {
o.defaultTTL = ttl
}
}
// WithStaleTTL 设置过期但可用时间
func WithStaleTTL(ttl time.Duration) TokenStoreOption {
return func(o *tokenStoreOptions) {
o.staleTTL = ttl
}
}
// WithMaxTokensPerUser 设置每个用户的最大Token数
func WithMaxTokensPerUser(max int) TokenStoreOption {
return func(o *tokenStoreOptions) {
o.maxTokensPerUser = max
}
}
// Store 存储Token
func (s *TokenStoreRedis) Store(ctx context.Context, accessToken string, metadata *TokenMetadata, ttl time.Duration) error {
if ttl <= 0 {
ttl = s.defaultTTL
}
// 序列化元数据
data, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("序列化Token元数据失败: %w", err)
}
// 存储Token
tokenKey := s.getTokenKey(accessToken)
if err := s.redis.Set(ctx, tokenKey, data, ttl); err != nil {
return fmt.Errorf("存储Token失败: %w", err)
}
// 添加到用户Token集合
userTokensKey := s.getUserTokensKey(metadata.UserID)
if err := s.redis.SAdd(ctx, userTokensKey, accessToken); err != nil {
return fmt.Errorf("添加到用户Token集合失败: %w", err)
}
// 清理过期Token后台执行
go s.cleanupUserTokens(context.Background(), metadata.UserID)
s.logger.Debug("Token已存储",
zap.String("token", accessToken[:20]+"..."),
zap.Int64("userId", metadata.UserID),
zap.Duration("ttl", ttl),
)
return nil
}
// Retrieve 获取Token元数据
func (s *TokenStoreRedis) Retrieve(ctx context.Context, accessToken string) (*TokenMetadata, error) {
tokenKey := s.getTokenKey(accessToken)
data, err := s.redis.Get(ctx, tokenKey)
if err != nil {
return nil, fmt.Errorf("获取Token失败: %w", err)
}
var metadata TokenMetadata
if err := json.Unmarshal([]byte(data), &metadata); err != nil {
return nil, fmt.Errorf("解析Token元数据失败: %w", err)
}
return &metadata, nil
}
// Delete 删除Token
func (s *TokenStoreRedis) Delete(ctx context.Context, accessToken string) error {
tokenKey := s.getTokenKey(accessToken)
// 先获取Token元数据以获取UserID
metadata, err := s.Retrieve(ctx, accessToken)
if err != nil {
// Token可能已过期忽略错误
return nil
}
// 删除Token
if err := s.redis.Del(ctx, tokenKey); err != nil {
return fmt.Errorf("删除Token失败: %w", err)
}
// 从用户Token集合中移除
userTokensKey := s.getUserTokensKey(metadata.UserID)
if err := s.redis.SRem(ctx, userTokensKey, accessToken); err != nil {
return fmt.Errorf("从用户Token集合移除失败: %w", err)
}
s.logger.Debug("Token已删除",
zap.String("token", accessToken[:20]+"..."),
zap.Int64("userId", metadata.UserID),
)
return nil
}
// DeleteByUserID 删除用户的所有Token
func (s *TokenStoreRedis) DeleteByUserID(ctx context.Context, userID int64) error {
userTokensKey := s.getUserTokensKey(userID)
// 获取用户所有Token
tokens, err := s.redis.SMembers(ctx, userTokensKey)
if err != nil {
return fmt.Errorf("获取用户Token列表失败: %w", err)
}
// 删除所有Token
if len(tokens) > 0 {
tokenKeys := make([]string, len(tokens))
for i, token := range tokens {
tokenKeys[i] = s.getTokenKey(token)
}
if err := s.redis.Del(ctx, tokenKeys...); err != nil {
return fmt.Errorf("批量删除Token失败: %w", err)
}
}
// 删除用户Token集合
if err := s.redis.Del(ctx, userTokensKey); err != nil {
return fmt.Errorf("删除用户Token集合失败: %w", err)
}
s.logger.Info("用户所有Token已删除",
zap.Int64("userId", userID),
zap.Int("count", len(tokens)),
)
return nil
}
// Exists 检查Token是否存在
func (s *TokenStoreRedis) Exists(ctx context.Context, accessToken string) (bool, error) {
tokenKey := s.getTokenKey(accessToken)
count, err := s.redis.Exists(ctx, tokenKey)
if err != nil {
return false, fmt.Errorf("检查Token存在失败: %w", err)
}
return count > 0, nil
}
// GetTTL 获取Token的剩余TTL
func (s *TokenStoreRedis) GetTTL(ctx context.Context, accessToken string) (time.Duration, error) {
tokenKey := s.getTokenKey(accessToken)
return s.redis.TTL(ctx, tokenKey)
}
// RefreshTTL 刷新Token的TTL
func (s *TokenStoreRedis) RefreshTTL(ctx context.Context, accessToken string, ttl time.Duration) error {
if ttl <= 0 {
ttl = s.defaultTTL
}
tokenKey := s.getTokenKey(accessToken)
if err := s.redis.Expire(ctx, tokenKey, ttl); err != nil {
return fmt.Errorf("刷新Token TTL失败: %w", err)
}
return nil
}
// GetCountByUser 获取用户的Token数量
func (s *TokenStoreRedis) GetCountByUser(ctx context.Context, userID int64) (int64, error) {
userTokensKey := s.getUserTokensKey(userID)
count, err := s.redis.SMembers(ctx, userTokensKey)
if err != nil {
return 0, fmt.Errorf("获取用户Token数量失败: %w", err)
}
return int64(len(count)), nil
}
// cleanupUserTokens 清理用户的过期Token保留最新的N个
func (s *TokenStoreRedis) cleanupUserTokens(ctx context.Context, userID int64) {
userTokensKey := s.getUserTokensKey(userID)
// 获取用户所有Token
tokens, err := s.redis.SMembers(ctx, userTokensKey)
if err != nil {
s.logger.Error("获取用户Token列表失败", zap.Error(err), zap.Int64("userId", userID))
return
}
// 清理过期的Token验证它们是否仍存在
validTokens := make([]string, 0, len(tokens))
for _, token := range tokens {
tokenKey := s.getTokenKey(token)
exists, err := s.redis.Exists(ctx, tokenKey)
if err != nil {
s.logger.Error("检查Token存在失败", zap.Error(err), zap.String("token", token[:20]+"..."))
continue
}
if exists > 0 {
validTokens = append(validTokens, token)
}
}
// 如果没有变化,直接返回
if len(validTokens) == len(tokens) {
return
}
// 更新用户Token集合
if len(validTokens) == 0 {
s.redis.Del(ctx, userTokensKey)
} else {
// 重新设置集合
s.redis.Del(ctx, userTokensKey)
for _, token := range validTokens {
s.redis.SAdd(ctx, userTokensKey, token)
}
}
// 如果超过限制删除最旧的Token这里简化处理可以根据createdAt排序
if len(validTokens) > s.maxTokensPerUser {
tokensToDelete := validTokens[s.maxTokensPerUser:]
for _, token := range tokensToDelete {
s.Delete(ctx, token)
}
s.logger.Info("清理用户多余Token",
zap.Int64("userId", userID),
zap.Int("deleted", len(tokensToDelete)),
)
}
}
// getTokenKey 生成Token的Redis Key
func (s *TokenStoreRedis) getTokenKey(accessToken string) string {
return s.keyPrefix + accessToken
}
// getUserTokensKey 生成用户Token集合的Redis Key
func (s *TokenStoreRedis) getUserTokensKey(userID int64) string {
return fmt.Sprintf("user:%d:tokens", userID)
}

View File

@@ -0,0 +1,47 @@
package config
import (
"os"
"testing"
"github.com/spf13/viper"
)
// 重置 viper避免测试间干扰
func resetViper() {
viper.Reset()
}
func TestLoad_DefaultsAndBucketsOverride(t *testing.T) {
resetViper()
// 设置部分环境变量覆盖
_ = os.Setenv("RUSTFS_BUCKET_TEXTURES", "tex-bkt")
_ = os.Setenv("RUSTFS_BUCKET_AVATARS", "ava-bkt")
_ = os.Setenv("DATABASE_MAX_IDLE_CONNS", "20")
_ = os.Setenv("DATABASE_MAX_OPEN_CONNS", "50")
_ = os.Setenv("DATABASE_CONN_MAX_LIFETIME", "2h")
_ = os.Setenv("DATABASE_CONN_MAX_IDLE_TIME", "30m")
cfg, err := Load()
if err != nil {
t.Fatalf("Load err: %v", err)
}
// 默认值检查
if cfg.Server.Port == "" || cfg.Database.Driver == "" || cfg.Redis.Host == "" {
t.Fatalf("expected defaults filled: %+v", cfg)
}
// 覆盖检查
if cfg.RustFS.Buckets["textures"] != "tex-bkt" || cfg.RustFS.Buckets["avatars"] != "ava-bkt" {
t.Fatalf("buckets override failed: %+v", cfg.RustFS.Buckets)
}
if cfg.Database.MaxIdleConns != 20 || cfg.Database.MaxOpenConns != 50 {
t.Fatalf("db pool override failed: %+v", cfg.Database)
}
if cfg.Database.ConnMaxLifetime.String() != "2h0m0s" || cfg.Database.ConnMaxIdleTime.String() != "30m0s" {
t.Fatalf("db duration override failed: %v %v", cfg.Database.ConnMaxLifetime, cfg.Database.ConnMaxIdleTime)
}
}

View File

@@ -14,12 +14,24 @@ type CacheConfig struct {
Prefix string // 缓存键前缀 Prefix string // 缓存键前缀
Expiration time.Duration // 过期时间 Expiration time.Duration // 过期时间
Enabled bool // 是否启用缓存 Enabled bool // 是否启用缓存
Policy CachePolicy // 缓存策略(可选,不配置则回落到 Expiration
}
// CachePolicy 缓存策略,用于为不同实体设置默认 TTL
type CachePolicy struct {
UserTTL time.Duration
UserEmailTTL time.Duration
ProfileTTL time.Duration
ProfileListTTL time.Duration
TextureTTL time.Duration
TextureListTTL time.Duration
} }
// CacheManager 缓存管理器 // CacheManager 缓存管理器
type CacheManager struct { type CacheManager struct {
redis *redis.Client redis *redis.Client
config CacheConfig config CacheConfig
Policy CachePolicy
} }
// NewCacheManager 创建缓存管理器 // NewCacheManager 创建缓存管理器
@@ -31,9 +43,33 @@ func NewCacheManager(redisClient *redis.Client, config CacheConfig) *CacheManage
config.Expiration = 5 * time.Minute config.Expiration = 5 * time.Minute
} }
// 填充默认策略(未配置时退回全局过期时间)
applyPolicyDefaults := func(p *CachePolicy) {
if p.UserTTL == 0 {
p.UserTTL = config.Expiration
}
if p.UserEmailTTL == 0 {
p.UserEmailTTL = config.Expiration
}
if p.ProfileTTL == 0 {
p.ProfileTTL = config.Expiration
}
if p.ProfileListTTL == 0 {
p.ProfileListTTL = config.Expiration
}
if p.TextureTTL == 0 {
p.TextureTTL = config.Expiration
}
if p.TextureListTTL == 0 {
p.TextureListTTL = config.Expiration
}
}
applyPolicyDefaults(&config.Policy)
return &CacheManager{ return &CacheManager{
redis: redisClient, redis: redisClient,
config: config, config: config,
Policy: config.Policy,
} }
} }
@@ -56,6 +92,14 @@ func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) e
return json.Unmarshal(data, dest) return json.Unmarshal(data, dest)
} }
// TryGet 获取缓存,命中时返回 true不视为错误
func (cm *CacheManager) TryGet(ctx context.Context, key string, dest interface{}) (bool, error) {
if err := cm.Get(ctx, key, dest); err != nil {
return false, err
}
return true, nil
}
// Set 设置缓存 // Set 设置缓存
func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error { func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
if !cm.config.Enabled || cm.redis == nil { if !cm.config.Enabled || cm.redis == nil {
@@ -75,6 +119,13 @@ func (cm *CacheManager) Set(ctx context.Context, key string, value interface{},
return cm.redis.Set(ctx, cm.buildKey(key), data, exp) return cm.redis.Set(ctx, cm.buildKey(key), data, exp)
} }
// SetAsync 异步设置缓存,避免在主请求链路阻塞
func (cm *CacheManager) SetAsync(ctx context.Context, key string, value interface{}, expiration ...time.Duration) {
go func() {
_ = cm.Set(ctx, key, value, expiration...)
}()
}
// Delete 删除缓存 // Delete 删除缓存
func (cm *CacheManager) Delete(ctx context.Context, keys ...string) error { func (cm *CacheManager) Delete(ctx context.Context, keys ...string) error {
if !cm.config.Enabled || cm.redis == nil { if !cm.config.Enabled || cm.redis == nil {
@@ -187,11 +238,7 @@ func Cached[T any](
} }
// 设置缓存(异步,不阻塞) // 设置缓存(异步,不阻塞)
go func() { cache.SetAsync(context.Background(), key, data, expiration...)
cacheCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = cache.Set(cacheCtx, key, data, expiration...)
}()
return data, nil return data, nil
} }
@@ -217,11 +264,7 @@ func CachedList[T any](
} }
// 设置缓存(异步,不阻塞) // 设置缓存(异步,不阻塞)
go func() { cache.SetAsync(context.Background(), key, data, expiration...)
cacheCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = cache.Set(cacheCtx, key, data, expiration...)
}()
return data, nil return data, nil
} }
@@ -306,6 +349,11 @@ func (b *CacheKeyBuilder) TextureList(userID int64, page int) string {
return fmt.Sprintf("%stexture:user:%d:page:%d", b.prefix, userID, page) return fmt.Sprintf("%stexture:user:%d:page:%d", b.prefix, userID, page)
} }
// TextureListPattern 构建材质列表缓存键模式(用于批量失效)
func (b *CacheKeyBuilder) TextureListPattern(userID int64) string {
return fmt.Sprintf("%stexture:user:%d:*", b.prefix, userID)
}
// Token 构建令牌缓存键 // Token 构建令牌缓存键
func (b *CacheKeyBuilder) Token(accessToken string) string { func (b *CacheKeyBuilder) Token(accessToken string) string {
return fmt.Sprintf("%stoken:%s", b.prefix, accessToken) return fmt.Sprintf("%stoken:%s", b.prefix, accessToken)

184
pkg/database/cache_test.go Normal file
View File

@@ -0,0 +1,184 @@
package database
import (
"context"
"testing"
"time"
pkgRedis "carrotskin/pkg/redis"
miniredis "github.com/alicebob/miniredis/v2"
goRedis "github.com/redis/go-redis/v9"
)
func newCacheWithMiniRedis(t *testing.T) (*CacheManager, func()) {
t.Helper()
mr, err := miniredis.Run()
if err != nil {
t.Fatalf("failed to start miniredis: %v", err)
}
rdb := goRedis.NewClient(&goRedis.Options{
Addr: mr.Addr(),
})
client := &pkgRedis.Client{Client: rdb}
cache := NewCacheManager(client, CacheConfig{
Prefix: "t:",
Expiration: time.Minute,
Enabled: true,
Policy: CachePolicy{
UserTTL: 2 * time.Minute,
UserEmailTTL: 3 * time.Minute,
ProfileTTL: 2 * time.Minute,
ProfileListTTL: 90 * time.Second,
TextureTTL: 2 * time.Minute,
TextureListTTL: 45 * time.Second,
},
})
cleanup := func() {
_ = rdb.Close()
mr.Close()
}
return cache, cleanup
}
func TestCacheManager_GetSet_TryGet(t *testing.T) {
cache, cleanup := newCacheWithMiniRedis(t)
defer cleanup()
ctx := context.Background()
type User struct {
ID int
Name string
}
u := User{ID: 1, Name: "alice"}
if err := cache.Set(ctx, "user:1", u, 10*time.Second); err != nil {
t.Fatalf("Set err: %v", err)
}
var got User
if err := cache.Get(ctx, "user:1", &got); err != nil {
t.Fatalf("Get err: %v", err)
}
if got != u {
t.Fatalf("unexpected value: %+v", got)
}
var got2 User
ok, err := cache.TryGet(ctx, "user:1", &got2)
if err != nil || !ok {
t.Fatalf("TryGet failed, ok=%v err=%v", ok, err)
}
if got2 != u {
t.Fatalf("unexpected TryGet: %+v", got2)
}
}
func TestCacheManager_DeletePattern(t *testing.T) {
cache, cleanup := newCacheWithMiniRedis(t)
defer cleanup()
ctx := context.Background()
_ = cache.Set(ctx, "user:1", "a", 0)
_ = cache.Set(ctx, "user:2", "b", 0)
_ = cache.Set(ctx, "profile:1", "c", 0)
// 删除 user:* 键
if err := cache.DeletePattern(ctx, "user:*"); err != nil {
t.Fatalf("DeletePattern err: %v", err)
}
var v string
ok, _ := cache.TryGet(ctx, "user:1", &v)
if ok {
t.Fatalf("expected user:1 deleted")
}
ok, _ = cache.TryGet(ctx, "user:2", &v)
if ok {
t.Fatalf("expected user:2 deleted")
}
ok, _ = cache.TryGet(ctx, "profile:1", &v)
if !ok {
t.Fatalf("expected profile:1 kept")
}
}
func TestCachedAndCachedList(t *testing.T) {
cache, cleanup := newCacheWithMiniRedis(t)
defer cleanup()
ctx := context.Background()
callCount := 0
result, err := Cached(ctx, cache, "key1", func() (*string, error) {
callCount++
val := "hello"
return &val, nil
}, cache.Policy.UserTTL)
if err != nil || *result != "hello" || callCount != 1 {
t.Fatalf("Cached first call failed")
}
// 等待缓存写入完成
for i := 0; i < 10; i++ {
var tmp string
if ok, _ := cache.TryGet(ctx, "key1", &tmp); ok {
break
}
time.Sleep(10 * time.Millisecond)
}
// 第二次应命中缓存
_, err = Cached(ctx, cache, "key1", func() (*string, error) {
callCount++
val := "world"
return &val, nil
}, cache.Policy.UserTTL)
if err != nil || callCount != 1 {
t.Fatalf("Cached should hit cache, callCount=%d err=%v", callCount, err)
}
listCall := 0
_, err = CachedList(ctx, cache, "list", func() ([]string, error) {
listCall++
return []string{"a", "b"}, nil
}, cache.Policy.ProfileListTTL)
if err != nil || listCall != 1 {
t.Fatalf("CachedList first call failed")
}
for i := 0; i < 10; i++ {
var tmp []string
if ok, _ := cache.TryGet(ctx, "list", &tmp); ok {
break
}
time.Sleep(10 * time.Millisecond)
}
_, err = CachedList(ctx, cache, "list", func() ([]string, error) {
listCall++
return []string{"c"}, nil
}, cache.Policy.ProfileListTTL)
if err != nil || listCall != 1 {
t.Fatalf("CachedList should hit cache, calls=%d err=%v", listCall, err)
}
}
func TestIncrementWithExpire(t *testing.T) {
cache, cleanup := newCacheWithMiniRedis(t)
defer cleanup()
ctx := context.Background()
val, err := cache.IncrementWithExpire(ctx, "counter", time.Second)
if err != nil || val != 1 {
t.Fatalf("first increment failed, val=%d err=%v", val, err)
}
val, err = cache.IncrementWithExpire(ctx, "counter", time.Second)
if err != nil || val != 2 {
t.Fatalf("second increment failed, val=%d err=%v", val, err)
}
ttl, err := cache.TTL(ctx, "counter")
if err != nil || ttl <= 0 {
t.Fatalf("TTL not set: ttl=%v err=%v", ttl, err)
}
}

View File

@@ -75,7 +75,6 @@ func AutoMigrate(logger *zap.Logger) error {
&model.TextureDownloadLog{}, &model.TextureDownloadLog{},
// 认证相关表 // 认证相关表
&model.Token{},
&model.Client{}, // Client表用于管理Token版本 &model.Client{}, // Client表用于管理Token版本
// Yggdrasil相关表在User之后创建因为它引用User // Yggdrasil相关表在User之后创建因为它引用User

View File

@@ -0,0 +1,24 @@
package database
import (
"testing"
"go.uber.org/zap/zaptest"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// 使用内存 sqlite 验证 AutoMigrate 关键路径,无需真实 Postgres
func TestAutoMigrate_WithSQLite(t *testing.T) {
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
if err != nil {
t.Fatalf("open sqlite err: %v", err)
}
dbInstance = db
defer func() { dbInstance = nil }()
logger := zaptest.NewLogger(t)
if err := AutoMigrate(logger); err != nil {
t.Fatalf("AutoMigrate sqlite err: %v", err)
}
}

View File

@@ -9,6 +9,7 @@ import (
// TestGetDB_NotInitialized 测试未初始化时获取数据库实例 // TestGetDB_NotInitialized 测试未初始化时获取数据库实例
func TestGetDB_NotInitialized(t *testing.T) { func TestGetDB_NotInitialized(t *testing.T) {
dbInstance = nil
_, err := GetDB() _, err := GetDB()
if err == nil { if err == nil {
t.Error("未初始化时应该返回错误") t.Error("未初始化时应该返回错误")
@@ -22,6 +23,7 @@ func TestGetDB_NotInitialized(t *testing.T) {
// TestMustGetDB_Panic 测试MustGetDB在未初始化时panic // TestMustGetDB_Panic 测试MustGetDB在未初始化时panic
func TestMustGetDB_Panic(t *testing.T) { func TestMustGetDB_Panic(t *testing.T) {
dbInstance = nil
defer func() { defer func() {
if r := recover(); r == nil { if r := recover(); r == nil {
t.Error("MustGetDB 应该在未初始化时panic") t.Error("MustGetDB 应该在未初始化时panic")
@@ -33,6 +35,7 @@ func TestMustGetDB_Panic(t *testing.T) {
// TestInit_Database 测试数据库初始化逻辑 // TestInit_Database 测试数据库初始化逻辑
func TestInit_Database(t *testing.T) { func TestInit_Database(t *testing.T) {
dbInstance = nil
cfg := config.DatabaseConfig{ cfg := config.DatabaseConfig{
Driver: "postgres", Driver: "postgres",
Host: "localhost", Host: "localhost",
@@ -53,7 +56,7 @@ func TestInit_Database(t *testing.T) {
// 注意:实际连接可能失败,这是可以接受的 // 注意:实际连接可能失败,这是可以接受的
err := Init(cfg, logger) err := Init(cfg, logger)
if err != nil { if err != nil {
t.Logf("Init() 返回错误(可能正常,如果数据库未运行): %v", err) t.Skipf("数据库未运行,跳过连接测试: %v", err)
} }
} }
@@ -82,4 +85,3 @@ func TestClose_NotInitialized(t *testing.T) {
t.Errorf("Close() 在未初始化时应该返回nil实际返回: %v", err) t.Errorf("Close() 在未初始化时应该返回nil实际返回: %v", err)
} }
} }

56
pkg/email/email_test.go Normal file
View File

@@ -0,0 +1,56 @@
package email
import (
"strings"
"sync"
"testing"
"carrotskin/pkg/config"
"go.uber.org/zap"
)
func resetEmailOnce() {
serviceInstance = nil
once = sync.Once{}
}
func TestEmailManager_Disabled(t *testing.T) {
resetEmailOnce()
cfg := config.EmailConfig{Enabled: false}
if err := Init(cfg, zap.NewNop()); err != nil {
t.Fatalf("Init disabled err: %v", err)
}
svc := MustGetService()
if err := svc.SendVerificationCode("to@test.com", "123456", "email_verification"); err == nil {
t.Fatalf("expected error when disabled")
}
}
func TestEmailManager_SendFailsWithInvalidSMTP(t *testing.T) {
resetEmailOnce()
cfg := config.EmailConfig{
Enabled: true,
SMTPHost: "127.0.0.1",
SMTPPort: 1, // invalid/closed port to trigger error quickly
Username: "user",
Password: "pwd",
FromName: "name",
}
_ = Init(cfg, zap.NewNop())
svc := MustGetService()
if err := svc.SendVerificationCode("to@test.com", "123456", "reset_password"); err == nil {
t.Fatalf("expected send error with invalid smtp")
}
}
func TestEmailManager_SubjectAndBody(t *testing.T) {
svc := &Service{cfg: config.EmailConfig{FromName: "name", Username: "user"}, logger: zap.NewNop()}
if subj := svc.getSubject("email_verification"); subj == "" {
t.Fatalf("subject empty")
}
body := svc.getBody("123456", "change_email")
if !strings.Contains(body, "123456") || !strings.Contains(body, "更换邮箱") {
t.Fatalf("body content mismatch")
}
}

View File

@@ -2,13 +2,20 @@ package email
import ( import (
"carrotskin/pkg/config" "carrotskin/pkg/config"
"sync"
"testing" "testing"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
) )
func resetEmail() {
serviceInstance = nil
once = sync.Once{}
}
// TestGetService_NotInitialized 测试未初始化时获取邮件服务 // TestGetService_NotInitialized 测试未初始化时获取邮件服务
func TestGetService_NotInitialized(t *testing.T) { func TestGetService_NotInitialized(t *testing.T) {
resetEmail()
_, err := GetService() _, err := GetService()
if err == nil { if err == nil {
t.Error("未初始化时应该返回错误") t.Error("未初始化时应该返回错误")
@@ -22,6 +29,7 @@ func TestGetService_NotInitialized(t *testing.T) {
// TestMustGetService_Panic 测试MustGetService在未初始化时panic // TestMustGetService_Panic 测试MustGetService在未初始化时panic
func TestMustGetService_Panic(t *testing.T) { func TestMustGetService_Panic(t *testing.T) {
resetEmail()
defer func() { defer func() {
if r := recover(); r == nil { if r := recover(); r == nil {
t.Error("MustGetService 应该在未初始化时panic") t.Error("MustGetService 应该在未初始化时panic")
@@ -33,6 +41,7 @@ func TestMustGetService_Panic(t *testing.T) {
// TestInit_Email 测试邮件服务初始化 // TestInit_Email 测试邮件服务初始化
func TestInit_Email(t *testing.T) { func TestInit_Email(t *testing.T) {
resetEmail()
cfg := config.EmailConfig{ cfg := config.EmailConfig{
Enabled: false, Enabled: false,
SMTPHost: "smtp.example.com", SMTPHost: "smtp.example.com",
@@ -58,4 +67,3 @@ func TestInit_Email(t *testing.T) {
t.Error("GetService() 返回的服务不应为nil") t.Error("GetService() 返回的服务不应为nil")
} }
} }

View File

@@ -3,8 +3,11 @@ package redis
import ( import (
"carrotskin/pkg/config" "carrotskin/pkg/config"
"fmt" "fmt"
"os"
"sync" "sync"
"github.com/alicebob/miniredis/v2"
redis9 "github.com/redis/go-redis/v9"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -15,19 +18,69 @@ var (
once sync.Once once sync.Once
// initError 初始化错误 // initError 初始化错误
initError error initError error
// miniredisInstance 用于测试/开发环境
miniredisInstance *miniredis.Miniredis
) )
// Init 初始化Redis客户端线程安全只会执行一次 // Init 初始化Redis客户端线程安全只会执行一次
// 如果Redis连接失败且环境为测试/开发则回退到miniredis
func Init(cfg config.RedisConfig, logger *zap.Logger) error { func Init(cfg config.RedisConfig, logger *zap.Logger) error {
var err error
once.Do(func() { once.Do(func() {
clientInstance, initError = New(cfg, logger) // 尝试连接真实Redis
if initError != nil { clientInstance, err = New(cfg, logger)
if err != nil {
logger.Warn("Redis连接失败尝试使用miniredis回退", zap.Error(err))
// 检查是否允许回退到miniredis仅开发/测试环境)
if allowFallbackToMiniRedis() {
clientInstance, err = initMiniRedis(logger)
if err != nil {
initError = fmt.Errorf("Redis和miniredis都初始化失败: %w", err)
logger.Error("miniredis初始化失败", zap.Error(initError))
return return
} }
logger.Info("已回退到miniredis用于开发/测试环境")
} else {
initError = fmt.Errorf("Redis连接失败且不允许回退: %w", err)
logger.Error("Redis连接失败", zap.Error(initError))
return
}
}
}) })
return initError return initError
} }
// allowFallbackToMiniRedis 检查是否允许回退到miniredis
func allowFallbackToMiniRedis() bool {
// 检查环境变量
env := os.Getenv("ENVIRONMENT")
return env == "development" || env == "test" || env == "dev" ||
os.Getenv("USE_MINIREDIS") == "true"
}
// initMiniRedis 初始化miniredis用于开发/测试环境)
func initMiniRedis(logger *zap.Logger) (*Client, error) {
var err error
miniredisInstance, err = miniredis.Run()
if err != nil {
return nil, fmt.Errorf("启动miniredis失败: %w", err)
}
// 创建Redis客户端连接到miniredis
redisClient := redis9.NewClient(&redis9.Options{
Addr: miniredisInstance.Addr(),
})
client := &Client{
Client: redisClient,
logger: logger,
}
logger.Info("miniredis已启动", zap.String("addr", miniredisInstance.Addr()))
return client, nil
}
// GetClient 获取Redis客户端实例线程安全 // GetClient 获取Redis客户端实例线程安全
func GetClient() (*Client, error) { func GetClient() (*Client, error) {
if clientInstance == nil { if clientInstance == nil {
@@ -45,7 +98,21 @@ func MustGetClient() *Client {
return client return client
} }
// Close 关闭Redis连接包括miniredis如果使用了
func Close() error {
var err error
if miniredisInstance != nil {
miniredisInstance.Close()
miniredisInstance = nil
}
if clientInstance != nil {
err = clientInstance.Close()
clientInstance = nil
}
return err
}
// IsUsingMiniRedis 检查是否使用了miniredis
func IsUsingMiniRedis() bool {
return miniredisInstance != nil
}

71
pkg/storage/minio_test.go Normal file
View File

@@ -0,0 +1,71 @@
package storage
import (
"context"
"testing"
"time"
"carrotskin/pkg/config"
"github.com/minio/minio-go/v7"
)
// 使用 nil client 仅测试纯函数和错误分支
func TestStorage_GetBucketAndBuildURL(t *testing.T) {
s := &StorageClient{
client: (*minio.Client)(nil),
buckets: map[string]string{"textures": "tex-bkt"},
publicURL: "http://localhost:9000",
}
if b, err := s.GetBucket("textures"); err != nil || b != "tex-bkt" {
t.Fatalf("GetBucket mismatch: %v %s", err, b)
}
if _, err := s.GetBucket("missing"); err == nil {
t.Fatalf("expected error for missing bucket")
}
if url := s.BuildFileURL("tex-bkt", "obj"); url != "http://localhost:9000/tex-bkt/obj" {
t.Fatalf("BuildFileURL mismatch: %s", url)
}
}
func TestNewStorage_SkipConnectWhenNoCreds(t *testing.T) {
// 当 AccessKey/Secret 为空时跳过 ListBuckets 测试,避免真实依赖
cfg := config.RustFSConfig{
Endpoint: "127.0.0.1:9000",
Buckets: map[string]string{"avatars": "ava", "textures": "tex"},
UseSSL: false,
}
if _, err := NewStorage(cfg); err != nil {
t.Fatalf("NewStorage should not error when creds empty: %v", err)
}
}
func TestPresignedHelpers_WithNilClient(t *testing.T) {
s := &StorageClient{
client: (*minio.Client)(nil),
buckets: map[string]string{"textures": "tex-bkt"},
publicURL: "http://localhost:9000",
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// 预期会panicnil client用recover捕获
func() {
defer func() {
if r := recover(); r == nil {
t.Fatalf("GeneratePresignedURL expected panic with nil client")
}
}()
_, _ = s.GeneratePresignedURL(ctx, "tex-bkt", "obj", time.Minute)
}()
func() {
defer func() {
if r := recover(); r == nil {
t.Fatalf("GeneratePresignedPostURL expected panic with nil client")
}
}()
_, _ = s.GeneratePresignedPostURL(ctx, "tex-bkt", "obj", 0, 10, time.Minute)
}()
}