Merge remote-tracking branch 'origin/feature/redis-auth-integration' into dev
# Conflicts: # go.mod # go.sum # internal/container/container.go # internal/repository/interfaces.go # internal/service/mocks_test.go # internal/service/texture_service_test.go # internal/service/token_service_test.go # pkg/redis/manager.go
This commit is contained in:
@@ -214,6 +214,10 @@ func (m *MockProfileRepository) CountByUserID(ctx context.Context, userID int64)
|
||||
return int64(len(m.userProfiles[userID])), nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) SetActive(ctx context.Context, uuid string, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) UpdateLastUsedAt(ctx context.Context, uuid string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -470,99 +474,51 @@ func (m *MockTextureRepository) BatchDelete(ctx context.Context, ids []int64) (i
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// MockTokenRepository 模拟TokenRepository
|
||||
type MockTokenRepository struct {
|
||||
tokens map[string]*model.Token
|
||||
userTokens map[int64][]*model.Token
|
||||
FailCreate bool
|
||||
FailFind bool
|
||||
FailDelete bool
|
||||
// MockSystemConfigRepository 模拟SystemConfigRepository
|
||||
type MockSystemConfigRepository struct {
|
||||
configs map[string]*model.SystemConfig
|
||||
}
|
||||
|
||||
func NewMockTokenRepository() *MockTokenRepository {
|
||||
return &MockTokenRepository{
|
||||
tokens: make(map[string]*model.Token),
|
||||
userTokens: make(map[int64][]*model.Token),
|
||||
func NewMockSystemConfigRepository() *MockSystemConfigRepository {
|
||||
return &MockSystemConfigRepository{
|
||||
configs: make(map[string]*model.SystemConfig),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) Create(ctx context.Context, token *model.Token) error {
|
||||
if m.FailCreate {
|
||||
return errors.New("mock create error")
|
||||
func (m *MockSystemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) {
|
||||
if config, ok := m.configs[key]; ok {
|
||||
return config, nil
|
||||
}
|
||||
m.tokens[token.AccessToken] = token
|
||||
m.userTokens[token.UserID] = append(m.userTokens[token.UserID], token)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockSystemConfigRepository) GetPublic(ctx context.Context) ([]model.SystemConfig, error) {
|
||||
var result []model.SystemConfig
|
||||
for _, v := range m.configs {
|
||||
result = append(result, *v)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MockSystemConfigRepository) GetAll(ctx context.Context) ([]model.SystemConfig, error) {
|
||||
var result []model.SystemConfig
|
||||
for _, v := range m.configs {
|
||||
result = append(result, *v)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MockSystemConfigRepository) Update(ctx context.Context, config *model.SystemConfig) error {
|
||||
m.configs[config.Key] = config
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
func (m *MockSystemConfigRepository) UpdateValue(ctx context.Context, key, value string) error {
|
||||
if config, ok := m.configs[key]; ok {
|
||||
config.Value = value
|
||||
return nil
|
||||
}
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token, nil
|
||||
}
|
||||
return nil, errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
return m.userTokens[userId], nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
if m.FailFind {
|
||||
return "", errors.New("mock find error")
|
||||
}
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
return "", errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
if m.FailFind {
|
||||
return 0, errors.New("mock find error")
|
||||
}
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token.UserID, nil
|
||||
}
|
||||
return 0, errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
|
||||
if m.FailDelete {
|
||||
return errors.New("mock delete error")
|
||||
}
|
||||
delete(m.tokens, accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
|
||||
if m.FailDelete {
|
||||
return errors.New("mock delete error")
|
||||
}
|
||||
for _, token := range m.userTokens[userId] {
|
||||
delete(m.tokens, token.AccessToken)
|
||||
}
|
||||
m.userTokens[userId] = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
|
||||
if m.FailDelete {
|
||||
return 0, errors.New("mock delete error")
|
||||
}
|
||||
var count int64
|
||||
for _, accessToken := range accessTokens {
|
||||
if _, ok := m.tokens[accessToken]; ok {
|
||||
delete(m.tokens, accessToken)
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
return errors.New("config not found")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -757,6 +713,10 @@ func (m *MockProfileService) Delete(uuid string, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileService) SetActive(uuid string, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileService) CheckLimit(userID int64, maxProfiles int) error {
|
||||
count := 0
|
||||
for _, profile := range m.profiles {
|
||||
@@ -913,90 +873,11 @@ func (m *MockTextureService) CheckUploadLimit(uploaderID int64, maxTextures int)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockTokenService 模拟TokenService
|
||||
type MockTokenService struct {
|
||||
tokens map[string]*model.Token
|
||||
FailCreate bool
|
||||
FailValidate bool
|
||||
FailRefresh bool
|
||||
}
|
||||
|
||||
func NewMockTokenService() *MockTokenService {
|
||||
return &MockTokenService{
|
||||
tokens: make(map[string]*model.Token),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTokenService) Create(userID int64, uuid, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
if m.FailCreate {
|
||||
return nil, nil, "", "", errors.New("mock create error")
|
||||
}
|
||||
accessToken := "mock-access-token"
|
||||
if clientToken == "" {
|
||||
clientToken = "mock-client-token"
|
||||
}
|
||||
token := &model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
UserID: userID,
|
||||
ProfileId: uuid,
|
||||
Usable: true,
|
||||
}
|
||||
m.tokens[accessToken] = token
|
||||
return nil, nil, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
func (m *MockTokenService) Validate(accessToken, clientToken string) bool {
|
||||
if m.FailValidate {
|
||||
return false
|
||||
}
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
if clientToken == "" || token.ClientToken == clientToken {
|
||||
return token.Usable
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockTokenService) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
if m.FailRefresh {
|
||||
return "", "", errors.New("mock refresh error")
|
||||
}
|
||||
return "new-access-token", clientToken, nil
|
||||
}
|
||||
|
||||
func (m *MockTokenService) Invalidate(accessToken string) {
|
||||
delete(m.tokens, accessToken)
|
||||
}
|
||||
|
||||
func (m *MockTokenService) InvalidateUserTokens(userID int64) {
|
||||
for key, token := range m.tokens {
|
||||
if token.UserID == userID {
|
||||
delete(m.tokens, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTokenService) GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
return "", errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenService) GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
if token, ok := m.tokens[accessToken]; ok {
|
||||
return token.UserID, nil
|
||||
}
|
||||
return 0, errors.New("token not found")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CacheManager Mock - uses database.CacheManager with nil redis
|
||||
// CacheManager Mock - 使用 database.CacheManager 的内存版本
|
||||
// ============================================================================
|
||||
|
||||
// NewMockCacheManager 创建一个禁用的 CacheManager 用于测试
|
||||
// 通过设置 Enabled = false,缓存操作会被跳过,测试不依赖 Redis
|
||||
// NewMockCacheManager 创建一个内存 CacheManager 用于测试
|
||||
func NewMockCacheManager() *database.CacheManager {
|
||||
return database.NewCacheManager(nil, database.CacheConfig{
|
||||
Prefix: "test:",
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
@@ -93,7 +92,7 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro
|
||||
// 尝试从缓存获取
|
||||
cacheKey := s.cacheKeys.Profile(uuid)
|
||||
var profile model.Profile
|
||||
if err := s.cache.Get(ctx, cacheKey, &profile); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &profile); ok {
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
@@ -106,11 +105,9 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 存入缓存(异步,5分钟过期)
|
||||
// 存入缓存(异步)
|
||||
if profile2 != nil {
|
||||
go func() {
|
||||
_ = s.cache.Set(context.Background(), cacheKey, profile2, 5*time.Minute)
|
||||
}()
|
||||
s.cache.SetAsync(context.Background(), cacheKey, profile2, s.cache.Policy.ProfileTTL)
|
||||
}
|
||||
|
||||
return profile2, nil
|
||||
@@ -120,7 +117,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
|
||||
// 尝试从缓存获取
|
||||
cacheKey := s.cacheKeys.ProfileList(userID)
|
||||
var profiles []*model.Profile
|
||||
if err := s.cache.Get(ctx, cacheKey, &profiles); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &profiles); ok {
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
@@ -130,11 +127,9 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
|
||||
return nil, fmt.Errorf("查询档案列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 存入缓存(异步,3分钟过期)
|
||||
// 存入缓存(异步)
|
||||
if profiles != nil {
|
||||
go func() {
|
||||
_ = s.cache.Set(context.Background(), cacheKey, profiles, 3*time.Minute)
|
||||
}()
|
||||
s.cache.SetAsync(context.Background(), cacheKey, profiles, s.cache.Policy.ProfileListTTL)
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -52,7 +51,7 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture,
|
||||
// 尝试从缓存获取
|
||||
cacheKey := s.cacheKeys.Texture(id)
|
||||
var texture model.Texture
|
||||
if err := s.cache.Get(ctx, cacheKey, &texture); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &texture); ok {
|
||||
if texture.Status == -1 {
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
@@ -71,11 +70,9 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture,
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
|
||||
// 存入缓存(异步,5分钟过期)
|
||||
// 存入缓存(异步)
|
||||
if texture2 != nil {
|
||||
go func() {
|
||||
_ = s.cache.Set(context.Background(), cacheKey, texture2, 5*time.Minute)
|
||||
}()
|
||||
s.cache.SetAsync(context.Background(), cacheKey, texture2, s.cache.Policy.TextureTTL)
|
||||
}
|
||||
|
||||
return texture2, nil
|
||||
@@ -85,7 +82,7 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex
|
||||
// 尝试从缓存获取
|
||||
cacheKey := s.cacheKeys.TextureByHash(hash)
|
||||
var texture model.Texture
|
||||
if err := s.cache.Get(ctx, cacheKey, &texture); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &texture); ok {
|
||||
if texture.Status == -1 {
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
@@ -104,10 +101,8 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
|
||||
// 存入缓存(异步,5分钟过期)
|
||||
go func() {
|
||||
_ = s.cache.Set(context.Background(), cacheKey, texture2, 5*time.Minute)
|
||||
}()
|
||||
// 存入缓存(异步)
|
||||
s.cache.SetAsync(context.Background(), cacheKey, texture2, s.cache.Policy.TextureTTL)
|
||||
|
||||
return texture2, nil
|
||||
}
|
||||
@@ -121,7 +116,7 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
|
||||
Textures []*model.Texture
|
||||
Total int64
|
||||
}
|
||||
if err := s.cache.Get(ctx, cacheKey, &cachedResult); err == nil {
|
||||
if ok, _ := s.cache.TryGet(ctx, cacheKey, &cachedResult); ok {
|
||||
return cachedResult.Textures, cachedResult.Total, nil
|
||||
}
|
||||
|
||||
@@ -131,14 +126,12 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 存入缓存(异步,2分钟过期)
|
||||
go func() {
|
||||
result := struct {
|
||||
Textures []*model.Texture
|
||||
Total int64
|
||||
}{Textures: textures, Total: total}
|
||||
_ = s.cache.Set(context.Background(), cacheKey, result, 2*time.Minute)
|
||||
}()
|
||||
// 存入缓存(异步)
|
||||
result := struct {
|
||||
Textures []*model.Texture
|
||||
Total int64
|
||||
}{Textures: textures, Total: total}
|
||||
s.cache.SetAsync(context.Background(), cacheKey, result, s.cache.Policy.TextureListTTL)
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
@@ -181,7 +174,7 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64
|
||||
|
||||
// 清除 texture 缓存和用户列表缓存
|
||||
s.cacheInv.OnUpdate(ctx, s.cacheKeys.Texture(textureID))
|
||||
s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID))
|
||||
s.cacheInv.BatchInvalidate(ctx, s.cacheKeys.TextureListPattern(uploaderID))
|
||||
|
||||
return s.textureRepo.FindByID(ctx, textureID)
|
||||
}
|
||||
@@ -206,7 +199,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64
|
||||
|
||||
// 清除 texture 缓存和用户列表缓存
|
||||
s.cacheInv.OnDelete(ctx, s.cacheKeys.Texture(textureID))
|
||||
s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID))
|
||||
s.cacheInv.BatchInvalidate(ctx, s.cacheKeys.TextureListPattern(uploaderID))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -478,6 +478,128 @@ func boolPtr(b bool) *bool {
|
||||
// 使用 Mock 的集成测试
|
||||
// ============================================================================
|
||||
|
||||
// TestTextureServiceImpl_Create 测试创建Texture
|
||||
func TestTextureServiceImpl_Create(t *testing.T) {
|
||||
textureRepo := NewMockTextureRepository()
|
||||
userRepo := NewMockUserRepository()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// 预置用户
|
||||
testUser := &model.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Status: 1,
|
||||
}
|
||||
_ = userRepo.Create(context.Background(), testUser)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, nil, cacheManager, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uploaderID int64
|
||||
textureName string
|
||||
textureType string
|
||||
hash string
|
||||
wantErr bool
|
||||
errContains string
|
||||
setupMocks func()
|
||||
}{
|
||||
{
|
||||
name: "正常创建SKIN材质",
|
||||
uploaderID: 1,
|
||||
textureName: "TestSkin",
|
||||
textureType: "SKIN",
|
||||
hash: "unique-hash-1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "正常创建CAPE材质",
|
||||
uploaderID: 1,
|
||||
textureName: "TestCape",
|
||||
textureType: "CAPE",
|
||||
hash: "unique-hash-2",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
uploaderID: 999,
|
||||
textureName: "TestTexture",
|
||||
textureType: "SKIN",
|
||||
hash: "unique-hash-3",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "材质Hash已存在",
|
||||
uploaderID: 1,
|
||||
textureName: "DuplicateTexture",
|
||||
textureType: "SKIN",
|
||||
hash: "existing-hash",
|
||||
wantErr: false,
|
||||
setupMocks: func() {
|
||||
_ = textureRepo.Create(context.Background(), &model.Texture{
|
||||
ID: 100,
|
||||
UploaderID: 1,
|
||||
Name: "ExistingTexture",
|
||||
Hash: "existing-hash",
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "无效的材质类型",
|
||||
uploaderID: 1,
|
||||
textureName: "InvalidTypeTexture",
|
||||
textureType: "INVALID",
|
||||
hash: "unique-hash-4",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setupMocks != nil {
|
||||
tt.setupMocks()
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
texture, err := textureService.Create(
|
||||
ctx,
|
||||
tt.uploaderID,
|
||||
tt.textureName,
|
||||
"Test description",
|
||||
tt.textureType,
|
||||
"http://example.com/texture.png",
|
||||
tt.hash,
|
||||
512,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("期望返回错误,但实际没有错误")
|
||||
return
|
||||
}
|
||||
if tt.errContains != "" && !containsString(err.Error(), tt.errContains) {
|
||||
t.Errorf("错误信息应包含 %q, 实际为: %v", tt.errContains, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("不期望返回错误: %v", err)
|
||||
return
|
||||
}
|
||||
if texture == nil {
|
||||
t.Error("返回的Texture不应为nil")
|
||||
}
|
||||
if texture.Name != tt.textureName {
|
||||
t.Errorf("Texture名称不匹配: got %v, want %v", texture.Name, tt.textureName)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureServiceImpl_GetByID 测试获取Texture
|
||||
func TestTextureServiceImpl_GetByID(t *testing.T) {
|
||||
textureRepo := NewMockTextureRepository()
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// tokenService TokenService的实现
|
||||
type tokenService struct {
|
||||
tokenRepo repository.TokenRepository
|
||||
profileRepo repository.ProfileRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTokenService 创建TokenService实例
|
||||
func NewTokenService(
|
||||
tokenRepo repository.TokenRepository,
|
||||
profileRepo repository.ProfileRepository,
|
||||
logger *zap.Logger,
|
||||
) TokenService {
|
||||
return &tokenService{
|
||||
tokenRepo: tokenRepo,
|
||||
profileRepo: profileRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
tokenExtendedTimeout = 10 * time.Second
|
||||
tokensMaxCount = 10
|
||||
)
|
||||
|
||||
func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
var (
|
||||
selectedProfileID *model.Profile
|
||||
availableProfiles []*model.Profile
|
||||
)
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
if UUID != "" {
|
||||
_, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 生成令牌
|
||||
if clientToken == "" {
|
||||
clientToken = uuid.New().String()
|
||||
}
|
||||
|
||||
accessToken := uuid.New().String()
|
||||
token := model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
UserID: userID,
|
||||
Usable: true,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := s.profileRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果用户只有一个配置文件,自动选择
|
||||
if len(profiles) == 1 {
|
||||
selectedProfileID = profiles[0]
|
||||
token.ProfileId = selectedProfileID.UUID
|
||||
}
|
||||
availableProfiles = profiles
|
||||
|
||||
// 插入令牌
|
||||
err = s.tokenRepo.Create(ctx, &token)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 清理多余的令牌(使用独立的后台上下文)
|
||||
go s.checkAndCleanupExcessTokens(context.Background(), userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !token.Usable {
|
||||
return false
|
||||
}
|
||||
|
||||
if clientToken == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return token.ClientToken == clientToken
|
||||
}
|
||||
|
||||
func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
|
||||
// 查找旧令牌
|
||||
oldToken, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", errors.New("accessToken无效")
|
||||
}
|
||||
s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return "", "", fmt.Errorf("查询令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := s.validateProfileByUserID(ctx, oldToken.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
s.logger.Error("验证Profile失败",
|
||||
zap.Error(err),
|
||||
zap.Int64("userId", oldToken.UserID),
|
||||
zap.String("profileId", selectedProfileID),
|
||||
)
|
||||
return "", "", fmt.Errorf("验证角色失败: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return "", "", errors.New("角色与用户不匹配")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 clientToken 是否有效
|
||||
if clientToken != "" && clientToken != oldToken.ClientToken {
|
||||
return "", "", errors.New("clientToken无效")
|
||||
}
|
||||
|
||||
// 检查 selectedProfileID 的逻辑
|
||||
if selectedProfileID != "" {
|
||||
if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID {
|
||||
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
|
||||
}
|
||||
} else {
|
||||
selectedProfileID = oldToken.ProfileId
|
||||
}
|
||||
|
||||
// 生成新令牌
|
||||
newAccessToken := uuid.New().String()
|
||||
newToken := model.Token{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: oldToken.ClientToken,
|
||||
UserID: oldToken.UserID,
|
||||
Usable: true,
|
||||
ProfileId: selectedProfileID,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 先插入新令牌,再删除旧令牌
|
||||
err = s.tokenRepo.Create(ctx, &newToken)
|
||||
if err != nil {
|
||||
s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return "", "", fmt.Errorf("创建新Token失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("删除旧Token失败,但新Token已创建",
|
||||
zap.Error(err),
|
||||
zap.String("oldToken", oldToken.AccessToken),
|
||||
zap.String("newToken", newAccessToken),
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken))
|
||||
return newAccessToken, oldToken.ClientToken, nil
|
||||
}
|
||||
|
||||
func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return
|
||||
}
|
||||
s.logger.Info("成功删除Token", zap.String("token", accessToken))
|
||||
}
|
||||
|
||||
func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("成功删除用户Token", zap.Int64("userId", userID))
|
||||
}
|
||||
|
||||
func (s *tokenService) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
func (s *tokenService) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenService) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 为清理操作设置更长的超时时间
|
||||
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokens) <= tokensMaxCount {
|
||||
return
|
||||
}
|
||||
|
||||
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
|
||||
for i := tokensMaxCount; i < len(tokens); i++ {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenService) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
}
|
||||
return false, fmt.Errorf("验证配置文件失败: %w", err)
|
||||
}
|
||||
return profile.UserID == userID, nil
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -15,40 +14,38 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// tokenServiceJWT TokenService的JWT实现(使用JWT + Version机制)
|
||||
type tokenServiceJWT struct {
|
||||
tokenRepo repository.TokenRepository
|
||||
clientRepo repository.ClientRepository
|
||||
profileRepo repository.ProfileRepository
|
||||
yggdrasilJWT *auth.YggdrasilJWTService
|
||||
logger *zap.Logger
|
||||
tokenExpireSec int64 // Token过期时间(秒),0表示永不过期
|
||||
tokenStaleSec int64 // Token过期但可用时间(秒),0表示永不过期
|
||||
// tokenServiceRedis TokenService的Redis实现
|
||||
type tokenServiceRedis struct {
|
||||
tokenStore *auth.TokenStoreRedis
|
||||
clientRepo repository.ClientRepository
|
||||
profileRepo repository.ProfileRepository
|
||||
yggdrasilJWT *auth.YggdrasilJWTService
|
||||
logger *zap.Logger
|
||||
tokenExpireSec int64 // Token过期时间(秒),0表示永不过期
|
||||
tokenStaleSec int64 // Token过期但可用时间(秒),0表示永不过期
|
||||
}
|
||||
|
||||
// NewTokenServiceJWT 创建使用JWT的TokenService实例
|
||||
func NewTokenServiceJWT(
|
||||
tokenRepo repository.TokenRepository,
|
||||
// NewTokenServiceRedis 创建使用Redis的TokenService实例
|
||||
func NewTokenServiceRedis(
|
||||
tokenStore *auth.TokenStoreRedis,
|
||||
clientRepo repository.ClientRepository,
|
||||
profileRepo repository.ProfileRepository,
|
||||
yggdrasilJWT *auth.YggdrasilJWTService,
|
||||
logger *zap.Logger,
|
||||
) TokenService {
|
||||
return &tokenServiceJWT{
|
||||
tokenRepo: tokenRepo,
|
||||
return &tokenServiceRedis{
|
||||
tokenStore: tokenStore,
|
||||
clientRepo: clientRepo,
|
||||
profileRepo: profileRepo,
|
||||
yggdrasilJWT: yggdrasilJWT,
|
||||
logger: logger,
|
||||
tokenExpireSec: 24 * 3600, // 默认24小时
|
||||
tokenExpireSec: 24 * 3600, // 默认24小时
|
||||
tokenStaleSec: 30 * 24 * 3600, // 默认30天
|
||||
}
|
||||
}
|
||||
|
||||
// 常量已在 token_service.go 中定义,这里不重复定义
|
||||
|
||||
// Create 创建Token(使用JWT + Version机制)
|
||||
func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
// Create 创建Token(使用JWT + Redis存储)
|
||||
func (s *tokenServiceRedis) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
var (
|
||||
selectedProfileID *model.Profile
|
||||
availableProfiles []*model.Profile
|
||||
@@ -85,11 +82,11 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
|
||||
if UUID != "" {
|
||||
client.ProfileID = UUID
|
||||
}
|
||||
|
||||
|
||||
if err := s.clientRepo.Create(ctx, client); err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err)
|
||||
}
|
||||
@@ -103,7 +100,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
client.UpdatedAt = time.Now()
|
||||
if UUID != "" {
|
||||
client.ProfileID = UUID
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -130,14 +127,14 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
// 生成Token过期时间
|
||||
now := time.Now()
|
||||
var expiresAt, staleAt time.Time
|
||||
|
||||
|
||||
if s.tokenExpireSec > 0 {
|
||||
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
|
||||
} else {
|
||||
// 使用遥远的未来时间(类似drasl的DISTANT_FUTURE)
|
||||
// 使用遥远的未来时间
|
||||
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
|
||||
if s.tokenStaleSec > 0 {
|
||||
staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
|
||||
} else {
|
||||
@@ -157,36 +154,31 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("生成AccessToken失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存Token记录(用于查询和审计)
|
||||
token := model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
// 存储Token到Redis
|
||||
ttl := expiresAt.Sub(now)
|
||||
metadata := &auth.TokenMetadata{
|
||||
UserID: userID,
|
||||
ProfileId: profileID,
|
||||
ProfileID: profileID,
|
||||
ClientUUID: client.UUID,
|
||||
ClientToken: client.ClientToken,
|
||||
Version: client.Version,
|
||||
Usable: true,
|
||||
IssueDate: now,
|
||||
ExpiresAt: &expiresAt,
|
||||
StaleAt: &staleAt,
|
||||
CreatedAt: now.Unix(),
|
||||
}
|
||||
|
||||
err = s.tokenRepo.Create(ctx, &token)
|
||||
if err != nil {
|
||||
s.logger.Warn("保存Token记录失败,但JWT已生成", zap.Error(err))
|
||||
if err := s.tokenStore.Store(ctx, accessToken, metadata, ttl); err != nil {
|
||||
s.logger.Warn("存储Token到Redis失败", zap.Error(err))
|
||||
// 不返回错误,因为JWT本身已经生成成功
|
||||
}
|
||||
|
||||
// 清理多余的令牌(使用独立的后台上下文)
|
||||
go s.checkAndCleanupExcessTokens(context.Background(), userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
// Validate 验证Token(使用JWT验证)
|
||||
func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// Validate 验证Token(使用JWT验证 + Redis存储验证)
|
||||
func (s *tokenServiceRedis) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
@@ -197,6 +189,13 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
|
||||
return false
|
||||
}
|
||||
|
||||
// 从Redis获取Token元数据
|
||||
metadata, err := s.tokenStore.Retrieve(ctx, accessToken)
|
||||
if err != nil {
|
||||
// Token可能已过期或不存在
|
||||
return false
|
||||
}
|
||||
|
||||
// 查找Client
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
@@ -209,18 +208,19 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
|
||||
}
|
||||
|
||||
// 验证ClientToken(如果提供)
|
||||
if clientToken != "" && client.ClientToken != clientToken {
|
||||
if clientToken != "" && metadata.ClientToken != clientToken {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Refresh 刷新Token(使用Version机制,无需删除旧Token)
|
||||
func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// Refresh 刷新Token(使用Version机制,Redis存储)
|
||||
func (s *tokenServiceRedis) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
@@ -279,16 +279,21 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
return "", "", fmt.Errorf("更新Client版本失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除旧Token(从Redis)
|
||||
if err := s.tokenStore.Delete(ctx, accessToken); err != nil {
|
||||
s.logger.Warn("删除旧Token失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 生成Token过期时间
|
||||
now := time.Now()
|
||||
var expiresAt, staleAt time.Time
|
||||
|
||||
|
||||
if s.tokenExpireSec > 0 {
|
||||
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
|
||||
} else {
|
||||
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
|
||||
if s.tokenStaleSec > 0 {
|
||||
staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
|
||||
} else {
|
||||
@@ -308,30 +313,27 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
return "", "", fmt.Errorf("生成新AccessToken失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存新Token记录
|
||||
newToken := model.Token{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: client.ClientToken,
|
||||
// 存储新Token到Redis
|
||||
ttl := expiresAt.Sub(now)
|
||||
metadata := &auth.TokenMetadata{
|
||||
UserID: client.UserID,
|
||||
ProfileId: selectedProfileID,
|
||||
ProfileID: selectedProfileID,
|
||||
ClientUUID: client.UUID,
|
||||
ClientToken: client.ClientToken,
|
||||
Version: client.Version,
|
||||
Usable: true,
|
||||
IssueDate: now,
|
||||
ExpiresAt: &expiresAt,
|
||||
StaleAt: &staleAt,
|
||||
CreatedAt: now.Unix(),
|
||||
}
|
||||
|
||||
err = s.tokenRepo.Create(ctx, &newToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("保存新Token记录失败,但JWT已生成", zap.Error(err))
|
||||
if err := s.tokenStore.Store(ctx, newAccessToken, metadata, ttl); err != nil {
|
||||
s.logger.Warn("存储新Token到Redis失败", zap.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Info("成功刷新Token", zap.Int64("userId", client.UserID), zap.Int("version", client.Version))
|
||||
return newAccessToken, client.ClientToken, nil
|
||||
}
|
||||
|
||||
// Invalidate 使Token失效(通过增加Version)
|
||||
func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
// Invalidate 使Token失效(从Redis删除)
|
||||
func (s *tokenServiceRedis) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
@@ -347,7 +349,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 查找Client并增加Version
|
||||
// 查找Client并增加Version(失效所有旧Token)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
s.logger.Warn("无法找到对应的Client", zap.Error(err))
|
||||
@@ -362,11 +364,17 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 从Redis删除Token
|
||||
if err := s.tokenStore.Delete(ctx, accessToken); err != nil {
|
||||
s.logger.Warn("从Redis删除Token失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("成功失效Token", zap.String("clientUUID", client.UUID), zap.Int("version", client.Version))
|
||||
}
|
||||
|
||||
// InvalidateUserTokens 使用户所有Token失效
|
||||
func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// InvalidateUserTokens 使用户所有Token失效(从Redis删除)
|
||||
func (s *tokenServiceRedis) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
@@ -391,15 +399,20 @@ func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64
|
||||
}
|
||||
}
|
||||
|
||||
// 从Redis删除用户所有Token
|
||||
if err := s.tokenStore.DeleteByUserID(ctx, userID); err != nil {
|
||||
s.logger.Error("从Redis删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("成功失效用户所有Token", zap.Int64("userId", userID), zap.Int("clientCount", len(clients)))
|
||||
}
|
||||
|
||||
// GetUUIDByAccessToken 从AccessToken获取UUID(通过JWT解析)
|
||||
func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
func (s *tokenServiceRedis) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||
if err != nil {
|
||||
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
|
||||
return "", errors.New("accessToken无效")
|
||||
}
|
||||
|
||||
if claims.ProfileID != "" {
|
||||
@@ -420,11 +433,10 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
|
||||
}
|
||||
|
||||
// GetUserIDByAccessToken 从AccessToken获取UserID(通过JWT解析)
|
||||
func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
func (s *tokenServiceRedis) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||
if err != nil {
|
||||
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
|
||||
return 0, errors.New("accessToken无效")
|
||||
}
|
||||
|
||||
// 从Client获取UserID
|
||||
@@ -441,44 +453,8 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
|
||||
return client.UserID, nil
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 为清理操作设置更长的超时时间
|
||||
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokens) <= tokensMaxCount {
|
||||
return
|
||||
}
|
||||
|
||||
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
|
||||
for i := tokensMaxCount; i < len(tokens); i++ {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
// validateProfileByUserID 验证Profile是否属于用户
|
||||
func (s *tokenServiceRedis) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
@@ -492,24 +468,3 @@ func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID in
|
||||
}
|
||||
return profile.UserID == userID, nil
|
||||
}
|
||||
|
||||
// GetClientFromToken 从Token获取Client信息(辅助方法)
|
||||
func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken string, stalePolicy auth.StaleTokenPolicy) (*model.Client, error) {
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, stalePolicy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证Version
|
||||
if claims.Version != client.Version {
|
||||
return nil, errors.New("token版本不匹配")
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (s *userService) GetByID(ctx context.Context, id int64) (*model.User, error
|
||||
cacheKey := s.cacheKeys.User(id)
|
||||
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
|
||||
return s.userRepo.FindByID(ctx, id)
|
||||
}, 5*time.Minute)
|
||||
}, s.cache.Policy.UserTTL)
|
||||
}
|
||||
|
||||
func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User, error) {
|
||||
@@ -196,7 +196,7 @@ func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User
|
||||
cacheKey := s.cacheKeys.UserByEmail(email)
|
||||
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
|
||||
return s.userRepo.FindByEmail(ctx, email)
|
||||
}, 5*time.Minute)
|
||||
}, s.cache.Policy.UserEmailTTL)
|
||||
}
|
||||
|
||||
func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {
|
||||
|
||||
@@ -22,7 +22,7 @@ type yggdrasilServiceComposite struct {
|
||||
serializationService SerializationService
|
||||
certificateService CertificateService
|
||||
profileRepo repository.ProfileRepository
|
||||
tokenRepo repository.TokenRepository
|
||||
tokenService TokenService // 使用TokenService接口,不直接依赖TokenRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
@@ -31,11 +31,11 @@ func NewYggdrasilServiceComposite(
|
||||
db *gorm.DB,
|
||||
userRepo repository.UserRepository,
|
||||
profileRepo repository.ProfileRepository,
|
||||
tokenRepo repository.TokenRepository,
|
||||
yggdrasilRepo repository.YggdrasilRepository,
|
||||
signatureService *SignatureService,
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
tokenService TokenService, // 新增:TokenService接口
|
||||
) YggdrasilService {
|
||||
// 创建各个专门的服务
|
||||
authService := NewYggdrasilAuthService(db, userRepo, yggdrasilRepo, logger)
|
||||
@@ -53,7 +53,7 @@ func NewYggdrasilServiceComposite(
|
||||
serializationService: serializationService,
|
||||
certificateService: certificateService,
|
||||
profileRepo: profileRepo,
|
||||
tokenRepo: tokenRepo,
|
||||
tokenService: tokenService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
@@ -75,8 +75,8 @@ func (s *yggdrasilServiceComposite) ResetYggdrasilPassword(ctx context.Context,
|
||||
|
||||
// JoinServer 加入服务器
|
||||
func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, accessToken, selectedProfile, ip string) error {
|
||||
// 验证Token
|
||||
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
// 通过TokenService验证Token并获取UUID
|
||||
uuid, err := s.tokenService.GetUUIDByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Error("验证Token失败",
|
||||
zap.Error(err),
|
||||
@@ -87,7 +87,7 @@ func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, ac
|
||||
|
||||
// 格式化UUID并验证与Token关联的配置文件
|
||||
formattedProfile := utils.FormatUUID(selectedProfile)
|
||||
if token.ProfileId != formattedProfile {
|
||||
if uuid != formattedProfile {
|
||||
return errors.New("selectedProfile与Token不匹配")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user