refactor: Update service and repository methods to use context
- Refactored multiple service and repository methods to accept context as a parameter, enhancing consistency and enabling better control over request lifecycles. - Updated handlers to utilize context in method calls, improving error handling and performance. - Cleaned up Dockerfile by removing unnecessary whitespace.
This commit is contained in:
@@ -55,12 +55,12 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
)
|
||||
|
||||
// 设置超时上下文
|
||||
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
if UUID != "" {
|
||||
_, err := s.profileRepo.FindByUUID(UUID)
|
||||
_, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
|
||||
// 获取或创建Client
|
||||
var client *model.Client
|
||||
existingClient, err := s.clientRepo.FindByClientToken(clientToken)
|
||||
existingClient, err := s.clientRepo.FindByClientToken(ctx, clientToken)
|
||||
if err != nil {
|
||||
// Client不存在,创建新的
|
||||
clientUUID := uuid.New().String()
|
||||
@@ -90,7 +90,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
client.ProfileID = UUID
|
||||
}
|
||||
|
||||
if err := s.clientRepo.Create(client); err != nil {
|
||||
if err := s.clientRepo.Create(ctx, client); err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
@@ -103,14 +103,14 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
client.UpdatedAt = time.Now()
|
||||
if UUID != "" {
|
||||
client.ProfileID = UUID
|
||||
if err := s.clientRepo.Update(client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := s.profileRepo.FindByUserID(userID)
|
||||
profiles, err := s.profileRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
if profileID == "" {
|
||||
profileID = selectedProfileID.UUID
|
||||
client.ProfileID = profileID
|
||||
s.clientRepo.Update(client)
|
||||
_ = s.clientRepo.Update(ctx, client)
|
||||
}
|
||||
}
|
||||
availableProfiles = profiles
|
||||
@@ -170,20 +170,23 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
StaleAt: &staleAt,
|
||||
}
|
||||
|
||||
err = s.tokenRepo.Create(&token)
|
||||
err = s.tokenRepo.Create(ctx, &token)
|
||||
if err != nil {
|
||||
s.logger.Warn("保存Token记录失败,但JWT已生成", zap.Error(err))
|
||||
// 不返回错误,因为JWT本身已经生成成功
|
||||
}
|
||||
|
||||
// 清理多余的令牌
|
||||
go s.checkAndCleanupExcessTokens(userID)
|
||||
// 清理多余的令牌(使用独立的后台上下文)
|
||||
go s.checkAndCleanupExcessTokens(context.Background(), userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
// Validate 验证Token(使用JWT验证)
|
||||
func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
@@ -195,7 +198,7 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
|
||||
}
|
||||
|
||||
// 查找Client
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -215,6 +218,9 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
|
||||
|
||||
// Refresh 刷新Token(使用Version机制,无需删除旧Token)
|
||||
func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
@@ -226,7 +232,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
}
|
||||
|
||||
// 查找Client
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return "", "", errors.New("无法找到对应的Client")
|
||||
}
|
||||
@@ -243,7 +249,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
|
||||
// 验证Profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := s.validateProfileByUserID(client.UserID, selectedProfileID)
|
||||
valid, validErr := s.validateProfileByUserID(ctx, client.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
s.logger.Error("验证Profile失败",
|
||||
zap.Error(validErr),
|
||||
@@ -269,7 +275,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
// 增加Version(这是关键:通过Version失效所有旧Token)
|
||||
client.Version++
|
||||
client.UpdatedAt = time.Now()
|
||||
if err := s.clientRepo.Update(client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
return "", "", fmt.Errorf("更新Client版本失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -315,7 +321,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
StaleAt: &staleAt,
|
||||
}
|
||||
|
||||
err = s.tokenRepo.Create(&newToken)
|
||||
err = s.tokenRepo.Create(ctx, &newToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("保存新Token记录失败,但JWT已生成", zap.Error(err))
|
||||
}
|
||||
@@ -326,6 +332,10 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
|
||||
// Invalidate 使Token失效(通过增加Version)
|
||||
func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
@@ -338,7 +348,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
}
|
||||
|
||||
// 查找Client并增加Version
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
s.logger.Warn("无法找到对应的Client", zap.Error(err))
|
||||
return
|
||||
@@ -347,7 +357,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 增加Version以失效所有旧Token
|
||||
client.Version++
|
||||
client.UpdatedAt = time.Now()
|
||||
if err := s.clientRepo.Update(client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
s.logger.Error("失效Token失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
@@ -357,12 +367,16 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
|
||||
// InvalidateUserTokens 使用户所有Token失效
|
||||
func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户所有Client
|
||||
clients, err := s.clientRepo.FindByUserID(userID)
|
||||
clients, err := s.clientRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
@@ -372,7 +386,7 @@ func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64
|
||||
for _, client := range clients {
|
||||
client.Version++
|
||||
client.UpdatedAt = time.Now()
|
||||
if err := s.clientRepo.Update(client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
s.logger.Error("失效用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
}
|
||||
}
|
||||
@@ -385,7 +399,7 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||
if err != nil {
|
||||
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||
return s.tokenRepo.GetUUIDByAccessToken(accessToken)
|
||||
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
if claims.ProfileID != "" {
|
||||
@@ -393,7 +407,7 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
|
||||
}
|
||||
|
||||
// 如果没有ProfileID,从Client获取
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("无法找到对应的Client: %w", err)
|
||||
}
|
||||
@@ -410,11 +424,11 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||
if err != nil {
|
||||
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||
return s.tokenRepo.GetUserIDByAccessToken(accessToken)
|
||||
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
// 从Client获取UserID
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("无法找到对应的Client: %w", err)
|
||||
}
|
||||
@@ -429,12 +443,16 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
|
||||
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(userID)
|
||||
// 为清理操作设置更长的超时时间
|
||||
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
@@ -449,7 +467,7 @@ func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete)
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
@@ -460,12 +478,12 @@ func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenServiceJWT) validateProfileByUserID(userID int64, UUID string) (bool, error) {
|
||||
func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := s.profileRepo.FindByUUID(UUID)
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
@@ -482,7 +500,7 @@ func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken st
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user