refactor: Update service and repository methods to use context
- Refactored multiple service and repository methods to accept context as a parameter, enhancing consistency and enabling better control over request lifecycles. - Updated handlers to utilize context in method calls, improving error handling and performance. - Cleaned up Dockerfile by removing unnecessary whitespace.
This commit is contained in:
@@ -46,12 +46,12 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
|
||||
)
|
||||
|
||||
// 设置超时上下文
|
||||
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
if UUID != "" {
|
||||
_, err := s.profileRepo.FindByUUID(UUID)
|
||||
_, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := s.profileRepo.FindByUserID(userID)
|
||||
profiles, err := s.profileRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
@@ -85,23 +85,27 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
|
||||
availableProfiles = profiles
|
||||
|
||||
// 插入令牌
|
||||
err = s.tokenRepo.Create(&token)
|
||||
err = s.tokenRepo.Create(ctx, &token)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 清理多余的令牌
|
||||
go s.checkAndCleanupExcessTokens(userID)
|
||||
// 清理多余的令牌(使用独立的后台上下文)
|
||||
go s.checkAndCleanupExcessTokens(context.Background(), userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.tokenRepo.FindByAccessToken(accessToken)
|
||||
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -118,12 +122,16 @@ func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken st
|
||||
}
|
||||
|
||||
func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
|
||||
// 查找旧令牌
|
||||
oldToken, err := s.tokenRepo.FindByAccessToken(accessToken)
|
||||
oldToken, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", errors.New("accessToken无效")
|
||||
@@ -134,7 +142,7 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
|
||||
|
||||
// 验证profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID)
|
||||
valid, validErr := s.validateProfileByUserID(ctx, oldToken.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
s.logger.Error("验证Profile失败",
|
||||
zap.Error(err),
|
||||
@@ -174,13 +182,13 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
|
||||
}
|
||||
|
||||
// 先插入新令牌,再删除旧令牌
|
||||
err = s.tokenRepo.Create(&newToken)
|
||||
err = s.tokenRepo.Create(ctx, &newToken)
|
||||
if err != nil {
|
||||
s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return "", "", fmt.Errorf("创建新Token失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.tokenRepo.DeleteByAccessToken(accessToken)
|
||||
err = s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("删除旧Token失败,但新Token已创建",
|
||||
zap.Error(err),
|
||||
@@ -194,11 +202,15 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
|
||||
}
|
||||
|
||||
func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByAccessToken(accessToken)
|
||||
err := s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return
|
||||
@@ -207,11 +219,15 @@ func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
|
||||
}
|
||||
|
||||
func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByUserID(userID)
|
||||
err := s.tokenRepo.DeleteByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
@@ -221,21 +237,33 @@ func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
}
|
||||
|
||||
func (s *tokenService) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
return s.tokenRepo.GetUUIDByAccessToken(accessToken)
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
func (s *tokenService) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
return s.tokenRepo.GetUserIDByAccessToken(accessToken)
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenService) checkAndCleanupExcessTokens(userID int64) {
|
||||
func (s *tokenService) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(userID)
|
||||
// 为清理操作设置更长的超时时间
|
||||
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
@@ -250,7 +278,7 @@ func (s *tokenService) checkAndCleanupExcessTokens(userID int64) {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete)
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
@@ -261,12 +289,12 @@ func (s *tokenService) checkAndCleanupExcessTokens(userID int64) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenService) validateProfileByUserID(userID int64, UUID string) (bool, error) {
|
||||
func (s *tokenService) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := s.profileRepo.FindByUUID(UUID)
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
|
||||
Reference in New Issue
Block a user