refactor: Implement dependency injection for handlers and services
- Refactored AuthHandler, UserHandler, TextureHandler, ProfileHandler, CaptchaHandler, and YggdrasilHandler to use dependency injection. - Removed direct instantiation of services and repositories within handlers, replacing them with constructor injection. - Updated the container to initialize service instances and provide them to handlers. - Enhanced code structure for better testability and adherence to Go best practices.
This commit is contained in:
@@ -12,12 +12,39 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RegisterUser 用户注册
|
||||
func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) {
|
||||
// userServiceImpl UserService的实现
|
||||
type userServiceImpl struct {
|
||||
userRepo repository.UserRepository
|
||||
configRepo repository.SystemConfigRepository
|
||||
jwtService *auth.JWTService
|
||||
redis *redis.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewUserService 创建UserService实例
|
||||
func NewUserService(
|
||||
userRepo repository.UserRepository,
|
||||
configRepo repository.SystemConfigRepository,
|
||||
jwtService *auth.JWTService,
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
) UserService {
|
||||
return &userServiceImpl{
|
||||
userRepo: userRepo,
|
||||
configRepo: configRepo,
|
||||
jwtService: jwtService,
|
||||
redis: redisClient,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) Register(username, password, email, avatar string) (*model.User, string, error) {
|
||||
// 检查用户名是否已存在
|
||||
existingUser, err := repository.FindUserByUsername(username)
|
||||
existingUser, err := s.userRepo.FindByUsername(username)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -26,7 +53,7 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existingEmail, err := repository.FindUserByEmail(email)
|
||||
existingEmail, err := s.userRepo.FindByEmail(email)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@@ -40,15 +67,14 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar
|
||||
return nil, "", errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 确定头像URL:优先使用用户提供的头像,否则使用默认头像
|
||||
// 确定头像URL
|
||||
avatarURL := avatar
|
||||
if avatarURL != "" {
|
||||
// 验证用户提供的头像 URL 是否来自允许的域名
|
||||
if err := ValidateAvatarURL(avatarURL); err != nil {
|
||||
if err := s.ValidateAvatarURL(avatarURL); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
} else {
|
||||
avatarURL = getDefaultAvatar()
|
||||
avatarURL = s.getDefaultAvatar()
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
@@ -62,12 +88,12 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar
|
||||
Points: 0,
|
||||
}
|
||||
|
||||
if err := repository.CreateUser(user); err != nil {
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
@@ -75,92 +101,56 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// LoginUser 用户登录(支持用户名或邮箱登录)
|
||||
func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
|
||||
return LoginUserWithRateLimit(nil, jwtService, usernameOrEmail, password, ipAddress, userAgent)
|
||||
}
|
||||
|
||||
// LoginUserWithRateLimit 用户登录(带频率限制)
|
||||
func LoginUserWithRateLimit(redisClient *redis.Client, jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
|
||||
func (s *userServiceImpl) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 检查账号是否被锁定(基于用户名/邮箱和IP)
|
||||
if redisClient != nil {
|
||||
// 检查账号是否被锁定
|
||||
if s.redis != nil {
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
locked, ttl, err := CheckLoginLocked(ctx, redisClient, identifier)
|
||||
locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier)
|
||||
if err == nil && locked {
|
||||
return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1)
|
||||
}
|
||||
}
|
||||
|
||||
// 查找用户:判断是用户名还是邮箱
|
||||
// 查找用户
|
||||
var user *model.User
|
||||
var err error
|
||||
|
||||
if strings.Contains(usernameOrEmail, "@") {
|
||||
user, err = repository.FindUserByEmail(usernameOrEmail)
|
||||
user, err = s.userRepo.FindByEmail(usernameOrEmail)
|
||||
} else {
|
||||
user, err = repository.FindUserByUsername(usernameOrEmail)
|
||||
user, err = s.userRepo.FindByUsername(usernameOrEmail)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if user == nil {
|
||||
// 记录失败尝试
|
||||
if redisClient != nil {
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
count, _ := RecordLoginFailure(ctx, redisClient, identifier)
|
||||
// 检查是否触发锁定
|
||||
if count >= MaxLoginAttempts {
|
||||
logFailedLogin(0, ipAddress, userAgent, "用户不存在-账号已锁定")
|
||||
return nil, "", fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes()))
|
||||
}
|
||||
remaining := MaxLoginAttempts - count
|
||||
if remaining > 0 {
|
||||
logFailedLogin(0, ipAddress, userAgent, "用户不存在")
|
||||
return nil, "", fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", remaining)
|
||||
}
|
||||
}
|
||||
logFailedLogin(0, ipAddress, userAgent, "用户不存在")
|
||||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if user.Status != 1 {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用")
|
||||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用")
|
||||
return nil, "", errors.New("账号已被禁用")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !auth.CheckPassword(user.Password, password) {
|
||||
// 记录失败尝试
|
||||
if redisClient != nil {
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
count, _ := RecordLoginFailure(ctx, redisClient, identifier)
|
||||
// 检查是否触发锁定
|
||||
if count >= MaxLoginAttempts {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误-账号已锁定")
|
||||
return nil, "", fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes()))
|
||||
}
|
||||
remaining := MaxLoginAttempts - count
|
||||
if remaining > 0 {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误")
|
||||
return nil, "", fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", remaining)
|
||||
}
|
||||
}
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误")
|
||||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 登录成功,清除失败计数
|
||||
if redisClient != nil {
|
||||
if s.redis != nil {
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
_ = ClearLoginAttempts(ctx, redisClient, identifier)
|
||||
_ = ClearLoginAttempts(ctx, s.redis, identifier)
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
@@ -168,37 +158,37 @@ func LoginUserWithRateLimit(redisClient *redis.Client, jwtService *auth.JWTServi
|
||||
// 更新最后登录时间
|
||||
now := time.Now()
|
||||
user.LastLoginAt = &now
|
||||
_ = repository.UpdateUserFields(user.ID, map[string]interface{}{
|
||||
_ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{
|
||||
"last_login_at": now,
|
||||
})
|
||||
|
||||
// 记录成功登录日志
|
||||
logSuccessLogin(user.ID, ipAddress, userAgent)
|
||||
s.logSuccessLogin(user.ID, ipAddress, userAgent)
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
func GetUserByID(id int64) (*model.User, error) {
|
||||
return repository.FindUserByID(id)
|
||||
func (s *userServiceImpl) GetByID(id int64) (*model.User, error) {
|
||||
return s.userRepo.FindByID(id)
|
||||
}
|
||||
|
||||
// UpdateUserInfo 更新用户信息
|
||||
func UpdateUserInfo(user *model.User) error {
|
||||
return repository.UpdateUser(user)
|
||||
func (s *userServiceImpl) GetByEmail(email string) (*model.User, error) {
|
||||
return s.userRepo.FindByEmail(email)
|
||||
}
|
||||
|
||||
// UpdateUserAvatar 更新用户头像
|
||||
func UpdateUserAvatar(userID int64, avatarURL string) error {
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
func (s *userServiceImpl) UpdateInfo(user *model.User) error {
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) UpdateAvatar(userID int64, avatarURL string) error {
|
||||
return s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
"avatar": avatarURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangeUserPassword 修改密码
|
||||
func ChangeUserPassword(userID int64, oldPassword, newPassword string) error {
|
||||
user, err := repository.FindUserByID(userID)
|
||||
if err != nil {
|
||||
func (s *userServiceImpl) ChangePassword(userID int64, oldPassword, newPassword string) error {
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
@@ -211,15 +201,14 @@ func ChangeUserPassword(userID int64, oldPassword, newPassword string) error {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
return s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// ResetUserPassword 重置密码(通过邮箱)
|
||||
func ResetUserPassword(email, newPassword string) error {
|
||||
user, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
func (s *userServiceImpl) ResetPassword(email, newPassword string) error {
|
||||
user, err := s.userRepo.FindByEmail(email)
|
||||
if err != nil || user == nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
@@ -228,14 +217,13 @@ func ResetUserPassword(email, newPassword string) error {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
return repository.UpdateUserFields(user.ID, map[string]interface{}{
|
||||
return s.userRepo.UpdateFields(user.ID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangeUserEmail 更换邮箱
|
||||
func ChangeUserEmail(userID int64, newEmail string) error {
|
||||
existingUser, err := repository.FindUserByEmail(newEmail)
|
||||
func (s *userServiceImpl) ChangeEmail(userID int64, newEmail string) error {
|
||||
existingUser, err := s.userRepo.FindByEmail(newEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -243,47 +231,12 @@ func ChangeUserEmail(userID int64, newEmail string) error {
|
||||
return errors.New("邮箱已被其他用户使用")
|
||||
}
|
||||
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
return s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
"email": newEmail,
|
||||
})
|
||||
}
|
||||
|
||||
// logSuccessLogin 记录成功登录
|
||||
func logSuccessLogin(userID int64, ipAddress, userAgent string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: true,
|
||||
}
|
||||
_ = repository.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// logFailedLogin 记录失败登录
|
||||
func logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: false,
|
||||
FailureReason: reason,
|
||||
}
|
||||
_ = repository.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// getDefaultAvatar 获取默认头像URL
|
||||
func getDefaultAvatar() string {
|
||||
config, err := repository.GetSystemConfigByKey("default_avatar")
|
||||
if err != nil || config == nil || config.Value == "" {
|
||||
return ""
|
||||
}
|
||||
return config.Value
|
||||
}
|
||||
|
||||
// ValidateAvatarURL 验证头像URL是否合法
|
||||
func ValidateAvatarURL(avatarURL string) error {
|
||||
func (s *userServiceImpl) ValidateAvatarURL(avatarURL string) error {
|
||||
if avatarURL == "" {
|
||||
return nil
|
||||
}
|
||||
@@ -293,13 +246,8 @@ func ValidateAvatarURL(avatarURL string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ValidateURLDomain(avatarURL)
|
||||
}
|
||||
|
||||
// ValidateURLDomain 验证URL的域名是否在允许列表中
|
||||
func ValidateURLDomain(rawURL string) error {
|
||||
// 解析URL
|
||||
parsedURL, err := url.Parse(rawURL)
|
||||
parsedURL, err := url.Parse(avatarURL)
|
||||
if err != nil {
|
||||
return errors.New("无效的URL格式")
|
||||
}
|
||||
@@ -309,7 +257,6 @@ func ValidateURLDomain(rawURL string) error {
|
||||
return errors.New("URL必须使用http或https协议")
|
||||
}
|
||||
|
||||
// 获取主机名(不包含端口)
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
return errors.New("URL缺少主机名")
|
||||
@@ -318,16 +265,50 @@ func ValidateURLDomain(rawURL string) error {
|
||||
// 从配置获取允许的域名列表
|
||||
cfg, err := config.GetConfig()
|
||||
if err != nil {
|
||||
// 如果配置获取失败,使用默认的安全域名列表
|
||||
allowedDomains := []string{"localhost", "127.0.0.1"}
|
||||
return checkDomainAllowed(host, allowedDomains)
|
||||
return s.checkDomainAllowed(host, allowedDomains)
|
||||
}
|
||||
|
||||
return checkDomainAllowed(host, cfg.Security.AllowedDomains)
|
||||
return s.checkDomainAllowed(host, cfg.Security.AllowedDomains)
|
||||
}
|
||||
|
||||
// checkDomainAllowed 检查域名是否在允许列表中
|
||||
func checkDomainAllowed(host string, allowedDomains []string) error {
|
||||
func (s *userServiceImpl) GetMaxProfilesPerUser() int {
|
||||
config, err := s.configRepo.GetByKey("max_profiles_per_user")
|
||||
if err != nil || config == nil {
|
||||
return 5
|
||||
}
|
||||
var value int
|
||||
fmt.Sscanf(config.Value, "%d", &value)
|
||||
if value <= 0 {
|
||||
return 5
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) GetMaxTexturesPerUser() int {
|
||||
config, err := s.configRepo.GetByKey("max_textures_per_user")
|
||||
if err != nil || config == nil {
|
||||
return 50
|
||||
}
|
||||
var value int
|
||||
fmt.Sscanf(config.Value, "%d", &value)
|
||||
if value <= 0 {
|
||||
return 50
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *userServiceImpl) getDefaultAvatar() string {
|
||||
config, err := s.configRepo.GetByKey("default_avatar")
|
||||
if err != nil || config == nil || config.Value == "" {
|
||||
return ""
|
||||
}
|
||||
return config.Value
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) checkDomainAllowed(host string, allowedDomains []string) error {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
for _, allowed := range allowedDomains {
|
||||
@@ -336,14 +317,12 @@ func checkDomainAllowed(host string, allowedDomains []string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// 精确匹配
|
||||
if host == allowed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 支持通配符子域名匹配 (如 *.example.com)
|
||||
if strings.HasPrefix(allowed, "*.") {
|
||||
suffix := allowed[1:] // 移除 "*",保留 ".example.com"
|
||||
suffix := allowed[1:]
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
return nil
|
||||
}
|
||||
@@ -353,39 +332,37 @@ func checkDomainAllowed(host string, allowedDomains []string) error {
|
||||
return errors.New("URL域名不在允许的列表中")
|
||||
}
|
||||
|
||||
// GetUserByEmail 根据邮箱获取用户
|
||||
func GetUserByEmail(email string) (*model.User, error) {
|
||||
user, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, errors.New("邮箱查找失败")
|
||||
func (s *userServiceImpl) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) {
|
||||
if s.redis != nil {
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
count, _ := RecordLoginFailure(ctx, s.redis, identifier)
|
||||
if count >= MaxLoginAttempts {
|
||||
s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定")
|
||||
return
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
s.logFailedLogin(userID, ipAddress, userAgent, reason)
|
||||
}
|
||||
|
||||
// GetMaxProfilesPerUser 获取每用户最大档案数量配置
|
||||
func GetMaxProfilesPerUser() int {
|
||||
config, err := repository.GetSystemConfigByKey("max_profiles_per_user")
|
||||
if err != nil || config == nil {
|
||||
return 5
|
||||
func (s *userServiceImpl) logSuccessLogin(userID int64, ipAddress, userAgent string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: true,
|
||||
}
|
||||
var value int
|
||||
fmt.Sscanf(config.Value, "%d", &value)
|
||||
if value <= 0 {
|
||||
return 5
|
||||
}
|
||||
return value
|
||||
_ = s.userRepo.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// GetMaxTexturesPerUser 获取每用户最大材质数量配置
|
||||
func GetMaxTexturesPerUser() int {
|
||||
config, err := repository.GetSystemConfigByKey("max_textures_per_user")
|
||||
if err != nil || config == nil {
|
||||
return 50
|
||||
func (s *userServiceImpl) logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: false,
|
||||
FailureReason: reason,
|
||||
}
|
||||
var value int
|
||||
fmt.Sscanf(config.Value, "%d", &value)
|
||||
if value <= 0 {
|
||||
return 50
|
||||
}
|
||||
return value
|
||||
_ = s.userRepo.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user