Files
backend/internal/service/user_service.go

392 lines
10 KiB
Go
Raw Normal View History

package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/auth"
"carrotskin/pkg/config"
"carrotskin/pkg/redis"
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"
)
// RegisterUser creates a new user account and issues a JWT for it.
//
// The username and the email must both be unused. The password is stored as
// the hash produced by auth.HashPassword. A non-empty avatar is checked with
// ValidateAvatarURL; an empty avatar is replaced by the value returned from
// getDefaultAvatar. New accounts are created with role "user", status 1 and
// zero points.
//
// It returns the persisted user and the signed token. An error is returned
// when a lookup fails, the username or email is already taken, the avatar URL
// is rejected, hashing fails, the user cannot be stored, or the token cannot
// be generated.
func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) {
	// Reject a username that is already taken.
	switch byName, err := repository.FindUserByUsername(username); {
	case err != nil:
		return nil, "", err
	case byName != nil:
		return nil, "", errors.New("用户名已存在")
	}

	// Reject an email address that is already registered.
	switch byEmail, err := repository.FindUserByEmail(email); {
	case err != nil:
		return nil, "", err
	case byEmail != nil:
		return nil, "", errors.New("邮箱已被注册")
	}

	// Only the hash is ever persisted.
	passwordHash, err := auth.HashPassword(password)
	if err != nil {
		return nil, "", errors.New("密码加密失败")
	}

	// Prefer the caller-supplied avatar (after validation); otherwise fall
	// back to the configured default.
	avatarURL := avatar
	if avatarURL == "" {
		avatarURL = getDefaultAvatar()
	} else if err := ValidateAvatarURL(avatarURL); err != nil {
		return nil, "", err
	}

	newUser := &model.User{
		Username: username,
		Password: passwordHash,
		Email:    email,
		Avatar:   avatarURL,
		Role:     "user",
		Status:   1,
		Points:   0,
	}
	if err := repository.CreateUser(newUser); err != nil {
		return nil, "", err
	}

	// The token is generated after creation because it embeds newUser.ID.
	signed, err := jwtService.GenerateToken(newUser.ID, newUser.Username, newUser.Role)
	if err != nil {
		return nil, "", errors.New("生成Token失败")
	}
	return newUser, signed, nil
}
// LoginUser authenticates a user by username or email without any
// failed-attempt tracking. It forwards to LoginUserWithRateLimit with a nil
// Redis client and returns that function's results unchanged.
func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
	// A nil client is passed on purpose: no Redis client means no lockout checks.
	var noLimiter *redis.Client
	return LoginUserWithRateLimit(noLimiter, jwtService, usernameOrEmail, password, ipAddress, userAgent)
}
// LoginUserWithRateLimit authenticates a user by username or email and, when
// redisClient is non-nil, enforces a failed-attempt lockout.
//
// usernameOrEmail is treated as an email address when it contains "@", and as
// a username otherwise. Failed attempts are tracked per
// "usernameOrEmail:ipAddress" pair. Passing a nil redisClient disables both
// the lock check and the attempt counting.
//
// The steps are, in order:
//  1. refuse the attempt if the identifier is currently locked;
//  2. look the user up; an unknown user counts as a failed attempt;
//  3. refuse accounts whose Status is not 1 (this happens before the password
//     is verified and is not counted as a failed attempt);
//  4. verify the password; a mismatch counts as a failed attempt;
//  5. clear the failure counter, issue a JWT, store the last-login time and
//     write a success entry to the login log.
//
// It returns the user and the signed token on success. Every refused attempt
// after the lock check is written to the login log via logFailedLogin.
func LoginUserWithRateLimit(redisClient *redis.Client, jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
	ctx := context.Background()
	// Key under which failed attempts for this login name and client IP are tracked.
	identifier := usernameOrEmail + ":" + ipAddress

	if redisClient != nil {
		// An error from the lock check is not fatal: the attempt proceeds as
		// if the identifier were not locked.
		locked, ttl, err := CheckLoginLocked(ctx, redisClient, identifier)
		if err == nil && locked {
			return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1)
		}
	}

	// Decide between email and username lookup.
	var (
		user *model.User
		err  error
	)
	if strings.Contains(usernameOrEmail, "@") {
		user, err = repository.FindUserByEmail(usernameOrEmail)
	} else {
		user, err = repository.FindUserByUsername(usernameOrEmail)
	}
	if err != nil {
		return nil, "", err
	}
	if user == nil {
		// No user ID is available, so the attempt is logged with ID 0.
		return nil, "", rejectLoginAttempt(ctx, redisClient, identifier, 0, ipAddress, userAgent, "用户不存在")
	}

	// Disabled accounts are refused before the password is looked at.
	if user.Status != 1 {
		logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用")
		return nil, "", errors.New("账号已被禁用")
	}

	if !auth.CheckPassword(user.Password, password) {
		return nil, "", rejectLoginAttempt(ctx, redisClient, identifier, user.ID, ipAddress, userAgent, "密码错误")
	}

	// Successful login: reset the failure counter. Best effort; a failure to
	// clear it does not fail the login.
	if redisClient != nil {
		_ = ClearLoginAttempts(ctx, redisClient, identifier)
	}

	token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
	if err != nil {
		return nil, "", errors.New("生成Token失败")
	}

	// Record the login time on the returned user and in storage. The storage
	// update is best effort and its error is deliberately ignored.
	now := time.Now()
	user.LastLoginAt = &now
	_ = repository.UpdateUserFields(user.ID, map[string]interface{}{
		"last_login_at": now,
	})

	logSuccessLogin(user.ID, ipAddress, userAgent)
	return user, token, nil
}

