Files
backend/internal/service/user_service.go

500 lines
13 KiB
Go
Raw Normal View History

package service
import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"path/filepath"
	"strings"
	"time"

	"go.uber.org/zap"

	apperrors "carrotskin/internal/errors"
	"carrotskin/internal/model"
	"carrotskin/internal/repository"
	"carrotskin/pkg/auth"
	"carrotskin/pkg/config"
	"carrotskin/pkg/database"
	"carrotskin/pkg/redis"
	"carrotskin/pkg/storage"
)
// userService is the default implementation of UserService.
type userService struct {
	// userRepo is the persistence layer for users and login logs.
	userRepo repository.UserRepository
	// jwtService issues tokens after successful registration/login.
	jwtService *auth.JWTService
	// redis backs the login-lockout counters; it may be nil, in which case
	// lockout tracking is skipped.
	redis *redis.Client
	// cache is the cache manager passed to database.Cached for read-through lookups.
	cache *database.CacheManager
	// cacheKeys builds the cache keys for user lookups (by ID, email, username).
	cacheKeys *database.CacheKeyBuilder
	// cacheInv drops cache entries after writes.
	cacheInv *database.CacheInvalidator
	// storage is the object store used for avatar uploads; may be nil.
	storage *storage.StorageClient
	// logger receives structured service logs.
	logger *zap.Logger
}
// NewUserService wires the dependencies into a UserService.
//
// The CacheKeyBuilder is created with an empty prefix because the
// CacheManager already applies its own prefix; a final cache key is therefore
// "<CacheManager prefix>" + "<CacheKeyBuilder key>". The CacheInvalidator is
// built on the same CacheManager so it removes exactly the keys Cached writes.
func NewUserService(
	userRepo repository.UserRepository,
	jwtService *auth.JWTService,
	redisClient *redis.Client,
	cacheManager *database.CacheManager,
	storageClient *storage.StorageClient,
	logger *zap.Logger,
) UserService {
	svc := new(userService)
	svc.userRepo = userRepo
	svc.jwtService = jwtService
	svc.redis = redisClient
	svc.cache = cacheManager
	svc.cacheKeys = database.NewCacheKeyBuilder("")
	svc.cacheInv = database.NewCacheInvalidator(cacheManager)
	svc.storage = storageClient
	svc.logger = logger
	return svc
}
// Register creates a new account and returns it together with a freshly
// issued JWT.
//
// Order of operations: reject a duplicate username, reject a duplicate email,
// validate the supplied avatar URL (or fall back to the site default), hash the
// password, persist the user with Role "user", Status 1 and Points 0, then
// generate the token. Cheap input validation deliberately runs before password
// hashing, which is the expensive step, so invalid requests fail fast.
//
// Errors:
//   - apperrors.ErrUserAlreadyExists / apperrors.ErrEmailAlreadyExists for duplicates
//   - the error from ValidateAvatarURL for a rejected avatar
//   - "密码加密失败" / "生成Token失败" for hashing / token failures
//   - repository errors are returned unchanged
//
// NOTE(review): the uniqueness checks are check-then-insert, so two concurrent
// registrations can race; presumably a DB unique constraint is the final guard.
func (s *userService) Register(ctx context.Context, username, password, email, avatar string) (*model.User, string, error) {
	existingUser, err := s.userRepo.FindByUsername(ctx, username)
	if err != nil {
		return nil, "", err
	}
	if existingUser != nil {
		return nil, "", apperrors.ErrUserAlreadyExists
	}

	existingEmail, err := s.userRepo.FindByEmail(ctx, email)
	if err != nil {
		return nil, "", err
	}
	if existingEmail != nil {
		return nil, "", apperrors.ErrEmailAlreadyExists
	}

	// Resolve the avatar before hashing so a bad URL is rejected without
	// paying for the password hash.
	avatarURL := avatar
	if avatarURL != "" {
		if err := s.ValidateAvatarURL(ctx, avatarURL); err != nil {
			return nil, "", err
		}
	} else {
		avatarURL = s.getDefaultAvatar()
	}

	hashedPassword, err := auth.HashPassword(password)
	if err != nil {
		return nil, "", errors.New("密码加密失败")
	}

	user := &model.User{
		Username: username,
		Password: hashedPassword,
		Email:    email,
		Avatar:   avatarURL,
		Role:     "user",
		Status:   1,
		Points:   0,
	}
	if err := s.userRepo.Create(ctx, user); err != nil {
		return nil, "", err
	}

	token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role)
	if err != nil {
		return nil, "", errors.New("生成Token失败")
	}
	return user, token, nil
}
// Login authenticates a user by username or email and returns the user plus a
// new JWT. An identifier containing "@" is looked up as an email; anything else
// is looked up as a username.
//
// Flow:
//  1. If Redis is configured, refuse early while the (identifier, IP) pair is
//     locked out. Errors from the lock check are ignored, so the check fails
//     open rather than blocking every login when Redis is unavailable.
//  2. Look up the user; an unknown account gets the generic credential error.
//  3. Verify the password BEFORE reporting anything about account state.
//     Otherwise anyone could learn which accounts are disabled without knowing
//     their password.
//  4. Reject disabled accounts (Status != 1).
//  5. Clear the failure counter, issue the token, persist last_login_at and
//     write a success log entry.
//
// Each credential/state failure is recorded through recordLoginFailure.
// Repository lookup errors are returned unchanged.
func (s *userService) Login(ctx context.Context, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
	if s.redis != nil {
		identifier := usernameOrEmail + ":" + ipAddress
		locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier)
		if err == nil && locked {
			// floor(minutes)+1 so the suggested wait is never shorter than the
			// remaining lock time.
			return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1)
		}
	}

	var (
		user *model.User
		err  error
	)
	if strings.Contains(usernameOrEmail, "@") {
		user, err = s.userRepo.FindByEmail(ctx, usernameOrEmail)
	} else {
		user, err = s.userRepo.FindByUsername(ctx, usernameOrEmail)
	}
	if err != nil {
		return nil, "", err
	}
	if user == nil {
		s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在")
		return nil, "", errors.New("用户名/邮箱或密码错误")
	}

	// Password first: only a caller who proves knowledge of the password may
	// learn that the account is disabled.
	if !auth.CheckPassword(user.Password, password) {
		s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误")
		return nil, "", errors.New("用户名/邮箱或密码错误")
	}

	if user.Status != 1 {
		s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用")
		return nil, "", errors.New("账号已被禁用")
	}

	if s.redis != nil {
		identifier := usernameOrEmail + ":" + ipAddress
		// Best-effort: a leftover counter only affects future lockout checks.
		_ = ClearLoginAttempts(ctx, s.redis, identifier)
	}

	token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role)
	if err != nil {
		return nil, "", errors.New("生成Token失败")
	}

	now := time.Now()
	user.LastLoginAt = &now
	if updateErr := s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{
		"last_login_at": now,
	}); updateErr != nil && s.logger != nil {
		// Non-fatal: the login itself succeeded; surface the failure in logs
		// instead of silently dropping it.
		s.logger.Warn("更新最后登录时间失败",
			zap.Int64("user_id", user.ID),
			zap.Error(updateErr),
		)
	}

	s.logSuccessLogin(ctx, user.ID, ipAddress, userAgent)
	return user, token, nil
}
// GetByID returns the user with the given ID through the read-through cache:
// database.Cached is given the by-ID cache key, a loader that calls
// userRepo.FindByID, and a 5-minute TTL.
func (s *userService) GetByID(ctx context.Context, id int64) (*model.User, error) {
	const ttl = 5 * time.Minute
	load := func() (*model.User, error) {
		return s.userRepo.FindByID(ctx, id)
	}
	return database.Cached(ctx, s.cache, s.cacheKeys.User(id), load, ttl)
}
// GetByEmail returns the user with the given email through the read-through
// cache: database.Cached is given the by-email cache key, a loader that calls
// userRepo.FindByEmail, and a 5-minute TTL.
func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User, error) {
	const ttl = 5 * time.Minute
	load := func() (*model.User, error) {
		return s.userRepo.FindByEmail(ctx, email)
	}
	return database.Cached(ctx, s.cache, s.cacheKeys.UserByEmail(email), load, ttl)
}
// UpdateInfo persists the full user record and invalidates every cache entry
// that may hold a copy of it.
//
// The by-email and by-username entries are keyed by value. If this update
// changes the email or username, the entries under the OLD values would keep
// serving stale data. The stored record is therefore read first, and its old
// keys are invalidated too, mirroring what ChangeEmail does.
func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {
	// Best-effort snapshot: if the lookup fails we still invalidate the new keys.
	previous, _ := s.userRepo.FindByID(ctx, user.ID)

	if err := s.userRepo.Update(ctx, user); err != nil {
		return err
	}

	keys := []string{
		s.cacheKeys.User(user.ID),
		s.cacheKeys.UserByEmail(user.Email),
		s.cacheKeys.UserByUsername(user.Username),
	}
	if previous != nil {
		if previous.Email != user.Email {
			keys = append(keys, s.cacheKeys.UserByEmail(previous.Email))
		}
		if previous.Username != user.Username {
			keys = append(keys, s.cacheKeys.UserByUsername(previous.Username))
		}
	}
	s.cacheInv.OnUpdate(ctx, keys...)
	return nil
}
// UpdateAvatar sets the user's avatar URL and invalidates the cached copies of
// the user.
//
// Besides the by-ID entry, the by-email and by-username entries also cache the
// whole user (including Avatar). The user is re-read after the update so those
// keys can be dropped as well. The re-read is best-effort: if it fails, only
// the by-ID entry is invalidated, which was the previous behavior.
func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL string) error {
	err := s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
		"avatar": avatarURL,
	})
	if err != nil {
		return err
	}

	keys := []string{s.cacheKeys.User(userID)}
	if user, findErr := s.userRepo.FindByID(ctx, userID); findErr == nil && user != nil {
		keys = append(keys,
			s.cacheKeys.UserByEmail(user.Email),
			s.cacheKeys.UserByUsername(user.Username),
		)
	}
	s.cacheInv.OnUpdate(ctx, keys...)
	return nil
}
// ChangePassword replaces the user's password after verifying the old one.
//
// Errors: "用户不存在" when the lookup fails or finds nothing, "原密码错误" when
// oldPassword does not match, "密码加密失败" when hashing fails, or the
// repository error from the update.
//
// All cached copies of the user are invalidated: by ID, by email and by
// username. The by-email/by-username entries would otherwise keep serving the
// old password hash. The user is already loaded, so this costs nothing extra.
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
	user, err := s.userRepo.FindByID(ctx, userID)
	if err != nil || user == nil {
		return errors.New("用户不存在")
	}
	if !auth.CheckPassword(user.Password, oldPassword) {
		return errors.New("原密码错误")
	}

	hashedPassword, err := auth.HashPassword(newPassword)
	if err != nil {
		return errors.New("密码加密失败")
	}
	err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
		"password": hashedPassword,
	})
	if err != nil {
		return err
	}

	s.cacheInv.OnUpdate(ctx,
		s.cacheKeys.User(userID),
		s.cacheKeys.UserByEmail(user.Email),
		s.cacheKeys.UserByUsername(user.Username),
	)
	return nil
}
// ResetPassword sets a new password for the account identified by email,
// without checking the old one (the caller is responsible for proving
// ownership, e.g. via a verification code — NOTE(review): confirm at call site).
//
// Errors: "用户不存在" when the lookup fails or finds nothing, "密码加密失败"
// when hashing fails, or the repository error from the update.
//
// Cached copies are invalidated by ID, by the email argument, by the stored
// email (when it differs from the argument, e.g. in letter case), and by
// username, so no lookup path keeps serving the old password hash.
func (s *userService) ResetPassword(ctx context.Context, email, newPassword string) error {
	user, err := s.userRepo.FindByEmail(ctx, email)
	if err != nil || user == nil {
		return errors.New("用户不存在")
	}

	hashedPassword, err := auth.HashPassword(newPassword)
	if err != nil {
		return errors.New("密码加密失败")
	}
	err = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{
		"password": hashedPassword,
	})
	if err != nil {
		return err
	}

	keys := []string{
		s.cacheKeys.User(user.ID),
		s.cacheKeys.UserByEmail(email),
		s.cacheKeys.UserByUsername(user.Username),
	}
	if user.Email != email {
		keys = append(keys, s.cacheKeys.UserByEmail(user.Email))
	}
	s.cacheInv.OnUpdate(ctx, keys...)
	return nil
}
// ChangeEmail moves the user to newEmail.
//
// It returns apperrors.ErrEmailAlreadyExists when another user already owns
// newEmail. Repository errors from the uniqueness lookup or the update are
// returned unchanged.
//
// Cache invalidation covers the by-ID key, the new email key, the old email
// key, and the by-username key, since that cached user still carries the old
// email. The old record is fetched best-effort: if that lookup fails, only the
// ID and new-email keys are invalidated.
func (s *userService) ChangeEmail(ctx context.Context, userID int64, newEmail string) error {
	// Error deliberately ignored: the old record is only needed for cache keys.
	oldUser, _ := s.userRepo.FindByID(ctx, userID)

	existingUser, err := s.userRepo.FindByEmail(ctx, newEmail)
	if err != nil {
		return err
	}
	if existingUser != nil && existingUser.ID != userID {
		return apperrors.ErrEmailAlreadyExists
	}

	err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
		"email": newEmail,
	})
	if err != nil {
		return err
	}

	keysToInvalidate := []string{
		s.cacheKeys.User(userID),
		s.cacheKeys.UserByEmail(newEmail),
	}
	if oldUser != nil {
		keysToInvalidate = append(keysToInvalidate,
			s.cacheKeys.UserByEmail(oldUser.Email),
			s.cacheKeys.UserByUsername(oldUser.Username),
		)
	}
	s.cacheInv.OnUpdate(ctx, keysToInvalidate...)
	return nil
}
// ValidateAvatarURL checks that avatarURL is acceptable as a user avatar.
//
// Accepted forms:
//   - "" (no avatar; always valid)
//   - a site-relative path such as "/uploads/a.png"
//   - an absolute http/https URL whose host passes checkDomainAllowed against
//     cfg.Security.AllowedDomains, or against localhost/127.0.0.1 when the
//     config cannot be loaded
//
// Values beginning with "//" or "/\" are NOT treated as relative paths.
// Browsers resolve them as protocol-relative URLs pointing at another host,
// which would bypass the domain allow-list. They fall through to URL parsing
// and are rejected for lacking an http/https scheme. Tabs and newlines are
// rejected outright, because browsers strip them while parsing
// (so "/\t/evil.com" would become "//evil.com").
//
// ctx is currently unused.
func (s *userService) ValidateAvatarURL(ctx context.Context, avatarURL string) error {
	if avatarURL == "" {
		return nil
	}
	if strings.ContainsAny(avatarURL, "\t\r\n") {
		return errors.New("无效的URL格式")
	}

	isSiteRelative := strings.HasPrefix(avatarURL, "/") &&
		!strings.HasPrefix(avatarURL, "//") &&
		!strings.HasPrefix(avatarURL, "/\\")
	if isSiteRelative {
		return nil
	}

	parsedURL, err := url.Parse(avatarURL)
	if err != nil {
		return errors.New("无效的URL格式")
	}
	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
		return errors.New("URL必须使用http或https协议")
	}
	host := parsedURL.Hostname()
	if host == "" {
		return errors.New("URL缺少主机名")
	}

	cfg, err := config.GetConfig()
	if err != nil {
		// Without configuration, only local hosts are trusted.
		allowedDomains := []string{"localhost", "127.0.0.1"}
		return s.checkDomainAllowed(host, allowedDomains)
	}
	return s.checkDomainAllowed(host, cfg.Security.AllowedDomains)
}
// UploadAvatar validates an uploaded image, stores it in the "avatars" bucket
// under a content-addressed path, and points the user's avatar at it.
//
// The object path is "{hash[:2]}/{hash[2:4]}/{hash}{ext}", where hash is the
// hex SHA-256 of the file, so identical uploads map to the same object.
// It returns the public URL built by the storage client.
//
// Errors: size/format/content validation failures (see validateAvatarUpload),
// "存储服务不可用" when no storage client is configured, and wrapped errors
// from bucket lookup, upload, or the avatar update.
func (s *userService) UploadAvatar(ctx context.Context, userID int64, fileData []byte, fileName string) (string, error) {
	ext, contentType, err := validateAvatarUpload(fileData, fileName)
	if err != nil {
		return "", err
	}

	if s.storage == nil {
		return "", errors.New("存储服务不可用")
	}

	hashBytes := sha256.Sum256(fileData)
	hash := hex.EncodeToString(hashBytes[:])

	bucketName, err := s.storage.GetBucket("avatars")
	if err != nil {
		return "", fmt.Errorf("获取存储桶失败: %w", err)
	}

	objectName := fmt.Sprintf("%s/%s/%s%s", hash[:2], hash[2:4], hash, ext)
	reader := bytes.NewReader(fileData)
	if err := s.storage.UploadObject(ctx, bucketName, objectName, reader, int64(len(fileData)), contentType); err != nil {
		return "", fmt.Errorf("上传文件失败: %w", err)
	}

	avatarURL := s.storage.BuildFileURL(bucketName, objectName)
	if err := s.UpdateAvatar(ctx, userID, avatarURL); err != nil {
		return "", fmt.Errorf("更新用户头像失败: %w", err)
	}

	s.logger.Info("上传头像成功",
		zap.Int64("user_id", userID),
		zap.String("hash", hash),
		zap.String("url", avatarURL),
	)
	return avatarURL, nil
}

