Files
backend/internal/service/user_service.go
lan f7589ebbb8 feat: 引入依赖注入模式
- 创建Repository接口定义(UserRepository、ProfileRepository、TextureRepository等)
- 创建Repository接口实现
- 创建依赖注入容器(container.Container)
- 改造Handler层使用依赖注入(AuthHandler、UserHandler、TextureHandler)
- 创建新的路由注册方式(RegisterRoutesWithDI)
- 提供main.go示例文件展示如何使用依赖注入

同时包含之前的安全修复:
- CORS配置安全加固
- 头像URL验证安全修复
- JWT algorithm confusion漏洞修复
- Recovery中间件增强
- 敏感错误信息泄露修复
- 类型断言安全修复
2025-12-02 17:40:39 +08:00

392 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/auth"
"carrotskin/pkg/config"
"carrotskin/pkg/redis"
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"
)
// RegisterUser 用户注册
func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) {
// 检查用户名是否已存在
existingUser, err := repository.FindUserByUsername(username)
if err != nil {
return nil, "", err
}
if existingUser != nil {
return nil, "", errors.New("用户名已存在")
}
// 检查邮箱是否已存在
existingEmail, err := repository.FindUserByEmail(email)
if err != nil {
return nil, "", err
}
if existingEmail != nil {
return nil, "", errors.New("邮箱已被注册")
}
// 加密密码
hashedPassword, err := auth.HashPassword(password)
if err != nil {
return nil, "", errors.New("密码加密失败")
}
// 确定头像URL优先使用用户提供的头像否则使用默认头像
avatarURL := avatar
if avatarURL != "" {
// 验证用户提供的头像 URL 是否来自允许的域名
if err := ValidateAvatarURL(avatarURL); err != nil {
return nil, "", err
}
} else {
avatarURL = getDefaultAvatar()
}
// 创建用户
user := &model.User{
Username: username,
Password: hashedPassword,
Email: email,
Avatar: avatarURL,
Role: "user",
Status: 1,
Points: 0,
}
if err := repository.CreateUser(user); err != nil {
return nil, "", err
}
// 生成JWT Token
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
if err != nil {
return nil, "", errors.New("生成Token失败")
}
return user, token, nil
}
// LoginUser 用户登录(支持用户名或邮箱登录)
func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
return LoginUserWithRateLimit(nil, jwtService, usernameOrEmail, password, ipAddress, userAgent)
}
// LoginUserWithRateLimit 用户登录(带频率限制)
func LoginUserWithRateLimit(redisClient *redis.Client, jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
ctx := context.Background()
// 检查账号是否被锁定(基于用户名/邮箱和IP
if redisClient != nil {
identifier := usernameOrEmail + ":" + ipAddress
locked, ttl, err := CheckLoginLocked(ctx, redisClient, identifier)
if err == nil && locked {
return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1)
}
}
// 查找用户:判断是用户名还是邮箱
var user *model.User
var err error
if strings.Contains(usernameOrEmail, "@") {
user, err = repository.FindUserByEmail(usernameOrEmail)
} else {
user, err = repository.FindUserByUsername(usernameOrEmail)
}
if err != nil {
return nil, "", err
}
if user == nil {
// 记录失败尝试
if redisClient != nil {
identifier := usernameOrEmail + ":" + ipAddress
count, _ := RecordLoginFailure(ctx, redisClient, identifier)
// 检查是否触发锁定
if count >= MaxLoginAttempts {
logFailedLogin(0, ipAddress, userAgent, "用户不存在-账号已锁定")
return nil, "", fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes()))
}
remaining := MaxLoginAttempts - count
if remaining > 0 {
logFailedLogin(0, ipAddress, userAgent, "用户不存在")
return nil, "", fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", remaining)
}
}
logFailedLogin(0, ipAddress, userAgent, "用户不存在")
return nil, "", errors.New("用户名/邮箱或密码错误")
}
// 检查用户状态
if user.Status != 1 {
logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用")
return nil, "", errors.New("账号已被禁用")
}
// 验证密码
if !auth.CheckPassword(user.Password, password) {
// 记录失败尝试
if redisClient != nil {
identifier := usernameOrEmail + ":" + ipAddress
count, _ := RecordLoginFailure(ctx, redisClient, identifier)
// 检查是否触发锁定
if count >= MaxLoginAttempts {
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误-账号已锁定")
return nil, "", fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes()))
}
remaining := MaxLoginAttempts - count
if remaining > 0 {
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误")
return nil, "", fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", remaining)
}
}
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误")
return nil, "", errors.New("用户名/邮箱或密码错误")
}
// 登录成功,清除失败计数
if redisClient != nil {
identifier := usernameOrEmail + ":" + ipAddress
_ = ClearLoginAttempts(ctx, redisClient, identifier)
}
// 生成JWT Token
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
if err != nil {
return nil, "", errors.New("生成Token失败")
}
// 更新最后登录时间
now := time.Now()
user.LastLoginAt = &now
_ = repository.UpdateUserFields(user.ID, map[string]interface{}{
"last_login_at": now,
})
// 记录成功登录日志
logSuccessLogin(user.ID, ipAddress, userAgent)
return user, token, nil
}
// GetUserByID 根据ID获取用户
func GetUserByID(id int64) (*model.User, error) {
return repository.FindUserByID(id)
}
// UpdateUserInfo 更新用户信息
func UpdateUserInfo(user *model.User) error {
return repository.UpdateUser(user)
}
// UpdateUserAvatar 更新用户头像
func UpdateUserAvatar(userID int64, avatarURL string) error {
return repository.UpdateUserFields(userID, map[string]interface{}{
"avatar": avatarURL,
})
}
// ChangeUserPassword 修改密码
func ChangeUserPassword(userID int64, oldPassword, newPassword string) error {
user, err := repository.FindUserByID(userID)
if err != nil {
return errors.New("用户不存在")
}
if !auth.CheckPassword(user.Password, oldPassword) {
return errors.New("原密码错误")
}
hashedPassword, err := auth.HashPassword(newPassword)
if err != nil {
return errors.New("密码加密失败")
}
return repository.UpdateUserFields(userID, map[string]interface{}{
"password": hashedPassword,
})
}
// ResetUserPassword 重置密码(通过邮箱)
func ResetUserPassword(email, newPassword string) error {
user, err := repository.FindUserByEmail(email)
if err != nil {
return errors.New("用户不存在")
}
hashedPassword, err := auth.HashPassword(newPassword)
if err != nil {
return errors.New("密码加密失败")
}
return repository.UpdateUserFields(user.ID, map[string]interface{}{
"password": hashedPassword,
})
}
// ChangeUserEmail 更换邮箱
func ChangeUserEmail(userID int64, newEmail string) error {
existingUser, err := repository.FindUserByEmail(newEmail)
if err != nil {
return err
}
if existingUser != nil && existingUser.ID != userID {
return errors.New("邮箱已被其他用户使用")
}
return repository.UpdateUserFields(userID, map[string]interface{}{
"email": newEmail,
})
}
// logSuccessLogin 记录成功登录
func logSuccessLogin(userID int64, ipAddress, userAgent string) {
log := &model.UserLoginLog{
UserID: userID,
IPAddress: ipAddress,
UserAgent: userAgent,
LoginMethod: "PASSWORD",
IsSuccess: true,
}
_ = repository.CreateLoginLog(log)
}
// logFailedLogin 记录失败登录
func logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
log := &model.UserLoginLog{
UserID: userID,
IPAddress: ipAddress,
UserAgent: userAgent,
LoginMethod: "PASSWORD",
IsSuccess: false,
FailureReason: reason,
}
_ = repository.CreateLoginLog(log)
}
// getDefaultAvatar 获取默认头像URL
func getDefaultAvatar() string {
config, err := repository.GetSystemConfigByKey("default_avatar")
if err != nil || config == nil || config.Value == "" {
return ""
}
return config.Value
}
// ValidateAvatarURL 验证头像URL是否合法
func ValidateAvatarURL(avatarURL string) error {
if avatarURL == "" {
return nil
}
// 允许相对路径
if strings.HasPrefix(avatarURL, "/") {
return nil
}
return ValidateURLDomain(avatarURL)
}
// ValidateURLDomain 验证URL的域名是否在允许列表中
func ValidateURLDomain(rawURL string) error {
// 解析URL
parsedURL, err := url.Parse(rawURL)
if err != nil {
return errors.New("无效的URL格式")
}
// 必须是HTTP或HTTPS协议
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return errors.New("URL必须使用http或https协议")
}
// 获取主机名(不包含端口)
host := parsedURL.Hostname()
if host == "" {
return errors.New("URL缺少主机名")
}
// 从配置获取允许的域名列表
cfg, err := config.GetConfig()
if err != nil {
// 如果配置获取失败,使用默认的安全域名列表
allowedDomains := []string{"localhost", "127.0.0.1"}
return checkDomainAllowed(host, allowedDomains)
}
return checkDomainAllowed(host, cfg.Security.AllowedDomains)
}
// checkDomainAllowed 检查域名是否在允许列表中
func checkDomainAllowed(host string, allowedDomains []string) error {
host = strings.ToLower(host)
for _, allowed := range allowedDomains {
allowed = strings.ToLower(strings.TrimSpace(allowed))
if allowed == "" {
continue
}
// 精确匹配
if host == allowed {
return nil
}
// 支持通配符子域名匹配 (如 *.example.com)
if strings.HasPrefix(allowed, "*.") {
suffix := allowed[1:] // 移除 "*",保留 ".example.com"
if strings.HasSuffix(host, suffix) {
return nil
}
}
}
return errors.New("URL域名不在允许的列表中")
}
// GetUserByEmail 根据邮箱获取用户
func GetUserByEmail(email string) (*model.User, error) {
user, err := repository.FindUserByEmail(email)
if err != nil {
return nil, errors.New("邮箱查找失败")
}
return user, nil
}
// GetMaxProfilesPerUser 获取每用户最大档案数量配置
func GetMaxProfilesPerUser() int {
config, err := repository.GetSystemConfigByKey("max_profiles_per_user")
if err != nil || config == nil {
return 5
}
var value int
fmt.Sscanf(config.Value, "%d", &value)
if value <= 0 {
return 5
}
return value
}
// GetMaxTexturesPerUser 获取每用户最大材质数量配置
func GetMaxTexturesPerUser() int {
config, err := repository.GetSystemConfigByKey("max_textures_per_user")
if err != nil || config == nil {
return 50
}
var value int
fmt.Sscanf(config.Value, "%d", &value)
if value <= 0 {
return 50
}
return value
}