// rejectLoginAttempt records one failed login for identifier, writes it to
// the login log and returns the error to hand back to the caller.
//
// With a nil redisClient nothing is counted and a generic credentials error
// is returned. Otherwise the failure counter is incremented; once it reaches
// MaxLoginAttempts the log reason is suffixed with "-账号已锁定" and a lockout
// error is returned, else the error states how many attempts remain.
func rejectLoginAttempt(ctx context.Context, redisClient *redis.Client, identifier string, userID int64, ipAddress, userAgent, reason string) error {
	if redisClient == nil {
		logFailedLogin(userID, ipAddress, userAgent, reason)
		return errors.New("用户名/邮箱或密码错误")
	}

	// NOTE(review): the second result of RecordLoginFailure is discarded, as
	// in the original code; if recording fails the reported remaining count
	// is based on whatever count was returned — confirm this is acceptable.
	count, _ := RecordLoginFailure(ctx, redisClient, identifier)
	if count >= MaxLoginAttempts {
		logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定")
		return fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes()))
	}

	// count < MaxLoginAttempts here, so at least one attempt remains.
	logFailedLogin(userID, ipAddress, userAgent, reason)
	return fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", MaxLoginAttempts-count)
}
// GetUserByID looks a user up by primary key. It returns exactly what
// repository.FindUserByID returns, without translating the result or error.
func GetUserByID(id int64) (*model.User, error) {
	user, err := repository.FindUserByID(id)
	return user, err
}
// UpdateUserInfo persists the given user through repository.UpdateUser and
// returns that call's error unchanged.
func UpdateUserInfo(user *model.User) error {
	err := repository.UpdateUser(user)
	return err
}
// UpdateUserAvatar sets the "avatar" field of the user identified by userID
// to avatarURL. The URL is stored as given; this function performs no
// validation of it.
func UpdateUserAvatar(userID int64, avatarURL string) error {
	changes := make(map[string]interface{}, 1)
	changes["avatar"] = avatarURL
	return repository.UpdateUserFields(userID, changes)
}
// ChangeUserPassword replaces a user's password after verifying the current
// one.
//
// It returns an error when the user cannot be loaded, when oldPassword does
// not match the stored hash, when hashing newPassword fails, or when the
// update itself fails. On success only the "password" field is updated, with
// the hash of newPassword.
func ChangeUserPassword(userID int64, oldPassword, newPassword string) error {
	user, err := repository.FindUserByID(userID)
	// A lookup error and a missing user are both reported as "user not found".
	// The nil check matters because the sibling lookups in this package signal
	// "not found" as (nil, nil); presumably FindUserByID does the same, and
	// dereferencing a nil user below would panic.
	if err != nil || user == nil {
		return errors.New("用户不存在")
	}

	if !auth.CheckPassword(user.Password, oldPassword) {
		return errors.New("原密码错误")
	}

	hashedPassword, err := auth.HashPassword(newPassword)
	if err != nil {
		return errors.New("密码加密失败")
	}

	return repository.UpdateUserFields(userID, map[string]interface{}{
		"password": hashedPassword,
	})
}
// ResetUserPassword sets a new password for the account registered under
// email.
//
// The function itself does not check that the caller is entitled to reset
// the password; it only looks the account up, hashes newPassword and updates
// the "password" field. It returns an error when no account is found for the
// email, when the lookup fails, when hashing fails, or when the update fails.
func ResetUserPassword(email, newPassword string) error {
	user, err := repository.FindUserByEmail(email)
	// FindUserByEmail reports an unknown email as (nil, nil), so the nil check
	// is required to avoid a nil pointer dereference on user.ID below.
	if err != nil || user == nil {
		return errors.New("用户不存在")
	}

	hashedPassword, err := auth.HashPassword(newPassword)
	if err != nil {
		return errors.New("密码加密失败")
	}

	return repository.UpdateUserFields(user.ID, map[string]interface{}{
		"password": hashedPassword,
	})
}
// ChangeUserEmail updates the "email" field of the user identified by userID.
//
// The change is refused when newEmail already belongs to a different user.
// Re-submitting the user's own current address is allowed. Lookup and update
// errors are returned unchanged.
func ChangeUserEmail(userID int64, newEmail string) error {
	owner, err := repository.FindUserByEmail(newEmail)
	if err != nil {
		return err
	}

	// The address is only a conflict when somebody else holds it.
	takenByOther := owner != nil && owner.ID != userID
	if takenByOther {
		return errors.New("邮箱已被其他用户使用")
	}

	fields := map[string]interface{}{"email": newEmail}
	return repository.UpdateUserFields(userID, fields)
}
// logSuccessLogin writes a successful password login to the login log.
// Logging is best effort: the result of repository.CreateLoginLog is ignored.
func logSuccessLogin(userID int64, ipAddress, userAgent string) {
	entry := model.UserLoginLog{
		LoginMethod: "PASSWORD",
		IsSuccess:   true,
		UserID:      userID,
		IPAddress:   ipAddress,
		UserAgent:   userAgent,
	}
	_ = repository.CreateLoginLog(&entry)
}
// logFailedLogin writes a failed password login, together with the reason it
// failed, to the login log. Logging is best effort: the result of
// repository.CreateLoginLog is ignored.
func logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
	entry := new(model.UserLoginLog)
	entry.UserID = userID
	entry.IPAddress = ipAddress
	entry.UserAgent = userAgent
	entry.LoginMethod = "PASSWORD"
	entry.IsSuccess = false
	entry.FailureReason = reason
	_ = repository.CreateLoginLog(entry)
}
// getDefaultAvatar returns the value of the "default_avatar" system setting.
// It returns the empty string when the setting cannot be read, does not
// exist, or is empty.
func getDefaultAvatar() string {
	setting, err := repository.GetSystemConfigByKey("default_avatar")
	if err != nil {
		return ""
	}
	if setting == nil {
		return ""
	}
	// An empty stored value is already the "no default" result.
	return setting.Value
}
// ValidateAvatarURL reports whether avatarURL may be stored as a user avatar.
//
// Accepted without further checks:
//   - the empty string;
//   - a site-relative path, i.e. a value that starts with a single "/".
//
// Everything else must pass ValidateURLDomain. Protocol-relative URLs
// ("//host/path") also start with "/", but browsers resolve them against a
// different host, so they are validated as if written with the https scheme
// rather than being treated as relative paths. Values starting with "/\" are
// rejected outright because browsers interpret the backslash like a slash.
//
// It returns nil when the URL is acceptable and a descriptive error otherwise.
func ValidateAvatarURL(avatarURL string) error {
	if avatarURL == "" {
		return nil
	}

	// "//host/..." points at another host: check that host against the
	// allow-list instead of waving it through as a relative path.
	if strings.HasPrefix(avatarURL, "//") {
		return ValidateURLDomain("https:" + avatarURL)
	}

	// "/\host/..." is the backslash spelling of the same trick.
	if strings.HasPrefix(avatarURL, "/\\") {
		return errors.New("无效的URL格式")
	}

	// Genuine site-relative paths are allowed.
	if strings.HasPrefix(avatarURL, "/") {
		return nil
	}

	return ValidateURLDomain(avatarURL)
}
// ValidateURLDomain checks that rawURL is an http or https URL whose host is
// on the allow-list.
//
// The allow-list is cfg.Security.AllowedDomains from config.GetConfig. When
// the configuration cannot be loaded, only "localhost" and "127.0.0.1" are
// allowed. The host is compared without its port.
//
// It returns nil when the URL is acceptable, and an error when the URL cannot
// be parsed, uses another scheme, has no host, or names a host that is not
// allowed.
func ValidateURLDomain(rawURL string) error {
	u, parseErr := url.Parse(rawURL)
	if parseErr != nil {
		return errors.New("无效的URL格式")
	}

	// Only web URLs are acceptable.
	switch u.Scheme {
	case "http", "https":
	default:
		return errors.New("URL必须使用http或https协议")
	}

	// Hostname strips any port.
	hostname := u.Hostname()
	if hostname == "" {
		return errors.New("URL缺少主机名")
	}

	// Start from the conservative fallback and replace it with the
	// configured list when the configuration is available.
	allowed := []string{"localhost", "127.0.0.1"}
	if cfg, cfgErr := config.GetConfig(); cfgErr == nil {
		allowed = cfg.Security.AllowedDomains
	}
	return checkDomainAllowed(hostname, allowed)
}
// checkDomainAllowed reports whether host matches an entry of allowedDomains.
//
// Matching is case-insensitive. Entries are trimmed of surrounding
// whitespace and empty entries are skipped. An entry matches when it equals
// the host exactly, or when it has the form "*.example.com" and the host ends
// with ".example.com" (so the bare "example.com" does not match the wildcard).
//
// It returns nil on a match and an error when no entry matches.
func checkDomainAllowed(host string, allowedDomains []string) error {
	target := strings.ToLower(host)
	for _, entry := range allowedDomains {
		pattern := strings.ToLower(strings.TrimSpace(entry))
		switch {
		case pattern == "":
			continue
		case pattern == target:
			// Exact host match.
			return nil
		case strings.HasPrefix(pattern, "*.") &&
			// Drop the "*" and keep ".example.com" as the required suffix.
			strings.HasSuffix(target, strings.TrimPrefix(pattern, "*")):
			return nil
		}
	}
	return errors.New("URL域名不在允许的列表中")
}
// GetUserByEmail looks a user up by email address.
//
// A lookup error is replaced by a generic error. When the lookup succeeds the
// repository result is returned as is, which may be a nil user if no account
// uses that email.
func GetUserByEmail(email string) (*model.User, error) {
	found, lookupErr := repository.FindUserByEmail(email)
	if lookupErr == nil {
		return found, nil
	}
	return nil, errors.New("邮箱查找失败")
}
// GetMaxProfilesPerUser returns the "max_profiles_per_user" system setting
// as an integer. It falls back to 5 when the setting cannot be read, does not
// exist, or does not hold a positive number.
func GetMaxProfilesPerUser() int {
	const fallback = 5

	setting, err := repository.GetSystemConfigByKey("max_profiles_per_user")
	if err != nil || setting == nil {
		return fallback
	}

	// The scan result is ignored on purpose: when nothing parses, limit keeps
	// its zero value and the fallback below applies.
	limit := 0
	_, _ = fmt.Sscanf(setting.Value, "%d", &limit)
	if limit > 0 {
		return limit
	}
	return fallback
}
// GetMaxTexturesPerUser returns the "max_textures_per_user" system setting
// as an integer. It falls back to 50 when the setting cannot be read, does
// not exist, or does not hold a positive number.
func GetMaxTexturesPerUser() int {
	const fallback = 50

	setting, err := repository.GetSystemConfigByKey("max_textures_per_user")
	if err != nil || setting == nil {
		return fallback
	}

	// The scan result is ignored on purpose: when nothing parses, limit keeps
	// its zero value and the fallback below applies.
	limit := 0
	_, _ = fmt.Sscanf(setting.Value, "%d", &limit)
	if limit > 0 {
		return limit
	}
	return fallback
}