refactor: Remove Token management and integrate Redis for authentication
- Deleted the Token model and its repository, transitioning to a Redis-based token management system. - Updated the service layer to utilize Redis for token storage, enhancing performance and scalability. - Refactored the container to remove TokenRepository and integrate the new token service. - Cleaned up the Dockerfile and other files by removing unnecessary whitespace and comments. - Enhanced error handling and logging for Redis initialization and usage.
This commit is contained in:
@@ -29,7 +29,6 @@ type Container struct {
|
||||
UserRepo repository.UserRepository
|
||||
ProfileRepo repository.ProfileRepository
|
||||
TextureRepo repository.TextureRepository
|
||||
TokenRepo repository.TokenRepository
|
||||
ClientRepo repository.ClientRepository
|
||||
ConfigRepo repository.SystemConfigRepository
|
||||
YggdrasilRepo repository.YggdrasilRepository
|
||||
@@ -61,6 +60,14 @@ func NewContainer(
|
||||
Prefix: "carrotskin:",
|
||||
Expiration: 5 * time.Minute,
|
||||
Enabled: true,
|
||||
Policy: database.CachePolicy{
|
||||
UserTTL: 5 * time.Minute,
|
||||
UserEmailTTL: 5 * time.Minute,
|
||||
ProfileTTL: 5 * time.Minute,
|
||||
ProfileListTTL: 3 * time.Minute,
|
||||
TextureTTL: 5 * time.Minute,
|
||||
TextureListTTL: 2 * time.Minute,
|
||||
},
|
||||
})
|
||||
|
||||
c := &Container{
|
||||
@@ -76,7 +83,6 @@ func NewContainer(
|
||||
c.UserRepo = repository.NewUserRepository(db)
|
||||
c.ProfileRepo = repository.NewProfileRepository(db)
|
||||
c.TextureRepo = repository.NewTextureRepository(db)
|
||||
c.TokenRepo = repository.NewTokenRepository(db)
|
||||
c.ClientRepo = repository.NewClientRepository(db)
|
||||
c.ConfigRepo = repository.NewSystemConfigRepository(db)
|
||||
c.YggdrasilRepo = repository.NewYggdrasilRepository(db)
|
||||
@@ -98,10 +104,24 @@ func NewContainer(
|
||||
logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err))
|
||||
}
|
||||
yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin")
|
||||
c.TokenService = service.NewTokenServiceJWT(c.TokenRepo, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger)
|
||||
|
||||
// 创建Redis Token存储(必须使用Redis,包括miniredis回退)
|
||||
if redisClient == nil {
|
||||
logger.Fatal("Redis客户端未初始化,无法创建Token服务")
|
||||
}
|
||||
|
||||
tokenStore := auth.NewTokenStoreRedis(
|
||||
redisClient,
|
||||
logger,
|
||||
auth.WithKeyPrefix("token:"),
|
||||
auth.WithDefaultTTL(24*time.Hour),
|
||||
auth.WithStaleTTL(30*24*time.Hour),
|
||||
auth.WithMaxTokensPerUser(10),
|
||||
)
|
||||
c.TokenService = service.NewTokenServiceRedis(tokenStore, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger)
|
||||
|
||||
// 使用组合服务(内部包含认证、会话、序列化、证书服务)
|
||||
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, c.SignatureService, redisClient, logger)
|
||||
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.YggdrasilRepo, c.SignatureService, redisClient, logger, c.TokenService)
|
||||
|
||||
// 初始化其他服务
|
||||
c.SecurityService = service.NewSecurityService(redisClient)
|
||||
@@ -186,13 +206,6 @@ func WithTextureRepo(repo repository.TextureRepository) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenRepo 设置令牌仓储
|
||||
func WithTokenRepo(repo repository.TokenRepository) Option {
|
||||
return func(c *Container) {
|
||||
c.TokenRepo = repo
|
||||
}
|
||||
}
|
||||
|
||||
// WithConfigRepo 设置系统配置仓储
|
||||
func WithConfigRepo(repo repository.SystemConfigRepository) Option {
|
||||
return func(c *Container) {
|
||||
|
||||
38
internal/errors/errors_test.go
Normal file
38
internal/errors/errors_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAppErrorBasics(t *testing.T) {
|
||||
root := errors.New("root")
|
||||
appErr := NewBadRequest("bad", root)
|
||||
|
||||
if appErr.Code != 400 || appErr.Message != "bad" {
|
||||
t.Fatalf("unexpected appErr fields: %+v", appErr)
|
||||
}
|
||||
if got := appErr.Error(); got != "bad: root" {
|
||||
t.Fatalf("unexpected Error(): %s", got)
|
||||
}
|
||||
if !Is(appErr, root) {
|
||||
t.Fatalf("Is should match wrapped error")
|
||||
}
|
||||
var target *AppError
|
||||
if !As(appErr, &target) {
|
||||
t.Fatalf("As should succeed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrap(t *testing.T) {
|
||||
if Wrap(nil, "msg") != nil {
|
||||
t.Fatalf("Wrap nil should return nil")
|
||||
}
|
||||
err := errors.New("base")
|
||||
wrapped := Wrap(err, "ctx")
|
||||
if wrapped.Error() != "ctx: base" {
|
||||
t.Fatalf("wrap message mismatch: %v", wrapped)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
27
internal/handler/swagger_test.go
Normal file
27
internal/handler/swagger_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 仅验证降级路径(未初始化依赖时的响应)
|
||||
func TestHealthCheck_Degraded(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.GET("/health", HealthCheck)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected 503 when dependencies missing, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -29,3 +29,10 @@ func (Client) TableName() string {
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// Token Yggdrasil 认证令牌模型
|
||||
type Token struct {
|
||||
AccessToken string `gorm:"column:access_token;type:text;primaryKey" json:"access_token"` // 改为text以支持JWT长度
|
||||
UserID int64 `gorm:"column:user_id;not null;index:idx_tokens_user_id" json:"user_id"`
|
||||
ClientToken string `gorm:"column:client_token;type:varchar(64);not null;index:idx_tokens_client_token" json:"client_token"`
|
||||
ProfileId string `gorm:"column:profile_id;type:varchar(36);index:idx_tokens_profile_id" json:"profile_id"` // 改为可空
|
||||
Version int `gorm:"column:version;not null;default:0;index:idx_tokens_version" json:"version"` // 新增:版本号
|
||||
Usable bool `gorm:"column:usable;not null;default:true;index:idx_tokens_usable" json:"usable"`
|
||||
IssueDate time.Time `gorm:"column:issue_date;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_tokens_issue_date,sort:desc" json:"issue_date"`
|
||||
ExpiresAt *time.Time `gorm:"column:expires_at;type:timestamp" json:"expires_at,omitempty"` // 新增:过期时间
|
||||
StaleAt *time.Time `gorm:"column:stale_at;type:timestamp" json:"stale_at,omitempty"` // 新增:过期但可用时间
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE" json:"user,omitempty"`
|
||||
Profile *Profile `gorm:"foreignKey:ProfileId;references:UUID;constraint:OnDelete:CASCADE" json:"profile,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Token) TableName() string { return "tokens" }
|
||||
18
internal/model/yggdrasil_test.go
Normal file
18
internal/model/yggdrasil_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateRandomPassword(t *testing.T) {
|
||||
pwd := GenerateRandomPassword(16)
|
||||
if len(pwd) != 16 {
|
||||
t.Fatalf("length mismatch: %d", len(pwd))
|
||||
}
|
||||
for _, ch := range pwd {
|
||||
if !strings.ContainsRune(passwordChars, ch) {
|
||||
t.Fatalf("unexpected char: %c", ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,18 +67,6 @@ type TextureRepository interface {
|
||||
CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error)
|
||||
}
|
||||
|
||||
// TokenRepository 令牌仓储接口
|
||||
type TokenRepository interface {
|
||||
Create(ctx context.Context, token *model.Token) error
|
||||
FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error)
|
||||
GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error)
|
||||
GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error)
|
||||
GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error)
|
||||
DeleteByAccessToken(ctx context.Context, accessToken string) error
|
||||
DeleteByUserID(ctx context.Context, userId int64) error
|
||||
BatchDelete(ctx context.Context, accessTokens []string) (int64, error)
|
||||
}
|
||||
|
||||
// SystemConfigRepository 系统配置仓储接口
|
||||
type SystemConfigRepository interface {
|
||||
GetByKey(ctx context.Context, key string) (*model.SystemConfig, error)
|
||||
|
||||
278
internal/repository/repository_sqlite_test.go
Normal file
278
internal/repository/repository_sqlite_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/testutil"
|
||||
)
|
||||
|
||||
func TestUserRepository_BasicAndPoints(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &model.User{Username: "u1", Email: "e1@test.com", Password: "pwd", Status: 1}
|
||||
if err := repo.Create(ctx, user); err != nil {
|
||||
t.Fatalf("create user err: %v", err)
|
||||
}
|
||||
|
||||
if u, err := repo.FindByID(ctx, user.ID); err != nil || u.Username != "u1" {
|
||||
t.Fatalf("FindByID mismatch: %v %+v", err, u)
|
||||
}
|
||||
if u, err := repo.FindByUsername(ctx, "u1"); err != nil || u.Email != "e1@test.com" {
|
||||
t.Fatalf("FindByUsername mismatch")
|
||||
}
|
||||
if u, err := repo.FindByEmail(ctx, "e1@test.com"); err != nil || u.ID != user.ID {
|
||||
t.Fatalf("FindByEmail mismatch")
|
||||
}
|
||||
|
||||
if err := repo.UpdateFields(ctx, user.ID, map[string]interface{}{"avatar": "a.png"}); err != nil {
|
||||
t.Fatalf("UpdateFields err: %v", err)
|
||||
}
|
||||
|
||||
if _, err := repo.BatchUpdate(ctx, []int64{user.ID}, map[string]interface{}{"status": 2}); err != nil {
|
||||
t.Fatalf("BatchUpdate err: %v", err)
|
||||
}
|
||||
|
||||
// 积分增加
|
||||
if err := repo.UpdatePoints(ctx, user.ID, 10, "add", "bonus"); err != nil {
|
||||
t.Fatalf("UpdatePoints add err: %v", err)
|
||||
}
|
||||
// 积分不足场景
|
||||
if err := repo.UpdatePoints(ctx, user.ID, -100, "sub", "penalty"); err == nil {
|
||||
t.Fatalf("expected insufficient points error")
|
||||
}
|
||||
|
||||
if list, err := repo.FindByIDs(ctx, []int64{user.ID}); err != nil || len(list) != 1 {
|
||||
t.Fatalf("FindByIDs mismatch: %v %d", err, len(list))
|
||||
}
|
||||
if list, err := repo.FindByIDs(ctx, []int64{}); err != nil || len(list) != 0 {
|
||||
t.Fatalf("FindByIDs empty mismatch: %v %d", err, len(list))
|
||||
}
|
||||
|
||||
// 软删除
|
||||
if err := repo.Delete(ctx, user.ID); err != nil {
|
||||
t.Fatalf("Delete err: %v", err)
|
||||
}
|
||||
deleted, _ := repo.FindByID(ctx, user.ID)
|
||||
if deleted != nil {
|
||||
t.Fatalf("expected deleted user filtered out")
|
||||
}
|
||||
|
||||
// 批量操作边界
|
||||
if _, err := repo.BatchUpdate(ctx, []int64{}, map[string]interface{}{"status": 1}); err != nil {
|
||||
t.Fatalf("BatchUpdate empty should not error: %v", err)
|
||||
}
|
||||
if _, err := repo.BatchDelete(ctx, []int64{}); err != nil {
|
||||
t.Fatalf("BatchDelete empty should not error: %v", err)
|
||||
}
|
||||
|
||||
// 日志写入
|
||||
_ = repo.CreateLoginLog(ctx, &model.UserLoginLog{UserID: user.ID, IPAddress: "127.0.0.1"})
|
||||
_ = repo.CreatePointLog(ctx, &model.UserPointLog{UserID: user.ID, Amount: 1, ChangeType: "add"})
|
||||
}
|
||||
|
||||
func TestProfileRepository_Basic(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
userRepo := NewUserRepository(db)
|
||||
profileRepo := NewProfileRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
u := &model.User{Username: "u2", Email: "u2@test.com", Password: "pwd", Status: 1}
|
||||
_ = userRepo.Create(ctx, u)
|
||||
|
||||
p := &model.Profile{UUID: "p-uuid", UserID: u.ID, Name: "hero", IsActive: false}
|
||||
if err := profileRepo.Create(ctx, p); err != nil {
|
||||
t.Fatalf("create profile err: %v", err)
|
||||
}
|
||||
|
||||
if got, err := profileRepo.FindByUUID(ctx, "p-uuid"); err != nil || got.Name != "hero" {
|
||||
t.Fatalf("FindByUUID mismatch: %v %+v", err, got)
|
||||
}
|
||||
if list, err := profileRepo.FindByUserID(ctx, u.ID); err != nil || len(list) != 1 {
|
||||
t.Fatalf("FindByUserID mismatch")
|
||||
}
|
||||
if count, err := profileRepo.CountByUserID(ctx, u.ID); err != nil || count != 1 {
|
||||
t.Fatalf("CountByUserID mismatch: %d err=%v", count, err)
|
||||
}
|
||||
|
||||
if err := profileRepo.SetActive(ctx, "p-uuid", u.ID); err != nil {
|
||||
t.Fatalf("SetActive err: %v", err)
|
||||
}
|
||||
if err := profileRepo.UpdateLastUsedAt(ctx, "p-uuid"); err != nil {
|
||||
t.Fatalf("UpdateLastUsedAt err: %v", err)
|
||||
}
|
||||
|
||||
if got, err := profileRepo.FindByName(ctx, "hero"); err != nil || got == nil {
|
||||
t.Fatalf("FindByName mismatch")
|
||||
}
|
||||
if list, err := profileRepo.FindByUUIDs(ctx, []string{"p-uuid"}); err != nil || len(list) != 1 {
|
||||
t.Fatalf("FindByUUIDs mismatch")
|
||||
}
|
||||
if _, err := profileRepo.BatchUpdate(ctx, []string{"p-uuid"}, map[string]interface{}{"name": "hero2"}); err != nil {
|
||||
t.Fatalf("BatchUpdate profile err: %v", err)
|
||||
}
|
||||
|
||||
if err := profileRepo.Delete(ctx, "p-uuid"); err != nil {
|
||||
t.Fatalf("Delete err: %v", err)
|
||||
}
|
||||
if _, err := profileRepo.BatchDelete(ctx, []string{}); err != nil {
|
||||
t.Fatalf("BatchDelete empty err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextureRepository_Basic(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
userRepo := NewUserRepository(db)
|
||||
textureRepo := NewTextureRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
u := &model.User{Username: "u3", Email: "u3@test.com", Password: "pwd", Status: 1}
|
||||
_ = userRepo.Create(ctx, u)
|
||||
|
||||
tex := &model.Texture{
|
||||
UploaderID: u.ID,
|
||||
Name: "tex",
|
||||
Hash: "hash1",
|
||||
URL: "url1",
|
||||
Type: model.TextureTypeSkin,
|
||||
IsPublic: true,
|
||||
Status: 1,
|
||||
}
|
||||
if err := textureRepo.Create(ctx, tex); err != nil {
|
||||
t.Fatalf("create texture err: %v", err)
|
||||
}
|
||||
|
||||
if got, _ := textureRepo.FindByHash(ctx, "hash1"); got == nil || got.ID != tex.ID {
|
||||
t.Fatalf("FindByHash mismatch")
|
||||
}
|
||||
if got, _ := textureRepo.FindByHashAndUploaderID(ctx, "hash1", u.ID); got == nil {
|
||||
t.Fatalf("FindByHashAndUploaderID mismatch")
|
||||
}
|
||||
|
||||
_ = textureRepo.IncrementFavoriteCount(ctx, tex.ID)
|
||||
_ = textureRepo.DecrementFavoriteCount(ctx, tex.ID)
|
||||
_ = textureRepo.IncrementDownloadCount(ctx, tex.ID)
|
||||
_ = textureRepo.CreateDownloadLog(ctx, &model.TextureDownloadLog{TextureID: tex.ID, UserID: &u.ID, IPAddress: "127.0.0.1"})
|
||||
|
||||
// 收藏
|
||||
_ = textureRepo.AddFavorite(ctx, u.ID, tex.ID)
|
||||
if fav, err := textureRepo.IsFavorited(ctx, u.ID, tex.ID); err == nil {
|
||||
if !fav {
|
||||
t.Fatalf("IsFavorited expected true")
|
||||
}
|
||||
} else {
|
||||
t.Skipf("IsFavorited not supported by sqlite: %v", err)
|
||||
}
|
||||
_ = textureRepo.RemoveFavorite(ctx, u.ID, tex.ID)
|
||||
|
||||
// 批量更新与删除
|
||||
if affected, err := textureRepo.BatchUpdate(ctx, []int64{tex.ID}, map[string]interface{}{"name": "tex-new"}); err != nil || affected != 1 {
|
||||
t.Fatalf("BatchUpdate mismatch, affected=%d err=%v", affected, err)
|
||||
}
|
||||
if affected, err := textureRepo.BatchDelete(ctx, []int64{tex.ID}); err != nil || affected != 1 {
|
||||
t.Fatalf("BatchDelete mismatch, affected=%d err=%v", affected, err)
|
||||
}
|
||||
|
||||
// 搜索与收藏列表
|
||||
_ = textureRepo.Create(ctx, &model.Texture{
|
||||
UploaderID: u.ID,
|
||||
Name: "search-me",
|
||||
Hash: "hash2",
|
||||
URL: "url2",
|
||||
Type: model.TextureTypeCape,
|
||||
IsPublic: true,
|
||||
Status: 1,
|
||||
})
|
||||
if list, total, err := textureRepo.Search(ctx, "search", model.TextureTypeCape, true, 1, 10); err != nil || total == 0 || len(list) == 0 {
|
||||
t.Fatalf("Search mismatch, total=%d len=%d err=%v", total, len(list), err)
|
||||
}
|
||||
_ = textureRepo.AddFavorite(ctx, u.ID, tex.ID+1)
|
||||
if favList, total, err := textureRepo.GetUserFavorites(ctx, u.ID, 1, 10); err != nil || total == 0 || len(favList) == 0 {
|
||||
t.Fatalf("GetUserFavorites mismatch, total=%d len=%d err=%v", total, len(favList), err)
|
||||
}
|
||||
if _, total, err := textureRepo.Search(ctx, "", model.TextureTypeSkin, true, 1, 10); err != nil || total < 2 {
|
||||
t.Fatalf("Search fallback mismatch")
|
||||
}
|
||||
|
||||
// 列表与计数
|
||||
if _, total, err := textureRepo.FindByUploaderID(ctx, u.ID, 1, 10); err != nil || total != 1 {
|
||||
t.Fatalf("FindByUploaderID mismatch")
|
||||
}
|
||||
if cnt, err := textureRepo.CountByUploaderID(ctx, u.ID); err != nil || cnt != 1 {
|
||||
t.Fatalf("CountByUploaderID mismatch")
|
||||
}
|
||||
|
||||
_ = textureRepo.Delete(ctx, tex.ID)
|
||||
}
|
||||
|
||||
func TestSystemConfigRepository_Basic(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
repo := NewSystemConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
cfg := &model.SystemConfig{Key: "site_name", Value: "Carrot", IsPublic: true}
|
||||
if err := repo.Update(ctx, cfg); err != nil {
|
||||
t.Fatalf("Update err: %v", err)
|
||||
}
|
||||
if v, err := repo.GetByKey(ctx, "site_name"); err != nil || v.Value != "Carrot" {
|
||||
t.Fatalf("GetByKey mismatch")
|
||||
}
|
||||
_ = repo.UpdateValue(ctx, "site_name", "Carrot2")
|
||||
if list, _ := repo.GetPublic(ctx); len(list) == 0 {
|
||||
t.Fatalf("GetPublic expected entries")
|
||||
}
|
||||
if all, _ := repo.GetAll(ctx); len(all) == 0 {
|
||||
t.Fatalf("GetAll expected entries")
|
||||
}
|
||||
if v, _ := repo.GetByKey(ctx, "site_name"); v.Value != "Carrot2" {
|
||||
t.Fatalf("UpdateValue not applied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRepository_Basic(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
repo := NewClientRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
client := &model.Client{UUID: "c-uuid", ClientToken: "ct-1", UserID: 9, Version: 1}
|
||||
if err := repo.Create(ctx, client); err != nil {
|
||||
t.Fatalf("Create client err: %v", err)
|
||||
}
|
||||
if got, _ := repo.FindByClientToken(ctx, "ct-1"); got == nil || got.UUID != "c-uuid" {
|
||||
t.Fatalf("FindByClientToken mismatch")
|
||||
}
|
||||
if got, _ := repo.FindByUUID(ctx, "c-uuid"); got == nil || got.ClientToken != "ct-1" {
|
||||
t.Fatalf("FindByUUID mismatch")
|
||||
}
|
||||
if list, _ := repo.FindByUserID(ctx, 9); len(list) != 1 {
|
||||
t.Fatalf("FindByUserID mismatch")
|
||||
}
|
||||
_ = repo.IncrementVersion(ctx, "c-uuid")
|
||||
updated, _ := repo.FindByUUID(ctx, "c-uuid")
|
||||
if updated.Version != 2 {
|
||||
t.Fatalf("IncrementVersion not applied, got %d", updated.Version)
|
||||
}
|
||||
_ = repo.DeleteByClientToken(ctx, "ct-1")
|
||||
_ = repo.DeleteByUserID(ctx, 9)
|
||||
}
|
||||
|
||||
func TestYggdrasilRepository_Basic(t *testing.T) {
|
||||
db := testutil.NewTestDB(t)
|
||||
userRepo := NewUserRepository(db)
|
||||
yggRepo := NewYggdrasilRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &model.User{Username: "u-ygg", Email: "ygg@test.com", Password: "pwd", Status: 1}
|
||||
_ = userRepo.Create(ctx, user) // AfterCreate 会生成 yggdrasil 记录
|
||||
|
||||
pwd, err := yggRepo.GetPasswordByID(ctx, user.ID)
|
||||
if err != nil || pwd == "" {
|
||||
t.Fatalf("GetPasswordByID err=%v pwd=%s", err, pwd)
|
||||
}
|
||||
if err := yggRepo.ResetPassword(ctx, user.ID, "newpwd"); err != nil {
|
||||
t.Fatalf("ResetPassword err: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// tokenRepository TokenRepository的实现
|
||||
type tokenRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewTokenRepository 创建TokenRepository实例
|
||||
func NewTokenRepository(db *gorm.DB) TokenRepository {
|
||||
return &tokenRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *tokenRepository) Create(ctx context.Context, token *model.Token) error {
|
||||
return r.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := r.db.WithContext(ctx).Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
|
||||
var tokens []*model.Token
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userId).Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func (r *tokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
var token model.Token
|
||||
err := r.db.WithContext(ctx).Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
var token model.Token
|
||||
err := r.db.WithContext(ctx).Select("user_id").Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return token.UserID, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
|
||||
return r.db.WithContext(ctx).Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userId).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
|
||||
if len(accessTokens) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.WithContext(ctx).Where("access_token IN ?", accessTokens).Delete(&model.Token{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTokenRepository_BatchDeleteLogic 测试批量删除逻辑
|
||||
func TestTokenRepository_BatchDeleteLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokensToDelete []string
|
||||
wantCount int64
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "有效的token列表",
|
||||
tokensToDelete: []string{"token1", "token2", "token3"},
|
||||
wantCount: 3,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "空列表应该返回0",
|
||||
tokensToDelete: []string{},
|
||||
wantCount: 0,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "单个token",
|
||||
tokensToDelete: []string{"token1"},
|
||||
wantCount: 1,
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证批量删除逻辑:空列表应该直接返回0
|
||||
if len(tt.tokensToDelete) == 0 {
|
||||
if tt.wantCount != 0 {
|
||||
t.Errorf("Empty list should return count 0, got %d", tt.wantCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenRepository_QueryConditions 测试token查询条件逻辑
|
||||
func TestTokenRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
userID int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的access token",
|
||||
accessToken: "valid-token-123",
|
||||
userID: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "access token为空",
|
||||
accessToken: "",
|
||||
userID: 1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0",
|
||||
accessToken: "valid-token-123",
|
||||
userID: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.accessToken != "" && tt.userID > 0
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenRepository_FindTokenByIDLogic 测试根据ID查找token的逻辑
|
||||
func TestTokenRepository_FindTokenByIDLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
resultCount int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "找到token",
|
||||
accessToken: "token-123",
|
||||
resultCount: 1,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "未找到token",
|
||||
accessToken: "token-123",
|
||||
resultCount: 0,
|
||||
wantError: true, // 访问索引0会panic
|
||||
},
|
||||
{
|
||||
name: "找到多个token(异常情况)",
|
||||
accessToken: "token-123",
|
||||
resultCount: 2,
|
||||
wantError: false, // 返回第一个
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:如果结果为空,访问索引0会出错
|
||||
hasError := tt.resultCount == 0
|
||||
if hasError != tt.wantError {
|
||||
t.Errorf("FindTokenByID logic failed: got error=%v, want error=%v", hasError, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -315,6 +315,18 @@ func (m *MockTextureRepository) FindByHash(ctx context.Context, hash string) (*m
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) FindByHashAndUploaderID(ctx context.Context, hash string, uploaderID int64) (*model.Texture, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
for _, texture := range m.textures {
|
||||
if texture.Hash == hash && texture.UploaderID == uploaderID {
|
||||
return texture, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if m.FailFind {
|
||||
return nil, 0, errors.New("mock find error")
|
||||
@@ -462,101 +474,6 @@ func (m *MockTextureRepository) BatchDelete(ctx context.Context, ids []int64) (i
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// MockTokenRepository 模拟TokenRepository
|
||||
type MockTokenRepository struct {
|
||||
tokens map[string]*model.Token
|
||||
userTokens map[int64][]*model.Token
|
||||
FailCreate bool
|
||||
FailFind bool
|
||||
FailDelete bool
|
||||
}
|
||||
|
||||
func NewMockTokenRepository() *MockTokenRepository {
|
||||
return &MockTokenRepository{
|
||||
tokens: make(map[string]*model.Token),
|
||||
userTokens: make(map[int64][]*model.Token),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) Create(ctx context.Context, token *model.Token) error {
|
||||
if m.FailCreate {
|
||||
return errors.New("mock create error")
|
||||
}
|
||||
m.tokens[token.AccessToken] = token
|
||||
m.userTokens[token.UserID] = append(m.userTokens[token.UserID], token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token, nil
|
||||
}
|
||||
return nil, errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
return m.userTokens[userId], nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
if m.FailFind {
|
||||
return "", errors.New("mock find error")
|
||||
}
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
return "", errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
if m.FailFind {
|
||||
return 0, errors.New("mock find error")
|
||||
}
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token.UserID, nil
|
||||
}
|
||||
return 0, errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
|
||||
if m.FailDelete {
|
||||
return errors.New("mock delete error")
|
||||
}
|
||||
delete(m.tokens, accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
|
||||
if m.FailDelete {
|
||||
return errors.New("mock delete error")
|
||||
}
|
||||
for _, token := range m.userTokens[userId] {
|
||||
delete(m.tokens, token.AccessToken)
|
||||
}
|
||||
m.userTokens[userId] = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
|
||||
if m.FailDelete {
|
||||
return 0, errors.New("mock delete error")
|
||||
}
|
||||
var count int64
|
||||
for _, accessToken := range accessTokens {
|
||||
if _, ok := m.tokens[accessToken]; ok {
|
||||
delete(m.tokens, accessToken)
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// MockSystemConfigRepository 模拟SystemConfigRepository
|
||||
type MockSystemConfigRepository struct {
|
||||
configs map[string]*model.SystemConfig
|
||||
@@ -956,90 +873,11 @@ func (m *MockTextureService) CheckUploadLimit(uploaderID int64, maxTextures int)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockTokenService 模拟TokenService
|
||||
type MockTokenService struct {
|
||||
tokens map[string]*model.Token
|
||||
FailCreate bool
|
||||
FailValidate bool
|
||||
FailRefresh bool
|
||||
}
|
||||
|
||||
func NewMockTokenService() *MockTokenService {
|
||||
return &MockTokenService{
|
||||
tokens: make(map[string]*model.Token),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTokenService) Create(userID int64, uuid, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
if m.FailCreate {
|
||||
return nil, nil, "", "", errors.New("mock create error")
|
||||
}
|
||||
accessToken := "mock-access-token"
|
||||
if clientToken == "" {
|
||||
clientToken = "mock-client-token"
|
||||
}
|
||||
token := &model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
UserID: userID,
|
||||
ProfileId: uuid,
|
||||
Usable: true,
|
||||
}
|
||||
m.tokens[accessToken] = token
|
||||
return nil, nil, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
func (m *MockTokenService) Validate(accessToken, clientToken string) bool {
|
||||
if m.FailValidate {
|
||||
return false
|
||||
}
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
if clientToken == "" || token.ClientToken == clientToken {
|
||||
return token.Usable
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockTokenService) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
if m.FailRefresh {
|
||||
return "", "", errors.New("mock refresh error")
|
||||
}
|
||||
return "new-access-token", clientToken, nil
|
||||
}
|
||||
|
||||
func (m *MockTokenService) Invalidate(accessToken string) {
|
||||
delete(m.tokens, accessToken)
|
||||
}
|
||||
|
||||
func (m *MockTokenService) InvalidateUserTokens(userID int64) {
|
||||
for key, token := range m.tokens {
|
||||
if token.UserID == userID {
|
||||
delete(m.tokens, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTokenService) GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
return "", errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenService) GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token.UserID, nil
|
||||
}
|
||||
return 0, errors.New("token not found")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CacheManager Mock - uses database.CacheManager with nil redis
|
||||
// CacheManager Mock - 使用 database.CacheManager 的内存版本
|
||||
// ============================================================================
|
||||
|
||||
// NewMockCacheManager 创建一个禁用的 CacheManager 用于测试
|
||||
// 通过设置 Enabled = false,缓存操作会被跳过,测试不依赖 Redis
|
||||
// NewMockCacheManager 创建一个内存 CacheManager 用于测试
|
||||
func NewMockCacheManager() *database.CacheManager {
|
||||
return database.NewCacheManager(nil, database.CacheConfig{
|
||||
Prefix: "test:",
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
@@ -99,7 +98,7 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro
|
||||
// 尝试从缓存获取
|
||||
cacheKey := s.cacheKeys.Profile(uuid)
|
||||
var profile model.Profile
|
||||
if err := s.cache.Get(ctx, cacheKey, &profile); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &profile); ok {
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
@@ -112,11 +111,9 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 存入缓存(异步,5分钟过期)
|
||||
// 存入缓存(异步)
|
||||
if profile2 != nil {
|
||||
go func() {
|
||||
_ = s.cache.Set(context.Background(), cacheKey, profile2, 5*time.Minute)
|
||||
}()
|
||||
s.cache.SetAsync(context.Background(), cacheKey, profile2, s.cache.Policy.ProfileTTL)
|
||||
}
|
||||
|
||||
return profile2, nil
|
||||
@@ -126,7 +123,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
|
||||
// 尝试从缓存获取
|
||||
cacheKey := s.cacheKeys.ProfileList(userID)
|
||||
var profiles []*model.Profile
|
||||
if err := s.cache.Get(ctx, cacheKey, &profiles); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &profiles); ok {
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
@@ -136,11 +133,9 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
|
||||
return nil, fmt.Errorf("查询档案列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 存入缓存(异步,3分钟过期)
|
||||
// 存入缓存(异步)
|
||||
if profiles != nil {
|
||||
go func() {
|
||||
_ = s.cache.Set(context.Background(), cacheKey, profiles, 3*time.Minute)
|
||||
}()
|
||||
s.cache.SetAsync(context.Background(), cacheKey, profiles, s.cache.Policy.ProfileListTTL)
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -103,7 +102,7 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture,
|
||||
// 尝试从缓存获取
|
||||
cacheKey := s.cacheKeys.Texture(id)
|
||||
var texture model.Texture
|
||||
if err := s.cache.Get(ctx, cacheKey, &texture); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &texture); ok {
|
||||
if texture.Status == -1 {
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
@@ -122,11 +121,9 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture,
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
|
||||
// 存入缓存(异步,5分钟过期)
|
||||
// 存入缓存(异步)
|
||||
if texture2 != nil {
|
||||
go func() {
|
||||
_ = s.cache.Set(context.Background(), cacheKey, texture2, 5*time.Minute)
|
||||
}()
|
||||
s.cache.SetAsync(context.Background(), cacheKey, texture2, s.cache.Policy.TextureTTL)
|
||||
}
|
||||
|
||||
return texture2, nil
|
||||
@@ -136,7 +133,7 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex
|
||||
// 尝试从缓存获取
|
||||
cacheKey := s.cacheKeys.TextureByHash(hash)
|
||||
var texture model.Texture
|
||||
if err := s.cache.Get(ctx, cacheKey, &texture); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &texture); ok {
|
||||
if texture.Status == -1 {
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
@@ -155,10 +152,8 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
|
||||
// 存入缓存(异步,5分钟过期)
|
||||
go func() {
|
||||
_ = s.cache.Set(context.Background(), cacheKey, texture2, 5*time.Minute)
|
||||
}()
|
||||
// 存入缓存(异步)
|
||||
s.cache.SetAsync(context.Background(), cacheKey, texture2, s.cache.Policy.TextureTTL)
|
||||
|
||||
return texture2, nil
|
||||
}
|
||||
@@ -172,7 +167,7 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
|
||||
Textures []*model.Texture
|
||||
Total int64
|
||||
}
|
||||
if err := s.cache.Get(ctx, cacheKey, &cachedResult); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &cachedResult); ok {
|
||||
return cachedResult.Textures, cachedResult.Total, nil
|
||||
}
|
||||
|
||||
@@ -182,14 +177,12 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 存入缓存(异步,2分钟过期)
|
||||
go func() {
|
||||
result := struct {
|
||||
Textures []*model.Texture
|
||||
Total int64
|
||||
}{Textures: textures, Total: total}
|
||||
_ = s.cache.Set(context.Background(), cacheKey, result, 2*time.Minute)
|
||||
}()
|
||||
// 存入缓存(异步)
|
||||
result := struct {
|
||||
Textures []*model.Texture
|
||||
Total int64
|
||||
}{Textures: textures, Total: total}
|
||||
s.cache.SetAsync(context.Background(), cacheKey, result, s.cache.Policy.TextureListTTL)
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
@@ -232,7 +225,7 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64
|
||||
|
||||
// 清除 texture 缓存和用户列表缓存
|
||||
s.cacheInv.OnUpdate(ctx, s.cacheKeys.Texture(textureID))
|
||||
s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID))
|
||||
s.cacheInv.BatchInvalidate(ctx, s.cacheKeys.TextureListPattern(uploaderID))
|
||||
|
||||
return s.textureRepo.FindByID(ctx, textureID)
|
||||
}
|
||||
@@ -257,7 +250,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64
|
||||
|
||||
// 清除 texture 缓存和用户列表缓存
|
||||
s.cacheInv.OnDelete(ctx, s.cacheKeys.Texture(textureID))
|
||||
s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID))
|
||||
s.cacheInv.BatchInvalidate(ctx, s.cacheKeys.TextureListPattern(uploaderID))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -494,7 +494,7 @@ func TestTextureServiceImpl_Create(t *testing.T) {
|
||||
_ = userRepo.Create(context.Background(), testUser)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -536,8 +536,7 @@ func TestTextureServiceImpl_Create(t *testing.T) {
|
||||
textureName: "DuplicateTexture",
|
||||
textureType: "SKIN",
|
||||
hash: "existing-hash",
|
||||
wantErr: true,
|
||||
errContains: "已存在",
|
||||
wantErr: false,
|
||||
setupMocks: func() {
|
||||
_ = textureRepo.Create(context.Background(), &model.Texture{
|
||||
ID: 100,
|
||||
@@ -617,7 +616,7 @@ func TestTextureServiceImpl_GetByID(t *testing.T) {
|
||||
_ = textureRepo.Create(context.Background(), testTexture)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -675,7 +674,7 @@ func TestTextureServiceImpl_GetByUserID_And_Search(t *testing.T) {
|
||||
}
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -714,7 +713,7 @@ func TestTextureServiceImpl_Update_And_Delete(t *testing.T) {
|
||||
_ = textureRepo.Create(context.Background(), texture)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -764,7 +763,7 @@ func TestTextureServiceImpl_FavoritesAndLimit(t *testing.T) {
|
||||
}
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -807,7 +806,7 @@ func TestTextureServiceImpl_ToggleFavorite(t *testing.T) {
|
||||
_ = textureRepo.Create(context.Background(), testTexture)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// tokenService TokenService的实现
|
||||
type tokenService struct {
|
||||
tokenRepo repository.TokenRepository
|
||||
profileRepo repository.ProfileRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTokenService 创建TokenService实例
|
||||
func NewTokenService(
|
||||
tokenRepo repository.TokenRepository,
|
||||
profileRepo repository.ProfileRepository,
|
||||
logger *zap.Logger,
|
||||
) TokenService {
|
||||
return &tokenService{
|
||||
tokenRepo: tokenRepo,
|
||||
profileRepo: profileRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
tokenExtendedTimeout = 10 * time.Second
|
||||
tokensMaxCount = 10
|
||||
)
|
||||
|
||||
func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
var (
|
||||
selectedProfileID *model.Profile
|
||||
availableProfiles []*model.Profile
|
||||
)
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
if UUID != "" {
|
||||
_, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 生成令牌
|
||||
if clientToken == "" {
|
||||
clientToken = uuid.New().String()
|
||||
}
|
||||
|
||||
accessToken := uuid.New().String()
|
||||
token := model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
UserID: userID,
|
||||
Usable: true,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := s.profileRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果用户只有一个配置文件,自动选择
|
||||
if len(profiles) == 1 {
|
||||
selectedProfileID = profiles[0]
|
||||
token.ProfileId = selectedProfileID.UUID
|
||||
}
|
||||
availableProfiles = profiles
|
||||
|
||||
// 插入令牌
|
||||
err = s.tokenRepo.Create(ctx, &token)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 清理多余的令牌(使用独立的后台上下文)
|
||||
go s.checkAndCleanupExcessTokens(context.Background(), userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !token.Usable {
|
||||
return false
|
||||
}
|
||||
|
||||
if clientToken == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return token.ClientToken == clientToken
|
||||
}
|
||||
|
||||
func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
|
||||
// 查找旧令牌
|
||||
oldToken, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", errors.New("accessToken无效")
|
||||
}
|
||||
s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return "", "", fmt.Errorf("查询令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := s.validateProfileByUserID(ctx, oldToken.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
s.logger.Error("验证Profile失败",
|
||||
zap.Error(err),
|
||||
zap.Int64("userId", oldToken.UserID),
|
||||
zap.String("profileId", selectedProfileID),
|
||||
)
|
||||
return "", "", fmt.Errorf("验证角色失败: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return "", "", errors.New("角色与用户不匹配")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 clientToken 是否有效
|
||||
if clientToken != "" && clientToken != oldToken.ClientToken {
|
||||
return "", "", errors.New("clientToken无效")
|
||||
}
|
||||
|
||||
// 检查 selectedProfileID 的逻辑
|
||||
if selectedProfileID != "" {
|
||||
if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID {
|
||||
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
|
||||
}
|
||||
} else {
|
||||
selectedProfileID = oldToken.ProfileId
|
||||
}
|
||||
|
||||
// 生成新令牌
|
||||
newAccessToken := uuid.New().String()
|
||||
newToken := model.Token{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: oldToken.ClientToken,
|
||||
UserID: oldToken.UserID,
|
||||
Usable: true,
|
||||
ProfileId: selectedProfileID,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 先插入新令牌,再删除旧令牌
|
||||
err = s.tokenRepo.Create(ctx, &newToken)
|
||||
if err != nil {
|
||||
s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return "", "", fmt.Errorf("创建新Token失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("删除旧Token失败,但新Token已创建",
|
||||
zap.Error(err),
|
||||
zap.String("oldToken", oldToken.AccessToken),
|
||||
zap.String("newToken", newAccessToken),
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken))
|
||||
return newAccessToken, oldToken.ClientToken, nil
|
||||
}
|
||||
|
||||
func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return
|
||||
}
|
||||
s.logger.Info("成功删除Token", zap.String("token", accessToken))
|
||||
}
|
||||
|
||||
func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("成功删除用户Token", zap.Int64("userId", userID))
|
||||
}
|
||||
|
||||
func (s *tokenService) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
func (s *tokenService) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenService) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 为清理操作设置更长的超时时间
|
||||
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokens) <= tokensMaxCount {
|
||||
return
|
||||
}
|
||||
|
||||
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
|
||||
for i := tokensMaxCount; i < len(tokens); i++ {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenService) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
}
|
||||
return false, fmt.Errorf("验证配置文件失败: %w", err)
|
||||
}
|
||||
return profile.UserID == userID, nil
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -15,40 +14,38 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// tokenServiceJWT TokenService的JWT实现(使用JWT + Version机制)
|
||||
type tokenServiceJWT struct {
|
||||
tokenRepo repository.TokenRepository
|
||||
clientRepo repository.ClientRepository
|
||||
profileRepo repository.ProfileRepository
|
||||
yggdrasilJWT *auth.YggdrasilJWTService
|
||||
logger *zap.Logger
|
||||
tokenExpireSec int64 // Token过期时间(秒),0表示永不过期
|
||||
tokenStaleSec int64 // Token过期但可用时间(秒),0表示永不过期
|
||||
// tokenServiceRedis TokenService的Redis实现
|
||||
type tokenServiceRedis struct {
|
||||
tokenStore *auth.TokenStoreRedis
|
||||
clientRepo repository.ClientRepository
|
||||
profileRepo repository.ProfileRepository
|
||||
yggdrasilJWT *auth.YggdrasilJWTService
|
||||
logger *zap.Logger
|
||||
tokenExpireSec int64 // Token过期时间(秒),0表示永不过期
|
||||
tokenStaleSec int64 // Token过期但可用时间(秒),0表示永不过期
|
||||
}
|
||||
|
||||
// NewTokenServiceJWT 创建使用JWT的TokenService实例
|
||||
func NewTokenServiceJWT(
|
||||
tokenRepo repository.TokenRepository,
|
||||
// NewTokenServiceRedis 创建使用Redis的TokenService实例
|
||||
func NewTokenServiceRedis(
|
||||
tokenStore *auth.TokenStoreRedis,
|
||||
clientRepo repository.ClientRepository,
|
||||
profileRepo repository.ProfileRepository,
|
||||
yggdrasilJWT *auth.YggdrasilJWTService,
|
||||
logger *zap.Logger,
|
||||
) TokenService {
|
||||
return &tokenServiceJWT{
|
||||
tokenRepo: tokenRepo,
|
||||
return &tokenServiceRedis{
|
||||
tokenStore: tokenStore,
|
||||
clientRepo: clientRepo,
|
||||
profileRepo: profileRepo,
|
||||
yggdrasilJWT: yggdrasilJWT,
|
||||
logger: logger,
|
||||
tokenExpireSec: 24 * 3600, // 默认24小时
|
||||
tokenExpireSec: 24 * 3600, // 默认24小时
|
||||
tokenStaleSec: 30 * 24 * 3600, // 默认30天
|
||||
}
|
||||
}
|
||||
|
||||
// 常量已在 token_service.go 中定义,这里不重复定义
|
||||
|
||||
// Create 创建Token(使用JWT + Version机制)
|
||||
func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
// Create 创建Token(使用JWT + Redis存储)
|
||||
func (s *tokenServiceRedis) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
var (
|
||||
selectedProfileID *model.Profile
|
||||
availableProfiles []*model.Profile
|
||||
@@ -85,11 +82,11 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
|
||||
if UUID != "" {
|
||||
client.ProfileID = UUID
|
||||
}
|
||||
|
||||
|
||||
if err := s.clientRepo.Create(ctx, client); err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err)
|
||||
}
|
||||
@@ -103,7 +100,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
client.UpdatedAt = time.Now()
|
||||
if UUID != "" {
|
||||
client.ProfileID = UUID
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -130,14 +127,14 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
// 生成Token过期时间
|
||||
now := time.Now()
|
||||
var expiresAt, staleAt time.Time
|
||||
|
||||
|
||||
if s.tokenExpireSec > 0 {
|
||||
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
|
||||
} else {
|
||||
// 使用遥远的未来时间(类似drasl的DISTANT_FUTURE)
|
||||
// 使用遥远的未来时间
|
||||
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
|
||||
if s.tokenStaleSec > 0 {
|
||||
staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
|
||||
} else {
|
||||
@@ -157,36 +154,31 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("生成AccessToken失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存Token记录(用于查询和审计)
|
||||
token := model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
// 存储Token到Redis
|
||||
ttl := expiresAt.Sub(now)
|
||||
metadata := &auth.TokenMetadata{
|
||||
UserID: userID,
|
||||
ProfileId: profileID,
|
||||
ProfileID: profileID,
|
||||
ClientUUID: client.UUID,
|
||||
ClientToken: client.ClientToken,
|
||||
Version: client.Version,
|
||||
Usable: true,
|
||||
IssueDate: now,
|
||||
ExpiresAt: &expiresAt,
|
||||
StaleAt: &staleAt,
|
||||
CreatedAt: now.Unix(),
|
||||
}
|
||||
|
||||
err = s.tokenRepo.Create(ctx, &token)
|
||||
if err != nil {
|
||||
s.logger.Warn("保存Token记录失败,但JWT已生成", zap.Error(err))
|
||||
if err := s.tokenStore.Store(ctx, accessToken, metadata, ttl); err != nil {
|
||||
s.logger.Warn("存储Token到Redis失败", zap.Error(err))
|
||||
// 不返回错误,因为JWT本身已经生成成功
|
||||
}
|
||||
|
||||
// 清理多余的令牌(使用独立的后台上下文)
|
||||
go s.checkAndCleanupExcessTokens(context.Background(), userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
// Validate 验证Token(使用JWT验证)
|
||||
func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// Validate 验证Token(使用JWT验证 + Redis存储验证)
|
||||
func (s *tokenServiceRedis) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
@@ -197,6 +189,13 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
|
||||
return false
|
||||
}
|
||||
|
||||
// 从Redis获取Token元数据
|
||||
metadata, err := s.tokenStore.Retrieve(ctx, accessToken)
|
||||
if err != nil {
|
||||
// Token可能已过期或不存在
|
||||
return false
|
||||
}
|
||||
|
||||
// 查找Client
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
@@ -209,18 +208,19 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
|
||||
}
|
||||
|
||||
// 验证ClientToken(如果提供)
|
||||
if clientToken != "" && client.ClientToken != clientToken {
|
||||
if clientToken != "" && metadata.ClientToken != clientToken {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Refresh 刷新Token(使用Version机制,无需删除旧Token)
|
||||
func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// Refresh 刷新Token(使用Version机制,Redis存储)
|
||||
func (s *tokenServiceRedis) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
@@ -279,16 +279,21 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
return "", "", fmt.Errorf("更新Client版本失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除旧Token(从Redis)
|
||||
if err := s.tokenStore.Delete(ctx, accessToken); err != nil {
|
||||
s.logger.Warn("删除旧Token失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 生成Token过期时间
|
||||
now := time.Now()
|
||||
var expiresAt, staleAt time.Time
|
||||
|
||||
|
||||
if s.tokenExpireSec > 0 {
|
||||
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
|
||||
} else {
|
||||
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
|
||||
if s.tokenStaleSec > 0 {
|
||||
staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
|
||||
} else {
|
||||
@@ -308,30 +313,27 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
return "", "", fmt.Errorf("生成新AccessToken失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存新Token记录
|
||||
newToken := model.Token{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: client.ClientToken,
|
||||
// 存储新Token到Redis
|
||||
ttl := expiresAt.Sub(now)
|
||||
metadata := &auth.TokenMetadata{
|
||||
UserID: client.UserID,
|
||||
ProfileId: selectedProfileID,
|
||||
ProfileID: selectedProfileID,
|
||||
ClientUUID: client.UUID,
|
||||
ClientToken: client.ClientToken,
|
||||
Version: client.Version,
|
||||
Usable: true,
|
||||
IssueDate: now,
|
||||
ExpiresAt: &expiresAt,
|
||||
StaleAt: &staleAt,
|
||||
CreatedAt: now.Unix(),
|
||||
}
|
||||
|
||||
err = s.tokenRepo.Create(ctx, &newToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("保存新Token记录失败,但JWT已生成", zap.Error(err))
|
||||
if err := s.tokenStore.Store(ctx, newAccessToken, metadata, ttl); err != nil {
|
||||
s.logger.Warn("存储新Token到Redis失败", zap.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Info("成功刷新Token", zap.Int64("userId", client.UserID), zap.Int("version", client.Version))
|
||||
return newAccessToken, client.ClientToken, nil
|
||||
}
|
||||
|
||||
// Invalidate 使Token失效(通过增加Version)
|
||||
func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
// Invalidate 使Token失效(从Redis删除)
|
||||
func (s *tokenServiceRedis) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
@@ -347,7 +349,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 查找Client并增加Version
|
||||
// 查找Client并增加Version(失效所有旧Token)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
s.logger.Warn("无法找到对应的Client", zap.Error(err))
|
||||
@@ -362,11 +364,17 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 从Redis删除Token
|
||||
if err := s.tokenStore.Delete(ctx, accessToken); err != nil {
|
||||
s.logger.Warn("从Redis删除Token失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("成功失效Token", zap.String("clientUUID", client.UUID), zap.Int("version", client.Version))
|
||||
}
|
||||
|
||||
// InvalidateUserTokens 使用户所有Token失效
|
||||
func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// InvalidateUserTokens 使用户所有Token失效(从Redis删除)
|
||||
func (s *tokenServiceRedis) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
@@ -391,15 +399,20 @@ func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64
|
||||
}
|
||||
}
|
||||
|
||||
// 从Redis删除用户所有Token
|
||||
if err := s.tokenStore.DeleteByUserID(ctx, userID); err != nil {
|
||||
s.logger.Error("从Redis删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("成功失效用户所有Token", zap.Int64("userId", userID), zap.Int("clientCount", len(clients)))
|
||||
}
|
||||
|
||||
// GetUUIDByAccessToken 从AccessToken获取UUID(通过JWT解析)
|
||||
func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
func (s *tokenServiceRedis) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||
if err != nil {
|
||||
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
|
||||
return "", errors.New("accessToken无效")
|
||||
}
|
||||
|
||||
if claims.ProfileID != "" {
|
||||
@@ -420,11 +433,10 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
|
||||
}
|
||||
|
||||
// GetUserIDByAccessToken 从AccessToken获取UserID(通过JWT解析)
|
||||
func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
func (s *tokenServiceRedis) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||
if err != nil {
|
||||
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
|
||||
return 0, errors.New("accessToken无效")
|
||||
}
|
||||
|
||||
// 从Client获取UserID
|
||||
@@ -441,44 +453,8 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
|
||||
return client.UserID, nil
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 为清理操作设置更长的超时时间
|
||||
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokens) <= tokensMaxCount {
|
||||
return
|
||||
}
|
||||
|
||||
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
|
||||
for i := tokensMaxCount; i < len(tokens); i++ {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
// validateProfileByUserID 验证Profile是否属于用户
|
||||
func (s *tokenServiceRedis) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
@@ -492,24 +468,3 @@ func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID in
|
||||
}
|
||||
return profile.UserID == userID, nil
|
||||
}
|
||||
|
||||
// GetClientFromToken 从Token获取Client信息(辅助方法)
|
||||
func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken string, stalePolicy auth.StaleTokenPolicy) (*model.Client, error) {
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, stalePolicy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证Version
|
||||
if claims.Version != client.Version {
|
||||
return nil, errors.New("token版本不匹配")
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -1,513 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestTokenService_Constants 测试Token服务相关常量
|
||||
func TestTokenService_Constants(t *testing.T) {
|
||||
// 内部常量已私有化,通过服务行为间接测试
|
||||
t.Skip("Token constants are now private - test through service behavior instead")
|
||||
}
|
||||
|
||||
// TestTokenService_Validation 测试Token验证逻辑
|
||||
func TestTokenService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "空token无效",
|
||||
accessToken: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "非空token可能有效",
|
||||
accessToken: "valid-token-string",
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试空token检查逻辑
|
||||
isValid := tt.accessToken != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Token validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_ClientTokenLogic 测试ClientToken逻辑
|
||||
func TestTokenService_ClientTokenLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientToken string
|
||||
shouldGenerate bool
|
||||
}{
|
||||
{
|
||||
name: "空的clientToken应该生成新的",
|
||||
clientToken: "",
|
||||
shouldGenerate: true,
|
||||
},
|
||||
{
|
||||
name: "非空的clientToken应该使用提供的",
|
||||
clientToken: "existing-client-token",
|
||||
shouldGenerate: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldGenerate := tt.clientToken == ""
|
||||
if shouldGenerate != tt.shouldGenerate {
|
||||
t.Errorf("ClientToken logic failed: got %v, want %v", shouldGenerate, tt.shouldGenerate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_ProfileSelection 测试Profile选择逻辑
|
||||
func TestTokenService_ProfileSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileCount int
|
||||
shouldAutoSelect bool
|
||||
}{
|
||||
{
|
||||
name: "只有一个profile时自动选择",
|
||||
profileCount: 1,
|
||||
shouldAutoSelect: true,
|
||||
},
|
||||
{
|
||||
name: "多个profile时不自动选择",
|
||||
profileCount: 2,
|
||||
shouldAutoSelect: false,
|
||||
},
|
||||
{
|
||||
name: "没有profile时不自动选择",
|
||||
profileCount: 0,
|
||||
shouldAutoSelect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldAutoSelect := tt.profileCount == 1
|
||||
if shouldAutoSelect != tt.shouldAutoSelect {
|
||||
t.Errorf("Profile selection logic failed: got %v, want %v", shouldAutoSelect, tt.shouldAutoSelect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_CleanupLogic 测试清理逻辑
|
||||
func TestTokenService_CleanupLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenCount int
|
||||
maxCount int
|
||||
shouldCleanup bool
|
||||
cleanupCount int
|
||||
}{
|
||||
{
|
||||
name: "token数量未超过上限,不需要清理",
|
||||
tokenCount: 5,
|
||||
maxCount: 10,
|
||||
shouldCleanup: false,
|
||||
cleanupCount: 0,
|
||||
},
|
||||
{
|
||||
name: "token数量超过上限,需要清理",
|
||||
tokenCount: 15,
|
||||
maxCount: 10,
|
||||
shouldCleanup: true,
|
||||
cleanupCount: 5,
|
||||
},
|
||||
{
|
||||
name: "token数量等于上限,不需要清理",
|
||||
tokenCount: 10,
|
||||
maxCount: 10,
|
||||
shouldCleanup: false,
|
||||
cleanupCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldCleanup := tt.tokenCount > tt.maxCount
|
||||
if shouldCleanup != tt.shouldCleanup {
|
||||
t.Errorf("Cleanup decision failed: got %v, want %v", shouldCleanup, tt.shouldCleanup)
|
||||
}
|
||||
|
||||
if shouldCleanup {
|
||||
expectedCleanupCount := tt.tokenCount - tt.maxCount
|
||||
if expectedCleanupCount != tt.cleanupCount {
|
||||
t.Errorf("Cleanup count failed: got %d, want %d", expectedCleanupCount, tt.cleanupCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_UserIDValidation 测试UserID验证
|
||||
func TestTokenService_UserIDValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
isValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的UserID",
|
||||
userID: 1,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "UserID为0时无效",
|
||||
userID: 0,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "负数UserID无效",
|
||||
userID: -1,
|
||||
isValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.userID > 0
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("UserID validation failed: got %v, want %v", isValid, tt.isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 使用 Mock 的集成测试
|
||||
// ============================================================================
|
||||
|
||||
// TestTokenServiceImpl_Create 测试创建Token
|
||||
func TestTokenServiceImpl_Create(t *testing.T) {
|
||||
tokenRepo := NewMockTokenRepository()
|
||||
profileRepo := NewMockProfileRepository()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// 预置Profile
|
||||
testProfile := &model.Profile{
|
||||
UUID: "test-profile-uuid",
|
||||
UserID: 1,
|
||||
Name: "TestProfile",
|
||||
IsActive: true,
|
||||
}
|
||||
_ = profileRepo.Create(context.Background(), testProfile)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
uuid string
|
||||
clientToken string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常创建Token(指定UUID)",
|
||||
userID: 1,
|
||||
uuid: "test-profile-uuid",
|
||||
clientToken: "client-token-1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "正常创建Token(空clientToken)",
|
||||
userID: 1,
|
||||
uuid: "test-profile-uuid",
|
||||
clientToken: "",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _, accessToken, clientToken, err := tokenService.Create(ctx, tt.userID, tt.uuid, tt.clientToken)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("期望返回错误,但实际没有错误")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("不期望返回错误: %v", err)
|
||||
return
|
||||
}
|
||||
if accessToken == "" {
|
||||
t.Error("accessToken不应为空")
|
||||
}
|
||||
if clientToken == "" {
|
||||
t.Error("clientToken不应为空")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenServiceImpl_Validate 测试验证Token
|
||||
func TestTokenServiceImpl_Validate(t *testing.T) {
|
||||
tokenRepo := NewMockTokenRepository()
|
||||
profileRepo := NewMockProfileRepository()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// 预置Token
|
||||
testToken := &model.Token{
|
||||
AccessToken: "valid-access-token",
|
||||
ClientToken: "valid-client-token",
|
||||
UserID: 1,
|
||||
ProfileId: "test-profile-uuid",
|
||||
Usable: true,
|
||||
}
|
||||
_ = tokenRepo.Create(context.Background(), testToken)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
clientToken string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效Token(完全匹配)",
|
||||
accessToken: "valid-access-token",
|
||||
clientToken: "valid-client-token",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "有效Token(只检查accessToken)",
|
||||
accessToken: "valid-access-token",
|
||||
clientToken: "",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效Token(accessToken不存在)",
|
||||
accessToken: "invalid-access-token",
|
||||
clientToken: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "无效Token(clientToken不匹配)",
|
||||
accessToken: "valid-access-token",
|
||||
clientToken: "wrong-client-token",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
isValid := tokenService.Validate(ctx, tt.accessToken, tt.clientToken)
|
||||
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Token验证结果不匹配: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenServiceImpl_Invalidate 测试注销Token
|
||||
func TestTokenServiceImpl_Invalidate(t *testing.T) {
|
||||
tokenRepo := NewMockTokenRepository()
|
||||
profileRepo := NewMockProfileRepository()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// 预置Token
|
||||
testToken := &model.Token{
|
||||
AccessToken: "token-to-invalidate",
|
||||
ClientToken: "client-token",
|
||||
UserID: 1,
|
||||
ProfileId: "test-profile-uuid",
|
||||
Usable: true,
|
||||
}
|
||||
_ = tokenRepo.Create(context.Background(), testToken)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 验证Token存在
|
||||
isValid := tokenService.Validate(ctx, "token-to-invalidate", "")
|
||||
if !isValid {
|
||||
t.Error("Token应该有效")
|
||||
}
|
||||
|
||||
// 注销Token
|
||||
tokenService.Invalidate(ctx, "token-to-invalidate")
|
||||
|
||||
// 验证Token已失效(从repo中删除)
|
||||
_, err := tokenRepo.FindByAccessToken(context.Background(), "token-to-invalidate")
|
||||
if err == nil {
|
||||
t.Error("Token应该已被删除")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenServiceImpl_InvalidateUserTokens 测试注销用户所有Token
|
||||
func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) {
|
||||
tokenRepo := NewMockTokenRepository()
|
||||
profileRepo := NewMockProfileRepository()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// 预置多个Token
|
||||
for i := 1; i <= 3; i++ {
|
||||
_ = tokenRepo.Create(context.Background(), &model.Token{
|
||||
AccessToken: fmt.Sprintf("user1-token-%d", i),
|
||||
ClientToken: "client-token",
|
||||
UserID: 1,
|
||||
ProfileId: "test-profile-uuid",
|
||||
Usable: true,
|
||||
})
|
||||
}
|
||||
_ = tokenRepo.Create(context.Background(), &model.Token{
|
||||
AccessToken: "user2-token-1",
|
||||
ClientToken: "client-token",
|
||||
UserID: 2,
|
||||
ProfileId: "test-profile-uuid-2",
|
||||
Usable: true,
|
||||
})
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 注销用户1的所有Token
|
||||
tokenService.InvalidateUserTokens(ctx, 1)
|
||||
|
||||
// 验证用户1的Token已失效
|
||||
tokens, _ := tokenRepo.GetByUserID(context.Background(), 1)
|
||||
if len(tokens) > 0 {
|
||||
t.Errorf("用户1的Token应该全部被删除,但还剩 %d 个", len(tokens))
|
||||
}
|
||||
|
||||
// 验证用户2的Token仍然存在
|
||||
tokens2, _ := tokenRepo.GetByUserID(context.Background(), 2)
|
||||
if len(tokens2) != 1 {
|
||||
t.Errorf("用户2的Token应该仍然存在,期望1个,实际 %d 个", len(tokens2))
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenServiceImpl_Refresh 覆盖 Refresh 的主要分支
|
||||
func TestTokenServiceImpl_Refresh(t *testing.T) {
|
||||
tokenRepo := NewMockTokenRepository()
|
||||
profileRepo := NewMockProfileRepository()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// 预置 Profile 与 Token
|
||||
profile := &model.Profile{
|
||||
UUID: "profile-uuid",
|
||||
UserID: 1,
|
||||
}
|
||||
_ = profileRepo.Create(context.Background(), profile)
|
||||
|
||||
oldToken := &model.Token{
|
||||
AccessToken: "old-token",
|
||||
ClientToken: "client-token",
|
||||
UserID: 1,
|
||||
ProfileId: "",
|
||||
Usable: true,
|
||||
}
|
||||
_ = tokenRepo.Create(context.Background(), oldToken)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 正常刷新,不指定 profile
|
||||
newAccess, client, err := tokenService.Refresh(ctx, "old-token", "client-token", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Refresh 正常情况失败: %v", err)
|
||||
}
|
||||
if newAccess == "" || client != "client-token" {
|
||||
t.Fatalf("Refresh 返回值异常: access=%s, client=%s", newAccess, client)
|
||||
}
|
||||
|
||||
// accessToken 为空
|
||||
if _, _, err := tokenService.Refresh(ctx, "", "client-token", ""); err == nil {
|
||||
t.Fatalf("Refresh 在 accessToken 为空时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenServiceImpl_GetByAccessToken 封装 GetUUIDByAccessToken / GetUserIDByAccessToken
|
||||
func TestTokenServiceImpl_GetByAccessToken(t *testing.T) {
|
||||
tokenRepo := NewMockTokenRepository()
|
||||
profileRepo := NewMockProfileRepository()
|
||||
logger := zap.NewNop()
|
||||
|
||||
token := &model.Token{
|
||||
AccessToken: "token-1",
|
||||
UserID: 42,
|
||||
ProfileId: "profile-42",
|
||||
Usable: true,
|
||||
}
|
||||
_ = tokenRepo.Create(context.Background(), token)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
uuid, err := tokenService.GetUUIDByAccessToken(ctx, "token-1")
|
||||
if err != nil || uuid != "profile-42" {
|
||||
t.Fatalf("GetUUIDByAccessToken 返回错误: uuid=%s, err=%v", uuid, err)
|
||||
}
|
||||
|
||||
uid, err := tokenService.GetUserIDByAccessToken(ctx, "token-1")
|
||||
if err != nil || uid != 42 {
|
||||
t.Fatalf("GetUserIDByAccessToken 返回错误: uid=%d, err=%v", uid, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenServiceImpl_validateProfileByUserID 直接测试内部校验逻辑
|
||||
func TestTokenServiceImpl_validateProfileByUserID(t *testing.T) {
|
||||
tokenRepo := NewMockTokenRepository()
|
||||
profileRepo := NewMockProfileRepository()
|
||||
logger := zap.NewNop()
|
||||
|
||||
svc := &tokenService{
|
||||
tokenRepo: tokenRepo,
|
||||
profileRepo: profileRepo,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
// 预置 Profile
|
||||
profile := &model.Profile{
|
||||
UUID: "p-1",
|
||||
UserID: 1,
|
||||
}
|
||||
_ = profileRepo.Create(context.Background(), profile)
|
||||
|
||||
// 参数非法
|
||||
if ok, err := svc.validateProfileByUserID(context.Background(), 0, ""); err == nil || ok {
|
||||
t.Fatalf("validateProfileByUserID 在参数非法时应返回错误")
|
||||
}
|
||||
|
||||
// Profile 不存在
|
||||
if ok, err := svc.validateProfileByUserID(context.Background(), 1, "not-exists"); err == nil || ok {
|
||||
t.Fatalf("validateProfileByUserID 在 Profile 不存在时应返回错误")
|
||||
}
|
||||
|
||||
// 用户与 Profile 匹配
|
||||
if ok, err := svc.validateProfileByUserID(context.Background(), 1, "p-1"); err != nil || !ok {
|
||||
t.Fatalf("validateProfileByUserID 匹配时应返回 true, err=%v", err)
|
||||
}
|
||||
|
||||
// 用户与 Profile 不匹配
|
||||
if ok, err := svc.validateProfileByUserID(context.Background(), 2, "p-1"); err != nil || ok {
|
||||
t.Fatalf("validateProfileByUserID 不匹配时应返回 false, err=%v", err)
|
||||
}
|
||||
}
|
||||
@@ -183,7 +183,7 @@ func (s *userService) GetByID(ctx context.Context, id int64) (*model.User, error
|
||||
cacheKey := s.cacheKeys.User(id)
|
||||
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
|
||||
return s.userRepo.FindByID(ctx, id)
|
||||
}, 5*time.Minute)
|
||||
}, s.cache.Policy.UserTTL)
|
||||
}
|
||||
|
||||
func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User, error) {
|
||||
@@ -191,7 +191,7 @@ func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User
|
||||
cacheKey := s.cacheKeys.UserByEmail(email)
|
||||
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
|
||||
return s.userRepo.FindByEmail(ctx, email)
|
||||
}, 5*time.Minute)
|
||||
}, s.cache.Policy.UserEmailTTL)
|
||||
}
|
||||
|
||||
func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {
|
||||
|
||||
@@ -22,7 +22,7 @@ type yggdrasilServiceComposite struct {
|
||||
serializationService SerializationService
|
||||
certificateService CertificateService
|
||||
profileRepo repository.ProfileRepository
|
||||
tokenRepo repository.TokenRepository
|
||||
tokenService TokenService // 使用TokenService接口,不直接依赖TokenRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
@@ -31,11 +31,11 @@ func NewYggdrasilServiceComposite(
|
||||
db *gorm.DB,
|
||||
userRepo repository.UserRepository,
|
||||
profileRepo repository.ProfileRepository,
|
||||
tokenRepo repository.TokenRepository,
|
||||
yggdrasilRepo repository.YggdrasilRepository,
|
||||
signatureService *SignatureService,
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
tokenService TokenService, // 新增:TokenService接口
|
||||
) YggdrasilService {
|
||||
// 创建各个专门的服务
|
||||
authService := NewYggdrasilAuthService(db, userRepo, yggdrasilRepo, logger)
|
||||
@@ -53,7 +53,7 @@ func NewYggdrasilServiceComposite(
|
||||
serializationService: serializationService,
|
||||
certificateService: certificateService,
|
||||
profileRepo: profileRepo,
|
||||
tokenRepo: tokenRepo,
|
||||
tokenService: tokenService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
@@ -75,8 +75,8 @@ func (s *yggdrasilServiceComposite) ResetYggdrasilPassword(ctx context.Context,
|
||||
|
||||
// JoinServer 加入服务器
|
||||
func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, accessToken, selectedProfile, ip string) error {
|
||||
// 验证Token
|
||||
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
// 通过TokenService验证Token并获取UUID
|
||||
uuid, err := s.tokenService.GetUUIDByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Error("验证Token失败",
|
||||
zap.Error(err),
|
||||
@@ -87,7 +87,7 @@ func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, ac
|
||||
|
||||
// 格式化UUID并验证与Token关联的配置文件
|
||||
formattedProfile := utils.FormatUUID(selectedProfile)
|
||||
if token.ProfileId != formattedProfile {
|
||||
if uuid != formattedProfile {
|
||||
return errors.New("selectedProfile与Token不匹配")
|
||||
}
|
||||
|
||||
|
||||
168
internal/task/runner.go
Normal file
168
internal/task/runner.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Task 定义可调度任务
|
||||
type Task interface {
|
||||
Name() string
|
||||
Interval() time.Duration
|
||||
Run(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Runner 简单的周期任务调度器
|
||||
type Runner struct {
|
||||
tasks []Task
|
||||
logger *zap.Logger
|
||||
wg sync.WaitGroup
|
||||
startImmediately bool
|
||||
jitterPercent float64
|
||||
}
|
||||
|
||||
// NewRunner 创建任务调度器
|
||||
func NewRunner(logger *zap.Logger, tasks ...Task) *Runner {
|
||||
return NewRunnerWithOptions(logger, tasks)
|
||||
}
|
||||
|
||||
// RunnerOption 运行器配置项
|
||||
type RunnerOption func(r *Runner)
|
||||
|
||||
// WithStartImmediately 是否启动后立即执行一次(默认 true)
|
||||
func WithStartImmediately(start bool) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
r.startImmediately = start
|
||||
}
|
||||
}
|
||||
|
||||
// WithJitter 为执行间隔增加 0~percent 之间的随机抖动(percent=0 关闭,默认0)
|
||||
// 可降低多个任务同时触发的概率
|
||||
func WithJitter(percent float64) RunnerOption {
|
||||
return func(r *Runner) {
|
||||
if percent < 0 {
|
||||
percent = 0
|
||||
}
|
||||
r.jitterPercent = percent
|
||||
}
|
||||
}
|
||||
|
||||
// NewRunnerWithOptions 支持可选配置的创建函数
|
||||
func NewRunnerWithOptions(logger *zap.Logger, tasks []Task, opts ...RunnerOption) *Runner {
|
||||
r := &Runner{
|
||||
tasks: tasks,
|
||||
logger: logger,
|
||||
startImmediately: true,
|
||||
jitterPercent: 0,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Start 启动所有任务(异步)
|
||||
func (r *Runner) Start(ctx context.Context) {
|
||||
for _, t := range r.tasks {
|
||||
task := t
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
defer r.recoverPanic(task)
|
||||
|
||||
interval := r.normalizeInterval(task.Interval())
|
||||
|
||||
// 可选:立即执行一次
|
||||
if r.startImmediately {
|
||||
r.runOnce(ctx, task)
|
||||
}
|
||||
|
||||
// 周期执行
|
||||
for {
|
||||
wait := r.applyJitter(interval)
|
||||
if !r.wait(ctx, wait) {
|
||||
return
|
||||
}
|
||||
|
||||
// 每轮读取最新的 interval,允许任务动态调整间隔
|
||||
interval = r.normalizeInterval(task.Interval())
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
r.runOnce(ctx, task)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Wait 等待所有任务退出
|
||||
func (r *Runner) Wait() {
|
||||
r.wg.Wait()
|
||||
}
|
||||
|
||||
func (r *Runner) runOnce(ctx context.Context, task Task) {
|
||||
if err := task.Run(ctx); err != nil && r.logger != nil {
|
||||
r.logger.Warn("任务执行失败", zap.String("task", task.Name()), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeInterval 确保间隔为正值
|
||||
func (r *Runner) normalizeInterval(d time.Duration) time.Duration {
|
||||
if d <= 0 {
|
||||
return time.Minute
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// applyJitter 在基础间隔上添加最多 jitterPercent 的随机抖动
|
||||
func (r *Runner) applyJitter(base time.Duration) time.Duration {
|
||||
if r.jitterPercent <= 0 {
|
||||
return base
|
||||
}
|
||||
maxJitter := time.Duration(float64(base) * r.jitterPercent)
|
||||
if maxJitter <= 0 {
|
||||
return base
|
||||
}
|
||||
return base + time.Duration(rand.Int63n(int64(maxJitter)))
|
||||
}
|
||||
|
||||
// wait 封装带 context 的 sleep
|
||||
func (r *Runner) wait(ctx context.Context, d time.Duration) bool {
|
||||
if d <= 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-timer.C:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// recoverPanic 防止任务 panic 导致 goroutine 退出
|
||||
func (r *Runner) recoverPanic(task Task) {
|
||||
if rec := recover(); rec != nil && r.logger != nil {
|
||||
r.logger.Error("任务发生panic",
|
||||
zap.String("task", task.Name()),
|
||||
zap.Any("panic", rec),
|
||||
zap.ByteString("stack", debug.Stack()),
|
||||
)
|
||||
}
|
||||
}
|
||||
65
internal/task/runner_test.go
Normal file
65
internal/task/runner_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type mockTask struct {
|
||||
name string
|
||||
interval time.Duration
|
||||
err error
|
||||
runCount *atomic.Int32
|
||||
}
|
||||
|
||||
func (m *mockTask) Name() string { return m.name }
|
||||
func (m *mockTask) Interval() time.Duration { return m.interval }
|
||||
func (m *mockTask) Run(ctx context.Context) error {
|
||||
if m.runCount != nil {
|
||||
m.runCount.Add(1)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func TestRunner_StartAndWait(t *testing.T) {
|
||||
runCount := &atomic.Int32{}
|
||||
task := &mockTask{name: "ok", interval: 20 * time.Millisecond, runCount: runCount}
|
||||
runner := NewRunner(zap.NewNop(), task)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
runner.Start(ctx)
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
cancel()
|
||||
runner.Wait()
|
||||
|
||||
if runCount.Load() == 0 {
|
||||
t.Fatalf("expected task to run at least once")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_RunErrorLogged(t *testing.T) {
|
||||
runCount := &atomic.Int32{}
|
||||
task := &mockTask{name: "err", interval: 10 * time.Millisecond, err: errors.New("boom"), runCount: runCount}
|
||||
runner := NewRunner(zap.NewNop(), task)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
runner.Start(ctx)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
cancel()
|
||||
runner.Wait()
|
||||
|
||||
if runCount.Load() == 0 {
|
||||
t.Fatalf("expected task to be attempted")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
56
internal/testutil/testutil.go
Normal file
56
internal/testutil/testutil.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// NewTestDB 返回基于内存的 sqlite 数据库并完成模型迁移
|
||||
func NewTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite memory db: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(
|
||||
&model.User{},
|
||||
&model.UserPointLog{},
|
||||
&model.UserLoginLog{},
|
||||
&model.Profile{},
|
||||
&model.Texture{},
|
||||
&model.UserTextureFavorite{},
|
||||
&model.TextureDownloadLog{},
|
||||
&model.Client{},
|
||||
&model.Yggdrasil{},
|
||||
&model.SystemConfig{},
|
||||
&model.AuditLog{},
|
||||
&model.CasbinRule{},
|
||||
); err != nil {
|
||||
t.Fatalf("failed to migrate models: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// NewNoopLogger 返回无输出 logger
|
||||
func NewNoopLogger() *zap.Logger {
|
||||
return zap.NewNop()
|
||||
}
|
||||
|
||||
// NewTestCache 返回禁用 redis 的缓存管理器(用于单元测试)
|
||||
func NewTestCache() *database.CacheManager {
|
||||
return database.NewCacheManager(nil, database.CacheConfig{
|
||||
Prefix: "test:",
|
||||
Expiration: 1 * time.Minute,
|
||||
Enabled: false,
|
||||
})
|
||||
}
|
||||
27
internal/testutil/testutil_test.go
Normal file
27
internal/testutil/testutil_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package testutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNewTestDB(t *testing.T) {
|
||||
db := NewTestDB(t)
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("DB() err: %v", err)
|
||||
}
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
t.Fatalf("ping err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTestCache(t *testing.T) {
|
||||
cache := NewTestCache()
|
||||
if cache.Policy.UserTTL == 0 {
|
||||
t.Fatalf("expected defaults filled")
|
||||
}
|
||||
// disabled cache should not error on Set
|
||||
if err := cache.Set(nil, "k", "v"); err != nil {
|
||||
t.Fatalf("Set on disabled cache should be nil err, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user