500 lines
13 KiB
Go
500 lines
13 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"errors"
|
||
"fmt"
|
||
"net/url"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
|
||
apperrors "carrotskin/internal/errors"
|
||
"carrotskin/internal/model"
|
||
"carrotskin/internal/repository"
|
||
"carrotskin/pkg/auth"
|
||
"carrotskin/pkg/config"
|
||
"carrotskin/pkg/database"
|
||
"carrotskin/pkg/redis"
|
||
"carrotskin/pkg/storage"
|
||
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// userService is the concrete implementation of UserService.
//
// It combines the user repository with JWT issuance, Redis-backed login
// attempt tracking, cached user lookups, and object storage for avatars.
type userService struct {
	// userRepo persists users and login-log entries.
	userRepo repository.UserRepository
	// jwtService issues tokens after registration and login.
	jwtService *auth.JWTService
	// redis backs login-attempt tracking; when nil, lock checks and failure
	// counting are skipped.
	redis *redis.Client
	// cache holds user lookups (see GetByID / GetByEmail).
	cache *database.CacheManager
	// cacheKeys builds the cache keys for user entries (by ID, email, username).
	cacheKeys *database.CacheKeyBuilder
	// cacheInv invalidates cache entries after writes.
	cacheInv *database.CacheInvalidator
	// storage stores uploaded avatars; when nil, UploadAvatar fails.
	storage *storage.StorageClient
	// logger records service-level events.
	logger *zap.Logger
}
|
||
|
||
// NewUserService 创建UserService实例
|
||
func NewUserService(
|
||
userRepo repository.UserRepository,
|
||
jwtService *auth.JWTService,
|
||
redisClient *redis.Client,
|
||
cacheManager *database.CacheManager,
|
||
storageClient *storage.StorageClient,
|
||
logger *zap.Logger,
|
||
) UserService {
|
||
// CacheKeyBuilder 使用空前缀,因为 CacheManager 已经处理了前缀
|
||
// 这样缓存键的格式为: CacheManager前缀 + CacheKeyBuilder生成的键
|
||
return &userService{
|
||
userRepo: userRepo,
|
||
jwtService: jwtService,
|
||
redis: redisClient,
|
||
cache: cacheManager,
|
||
cacheKeys: database.NewCacheKeyBuilder(""),
|
||
cacheInv: database.NewCacheInvalidator(cacheManager),
|
||
storage: storageClient,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
func (s *userService) Register(ctx context.Context, username, password, email, avatar string) (*model.User, string, error) {
|
||
// 检查用户名是否已存在
|
||
existingUser, err := s.userRepo.FindByUsername(ctx, username)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
if existingUser != nil {
|
||
return nil, "", apperrors.ErrUserAlreadyExists
|
||
}
|
||
|
||
// 检查邮箱是否已存在
|
||
existingEmail, err := s.userRepo.FindByEmail(ctx, email)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
if existingEmail != nil {
|
||
return nil, "", apperrors.ErrEmailAlreadyExists
|
||
}
|
||
|
||
// 加密密码
|
||
hashedPassword, err := auth.HashPassword(password)
|
||
if err != nil {
|
||
return nil, "", errors.New("密码加密失败")
|
||
}
|
||
|
||
// 确定头像URL
|
||
avatarURL := avatar
|
||
if avatarURL != "" {
|
||
if err := s.ValidateAvatarURL(ctx, avatarURL); err != nil {
|
||
return nil, "", err
|
||
}
|
||
} else {
|
||
avatarURL = s.getDefaultAvatar()
|
||
}
|
||
|
||
// 创建用户
|
||
user := &model.User{
|
||
Username: username,
|
||
Password: hashedPassword,
|
||
Email: email,
|
||
Avatar: avatarURL,
|
||
Role: "user",
|
||
Status: 1,
|
||
Points: 0,
|
||
}
|
||
|
||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
// 生成JWT Token
|
||
token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||
if err != nil {
|
||
return nil, "", errors.New("生成Token失败")
|
||
}
|
||
|
||
return user, token, nil
|
||
}
|
||
|
||
func (s *userService) Login(ctx context.Context, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
|
||
// 检查账号是否被锁定
|
||
if s.redis != nil {
|
||
identifier := usernameOrEmail + ":" + ipAddress
|
||
locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier)
|
||
if err == nil && locked {
|
||
return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1)
|
||
}
|
||
}
|
||
|
||
// 查找用户
|
||
var user *model.User
|
||
var err error
|
||
|
||
if strings.Contains(usernameOrEmail, "@") {
|
||
user, err = s.userRepo.FindByEmail(ctx, usernameOrEmail)
|
||
} else {
|
||
user, err = s.userRepo.FindByUsername(ctx, usernameOrEmail)
|
||
}
|
||
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
if user == nil {
|
||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在")
|
||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||
}
|
||
|
||
// 检查用户状态
|
||
if user.Status != 1 {
|
||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用")
|
||
return nil, "", errors.New("账号已被禁用")
|
||
}
|
||
|
||
// 验证密码
|
||
if !auth.CheckPassword(user.Password, password) {
|
||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误")
|
||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||
}
|
||
|
||
// 登录成功,清除失败计数
|
||
if s.redis != nil {
|
||
identifier := usernameOrEmail + ":" + ipAddress
|
||
_ = ClearLoginAttempts(ctx, s.redis, identifier)
|
||
}
|
||
|
||
// 生成JWT Token
|
||
token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||
if err != nil {
|
||
return nil, "", errors.New("生成Token失败")
|
||
}
|
||
|
||
// 更新最后登录时间
|
||
now := time.Now()
|
||
user.LastLoginAt = &now
|
||
_ = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{
|
||
"last_login_at": now,
|
||
})
|
||
|
||
// 记录成功登录日志
|
||
s.logSuccessLogin(ctx, user.ID, ipAddress, userAgent)
|
||
|
||
return user, token, nil
|
||
}
|
||
|
||
func (s *userService) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||
// 使用 Cached 装饰器自动处理缓存
|
||
cacheKey := s.cacheKeys.User(id)
|
||
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
|
||
return s.userRepo.FindByID(ctx, id)
|
||
}, 5*time.Minute)
|
||
}
|
||
|
||
func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User, error) {
|
||
// 使用 Cached 装饰器自动处理缓存
|
||
cacheKey := s.cacheKeys.UserByEmail(email)
|
||
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
|
||
return s.userRepo.FindByEmail(ctx, email)
|
||
}, 5*time.Minute)
|
||
}
|
||
|
||
func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {
|
||
err := s.userRepo.Update(ctx, user)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 清除缓存
|
||
s.cacheInv.OnUpdate(ctx,
|
||
s.cacheKeys.User(user.ID),
|
||
s.cacheKeys.UserByEmail(user.Email),
|
||
s.cacheKeys.UserByUsername(user.Username),
|
||
)
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL string) error {
|
||
err := s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
|
||
"avatar": avatarURL,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 清除用户缓存
|
||
s.cacheInv.OnUpdate(ctx, s.cacheKeys.User(userID))
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
|
||
user, err := s.userRepo.FindByID(ctx, userID)
|
||
if err != nil || user == nil {
|
||
return errors.New("用户不存在")
|
||
}
|
||
|
||
if !auth.CheckPassword(user.Password, oldPassword) {
|
||
return errors.New("原密码错误")
|
||
}
|
||
|
||
hashedPassword, err := auth.HashPassword(newPassword)
|
||
if err != nil {
|
||
return errors.New("密码加密失败")
|
||
}
|
||
|
||
err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
|
||
"password": hashedPassword,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 清除用户缓存
|
||
s.cacheInv.OnUpdate(ctx, s.cacheKeys.User(userID))
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *userService) ResetPassword(ctx context.Context, email, newPassword string) error {
|
||
user, err := s.userRepo.FindByEmail(ctx, email)
|
||
if err != nil || user == nil {
|
||
return errors.New("用户不存在")
|
||
}
|
||
|
||
hashedPassword, err := auth.HashPassword(newPassword)
|
||
if err != nil {
|
||
return errors.New("密码加密失败")
|
||
}
|
||
|
||
err = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{
|
||
"password": hashedPassword,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 清除用户缓存
|
||
s.cacheInv.OnUpdate(ctx,
|
||
s.cacheKeys.User(user.ID),
|
||
s.cacheKeys.UserByEmail(email),
|
||
)
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *userService) ChangeEmail(ctx context.Context, userID int64, newEmail string) error {
|
||
// 获取旧邮箱
|
||
oldUser, _ := s.userRepo.FindByID(ctx, userID)
|
||
|
||
existingUser, err := s.userRepo.FindByEmail(ctx, newEmail)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if existingUser != nil && existingUser.ID != userID {
|
||
return apperrors.ErrEmailAlreadyExists
|
||
}
|
||
|
||
err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
|
||
"email": newEmail,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 清除旧邮箱和用户ID的缓存
|
||
keysToInvalidate := []string{
|
||
s.cacheKeys.User(userID),
|
||
s.cacheKeys.UserByEmail(newEmail),
|
||
}
|
||
if oldUser != nil {
|
||
keysToInvalidate = append(keysToInvalidate, s.cacheKeys.UserByEmail(oldUser.Email))
|
||
}
|
||
s.cacheInv.OnUpdate(ctx, keysToInvalidate...)
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *userService) ValidateAvatarURL(ctx context.Context, avatarURL string) error {
|
||
if avatarURL == "" {
|
||
return nil
|
||
}
|
||
|
||
// 允许相对路径
|
||
if strings.HasPrefix(avatarURL, "/") {
|
||
return nil
|
||
}
|
||
|
||
// 解析URL
|
||
parsedURL, err := url.Parse(avatarURL)
|
||
if err != nil {
|
||
return errors.New("无效的URL格式")
|
||
}
|
||
|
||
// 必须是HTTP或HTTPS协议
|
||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||
return errors.New("URL必须使用http或https协议")
|
||
}
|
||
|
||
host := parsedURL.Hostname()
|
||
if host == "" {
|
||
return errors.New("URL缺少主机名")
|
||
}
|
||
|
||
// 从配置获取允许的域名列表
|
||
cfg, err := config.GetConfig()
|
||
if err != nil {
|
||
allowedDomains := []string{"localhost", "127.0.0.1"}
|
||
return s.checkDomainAllowed(host, allowedDomains)
|
||
}
|
||
|
||
return s.checkDomainAllowed(host, cfg.Security.AllowedDomains)
|
||
}
|
||
|
||
func (s *userService) UploadAvatar(ctx context.Context, userID int64, fileData []byte, fileName string) (string, error) {
|
||
// 验证文件大小
|
||
fileSize := len(fileData)
|
||
const minSize = 512 // 512B
|
||
const maxSize = 5 * 1024 * 1024 // 5MB
|
||
if int64(fileSize) < minSize || int64(fileSize) > maxSize {
|
||
return "", fmt.Errorf("文件大小必须在 %d 到 %d 字节之间", minSize, maxSize)
|
||
}
|
||
|
||
// 验证文件扩展名
|
||
ext := strings.ToLower(filepath.Ext(fileName))
|
||
allowedExts := map[string]bool{".jpg": true, ".jpeg": true, ".png": true, ".gif": true, ".webp": true}
|
||
if !allowedExts[ext] {
|
||
return "", fmt.Errorf("不支持的文件格式: %s,仅支持 jpg/jpeg/png/gif/webp", ext)
|
||
}
|
||
|
||
// 检查存储服务
|
||
if s.storage == nil {
|
||
return "", errors.New("存储服务不可用")
|
||
}
|
||
|
||
// 计算文件哈希
|
||
hashBytes := sha256.Sum256(fileData)
|
||
hash := hex.EncodeToString(hashBytes[:])
|
||
|
||
// 获取存储桶
|
||
bucketName, err := s.storage.GetBucket("avatars")
|
||
if err != nil {
|
||
return "", fmt.Errorf("获取存储桶失败: %w", err)
|
||
}
|
||
|
||
// 生成对象路径: avatars/{hash[:2]}/{hash[2:4]}/{hash}{ext}
|
||
objectName := fmt.Sprintf("%s/%s/%s%s", hash[:2], hash[2:4], hash, ext)
|
||
|
||
// 上传文件
|
||
reader := bytes.NewReader(fileData)
|
||
contentType := "image/" + strings.TrimPrefix(ext, ".")
|
||
if ext == ".jpg" {
|
||
contentType = "image/jpeg"
|
||
}
|
||
if err := s.storage.UploadObject(ctx, bucketName, objectName, reader, int64(fileSize), contentType); err != nil {
|
||
return "", fmt.Errorf("上传文件失败: %w", err)
|
||
}
|
||
|
||
// 构建文件URL
|
||
avatarURL := s.storage.BuildFileURL(bucketName, objectName)
|
||
|
||
// 更新用户头像
|
||
if err := s.UpdateAvatar(ctx, userID, avatarURL); err != nil {
|
||
return "", fmt.Errorf("更新用户头像失败: %w", err)
|
||
}
|
||
|
||
s.logger.Info("上传头像成功",
|
||
zap.Int64("user_id", userID),
|
||
zap.String("hash", hash),
|
||
zap.String("url", avatarURL),
|
||
)
|
||
|
||
return avatarURL, nil
|
||
}
|
||
|
||
func (s *userService) GetMaxProfilesPerUser() int {
|
||
cfg, err := config.GetConfig()
|
||
if err != nil || cfg.Site.MaxProfilesPerUser <= 0 {
|
||
return 5
|
||
}
|
||
return cfg.Site.MaxProfilesPerUser
|
||
}
|
||
|
||
func (s *userService) GetMaxTexturesPerUser() int {
|
||
cfg, err := config.GetConfig()
|
||
if err != nil || cfg.Site.MaxTexturesPerUser <= 0 {
|
||
return 50
|
||
}
|
||
return cfg.Site.MaxTexturesPerUser
|
||
}
|
||
|
||
// 私有辅助方法
|
||
|
||
func (s *userService) getDefaultAvatar() string {
|
||
cfg, err := config.GetConfig()
|
||
if err != nil {
|
||
return ""
|
||
}
|
||
return cfg.Site.DefaultAvatar
|
||
}
|
||
|
||
func (s *userService) checkDomainAllowed(host string, allowedDomains []string) error {
|
||
host = strings.ToLower(host)
|
||
|
||
for _, allowed := range allowedDomains {
|
||
allowed = strings.ToLower(strings.TrimSpace(allowed))
|
||
if allowed == "" {
|
||
continue
|
||
}
|
||
|
||
if host == allowed {
|
||
return nil
|
||
}
|
||
|
||
if strings.HasPrefix(allowed, "*.") {
|
||
suffix := allowed[1:]
|
||
if strings.HasSuffix(host, suffix) {
|
||
return nil
|
||
}
|
||
}
|
||
}
|
||
|
||
return errors.New("URL域名不在允许的列表中")
|
||
}
|
||
|
||
func (s *userService) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) {
|
||
if s.redis != nil {
|
||
identifier := usernameOrEmail + ":" + ipAddress
|
||
count, _ := RecordLoginFailure(ctx, s.redis, identifier)
|
||
if count >= MaxLoginAttempts {
|
||
s.logFailedLogin(ctx, userID, ipAddress, userAgent, reason+"-账号已锁定")
|
||
return
|
||
}
|
||
}
|
||
s.logFailedLogin(ctx, userID, ipAddress, userAgent, reason)
|
||
}
|
||
|
||
func (s *userService) logSuccessLogin(ctx context.Context, userID int64, ipAddress, userAgent string) {
|
||
log := &model.UserLoginLog{
|
||
UserID: userID,
|
||
IPAddress: ipAddress,
|
||
UserAgent: userAgent,
|
||
LoginMethod: "PASSWORD",
|
||
IsSuccess: true,
|
||
}
|
||
_ = s.userRepo.CreateLoginLog(ctx, log)
|
||
}
|
||
|
||
func (s *userService) logFailedLogin(ctx context.Context, userID int64, ipAddress, userAgent, reason string) {
|
||
log := &model.UserLoginLog{
|
||
UserID: userID,
|
||
IPAddress: ipAddress,
|
||
UserAgent: userAgent,
|
||
LoginMethod: "PASSWORD",
|
||
IsSuccess: false,
|
||
FailureReason: reason,
|
||
}
|
||
_ = s.userRepo.CreateLoginLog(ctx, log)
|
||
}
|