// validateAvatarUpload checks an avatar upload and returns its lower-cased
// extension and the Content-Type to store it with.
//
// Rules: size must be within [512 B, 5 MiB]; the extension must be one of
// .jpg/.jpeg/.png/.gif/.webp; and the sniffed content type
// (http.DetectContentType) must match the extension. Checking the content
// stops arbitrary bytes from being stored and served as an image merely
// because of the file name.
func validateAvatarUpload(fileData []byte, fileName string) (string, string, error) {
	const (
		minSize = 512             // 512B
		maxSize = 5 * 1024 * 1024 // 5MB
	)
	if size := len(fileData); size < minSize || size > maxSize {
		return "", "", fmt.Errorf("文件大小必须在 %d 到 %d 字节之间", minSize, maxSize)
	}

	ext := strings.ToLower(filepath.Ext(fileName))
	allowedTypes := map[string]string{
		".jpg":  "image/jpeg",
		".jpeg": "image/jpeg",
		".png":  "image/png",
		".gif":  "image/gif",
		".webp": "image/webp",
	}
	contentType, ok := allowedTypes[ext]
	if !ok {
		return "", "", fmt.Errorf("不支持的文件格式: %s, 仅支持 jpg/jpeg/png/gif/webp", ext)
	}

	if detected := http.DetectContentType(fileData); detected != contentType {
		return "", "", fmt.Errorf("文件内容与扩展名不符: 检测到 %s", detected)
	}
	return ext, contentType, nil
}
// GetMaxProfilesPerUser returns cfg.Site.MaxProfilesPerUser, or 5 when the
// config cannot be loaded or the configured value is not positive.
func (s *userService) GetMaxProfilesPerUser() int {
	const fallback = 5
	cfg, err := config.GetConfig()
	if err == nil && cfg.Site.MaxProfilesPerUser > 0 {
		return cfg.Site.MaxProfilesPerUser
	}
	return fallback
}
// GetMaxTexturesPerUser returns cfg.Site.MaxTexturesPerUser, or 50 when the
// config cannot be loaded or the configured value is not positive.
func (s *userService) GetMaxTexturesPerUser() int {
	const fallback = 50
	cfg, err := config.GetConfig()
	if err == nil && cfg.Site.MaxTexturesPerUser > 0 {
		return cfg.Site.MaxTexturesPerUser
	}
	return fallback
}
// Private helper methods.

// getDefaultAvatar returns cfg.Site.DefaultAvatar, or "" when the config
// cannot be loaded.
func (s *userService) getDefaultAvatar() string {
	if cfg, err := config.GetConfig(); err == nil {
		return cfg.Site.DefaultAvatar
	}
	return ""
}
// checkDomainAllowed returns nil when host matches an entry of allowedDomains,
// and an error otherwise.
//
// Matching is case-insensitive; entries are trimmed of surrounding whitespace
// and blank entries are skipped. An entry matches when it equals host exactly,
// or, for a wildcard entry "*.example.com", when host ends with ".example.com".
// Subdomains therefore match a wildcard, but the bare "example.com" does not.
func (s *userService) checkDomainAllowed(host string, allowedDomains []string) error {
	want := strings.ToLower(host)
	for i := range allowedDomains {
		pattern := strings.ToLower(strings.TrimSpace(allowedDomains[i]))
		switch {
		case pattern == "":
			// Blank entry: ignore it.
		case want == pattern:
			return nil
		case strings.HasPrefix(pattern, "*.") && strings.HasSuffix(want, strings.TrimPrefix(pattern, "*")):
			return nil
		}
	}
	return errors.New("URL域名不在允许的列表中")
}
// recordLoginFailure bumps the lockout counter for (usernameOrEmail, ipAddress)
// when Redis is configured, then writes a failed-login log entry.
//
// When the counter has reached MaxLoginAttempts, the logged reason gets the
// suffix "-账号已锁定". The counter's error is deliberately ignored
// (best-effort); whatever count is returned is still compared.
func (s *userService) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) {
	finalReason := reason
	if s.redis != nil {
		attempts, _ := RecordLoginFailure(ctx, s.redis, usernameOrEmail+":"+ipAddress)
		if attempts >= MaxLoginAttempts {
			finalReason = reason + "-账号已锁定"
		}
	}
	s.logFailedLogin(ctx, userID, ipAddress, userAgent, finalReason)
}
// logSuccessLogin stores a successful PASSWORD login entry for userID.
// Persistence errors are intentionally discarded: audit logging must not
// break the login itself.
func (s *userService) logSuccessLogin(ctx context.Context, userID int64, ipAddress, userAgent string) {
	entry := model.UserLoginLog{
		UserID:      userID,
		IPAddress:   ipAddress,
		UserAgent:   userAgent,
		LoginMethod: "PASSWORD",
		IsSuccess:   true,
	}
	_ = s.userRepo.CreateLoginLog(ctx, &entry)
}
// logFailedLogin stores a failed PASSWORD login entry with the given reason.
// userID may be 0 when no account matched. Persistence errors are
// intentionally discarded so auditing never masks the original failure.
func (s *userService) logFailedLogin(ctx context.Context, userID int64, ipAddress, userAgent, reason string) {
	entry := model.UserLoginLog{
		UserID:        userID,
		IPAddress:     ipAddress,
		UserAgent:     userAgent,
		LoginMethod:   "PASSWORD",
		IsSuccess:     false,
		FailureReason: reason,
	}
	_ = s.userRepo.CreateLoginLog(ctx, &entry)
}