Compare commits
4 Commits
373c61f625
...
188a05caa7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
188a05caa7 | ||
|
|
e05ba3b041 | ||
|
|
ffdc3e3e6b | ||
|
|
f7589ebbb8 |
@@ -74,3 +74,6 @@ local/
|
||||
dev/
|
||||
minio-data/
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
name: Build and Push Docker Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
- dev
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
REGISTRY: code.littlelan.cn
|
||||
IMAGE_NAME: carrotskin/backend
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: quay.io/buildah/stable:latest
|
||||
options: --privileged
|
||||
|
||||
steps:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
dnf install -y git nodejs
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Login to registry
|
||||
run: |
|
||||
buildah login \
|
||||
-u "${{ secrets.REGISTRY_USERNAME }}" \
|
||||
-p "${{ secrets.REGISTRY_PASSWORD }}" \
|
||||
${{ env.REGISTRY }}
|
||||
echo "Registry 登录成功"
|
||||
|
||||
- name: Build image
|
||||
run: |
|
||||
buildah bud \
|
||||
--format docker \
|
||||
--layers \
|
||||
-t ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \
|
||||
-f Dockerfile \
|
||||
.
|
||||
echo "镜像构建完成"
|
||||
|
||||
- name: Tag and push image
|
||||
run: |
|
||||
SHORT_SHA=$(echo "${{ github.sha }}" | cut -c1-7)
|
||||
REF_NAME="${{ github.ref_name }}"
|
||||
REF="${{ github.ref }}"
|
||||
|
||||
# 推送分支/标签名
|
||||
buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${REF_NAME}
|
||||
buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${REF_NAME}
|
||||
echo "✓ 推送: ${REF_NAME}"
|
||||
|
||||
# 推送 SHA 标签
|
||||
buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${SHORT_SHA}
|
||||
buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${SHORT_SHA}
|
||||
echo "✓ 推送: sha-${SHORT_SHA}"
|
||||
|
||||
# main/master 推送 latest
|
||||
if [ "$REF" = "refs/heads/main" ] || [ "$REF" = "refs/heads/master" ]; then
|
||||
buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||
buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest
|
||||
echo "✓ 推送: latest"
|
||||
fi
|
||||
|
||||
- name: Build summary
|
||||
run: |
|
||||
echo "=============================="
|
||||
echo "✅ 镜像构建完成!"
|
||||
echo "仓库: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}"
|
||||
echo "分支: ${{ github.ref_name }}"
|
||||
echo "=============================="
|
||||
@@ -59,3 +59,6 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
# 启动应用
|
||||
ENTRYPOINT ["./server"]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
_ "carrotskin/docs" // Swagger文档
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/handler"
|
||||
"carrotskin/internal/middleware"
|
||||
"carrotskin/pkg/auth"
|
||||
@@ -66,10 +67,11 @@ func main() {
|
||||
defer redis.MustGetClient().Close()
|
||||
|
||||
// 初始化对象存储 (RustFS - S3兼容)
|
||||
// 如果对象存储未配置或连接失败,记录警告但不退出(某些功能可能不可用)
|
||||
var storageClient *storage.StorageClient
|
||||
if err := storage.Init(cfg.RustFS); err != nil {
|
||||
loggerInstance.Warn("对象存储连接失败,某些功能可能不可用", zap.Error(err))
|
||||
} else {
|
||||
storageClient = storage.MustGetClient()
|
||||
loggerInstance.Info("对象存储连接成功")
|
||||
}
|
||||
|
||||
@@ -78,6 +80,15 @@ func main() {
|
||||
loggerInstance.Fatal("邮件服务初始化失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 创建依赖注入容器
|
||||
c := container.NewContainer(
|
||||
database.MustGetDB(),
|
||||
redis.MustGetClient(),
|
||||
loggerInstance,
|
||||
auth.MustGetJWTService(),
|
||||
storageClient,
|
||||
)
|
||||
|
||||
// 设置Gin模式
|
||||
if cfg.Server.Mode == "production" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -91,8 +102,8 @@ func main() {
|
||||
router.Use(middleware.Recovery(loggerInstance))
|
||||
router.Use(middleware.CORS())
|
||||
|
||||
// 注册路由
|
||||
handler.RegisterRoutes(router)
|
||||
// 使用依赖注入方式注册路由
|
||||
handler.RegisterRoutesWithDI(router, c)
|
||||
|
||||
// 创建HTTP服务器
|
||||
srv := &http.Server{
|
||||
|
||||
178
internal/container/container.go
Normal file
178
internal/container/container.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package container
|
||||
|
||||
import (
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/pkg/auth"
|
||||
"carrotskin/pkg/redis"
|
||||
"carrotskin/pkg/storage"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Container 依赖注入容器
|
||||
// 集中管理所有依赖,便于测试和维护
|
||||
type Container struct {
|
||||
// 基础设施依赖
|
||||
DB *gorm.DB
|
||||
Redis *redis.Client
|
||||
Logger *zap.Logger
|
||||
JWT *auth.JWTService
|
||||
Storage *storage.StorageClient
|
||||
|
||||
// Repository层
|
||||
UserRepo repository.UserRepository
|
||||
ProfileRepo repository.ProfileRepository
|
||||
TextureRepo repository.TextureRepository
|
||||
TokenRepo repository.TokenRepository
|
||||
ConfigRepo repository.SystemConfigRepository
|
||||
|
||||
// Service层
|
||||
UserService service.UserService
|
||||
ProfileService service.ProfileService
|
||||
TextureService service.TextureService
|
||||
TokenService service.TokenService
|
||||
}
|
||||
|
||||
// NewContainer 创建依赖容器
|
||||
func NewContainer(
|
||||
db *gorm.DB,
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
jwtService *auth.JWTService,
|
||||
storageClient *storage.StorageClient,
|
||||
) *Container {
|
||||
c := &Container{
|
||||
DB: db,
|
||||
Redis: redisClient,
|
||||
Logger: logger,
|
||||
JWT: jwtService,
|
||||
Storage: storageClient,
|
||||
}
|
||||
|
||||
// 初始化Repository
|
||||
c.UserRepo = repository.NewUserRepository(db)
|
||||
c.ProfileRepo = repository.NewProfileRepository(db)
|
||||
c.TextureRepo = repository.NewTextureRepository(db)
|
||||
c.TokenRepo = repository.NewTokenRepository(db)
|
||||
c.ConfigRepo = repository.NewSystemConfigRepository(db)
|
||||
|
||||
// 初始化Service
|
||||
c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, logger)
|
||||
c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, logger)
|
||||
c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, logger)
|
||||
c.TokenService = service.NewTokenService(c.TokenRepo, c.ProfileRepo, logger)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// NewTestContainer 创建测试用容器(可注入mock依赖)
|
||||
func NewTestContainer(opts ...Option) *Container {
|
||||
c := &Container{}
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Option 容器配置选项
|
||||
type Option func(*Container)
|
||||
|
||||
// WithDB 设置数据库连接
|
||||
func WithDB(db *gorm.DB) Option {
|
||||
return func(c *Container) {
|
||||
c.DB = db
|
||||
}
|
||||
}
|
||||
|
||||
// WithRedis 设置Redis客户端
|
||||
func WithRedis(redis *redis.Client) Option {
|
||||
return func(c *Container) {
|
||||
c.Redis = redis
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger 设置日志
|
||||
func WithLogger(logger *zap.Logger) Option {
|
||||
return func(c *Container) {
|
||||
c.Logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithJWT 设置JWT服务
|
||||
func WithJWT(jwt *auth.JWTService) Option {
|
||||
return func(c *Container) {
|
||||
c.JWT = jwt
|
||||
}
|
||||
}
|
||||
|
||||
// WithStorage 设置存储客户端
|
||||
func WithStorage(storage *storage.StorageClient) Option {
|
||||
return func(c *Container) {
|
||||
c.Storage = storage
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserRepo 设置用户仓储
|
||||
func WithUserRepo(repo repository.UserRepository) Option {
|
||||
return func(c *Container) {
|
||||
c.UserRepo = repo
|
||||
}
|
||||
}
|
||||
|
||||
// WithProfileRepo 设置档案仓储
|
||||
func WithProfileRepo(repo repository.ProfileRepository) Option {
|
||||
return func(c *Container) {
|
||||
c.ProfileRepo = repo
|
||||
}
|
||||
}
|
||||
|
||||
// WithTextureRepo 设置材质仓储
|
||||
func WithTextureRepo(repo repository.TextureRepository) Option {
|
||||
return func(c *Container) {
|
||||
c.TextureRepo = repo
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenRepo 设置令牌仓储
|
||||
func WithTokenRepo(repo repository.TokenRepository) Option {
|
||||
return func(c *Container) {
|
||||
c.TokenRepo = repo
|
||||
}
|
||||
}
|
||||
|
||||
// WithConfigRepo 设置系统配置仓储
|
||||
func WithConfigRepo(repo repository.SystemConfigRepository) Option {
|
||||
return func(c *Container) {
|
||||
c.ConfigRepo = repo
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserService 设置用户服务
|
||||
func WithUserService(svc service.UserService) Option {
|
||||
return func(c *Container) {
|
||||
c.UserService = svc
|
||||
}
|
||||
}
|
||||
|
||||
// WithProfileService 设置档案服务
|
||||
func WithProfileService(svc service.ProfileService) Option {
|
||||
return func(c *Container) {
|
||||
c.ProfileService = svc
|
||||
}
|
||||
}
|
||||
|
||||
// WithTextureService 设置材质服务
|
||||
func WithTextureService(svc service.TextureService) Option {
|
||||
return func(c *Container) {
|
||||
c.TextureService = svc
|
||||
}
|
||||
}
|
||||
|
||||
// WithTokenService 设置令牌服务
|
||||
func WithTokenService(svc service.TokenService) Option {
|
||||
return func(c *Container) {
|
||||
c.TokenService = svc
|
||||
}
|
||||
}
|
||||
127
internal/errors/errors.go
Normal file
127
internal/errors/errors.go
Normal file
@@ -0,0 +1,127 @@
|
||||
// Package errors 定义应用程序的错误类型
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// 预定义错误
|
||||
var (
|
||||
// 用户相关错误
|
||||
ErrUserNotFound = errors.New("用户不存在")
|
||||
ErrUserAlreadyExists = errors.New("用户已存在")
|
||||
ErrEmailAlreadyExists = errors.New("邮箱已被注册")
|
||||
ErrInvalidPassword = errors.New("密码错误")
|
||||
ErrAccountDisabled = errors.New("账号已被禁用")
|
||||
|
||||
// 认证相关错误
|
||||
ErrUnauthorized = errors.New("未授权")
|
||||
ErrInvalidToken = errors.New("无效的令牌")
|
||||
ErrTokenExpired = errors.New("令牌已过期")
|
||||
ErrInvalidSignature = errors.New("签名验证失败")
|
||||
|
||||
// 档案相关错误
|
||||
ErrProfileNotFound = errors.New("档案不存在")
|
||||
ErrProfileNameExists = errors.New("角色名已被使用")
|
||||
ErrProfileLimitReached = errors.New("已达档案数量上限")
|
||||
ErrProfileNoPermission = errors.New("无权操作此档案")
|
||||
|
||||
// 材质相关错误
|
||||
ErrTextureNotFound = errors.New("材质不存在")
|
||||
ErrTextureExists = errors.New("该材质已存在")
|
||||
ErrTextureLimitReached = errors.New("已达材质数量上限")
|
||||
ErrTextureNoPermission = errors.New("无权操作此材质")
|
||||
ErrInvalidTextureType = errors.New("无效的材质类型")
|
||||
|
||||
// 验证码相关错误
|
||||
ErrInvalidVerificationCode = errors.New("验证码错误或已过期")
|
||||
ErrTooManyAttempts = errors.New("尝试次数过多")
|
||||
ErrSendTooFrequent = errors.New("发送过于频繁")
|
||||
|
||||
// URL验证相关错误
|
||||
ErrInvalidURL = errors.New("无效的URL格式")
|
||||
ErrDomainNotAllowed = errors.New("URL域名不在允许的列表中")
|
||||
|
||||
// 存储相关错误
|
||||
ErrStorageUnavailable = errors.New("存储服务不可用")
|
||||
ErrUploadFailed = errors.New("上传失败")
|
||||
|
||||
// 通用错误
|
||||
ErrBadRequest = errors.New("请求参数错误")
|
||||
ErrInternalServer = errors.New("服务器内部错误")
|
||||
ErrNotFound = errors.New("资源不存在")
|
||||
ErrForbidden = errors.New("权限不足")
|
||||
)
|
||||
|
||||
// AppError 应用错误类型,包含错误码和消息
|
||||
type AppError struct {
|
||||
Code int // HTTP状态码
|
||||
Message string // 用户可见的错误消息
|
||||
Err error // 原始错误(用于日志)
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (e *AppError) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("%s: %v", e.Message, e.Err)
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Unwrap 支持errors.Is和errors.As
|
||||
func (e *AppError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// NewAppError 创建新的应用错误
|
||||
func NewAppError(code int, message string, err error) *AppError {
|
||||
return &AppError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBadRequest 创建400错误
|
||||
func NewBadRequest(message string, err error) *AppError {
|
||||
return NewAppError(400, message, err)
|
||||
}
|
||||
|
||||
// NewUnauthorized 创建401错误
|
||||
func NewUnauthorized(message string) *AppError {
|
||||
return NewAppError(401, message, nil)
|
||||
}
|
||||
|
||||
// NewForbidden 创建403错误
|
||||
func NewForbidden(message string) *AppError {
|
||||
return NewAppError(403, message, nil)
|
||||
}
|
||||
|
||||
// NewNotFound 创建404错误
|
||||
func NewNotFound(message string) *AppError {
|
||||
return NewAppError(404, message, nil)
|
||||
}
|
||||
|
||||
// NewInternalError 创建500错误
|
||||
func NewInternalError(message string, err error) *AppError {
|
||||
return NewAppError(500, message, err)
|
||||
}
|
||||
|
||||
// Is 检查错误是否匹配
|
||||
func Is(err, target error) bool {
|
||||
return errors.Is(err, target)
|
||||
}
|
||||
|
||||
// As 尝试将错误转换为指定类型
|
||||
func As(err error, target interface{}) bool {
|
||||
return errors.As(err, target)
|
||||
}
|
||||
|
||||
// Wrap 包装错误
|
||||
func Wrap(err error, message string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%s: %w", message, err)
|
||||
}
|
||||
177
internal/handler/auth_handler_di.go
Normal file
177
internal/handler/auth_handler_di.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/email"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuthHandler 认证处理器(依赖注入版本)
|
||||
type AuthHandler struct {
|
||||
container *container.Container
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuthHandler 创建AuthHandler实例
|
||||
func NewAuthHandler(c *container.Container) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
container: c,
|
||||
logger: c.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
// @Summary 用户注册
|
||||
// @Description 注册新用户账号
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.RegisterRequest true "注册信息"
|
||||
// @Success 200 {object} model.Response "注册成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/register [post]
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
var req types.RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证邮箱验证码
|
||||
if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil {
|
||||
h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err))
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
// 注册用户
|
||||
user, token, err := service.RegisterUser(h.container.JWT, req.Username, req.Password, req.Email, req.Avatar)
|
||||
if err != nil {
|
||||
h.logger.Error("用户注册失败", zap.Error(err))
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, &types.LoginResponse{
|
||||
Token: token,
|
||||
UserInfo: UserToUserInfo(user),
|
||||
})
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
// @Summary 用户登录
|
||||
// @Description 用户登录获取JWT Token,支持用户名或邮箱登录
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.LoginRequest true "登录信息(username字段支持用户名或邮箱)"
|
||||
// @Success 200 {object} model.Response{data=types.LoginResponse} "登录成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "登录失败"
|
||||
// @Router /api/v1/auth/login [post]
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req types.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
ipAddress := c.ClientIP()
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
user, token, err := service.LoginUserWithRateLimit(h.container.Redis, h.container.JWT, req.Username, req.Password, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
h.logger.Warn("用户登录失败",
|
||||
zap.String("username_or_email", req.Username),
|
||||
zap.String("ip", ipAddress),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondUnauthorized(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, &types.LoginResponse{
|
||||
Token: token,
|
||||
UserInfo: UserToUserInfo(user),
|
||||
})
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
// @Summary 发送验证码
|
||||
// @Description 发送邮箱验证码(注册/重置密码/更换邮箱)
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.SendVerificationCodeRequest true "发送验证码请求"
|
||||
// @Success 200 {object} model.Response "发送成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/send-code [post]
|
||||
func (h *AuthHandler) SendVerificationCode(c *gin.Context) {
|
||||
var req types.SendVerificationCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
emailService, err := h.getEmailService()
|
||||
if err != nil {
|
||||
RespondServerError(c, "邮件服务不可用", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.SendVerificationCode(c.Request.Context(), h.container.Redis, emailService, req.Email, req.Type); err != nil {
|
||||
h.logger.Error("发送验证码失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.String("type", req.Type),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, gin.H{"message": "验证码已发送,请查收邮件"})
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// @Summary 重置密码
|
||||
// @Description 通过邮箱验证码重置密码
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.ResetPasswordRequest true "重置密码请求"
|
||||
// @Success 200 {object} model.Response "重置成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/reset-password [post]
|
||||
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
var req types.ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil {
|
||||
h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err))
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
// 重置密码
|
||||
if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil {
|
||||
h.logger.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err))
|
||||
RespondServerError(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, gin.H{"message": "密码重置成功"})
|
||||
}
|
||||
|
||||
// getEmailService 获取邮件服务(暂时使用全局方式,后续可改为依赖注入)
|
||||
func (h *AuthHandler) getEmailService() (*email.Service, error) {
|
||||
return email.GetService()
|
||||
}
|
||||
|
||||
109
internal/handler/captcha_handler_di.go
Normal file
109
internal/handler/captcha_handler_di.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/service"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CaptchaHandler 验证码处理器
|
||||
type CaptchaHandler struct {
|
||||
container *container.Container
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCaptchaHandler 创建CaptchaHandler实例
|
||||
func NewCaptchaHandler(c *container.Container) *CaptchaHandler {
|
||||
return &CaptchaHandler{
|
||||
container: c,
|
||||
logger: c.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CaptchaVerifyRequest 验证码验证请求
|
||||
type CaptchaVerifyRequest struct {
|
||||
CaptchaID string `json:"captchaId" binding:"required"`
|
||||
Dx int `json:"dx" binding:"required"`
|
||||
}
|
||||
|
||||
// Generate 生成验证码
|
||||
// @Summary 生成滑动验证码
|
||||
// @Description 生成滑动验证码图片
|
||||
// @Tags captcha
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{} "生成成功"
|
||||
// @Failure 500 {object} map[string]interface{} "生成失败"
|
||||
// @Router /api/v1/captcha/generate [get]
|
||||
func (h *CaptchaHandler) Generate(c *gin.Context) {
|
||||
masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), h.container.Redis)
|
||||
if err != nil {
|
||||
h.logger.Error("生成验证码失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"msg": "生成验证码失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"data": gin.H{
|
||||
"masterImage": masterImg,
|
||||
"tileImage": tileImg,
|
||||
"captchaId": captchaID,
|
||||
"y": y,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Verify 验证验证码
|
||||
// @Summary 验证滑动验证码
|
||||
// @Description 验证用户滑动的偏移量是否正确
|
||||
// @Tags captcha
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body CaptchaVerifyRequest true "验证请求"
|
||||
// @Success 200 {object} map[string]interface{} "验证结果"
|
||||
// @Failure 400 {object} map[string]interface{} "参数错误"
|
||||
// @Router /api/v1/captcha/verify [post]
|
||||
func (h *CaptchaHandler) Verify(c *gin.Context) {
|
||||
var req CaptchaVerifyRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"msg": "参数错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
valid, err := service.VerifyCaptchaData(c.Request.Context(), h.container.Redis, req.Dx, req.CaptchaID)
|
||||
if err != nil {
|
||||
h.logger.Error("验证码验证失败",
|
||||
zap.String("captcha_id", req.CaptchaID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"msg": "验证失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if valid {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"msg": "验证成功",
|
||||
})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 400,
|
||||
"msg": "验证失败,请重试",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,14 +4,24 @@ import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/types"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// parseIntWithDefault 将字符串解析为整数,解析失败返回默认值
|
||||
func parseIntWithDefault(s string, defaultVal int) int {
|
||||
val, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// GetUserIDFromContext 从上下文获取用户ID,如果不存在返回未授权响应
|
||||
// 返回值: userID, ok (如果ok为false,已经发送了错误响应)
|
||||
func GetUserIDFromContext(c *gin.Context) (int64, bool) {
|
||||
userID, exists := c.Get("user_id")
|
||||
userIDValue, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
@@ -20,7 +30,19 @@ func GetUserIDFromContext(c *gin.Context) (int64, bool) {
|
||||
))
|
||||
return 0, false
|
||||
}
|
||||
return userID.(int64), true
|
||||
|
||||
// 安全的类型断言
|
||||
userID, ok := userIDValue.(int64)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"用户ID类型错误",
|
||||
nil,
|
||||
))
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return userID, true
|
||||
}
|
||||
|
||||
// UserToUserInfo 将 User 模型转换为 UserInfo 响应
|
||||
@@ -157,4 +179,3 @@ func RespondWithError(c *gin.Context, err error) {
|
||||
RespondServerError(c, msg, nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
247
internal/handler/profile_handler_di.go
Normal file
247
internal/handler/profile_handler_di.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProfileHandler 档案处理器
|
||||
type ProfileHandler struct {
|
||||
container *container.Container
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProfileHandler 创建ProfileHandler实例
|
||||
func NewProfileHandler(c *container.Container) *ProfileHandler {
|
||||
return &ProfileHandler{
|
||||
container: c,
|
||||
logger: c.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建档案
|
||||
// @Summary 创建Minecraft档案
|
||||
// @Description 创建新的Minecraft角色档案,UUID由后端自动生成
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.CreateProfileRequest true "档案信息(仅需提供角色名)"
|
||||
// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/profile [post]
|
||||
func (h *ProfileHandler) Create(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req types.CreateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误: "+err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
maxProfiles := service.GetMaxProfilesPerUser()
|
||||
if err := service.CheckProfileLimit(h.container.DB, userID, maxProfiles); err != nil {
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := service.CreateProfile(h.container.DB, userID, req.Name)
|
||||
if err != nil {
|
||||
h.logger.Error("创建档案失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.String("name", req.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondServerError(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, ProfileToProfileInfo(profile))
|
||||
}
|
||||
|
||||
// List 获取档案列表
|
||||
// @Summary 获取档案列表
|
||||
// @Description 获取当前用户的所有档案
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Router /api/v1/profile [get]
|
||||
func (h *ProfileHandler) List(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
profiles, err := service.GetUserProfiles(h.container.DB, userID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取档案列表失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondServerError(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, ProfilesToProfileInfos(profiles))
|
||||
}
|
||||
|
||||
// Get 获取档案详情
|
||||
// @Summary 获取档案详情
|
||||
// @Description 根据UUID获取档案详细信息
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Router /api/v1/profile/{uuid} [get]
|
||||
func (h *ProfileHandler) Get(c *gin.Context) {
|
||||
uuid := c.Param("uuid")
|
||||
if uuid == "" {
|
||||
RespondBadRequest(c, "UUID不能为空", nil)
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := service.GetProfileByUUID(h.container.DB, uuid)
|
||||
if err != nil {
|
||||
h.logger.Error("获取档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondNotFound(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, ProfileToProfileInfo(profile))
|
||||
}
|
||||
|
||||
// Update 更新档案
|
||||
// @Summary 更新档案
|
||||
// @Description 更新档案信息
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Param request body types.UpdateProfileRequest true "更新信息"
|
||||
// @Success 200 {object} model.Response "更新成功"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Router /api/v1/profile/{uuid} [put]
|
||||
func (h *ProfileHandler) Update(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
uuid := c.Param("uuid")
|
||||
if uuid == "" {
|
||||
RespondBadRequest(c, "UUID不能为空", nil)
|
||||
return
|
||||
}
|
||||
|
||||
var req types.UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误: "+err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
var namePtr *string
|
||||
if req.Name != "" {
|
||||
namePtr = &req.Name
|
||||
}
|
||||
|
||||
profile, err := service.UpdateProfile(h.container.DB, uuid, userID, namePtr, req.SkinID, req.CapeID)
|
||||
if err != nil {
|
||||
h.logger.Error("更新档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondWithError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, ProfileToProfileInfo(profile))
|
||||
}
|
||||
|
||||
// Delete 删除档案
|
||||
// @Summary 删除档案
|
||||
// @Description 删除指定的Minecraft档案
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "删除成功"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Router /api/v1/profile/{uuid} [delete]
|
||||
func (h *ProfileHandler) Delete(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
uuid := c.Param("uuid")
|
||||
if uuid == "" {
|
||||
RespondBadRequest(c, "UUID不能为空", nil)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.DeleteProfile(h.container.DB, uuid, userID); err != nil {
|
||||
h.logger.Error("删除档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondWithError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// SetActive 设置活跃档案
|
||||
// @Summary 设置活跃档案
|
||||
// @Description 将指定档案设置为活跃状态
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "设置成功"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Router /api/v1/profile/{uuid}/activate [post]
|
||||
func (h *ProfileHandler) SetActive(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
uuid := c.Param("uuid")
|
||||
if uuid == "" {
|
||||
RespondBadRequest(c, "UUID不能为空", nil)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.SetActiveProfile(h.container.DB, uuid, userID); err != nil {
|
||||
h.logger.Error("设置活跃档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondWithError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, gin.H{"message": "设置成功"})
|
||||
}
|
||||
|
||||
193
internal/handler/routes_di.go
Normal file
193
internal/handler/routes_di.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/middleware"
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Handlers 集中管理所有Handler
|
||||
type Handlers struct {
|
||||
Auth *AuthHandler
|
||||
User *UserHandler
|
||||
Texture *TextureHandler
|
||||
Profile *ProfileHandler
|
||||
Captcha *CaptchaHandler
|
||||
Yggdrasil *YggdrasilHandler
|
||||
}
|
||||
|
||||
// NewHandlers 创建所有Handler实例
|
||||
func NewHandlers(c *container.Container) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: NewAuthHandler(c),
|
||||
User: NewUserHandler(c),
|
||||
Texture: NewTextureHandler(c),
|
||||
Profile: NewProfileHandler(c),
|
||||
Captcha: NewCaptchaHandler(c),
|
||||
Yggdrasil: NewYggdrasilHandler(c),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutesWithDI 使用依赖注入注册所有路由
|
||||
func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) {
|
||||
// 设置Swagger文档
|
||||
SetupSwagger(router)
|
||||
|
||||
// 创建Handler实例
|
||||
h := NewHandlers(c)
|
||||
|
||||
// API路由组
|
||||
v1 := router.Group("/api/v1")
|
||||
{
|
||||
// 认证路由(无需JWT)
|
||||
registerAuthRoutes(v1, h.Auth)
|
||||
|
||||
// 用户路由(需要JWT认证)
|
||||
registerUserRoutes(v1, h.User)
|
||||
|
||||
// 材质路由
|
||||
registerTextureRoutes(v1, h.Texture)
|
||||
|
||||
// 档案路由
|
||||
registerProfileRoutesWithDI(v1, h.Profile)
|
||||
|
||||
// 验证码路由
|
||||
registerCaptchaRoutesWithDI(v1, h.Captcha)
|
||||
|
||||
// Yggdrasil API路由组
|
||||
registerYggdrasilRoutesWithDI(v1, h.Yggdrasil)
|
||||
|
||||
// 系统路由
|
||||
registerSystemRoutes(v1)
|
||||
}
|
||||
}
|
||||
|
||||
// registerAuthRoutes 注册认证路由
|
||||
func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) {
|
||||
authGroup := v1.Group("/auth")
|
||||
{
|
||||
authGroup.POST("/register", h.Register)
|
||||
authGroup.POST("/login", h.Login)
|
||||
authGroup.POST("/send-code", h.SendVerificationCode)
|
||||
authGroup.POST("/reset-password", h.ResetPassword)
|
||||
}
|
||||
}
|
||||
|
||||
// registerUserRoutes 注册用户路由
|
||||
func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) {
|
||||
userGroup := v1.Group("/user")
|
||||
userGroup.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
userGroup.GET("/profile", h.GetProfile)
|
||||
userGroup.PUT("/profile", h.UpdateProfile)
|
||||
|
||||
// 头像相关
|
||||
userGroup.POST("/avatar/upload-url", h.GenerateAvatarUploadURL)
|
||||
userGroup.PUT("/avatar", h.UpdateAvatar)
|
||||
|
||||
// 更换邮箱
|
||||
userGroup.POST("/change-email", h.ChangeEmail)
|
||||
|
||||
// Yggdrasil密码相关
|
||||
userGroup.POST("/yggdrasil-password/reset", h.ResetYggdrasilPassword)
|
||||
}
|
||||
}
|
||||
|
||||
// registerTextureRoutes 注册材质路由
|
||||
func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) {
|
||||
textureGroup := v1.Group("/texture")
|
||||
{
|
||||
// 公开路由(无需认证)
|
||||
textureGroup.GET("", h.Search)
|
||||
textureGroup.GET("/:id", h.Get)
|
||||
|
||||
// 需要认证的路由
|
||||
textureAuth := textureGroup.Group("")
|
||||
textureAuth.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
textureAuth.POST("/upload-url", h.GenerateUploadURL)
|
||||
textureAuth.POST("", h.Create)
|
||||
textureAuth.PUT("/:id", h.Update)
|
||||
textureAuth.DELETE("/:id", h.Delete)
|
||||
textureAuth.POST("/:id/favorite", h.ToggleFavorite)
|
||||
textureAuth.GET("/my", h.GetUserTextures)
|
||||
textureAuth.GET("/favorites", h.GetUserFavorites)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerProfileRoutesWithDI 注册档案路由(依赖注入版本)
|
||||
func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) {
|
||||
profileGroup := v1.Group("/profile")
|
||||
{
|
||||
// 公开路由(无需认证)
|
||||
profileGroup.GET("/:uuid", h.Get)
|
||||
|
||||
// 需要认证的路由
|
||||
profileAuth := profileGroup.Group("")
|
||||
profileAuth.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
profileAuth.POST("/", h.Create)
|
||||
profileAuth.GET("/", h.List)
|
||||
profileAuth.PUT("/:uuid", h.Update)
|
||||
profileAuth.DELETE("/:uuid", h.Delete)
|
||||
profileAuth.POST("/:uuid/activate", h.SetActive)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerCaptchaRoutesWithDI 注册验证码路由(依赖注入版本)
|
||||
func registerCaptchaRoutesWithDI(v1 *gin.RouterGroup, h *CaptchaHandler) {
|
||||
captchaGroup := v1.Group("/captcha")
|
||||
{
|
||||
captchaGroup.GET("/generate", h.Generate)
|
||||
captchaGroup.POST("/verify", h.Verify)
|
||||
}
|
||||
}
|
||||
|
||||
// registerYggdrasilRoutesWithDI 注册Yggdrasil API路由(依赖注入版本)
|
||||
func registerYggdrasilRoutesWithDI(v1 *gin.RouterGroup, h *YggdrasilHandler) {
|
||||
ygg := v1.Group("/yggdrasil")
|
||||
{
|
||||
ygg.GET("", h.GetMetaData)
|
||||
ygg.POST("/minecraftservices/player/certificates", h.GetPlayerCertificates)
|
||||
authserver := ygg.Group("/authserver")
|
||||
{
|
||||
authserver.POST("/authenticate", h.Authenticate)
|
||||
authserver.POST("/validate", h.ValidToken)
|
||||
authserver.POST("/refresh", h.RefreshToken)
|
||||
authserver.POST("/invalidate", h.InvalidToken)
|
||||
authserver.POST("/signout", h.SignOut)
|
||||
}
|
||||
sessionServer := ygg.Group("/sessionserver")
|
||||
{
|
||||
sessionServer.GET("/session/minecraft/profile/:uuid", h.GetProfileByUUID)
|
||||
sessionServer.POST("/session/minecraft/join", h.JoinServer)
|
||||
sessionServer.GET("/session/minecraft/hasJoined", h.HasJoinedServer)
|
||||
}
|
||||
api := ygg.Group("/api")
|
||||
profiles := api.Group("/profiles")
|
||||
{
|
||||
profiles.POST("/minecraft", h.GetProfilesByName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerSystemRoutes 注册系统路由
|
||||
func registerSystemRoutes(v1 *gin.RouterGroup) {
|
||||
system := v1.Group("/system")
|
||||
{
|
||||
system.GET("/config", func(c *gin.Context) {
|
||||
// TODO: 实现从数据库读取系统配置
|
||||
c.JSON(200, model.NewSuccessResponse(gin.H{
|
||||
"site_name": "CarrotSkin",
|
||||
"site_description": "A Minecraft Skin Station",
|
||||
"registration_enabled": true,
|
||||
"max_textures_per_user": 100,
|
||||
"max_profiles_per_user": 5,
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -160,8 +160,8 @@ func SearchTextures(c *gin.Context) {
|
||||
textureTypeStr := c.Query("type")
|
||||
publicOnly := c.Query("public_only") == "true"
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1)
|
||||
pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20)
|
||||
|
||||
var textureType model.TextureType
|
||||
switch textureTypeStr {
|
||||
@@ -314,8 +314,8 @@ func GetUserTextures(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1)
|
||||
pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20)
|
||||
|
||||
textures, total, err := service.GetUserTextures(database.MustGetDB(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
@@ -344,8 +344,8 @@ func GetUserFavorites(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1)
|
||||
pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20)
|
||||
|
||||
textures, total, err := service.GetUserTextureFavorites(database.MustGetDB(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
|
||||
285
internal/handler/texture_handler_di.go
Normal file
285
internal/handler/texture_handler_di.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TextureHandler 材质处理器(依赖注入版本)
|
||||
type TextureHandler struct {
|
||||
container *container.Container
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTextureHandler 创建TextureHandler实例
|
||||
func NewTextureHandler(c *container.Container) *TextureHandler {
|
||||
return &TextureHandler{
|
||||
container: c,
|
||||
logger: c.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateUploadURL 生成材质上传URL
|
||||
func (h *TextureHandler) GenerateUploadURL(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req types.GenerateTextureUploadURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
if h.container.Storage == nil {
|
||||
RespondServerError(c, "存储服务不可用", nil)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := service.GenerateTextureUploadURL(
|
||||
c.Request.Context(),
|
||||
h.container.Storage,
|
||||
userID,
|
||||
req.FileName,
|
||||
string(req.TextureType),
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error("生成材质上传URL失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.String("file_name", req.FileName),
|
||||
zap.String("texture_type", string(req.TextureType)),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, &types.GenerateTextureUploadURLResponse{
|
||||
PostURL: result.PostURL,
|
||||
FormData: result.FormData,
|
||||
TextureURL: result.FileURL,
|
||||
ExpiresIn: 900,
|
||||
})
|
||||
}
|
||||
|
||||
// Create 创建材质记录
|
||||
func (h *TextureHandler) Create(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req types.CreateTextureRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
maxTextures := service.GetMaxTexturesPerUser()
|
||||
if err := service.CheckTextureUploadLimit(h.container.DB, userID, maxTextures); err != nil {
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
texture, err := service.CreateTexture(h.container.DB,
|
||||
userID,
|
||||
req.Name,
|
||||
req.Description,
|
||||
string(req.Type),
|
||||
req.URL,
|
||||
req.Hash,
|
||||
req.Size,
|
||||
req.IsPublic,
|
||||
req.IsSlim,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error("创建材质失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.String("name", req.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, TextureToTextureInfo(texture))
|
||||
}
|
||||
|
||||
// Get 获取材质详情
|
||||
func (h *TextureHandler) Get(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
RespondBadRequest(c, "无效的材质ID", err)
|
||||
return
|
||||
}
|
||||
|
||||
texture, err := service.GetTextureByID(h.container.DB, id)
|
||||
if err != nil {
|
||||
RespondNotFound(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, TextureToTextureInfo(texture))
|
||||
}
|
||||
|
||||
// Search 搜索材质
|
||||
func (h *TextureHandler) Search(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
textureTypeStr := c.Query("type")
|
||||
publicOnly := c.Query("public_only") == "true"
|
||||
|
||||
page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1)
|
||||
pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20)
|
||||
|
||||
var textureType model.TextureType
|
||||
switch textureTypeStr {
|
||||
case "SKIN":
|
||||
textureType = model.TextureTypeSkin
|
||||
case "CAPE":
|
||||
textureType = model.TextureTypeCape
|
||||
}
|
||||
|
||||
textures, total, err := service.SearchTextures(h.container.DB, keyword, textureType, publicOnly, page, pageSize)
|
||||
if err != nil {
|
||||
h.logger.Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err))
|
||||
RespondServerError(c, "搜索材质失败", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize))
|
||||
}
|
||||
|
||||
// Update 更新材质
|
||||
func (h *TextureHandler) Update(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
textureID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
RespondBadRequest(c, "无效的材质ID", err)
|
||||
return
|
||||
}
|
||||
|
||||
var req types.UpdateTextureRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
texture, err := service.UpdateTexture(h.container.DB, textureID, userID, req.Name, req.Description, req.IsPublic)
|
||||
if err != nil {
|
||||
h.logger.Error("更新材质失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondForbidden(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, TextureToTextureInfo(texture))
|
||||
}
|
||||
|
||||
// Delete 删除材质
|
||||
func (h *TextureHandler) Delete(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
textureID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
RespondBadRequest(c, "无效的材质ID", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.DeleteTexture(h.container.DB, textureID, userID); err != nil {
|
||||
h.logger.Error("删除材质失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondForbidden(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, nil)
|
||||
}
|
||||
|
||||
// ToggleFavorite 切换收藏状态
|
||||
func (h *TextureHandler) ToggleFavorite(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
textureID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
RespondBadRequest(c, "无效的材质ID", err)
|
||||
return
|
||||
}
|
||||
|
||||
isFavorited, err := service.ToggleTextureFavorite(h.container.DB, userID, textureID)
|
||||
if err != nil {
|
||||
h.logger.Error("切换收藏状态失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, map[string]bool{"is_favorited": isFavorited})
|
||||
}
|
||||
|
||||
// GetUserTextures 获取用户上传的材质列表
|
||||
func (h *TextureHandler) GetUserTextures(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1)
|
||||
pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20)
|
||||
|
||||
textures, total, err := service.GetUserTextures(h.container.DB, userID, page, pageSize)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err))
|
||||
RespondServerError(c, "获取材质列表失败", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize))
|
||||
}
|
||||
|
||||
// GetUserFavorites 获取用户收藏的材质列表
|
||||
func (h *TextureHandler) GetUserFavorites(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1)
|
||||
pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20)
|
||||
|
||||
textures, total, err := service.GetUserTextureFavorites(h.container.DB, userID, page, pageSize)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err))
|
||||
RespondServerError(c, "获取收藏列表失败", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize))
|
||||
}
|
||||
|
||||
|
||||
233
internal/handler/user_handler_di.go
Normal file
233
internal/handler/user_handler_di.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// UserHandler 用户处理器(依赖注入版本)
|
||||
type UserHandler struct {
|
||||
container *container.Container
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewUserHandler 创建UserHandler实例
|
||||
func NewUserHandler(c *container.Container) *UserHandler {
|
||||
return &UserHandler{
|
||||
container: c,
|
||||
logger: c.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile 获取用户信息
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := service.GetUserByID(userID)
|
||||
if err != nil || user == nil {
|
||||
h.logger.Error("获取用户信息失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondNotFound(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, UserToUserInfo(user))
|
||||
}
|
||||
|
||||
// UpdateProfile 更新用户信息
|
||||
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req types.UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := service.GetUserByID(userID)
|
||||
if err != nil || user == nil {
|
||||
RespondNotFound(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 处理密码修改
|
||||
if req.NewPassword != "" {
|
||||
if req.OldPassword == "" {
|
||||
RespondBadRequest(c, "修改密码需要提供原密码", nil)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.ChangeUserPassword(userID, req.OldPassword, req.NewPassword); err != nil {
|
||||
h.logger.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err))
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("用户修改密码成功", zap.Int64("user_id", userID))
|
||||
}
|
||||
|
||||
// 更新头像
|
||||
if req.Avatar != "" {
|
||||
if err := service.ValidateAvatarURL(req.Avatar); err != nil {
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
user.Avatar = req.Avatar
|
||||
if err := service.UpdateUserInfo(user); err != nil {
|
||||
h.logger.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err))
|
||||
RespondServerError(c, "更新失败", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 重新获取更新后的用户信息
|
||||
updatedUser, err := service.GetUserByID(userID)
|
||||
if err != nil || updatedUser == nil {
|
||||
RespondNotFound(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, UserToUserInfo(updatedUser))
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURL 生成头像上传URL
|
||||
func (h *UserHandler) GenerateAvatarUploadURL(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req types.GenerateAvatarUploadURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
if h.container.Storage == nil {
|
||||
RespondServerError(c, "存储服务不可用", nil)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := service.GenerateAvatarUploadURL(c.Request.Context(), h.container.Storage, userID, req.FileName)
|
||||
if err != nil {
|
||||
h.logger.Error("生成头像上传URL失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.String("file_name", req.FileName),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, &types.GenerateAvatarUploadURLResponse{
|
||||
PostURL: result.PostURL,
|
||||
FormData: result.FormData,
|
||||
AvatarURL: result.FileURL,
|
||||
ExpiresIn: 900,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAvatar 更新头像URL
|
||||
func (h *UserHandler) UpdateAvatar(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
avatarURL := c.Query("avatar_url")
|
||||
if avatarURL == "" {
|
||||
RespondBadRequest(c, "头像URL不能为空", nil)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.ValidateAvatarURL(avatarURL); err != nil {
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.UpdateUserAvatar(userID, avatarURL); err != nil {
|
||||
h.logger.Error("更新头像失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.String("avatar_url", avatarURL),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondServerError(c, "更新头像失败", err)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := service.GetUserByID(userID)
|
||||
if err != nil || user == nil {
|
||||
RespondNotFound(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, UserToUserInfo(user))
|
||||
}
|
||||
|
||||
// ChangeEmail 更换邮箱
|
||||
func (h *UserHandler) ChangeEmail(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req types.ChangeEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
RespondBadRequest(c, "请求参数错误", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil {
|
||||
h.logger.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err))
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.ChangeUserEmail(userID, req.NewEmail); err != nil {
|
||||
h.logger.Error("更换邮箱失败",
|
||||
zap.Int64("user_id", userID),
|
||||
zap.String("new_email", req.NewEmail),
|
||||
zap.Error(err),
|
||||
)
|
||||
RespondBadRequest(c, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := service.GetUserByID(userID)
|
||||
if err != nil || user == nil {
|
||||
RespondNotFound(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
RespondSuccess(c, UserToUserInfo(user))
|
||||
}
|
||||
|
||||
// ResetYggdrasilPassword 重置Yggdrasil密码
|
||||
func (h *UserHandler) ResetYggdrasilPassword(c *gin.Context) {
|
||||
userID, ok := GetUserIDFromContext(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
newPassword, err := service.ResetYggdrasilPassword(h.container.DB, userID)
|
||||
if err != nil {
|
||||
h.logger.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
RespondServerError(c, "重置Yggdrasil密码失败", nil)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID))
|
||||
RespondSuccess(c, gin.H{"password": newPassword})
|
||||
}
|
||||
454
internal/handler/yggdrasil_handler_di.go
Normal file
454
internal/handler/yggdrasil_handler_di.go
Normal file
@@ -0,0 +1,454 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/pkg/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// YggdrasilHandler Yggdrasil API处理器
|
||||
type YggdrasilHandler struct {
|
||||
container *container.Container
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewYggdrasilHandler 创建YggdrasilHandler实例
|
||||
func NewYggdrasilHandler(c *container.Container) *YggdrasilHandler {
|
||||
return &YggdrasilHandler{
|
||||
container: c,
|
||||
logger: c.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate 用户认证
|
||||
func (h *YggdrasilHandler) Authenticate(c *gin.Context) {
|
||||
rawData, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
h.logger.Error("读取请求体失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(rawData))
|
||||
|
||||
var request AuthenticateRequest
|
||||
if err = c.ShouldBindJSON(&request); err != nil {
|
||||
h.logger.Error("解析认证请求失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var userId int64
|
||||
var profile *model.Profile
|
||||
var UUID string
|
||||
|
||||
if emailRegex.MatchString(request.Identifier) {
|
||||
userId, err = service.GetUserIDByEmail(h.container.DB, request.Identifier)
|
||||
} else {
|
||||
profile, err = service.GetProfileByProfileName(h.container.DB, request.Identifier)
|
||||
if err != nil {
|
||||
h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
userId = profile.UserID
|
||||
UUID = profile.UUID
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
h.logger.Warn("认证失败: 用户不存在", zap.String("identifier", request.Identifier), zap.Error(err))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "用户不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.VerifyPassword(h.container.DB, request.Password, userId); err != nil {
|
||||
h.logger.Warn("认证失败: 密码错误", zap.Error(err))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": ErrWrongPassword})
|
||||
return
|
||||
}
|
||||
|
||||
selectedProfile, availableProfiles, accessToken, clientToken, err := service.NewToken(h.container.DB, h.logger, userId, UUID, request.ClientToken)
|
||||
if err != nil {
|
||||
h.logger.Error("生成令牌失败", zap.Error(err), zap.Int64("userId", userId))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := service.GetUserByID(userId)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户信息失败", zap.Error(err), zap.Int64("userId", userId))
|
||||
}
|
||||
|
||||
availableProfilesData := make([]map[string]interface{}, 0, len(availableProfiles))
|
||||
for _, p := range availableProfiles {
|
||||
availableProfilesData = append(availableProfilesData, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *p))
|
||||
}
|
||||
|
||||
response := AuthenticateResponse{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
AvailableProfiles: availableProfilesData,
|
||||
}
|
||||
|
||||
if selectedProfile != nil {
|
||||
response.SelectedProfile = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *selectedProfile)
|
||||
}
|
||||
|
||||
if request.RequestUser && user != nil {
|
||||
response.User = service.SerializeUser(h.logger, user, UUID)
|
||||
}
|
||||
|
||||
h.logger.Info("用户认证成功", zap.Int64("userId", userId))
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// ValidToken 验证令牌
|
||||
func (h *YggdrasilHandler) ValidToken(c *gin.Context) {
|
||||
var request ValidTokenRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
h.logger.Error("解析验证令牌请求失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if service.ValidToken(h.container.DB, request.AccessToken, request.ClientToken) {
|
||||
h.logger.Info("令牌验证成功", zap.String("accessToken", request.AccessToken))
|
||||
c.JSON(http.StatusNoContent, gin.H{"valid": true})
|
||||
} else {
|
||||
h.logger.Warn("令牌验证失败", zap.String("accessToken", request.AccessToken))
|
||||
c.JSON(http.StatusForbidden, gin.H{"valid": false})
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshToken 刷新令牌
|
||||
func (h *YggdrasilHandler) RefreshToken(c *gin.Context) {
|
||||
var request RefreshRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
h.logger.Error("解析刷新令牌请求失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
UUID, err := service.GetUUIDByAccessToken(h.container.DB, request.AccessToken)
|
||||
if err != nil {
|
||||
h.logger.Warn("刷新令牌失败: 无效的访问令牌", zap.String("token", request.AccessToken), zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
userID, _ := service.GetUserIDByAccessToken(h.container.DB, request.AccessToken)
|
||||
UUID = utils.FormatUUID(UUID)
|
||||
|
||||
profile, err := service.GetProfileByUUID(h.container.DB, UUID)
|
||||
if err != nil {
|
||||
h.logger.Error("刷新令牌失败: 无法获取用户信息", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var profileData map[string]interface{}
|
||||
var userData map[string]interface{}
|
||||
var profileID string
|
||||
|
||||
if request.SelectedProfile != nil {
|
||||
profileIDValue, ok := request.SelectedProfile["id"]
|
||||
if !ok {
|
||||
h.logger.Error("刷新令牌失败: 缺少配置文件ID", zap.Int64("userId", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少配置文件ID"})
|
||||
return
|
||||
}
|
||||
|
||||
profileID, ok = profileIDValue.(string)
|
||||
if !ok {
|
||||
h.logger.Error("刷新令牌失败: 配置文件ID类型错误", zap.Int64("userId", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "配置文件ID必须是字符串"})
|
||||
return
|
||||
}
|
||||
|
||||
profileID = utils.FormatUUID(profileID)
|
||||
|
||||
if profile.UserID != userID {
|
||||
h.logger.Warn("刷新令牌失败: 用户不匹配",
|
||||
zap.Int64("userId", userID),
|
||||
zap.Int64("profileUserId", profile.UserID),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrUserNotMatch})
|
||||
return
|
||||
}
|
||||
|
||||
profileData = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)
|
||||
}
|
||||
|
||||
user, _ := service.GetUserByID(userID)
|
||||
if request.RequestUser && user != nil {
|
||||
userData = service.SerializeUser(h.logger, user, UUID)
|
||||
}
|
||||
|
||||
newAccessToken, newClientToken, err := service.RefreshToken(h.container.DB, h.logger,
|
||||
request.AccessToken,
|
||||
request.ClientToken,
|
||||
profileID,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error("刷新令牌失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("刷新令牌成功", zap.Int64("userId", userID))
|
||||
c.JSON(http.StatusOK, RefreshResponse{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: newClientToken,
|
||||
SelectedProfile: profileData,
|
||||
User: userData,
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidToken 使令牌失效
|
||||
func (h *YggdrasilHandler) InvalidToken(c *gin.Context) {
|
||||
var request ValidTokenRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
h.logger.Error("解析使令牌失效请求失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
service.InvalidToken(h.container.DB, h.logger, request.AccessToken)
|
||||
h.logger.Info("令牌已失效", zap.String("token", request.AccessToken))
|
||||
c.JSON(http.StatusNoContent, gin.H{})
|
||||
}
|
||||
|
||||
// SignOut 用户登出
|
||||
func (h *YggdrasilHandler) SignOut(c *gin.Context) {
|
||||
var request SignOutRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
h.logger.Error("解析登出请求失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !emailRegex.MatchString(request.Email) {
|
||||
h.logger.Warn("登出失败: 邮箱格式不正确", zap.String("email", request.Email))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidEmailFormat})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := service.GetUserByEmail(request.Email)
|
||||
if err != nil || user == nil {
|
||||
h.logger.Warn("登出失败: 用户不存在", zap.String("email", request.Email), zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "用户不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.VerifyPassword(h.container.DB, request.Password, user.ID); err != nil {
|
||||
h.logger.Warn("登出失败: 密码错误", zap.Int64("userId", user.ID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrWrongPassword})
|
||||
return
|
||||
}
|
||||
|
||||
service.InvalidUserTokens(h.container.DB, h.logger, user.ID)
|
||||
h.logger.Info("用户登出成功", zap.Int64("userId", user.ID))
|
||||
c.JSON(http.StatusNoContent, gin.H{"valid": true})
|
||||
}
|
||||
|
||||
// GetProfileByUUID 根据UUID获取档案
|
||||
func (h *YggdrasilHandler) GetProfileByUUID(c *gin.Context) {
|
||||
uuid := utils.FormatUUID(c.Param("uuid"))
|
||||
h.logger.Info("获取配置文件请求", zap.String("uuid", uuid))
|
||||
|
||||
profile, err := service.GetProfileByUUID(h.container.DB, uuid)
|
||||
if err != nil {
|
||||
h.logger.Error("获取配置文件失败", zap.Error(err), zap.String("uuid", uuid))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("成功获取配置文件", zap.String("uuid", uuid), zap.String("name", profile.Name))
|
||||
c.JSON(http.StatusOK, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile))
|
||||
}
|
||||
|
||||
// JoinServer 加入服务器
|
||||
func (h *YggdrasilHandler) JoinServer(c *gin.Context) {
|
||||
var request JoinServerRequest
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
h.logger.Error("解析加入服务器请求失败", zap.Error(err), zap.String("ip", clientIP))
|
||||
standardResponse(c, http.StatusBadRequest, nil, ErrInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("收到加入服务器请求",
|
||||
zap.String("serverId", request.ServerID),
|
||||
zap.String("userUUID", request.SelectedProfile),
|
||||
zap.String("ip", clientIP),
|
||||
)
|
||||
|
||||
if err := service.JoinServer(h.container.DB, h.logger, h.container.Redis, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil {
|
||||
h.logger.Error("加入服务器失败",
|
||||
zap.Error(err),
|
||||
zap.String("serverId", request.ServerID),
|
||||
zap.String("userUUID", request.SelectedProfile),
|
||||
zap.String("ip", clientIP),
|
||||
)
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrJoinServerFailed)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("加入服务器成功",
|
||||
zap.String("serverId", request.ServerID),
|
||||
zap.String("userUUID", request.SelectedProfile),
|
||||
zap.String("ip", clientIP),
|
||||
)
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// HasJoinedServer 验证玩家是否已加入服务器
|
||||
func (h *YggdrasilHandler) HasJoinedServer(c *gin.Context) {
|
||||
clientIP, _ := c.GetQuery("ip")
|
||||
|
||||
serverID, exists := c.GetQuery("serverId")
|
||||
if !exists || serverID == "" {
|
||||
h.logger.Warn("缺少服务器ID参数", zap.String("ip", clientIP))
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrServerIDRequired)
|
||||
return
|
||||
}
|
||||
|
||||
username, exists := c.GetQuery("username")
|
||||
if !exists || username == "" {
|
||||
h.logger.Warn("缺少用户名参数", zap.String("serverId", serverID), zap.String("ip", clientIP))
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrUsernameRequired)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("收到会话验证请求",
|
||||
zap.String("serverId", serverID),
|
||||
zap.String("username", username),
|
||||
zap.String("ip", clientIP),
|
||||
)
|
||||
|
||||
if err := service.HasJoinedServer(h.logger, h.container.Redis, serverID, username, clientIP); err != nil {
|
||||
h.logger.Warn("会话验证失败",
|
||||
zap.Error(err),
|
||||
zap.String("serverId", serverID),
|
||||
zap.String("username", username),
|
||||
zap.String("ip", clientIP),
|
||||
)
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrSessionVerifyFailed)
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := service.GetProfileByUUID(h.container.DB, username)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户配置文件失败", zap.Error(err), zap.String("username", username))
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrProfileNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("会话验证成功",
|
||||
zap.String("serverId", serverID),
|
||||
zap.String("username", username),
|
||||
zap.String("uuid", profile.UUID),
|
||||
)
|
||||
c.JSON(200, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile))
|
||||
}
|
||||
|
||||
// GetProfilesByName 批量获取配置文件
|
||||
func (h *YggdrasilHandler) GetProfilesByName(c *gin.Context) {
|
||||
var names []string
|
||||
|
||||
if err := c.ShouldBindJSON(&names); err != nil {
|
||||
h.logger.Error("解析名称数组请求失败", zap.Error(err))
|
||||
standardResponse(c, http.StatusBadRequest, nil, ErrInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("接收到批量获取配置文件请求", zap.Int("count", len(names)))
|
||||
|
||||
profiles, err := service.GetProfilesDataByNames(h.container.DB, names)
|
||||
if err != nil {
|
||||
h.logger.Error("获取配置文件失败", zap.Error(err))
|
||||
}
|
||||
|
||||
h.logger.Info("成功获取配置文件", zap.Int("requested", len(names)), zap.Int("returned", len(profiles)))
|
||||
c.JSON(http.StatusOK, profiles)
|
||||
}
|
||||
|
||||
// GetMetaData 获取Yggdrasil元数据
|
||||
func (h *YggdrasilHandler) GetMetaData(c *gin.Context) {
|
||||
meta := gin.H{
|
||||
"implementationName": "CellAuth",
|
||||
"implementationVersion": "0.0.1",
|
||||
"serverName": "LittleLan's Yggdrasil Server Implementation.",
|
||||
"links": gin.H{
|
||||
"homepage": "https://skin.littlelan.cn",
|
||||
"register": "https://skin.littlelan.cn/auth",
|
||||
},
|
||||
"feature.non_email_login": true,
|
||||
"feature.enable_profile_key": true,
|
||||
}
|
||||
|
||||
skinDomains := []string{".hitwh.games", ".littlelan.cn"}
|
||||
signature, err := service.GetPublicKeyFromRedisFunc(h.logger, h.container.Redis)
|
||||
if err != nil {
|
||||
h.logger.Error("获取公钥失败", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("提供元数据")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"meta": meta,
|
||||
"skinDomains": skinDomains,
|
||||
"signaturePublickey": signature,
|
||||
})
|
||||
}
|
||||
|
||||
// GetPlayerCertificates 获取玩家证书
|
||||
func (h *YggdrasilHandler) GetPlayerCertificates(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header not provided"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
bearerPrefix := "Bearer "
|
||||
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := authHeader[len(bearerPrefix):]
|
||||
if tokenID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
uuid, err := service.GetUUIDByAccessToken(h.container.DB, tokenID)
|
||||
if uuid == "" {
|
||||
h.logger.Error("获取玩家UUID失败", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
uuid = utils.FormatUUID(uuid)
|
||||
|
||||
certificate, err := service.GeneratePlayerCertificate(h.container.DB, h.logger, h.container.Redis, uuid)
|
||||
if err != nil {
|
||||
h.logger.Error("生成玩家证书失败", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("成功生成玩家证书")
|
||||
c.JSON(http.StatusOK, certificate)
|
||||
}
|
||||
@@ -1,16 +1,48 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
// 获取配置,如果配置未初始化则使用默认值
|
||||
var allowedOrigins []string
|
||||
if cfg, err := config.GetConfig(); err == nil {
|
||||
allowedOrigins = cfg.Security.AllowedOrigins
|
||||
} else {
|
||||
// 默认允许所有来源(向后兼容)
|
||||
allowedOrigins = []string{"*"}
|
||||
}
|
||||
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
origin := c.GetHeader("Origin")
|
||||
|
||||
// 检查是否允许该来源
|
||||
allowOrigin := "*"
|
||||
if len(allowedOrigins) > 0 && allowedOrigins[0] != "*" {
|
||||
allowOrigin = ""
|
||||
for _, allowed := range allowedOrigins {
|
||||
if allowed == origin || allowed == "*" {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if allowOrigin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
||||
// 只有在非通配符模式下才允许credentials
|
||||
if allowOrigin != "*" {
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
c.Header("Access-Control-Max-Age", "86400") // 缓存预检请求结果24小时
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
|
||||
@@ -24,10 +24,11 @@ func TestCORS_Headers(t *testing.T) {
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证CORS响应头
|
||||
// 注意:当 Access-Control-Allow-Origin 为 "*" 时,根据CORS规范,
|
||||
// 不应该设置 Access-Control-Allow-Credentials 为 "true"
|
||||
expectedHeaders := map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
|
||||
}
|
||||
|
||||
for header, expectedValue := range expectedHeaders {
|
||||
@@ -37,6 +38,11 @@ func TestCORS_Headers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// 验证在通配符模式下不设置Credentials(这是正确的安全行为)
|
||||
if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
|
||||
t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials)
|
||||
}
|
||||
|
||||
// 验证Access-Control-Allow-Headers包含必要字段
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
if allowHeaders == "" {
|
||||
@@ -117,6 +123,30 @@ func TestCORS_AllowHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORS_WithSpecificOrigin 测试配置了具体origin时的CORS行为
|
||||
func TestCORS_WithSpecificOrigin(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 注意:此测试验证的是在配置了具体allowed origins时的行为
|
||||
// 在没有配置初始化的情况下,默认使用通配符模式
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 默认配置下使用通配符,所以不应该设置credentials
|
||||
if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
|
||||
t.Logf("当前模式下 Access-Control-Allow-Credentials = %q (通配符模式不设置)", credentials)
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:检查字符串是否包含子字符串(简单实现)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) == 0 {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
@@ -11,16 +12,26 @@ import (
|
||||
// Recovery 恢复中间件
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||
if err, ok := recovered.(string); ok {
|
||||
logger.Error("服务器恐慌",
|
||||
zap.String("error", err),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("ip", c.ClientIP()),
|
||||
zap.String("stack", string(debug.Stack())),
|
||||
)
|
||||
// 将任意类型的panic转换为字符串
|
||||
var errMsg string
|
||||
switch v := recovered.(type) {
|
||||
case string:
|
||||
errMsg = v
|
||||
case error:
|
||||
errMsg = v.Error()
|
||||
default:
|
||||
errMsg = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
logger.Error("服务器恐慌",
|
||||
zap.String("error", errMsg),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("ip", c.ClientIP()),
|
||||
zap.String("user_agent", c.GetHeader("User-Agent")),
|
||||
zap.String("stack", string(debug.Stack())),
|
||||
)
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "服务器内部错误",
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package model
|
||||
|
||||
import "os"
|
||||
|
||||
// Response 通用API响应结构
|
||||
type Response struct {
|
||||
Code int `json:"code"` // 业务状态码
|
||||
Message string `json:"message"` // 响应消息
|
||||
Data interface{} `json:"data,omitempty"` // 响应数据
|
||||
Code int `json:"code"` // 业务状态码
|
||||
Message string `json:"message"` // 响应消息
|
||||
Data interface{} `json:"data,omitempty"` // 响应数据
|
||||
}
|
||||
|
||||
// PaginationResponse 分页响应结构
|
||||
@@ -12,9 +14,9 @@ type PaginationResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Total int64 `json:"total"` // 总记录数
|
||||
Page int `json:"page"` // 当前页码
|
||||
PerPage int `json:"per_page"` // 每页数量
|
||||
Total int64 `json:"total"` // 总记录数
|
||||
Page int `json:"page"` // 当前页码
|
||||
PerPage int `json:"per_page"` // 每页数量
|
||||
}
|
||||
|
||||
// ErrorResponse 错误响应
|
||||
@@ -26,14 +28,14 @@ type ErrorResponse struct {
|
||||
|
||||
// 常用状态码
|
||||
const (
|
||||
CodeSuccess = 200 // 成功
|
||||
CodeCreated = 201 // 创建成功
|
||||
CodeBadRequest = 400 // 请求参数错误
|
||||
CodeUnauthorized = 401 // 未授权
|
||||
CodeForbidden = 403 // 禁止访问
|
||||
CodeNotFound = 404 // 资源不存在
|
||||
CodeConflict = 409 // 资源冲突
|
||||
CodeServerError = 500 // 服务器错误
|
||||
CodeSuccess = 200 // 成功
|
||||
CodeCreated = 201 // 创建成功
|
||||
CodeBadRequest = 400 // 请求参数错误
|
||||
CodeUnauthorized = 401 // 未授权
|
||||
CodeForbidden = 403 // 禁止访问
|
||||
CodeNotFound = 404 // 资源不存在
|
||||
CodeConflict = 409 // 资源冲突
|
||||
CodeServerError = 500 // 服务器错误
|
||||
)
|
||||
|
||||
// 常用响应消息
|
||||
@@ -61,17 +63,26 @@ func NewSuccessResponse(data interface{}) *Response {
|
||||
}
|
||||
|
||||
// NewErrorResponse 创建错误响应
|
||||
// 注意:err参数仅在开发环境下显示,生产环境不应暴露详细错误信息
|
||||
func NewErrorResponse(code int, message string, err error) *ErrorResponse {
|
||||
resp := &ErrorResponse{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
if err != nil {
|
||||
// 仅在非生产环境下返回详细错误信息
|
||||
// 可以通过环境变量 ENVIRONMENT 控制
|
||||
if err != nil && !isProductionEnvironment() {
|
||||
resp.Error = err.Error()
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// isProductionEnvironment 检查是否为生产环境
|
||||
func isProductionEnvironment() bool {
|
||||
env := os.Getenv("ENVIRONMENT")
|
||||
return env == "production" || env == "prod"
|
||||
}
|
||||
|
||||
// NewPaginationResponse 创建分页响应
|
||||
func NewPaginationResponse(data interface{}, total int64, page, perPage int) *PaginationResponse {
|
||||
return &PaginationResponse{
|
||||
|
||||
85
internal/repository/interfaces.go
Normal file
85
internal/repository/interfaces.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
)
|
||||
|
||||
// UserRepository 用户仓储接口
|
||||
type UserRepository interface {
|
||||
Create(user *model.User) error
|
||||
FindByID(id int64) (*model.User, error)
|
||||
FindByUsername(username string) (*model.User, error)
|
||||
FindByEmail(email string) (*model.User, error)
|
||||
Update(user *model.User) error
|
||||
UpdateFields(id int64, fields map[string]interface{}) error
|
||||
Delete(id int64) error
|
||||
CreateLoginLog(log *model.UserLoginLog) error
|
||||
CreatePointLog(log *model.UserPointLog) error
|
||||
UpdatePoints(userID int64, amount int, changeType, reason string) error
|
||||
}
|
||||
|
||||
// ProfileRepository 档案仓储接口
|
||||
type ProfileRepository interface {
|
||||
Create(profile *model.Profile) error
|
||||
FindByUUID(uuid string) (*model.Profile, error)
|
||||
FindByName(name string) (*model.Profile, error)
|
||||
FindByUserID(userID int64) ([]*model.Profile, error)
|
||||
Update(profile *model.Profile) error
|
||||
UpdateFields(uuid string, updates map[string]interface{}) error
|
||||
Delete(uuid string) error
|
||||
CountByUserID(userID int64) (int64, error)
|
||||
SetActive(uuid string, userID int64) error
|
||||
UpdateLastUsedAt(uuid string) error
|
||||
GetByNames(names []string) ([]*model.Profile, error)
|
||||
GetKeyPair(profileId string) (*model.KeyPair, error)
|
||||
UpdateKeyPair(profileId string, keyPair *model.KeyPair) error
|
||||
}
|
||||
|
||||
// TextureRepository 材质仓储接口
|
||||
type TextureRepository interface {
|
||||
Create(texture *model.Texture) error
|
||||
FindByID(id int64) (*model.Texture, error)
|
||||
FindByHash(hash string) (*model.Texture, error)
|
||||
FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Update(texture *model.Texture) error
|
||||
UpdateFields(id int64, fields map[string]interface{}) error
|
||||
Delete(id int64) error
|
||||
IncrementDownloadCount(id int64) error
|
||||
IncrementFavoriteCount(id int64) error
|
||||
DecrementFavoriteCount(id int64) error
|
||||
CreateDownloadLog(log *model.TextureDownloadLog) error
|
||||
IsFavorited(userID, textureID int64) (bool, error)
|
||||
AddFavorite(userID, textureID int64) error
|
||||
RemoveFavorite(userID, textureID int64) error
|
||||
GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
CountByUploaderID(uploaderID int64) (int64, error)
|
||||
}
|
||||
|
||||
// TokenRepository 令牌仓储接口
|
||||
type TokenRepository interface {
|
||||
Create(token *model.Token) error
|
||||
FindByAccessToken(accessToken string) (*model.Token, error)
|
||||
GetByUserID(userId int64) ([]*model.Token, error)
|
||||
GetUUIDByAccessToken(accessToken string) (string, error)
|
||||
GetUserIDByAccessToken(accessToken string) (int64, error)
|
||||
DeleteByAccessToken(accessToken string) error
|
||||
DeleteByUserID(userId int64) error
|
||||
BatchDelete(accessTokens []string) (int64, error)
|
||||
}
|
||||
|
||||
// SystemConfigRepository 系统配置仓储接口
|
||||
type SystemConfigRepository interface {
|
||||
GetByKey(key string) (*model.SystemConfig, error)
|
||||
GetPublic() ([]model.SystemConfig, error)
|
||||
GetAll() ([]model.SystemConfig, error)
|
||||
Update(config *model.SystemConfig) error
|
||||
UpdateValue(key, value string) error
|
||||
}
|
||||
|
||||
// YggdrasilRepository Yggdrasil仓储接口
|
||||
type YggdrasilRepository interface {
|
||||
GetPasswordByID(id int64) (string, error)
|
||||
ResetPassword(id int64, password string) error
|
||||
}
|
||||
|
||||
149
internal/repository/profile_repository_impl.go
Normal file
149
internal/repository/profile_repository_impl.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// profileRepositoryImpl ProfileRepository的实现
|
||||
type profileRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewProfileRepository 创建ProfileRepository实例
|
||||
func NewProfileRepository(db *gorm.DB) ProfileRepository {
|
||||
return &profileRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) Create(profile *model.Profile) error {
|
||||
return r.db.Create(profile).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) FindByUUID(uuid string) (*model.Profile, error) {
|
||||
var profile model.Profile
|
||||
err := r.db.Where("uuid = ?", uuid).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) FindByName(name string) (*model.Profile, error) {
|
||||
var profile model.Profile
|
||||
err := r.db.Where("name = ?", name).First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) FindByUserID(userID int64) ([]*model.Profile, error) {
|
||||
var profiles []*model.Profile
|
||||
err := r.db.Where("user_id = ?", userID).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
Order("created_at DESC").
|
||||
Find(&profiles).Error
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) Update(profile *model.Profile) error {
|
||||
return r.db.Save(profile).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) UpdateFields(uuid string, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) Delete(uuid string) error {
|
||||
return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) CountByUserID(userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) SetActive(uuid string, userID int64) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("is_active", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Model(&model.Profile{}).
|
||||
Where("uuid = ? AND user_id = ?", uuid, userID).
|
||||
Update("is_active", true).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) UpdateLastUsedAt(uuid string) error {
|
||||
return r.db.Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) GetByNames(names []string) ([]*model.Profile, error) {
|
||||
var profiles []*model.Profile
|
||||
err := r.db.Where("name in (?)", names).Find(&profiles).Error
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) GetKeyPair(profileId string) (*model.KeyPair, error) {
|
||||
if profileId == "" {
|
||||
return nil, errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
var profile model.Profile
|
||||
result := r.db.WithContext(context.Background()).
|
||||
Select("key_pair").
|
||||
Where("id = ?", profileId).
|
||||
First(&profile)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("key pair未找到")
|
||||
}
|
||||
return nil, fmt.Errorf("获取key pair失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &model.KeyPair{}, nil
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error {
|
||||
if profileId == "" {
|
||||
return errors.New("profileId 不能为空")
|
||||
}
|
||||
if keyPair == nil {
|
||||
return errors.New("keyPair 不能为 nil")
|
||||
}
|
||||
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.WithContext(context.Background()).
|
||||
Table("profiles").
|
||||
Where("id = ?", profileId).
|
||||
UpdateColumns(map[string]interface{}{
|
||||
"private_key": keyPair.PrivateKey,
|
||||
"public_key": keyPair.PublicKey,
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新 keyPair 失败: %w", result.Error)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
45
internal/repository/system_config_repository_impl.go
Normal file
45
internal/repository/system_config_repository_impl.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// systemConfigRepositoryImpl SystemConfigRepository的实现
|
||||
type systemConfigRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSystemConfigRepository 创建SystemConfigRepository实例
|
||||
func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository {
|
||||
return &systemConfigRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) GetByKey(key string) (*model.SystemConfig, error) {
|
||||
var config model.SystemConfig
|
||||
err := r.db.Where("key = ?", key).First(&config).Error
|
||||
return handleNotFoundResult(&config, err)
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) GetPublic() ([]model.SystemConfig, error) {
|
||||
var configs []model.SystemConfig
|
||||
err := r.db.Where("is_public = ?", true).Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) GetAll() ([]model.SystemConfig, error) {
|
||||
var configs []model.SystemConfig
|
||||
err := r.db.Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) Update(config *model.SystemConfig) error {
|
||||
return r.db.Save(config).Error
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) UpdateValue(key, value string) error {
|
||||
return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
|
||||
}
|
||||
|
||||
|
||||
175
internal/repository/texture_repository_impl.go
Normal file
175
internal/repository/texture_repository_impl.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// textureRepositoryImpl TextureRepository的实现
|
||||
type textureRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewTextureRepository 创建TextureRepository实例
|
||||
func NewTextureRepository(db *gorm.DB) TextureRepository {
|
||||
return &textureRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) Create(texture *model.Texture) error {
|
||||
return r.db.Create(texture).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) FindByID(id int64) (*model.Texture, error) {
|
||||
var texture model.Texture
|
||||
err := r.db.Preload("Uploader").First(&texture, id).Error
|
||||
return handleNotFoundResult(&texture, err)
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) FindByHash(hash string) (*model.Texture, error) {
|
||||
var texture model.Texture
|
||||
err := r.db.Where("hash = ?", hash).First(&texture).Error
|
||||
return handleNotFoundResult(&texture, err)
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err := query.Scopes(Paginate(page, pageSize)).
|
||||
Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Texture{}).Where("status = 1")
|
||||
|
||||
if publicOnly {
|
||||
query = query.Where("is_public = ?", true)
|
||||
}
|
||||
if textureType != "" {
|
||||
query = query.Where("type = ?", textureType)
|
||||
}
|
||||
if keyword != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err := query.Scopes(Paginate(page, pageSize)).
|
||||
Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) Update(texture *model.Texture) error {
|
||||
return r.db.Save(texture).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) Delete(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) IncrementDownloadCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) IncrementFavoriteCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) DecrementFavoriteCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) CreateDownloadLog(log *model.TextureDownloadLog) error {
|
||||
return r.db.Create(log).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) IsFavorited(userID, textureID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserTextureFavorite{}).
|
||||
Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) AddFavorite(userID, textureID int64) error {
|
||||
favorite := &model.UserTextureFavorite{
|
||||
UserID: userID,
|
||||
TextureID: textureID,
|
||||
}
|
||||
return r.db.Create(favorite).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) RemoveFavorite(userID, textureID int64) error {
|
||||
return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Delete(&model.UserTextureFavorite{}).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
subQuery := r.db.Model(&model.UserTextureFavorite{}).
|
||||
Select("texture_id").
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
query := r.db.Model(&model.Texture{}).
|
||||
Where("id IN (?) AND status = 1", subQuery)
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err := query.Scopes(Paginate(page, pageSize)).
|
||||
Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) CountByUploaderID(uploaderID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Texture{}).
|
||||
Where("uploader_id = ? AND status != -1", uploaderID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
71
internal/repository/token_repository_impl.go
Normal file
71
internal/repository/token_repository_impl.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// tokenRepositoryImpl TokenRepository的实现
|
||||
type tokenRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewTokenRepository 创建TokenRepository实例
|
||||
func NewTokenRepository(db *gorm.DB) TokenRepository {
|
||||
return &tokenRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) Create(token *model.Token) error {
|
||||
return r.db.Create(token).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) FindByAccessToken(accessToken string) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) GetByUserID(userId int64) ([]*model.Token, error) {
|
||||
var tokens []*model.Token
|
||||
err := r.db.Where("user_id = ?", userId).Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
var token model.Token
|
||||
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
var token model.Token
|
||||
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return token.UserID, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) DeleteByAccessToken(accessToken string) error {
|
||||
return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) DeleteByUserID(userId int64) error {
|
||||
return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) BatchDelete(accessTokens []string) (int64, error) {
|
||||
if len(accessTokens) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
103
internal/repository/user_repository_impl.go
Normal file
103
internal/repository/user_repository_impl.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// userRepositoryImpl UserRepository的实现
|
||||
type userRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRepository 创建UserRepository实例
|
||||
func NewUserRepository(db *gorm.DB) UserRepository {
|
||||
return &userRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) Create(user *model.User) error {
|
||||
return r.db.Create(user).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) FindByID(id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("id = ? AND status != -1", id).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) FindByUsername(username string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("username = ? AND status != -1", username).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) FindByEmail(email string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("email = ? AND status != -1", email).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) Update(user *model.User) error {
|
||||
return r.db.Save(user).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) Delete(id int64) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) CreateLoginLog(log *model.UserLoginLog) error {
|
||||
return r.db.Create(log).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) CreatePointLog(log *model.UserPointLog) error {
|
||||
return r.db.Create(log).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) UpdatePoints(userID int64, amount int, changeType, reason string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var user model.User
|
||||
if err := tx.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
balanceBefore := user.Points
|
||||
balanceAfter := balanceBefore + amount
|
||||
|
||||
if balanceAfter < 0 {
|
||||
return errors.New("积分不足")
|
||||
}
|
||||
|
||||
if err := tx.Model(&user).Update("points", balanceAfter).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log := &model.UserPointLog{
|
||||
UserID: userID,
|
||||
ChangeType: changeType,
|
||||
Amount: amount,
|
||||
BalanceBefore: balanceBefore,
|
||||
BalanceAfter: balanceAfter,
|
||||
Reason: reason,
|
||||
}
|
||||
|
||||
return tx.Create(log).Error
|
||||
})
|
||||
}
|
||||
|
||||
// handleNotFoundResult 处理记录未找到的情况
|
||||
func handleNotFoundResult[T any](result *T, err error) (*T, error) {
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
145
internal/service/interfaces.go
Normal file
145
internal/service/interfaces.go
Normal file
@@ -0,0 +1,145 @@
|
||||
// Package service 定义业务逻辑层接口
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/storage"
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// UserService 用户服务接口
|
||||
type UserService interface {
|
||||
// 用户认证
|
||||
Register(username, password, email, avatar string) (*model.User, string, error)
|
||||
Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error)
|
||||
|
||||
// 用户查询
|
||||
GetByID(id int64) (*model.User, error)
|
||||
GetByEmail(email string) (*model.User, error)
|
||||
|
||||
// 用户更新
|
||||
UpdateInfo(user *model.User) error
|
||||
UpdateAvatar(userID int64, avatarURL string) error
|
||||
ChangePassword(userID int64, oldPassword, newPassword string) error
|
||||
ResetPassword(email, newPassword string) error
|
||||
ChangeEmail(userID int64, newEmail string) error
|
||||
|
||||
// URL验证
|
||||
ValidateAvatarURL(avatarURL string) error
|
||||
|
||||
// 配置获取
|
||||
GetMaxProfilesPerUser() int
|
||||
GetMaxTexturesPerUser() int
|
||||
}
|
||||
|
||||
// ProfileService 档案服务接口
|
||||
type ProfileService interface {
|
||||
// 档案CRUD
|
||||
Create(userID int64, name string) (*model.Profile, error)
|
||||
GetByUUID(uuid string) (*model.Profile, error)
|
||||
GetByUserID(userID int64) ([]*model.Profile, error)
|
||||
Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error)
|
||||
Delete(uuid string, userID int64) error
|
||||
|
||||
// 档案状态
|
||||
SetActive(uuid string, userID int64) error
|
||||
CheckLimit(userID int64, maxProfiles int) error
|
||||
|
||||
// 批量查询
|
||||
GetByNames(names []string) ([]*model.Profile, error)
|
||||
GetByProfileName(name string) (*model.Profile, error)
|
||||
}
|
||||
|
||||
// TextureService 材质服务接口
|
||||
type TextureService interface {
|
||||
// 材质CRUD
|
||||
Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error)
|
||||
GetByID(id int64) (*model.Texture, error)
|
||||
GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error)
|
||||
Delete(textureID, uploaderID int64) error
|
||||
|
||||
// 收藏
|
||||
ToggleFavorite(userID, textureID int64) (bool, error)
|
||||
GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
|
||||
// 限制检查
|
||||
CheckUploadLimit(uploaderID int64, maxTextures int) error
|
||||
}
|
||||
|
||||
// TokenService 令牌服务接口
|
||||
type TokenService interface {
|
||||
// 令牌管理
|
||||
Create(userID int64, uuid, clientToken string) (*model.Profile, []*model.Profile, string, string, error)
|
||||
Validate(accessToken, clientToken string) bool
|
||||
Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error)
|
||||
Invalidate(accessToken string)
|
||||
InvalidateUserTokens(userID int64)
|
||||
|
||||
// 令牌查询
|
||||
GetUUIDByAccessToken(accessToken string) (string, error)
|
||||
GetUserIDByAccessToken(accessToken string) (int64, error)
|
||||
}
|
||||
|
||||
// VerificationService 验证码服务接口
|
||||
type VerificationService interface {
|
||||
SendCode(ctx context.Context, email, codeType string) error
|
||||
VerifyCode(ctx context.Context, email, code, codeType string) error
|
||||
}
|
||||
|
||||
// CaptchaService 滑动验证码服务接口
|
||||
type CaptchaService interface {
|
||||
Generate(ctx context.Context) (masterImg, tileImg, captchaID string, y int, err error)
|
||||
Verify(ctx context.Context, dx int, captchaID string) (bool, error)
|
||||
}
|
||||
|
||||
// UploadService 上传服务接口
|
||||
type UploadService interface {
|
||||
GenerateAvatarUploadURL(ctx context.Context, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error)
|
||||
GenerateTextureUploadURL(ctx context.Context, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error)
|
||||
}
|
||||
|
||||
// YggdrasilService Yggdrasil服务接口
|
||||
type YggdrasilService interface {
|
||||
// 用户认证
|
||||
GetUserIDByEmail(email string) (int64, error)
|
||||
VerifyPassword(password string, userID int64) error
|
||||
|
||||
// 会话管理
|
||||
JoinServer(serverID, accessToken, selectedProfile, ip string) error
|
||||
HasJoinedServer(serverID, username, ip string) error
|
||||
|
||||
// 密码管理
|
||||
ResetYggdrasilPassword(userID int64) (string, error)
|
||||
|
||||
// 序列化
|
||||
SerializeProfile(profile model.Profile) map[string]interface{}
|
||||
SerializeUser(user *model.User, uuid string) map[string]interface{}
|
||||
|
||||
// 证书
|
||||
GeneratePlayerCertificate(uuid string) (map[string]interface{}, error)
|
||||
GetPublicKey() (string, error)
|
||||
}
|
||||
|
||||
// Services 服务集合
|
||||
type Services struct {
|
||||
User UserService
|
||||
Profile ProfileService
|
||||
Texture TextureService
|
||||
Token TokenService
|
||||
Verification VerificationService
|
||||
Captcha CaptchaService
|
||||
Upload UploadService
|
||||
Yggdrasil YggdrasilService
|
||||
}
|
||||
|
||||
// ServiceDeps 服务依赖
|
||||
type ServiceDeps struct {
|
||||
Logger *zap.Logger
|
||||
Storage *storage.StorageClient
|
||||
}
|
||||
|
||||
|
||||
234
internal/service/profile_service_impl.go
Normal file
234
internal/service/profile_service_impl.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// profileServiceImpl ProfileService的实现
|
||||
type profileServiceImpl struct {
|
||||
profileRepo repository.ProfileRepository
|
||||
userRepo repository.UserRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewProfileService 创建ProfileService实例
|
||||
func NewProfileService(
|
||||
profileRepo repository.ProfileRepository,
|
||||
userRepo repository.UserRepository,
|
||||
logger *zap.Logger,
|
||||
) ProfileService {
|
||||
return &profileServiceImpl{
|
||||
profileRepo: profileRepo,
|
||||
userRepo: userRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) Create(userID int64, name string) (*model.Profile, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
if user.Status != 1 {
|
||||
return nil, errors.New("用户状态异常")
|
||||
}
|
||||
|
||||
// 检查角色名是否已存在
|
||||
existingName, err := s.profileRepo.FindByName(name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
if existingName != nil {
|
||||
return nil, errors.New("角色名已被使用")
|
||||
}
|
||||
|
||||
// 生成UUID和RSA密钥
|
||||
profileUUID := uuid.New().String()
|
||||
privateKey, err := generateRSAPrivateKeyInternal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成RSA密钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建档案
|
||||
profile := &model.Profile{
|
||||
UUID: profileUUID,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
RSAPrivateKey: privateKey,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
if err := s.profileRepo.Create(profile); err != nil {
|
||||
return nil, fmt.Errorf("创建档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置活跃状态
|
||||
if err := s.profileRepo.SetActive(profileUUID, userID); err != nil {
|
||||
return nil, fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) GetByUUID(uuid string) (*model.Profile, error) {
|
||||
profile, err := s.profileRepo.FindByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrProfileNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) GetByUserID(userID int64) ([]*model.Profile, error) {
|
||||
profiles, err := s.profileRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询档案列表失败: %w", err)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) {
|
||||
// 获取档案并验证权限
|
||||
profile, err := s.profileRepo.FindByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrProfileNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
if profile.UserID != userID {
|
||||
return nil, ErrProfileNoPermission
|
||||
}
|
||||
|
||||
// 检查角色名是否重复
|
||||
if name != nil && *name != profile.Name {
|
||||
existingName, err := s.profileRepo.FindByName(*name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
if existingName != nil {
|
||||
return nil, errors.New("角色名已被使用")
|
||||
}
|
||||
profile.Name = *name
|
||||
}
|
||||
|
||||
// 更新皮肤和披风
|
||||
if skinID != nil {
|
||||
profile.SkinID = skinID
|
||||
}
|
||||
if capeID != nil {
|
||||
profile.CapeID = capeID
|
||||
}
|
||||
|
||||
if err := s.profileRepo.Update(profile); err != nil {
|
||||
return nil, fmt.Errorf("更新档案失败: %w", err)
|
||||
}
|
||||
|
||||
return s.profileRepo.FindByUUID(uuid)
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) Delete(uuid string, userID int64) error {
|
||||
// 获取档案并验证权限
|
||||
profile, err := s.profileRepo.FindByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
return fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
if profile.UserID != userID {
|
||||
return ErrProfileNoPermission
|
||||
}
|
||||
|
||||
if err := s.profileRepo.Delete(uuid); err != nil {
|
||||
return fmt.Errorf("删除档案失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) SetActive(uuid string, userID int64) error {
|
||||
// 获取档案并验证权限
|
||||
profile, err := s.profileRepo.FindByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
return fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
if profile.UserID != userID {
|
||||
return ErrProfileNoPermission
|
||||
}
|
||||
|
||||
if err := s.profileRepo.SetActive(uuid, userID); err != nil {
|
||||
return fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil {
|
||||
return fmt.Errorf("更新使用时间失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) CheckLimit(userID int64, maxProfiles int) error {
|
||||
count, err := s.profileRepo.CountByUserID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询档案数量失败: %w", err)
|
||||
}
|
||||
|
||||
if int(count) >= maxProfiles {
|
||||
return fmt.Errorf("已达到档案数量上限(%d个)", maxProfiles)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) GetByNames(names []string) ([]*model.Profile, error) {
|
||||
profiles, err := s.profileRepo.GetByNames(names)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找失败: %w", err)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func (s *profileServiceImpl) GetByProfileName(name string) (*model.Profile, error) {
|
||||
profile, err := s.profileRepo.FindByName(name)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户角色未创建")
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// generateRSAPrivateKeyInternal 生成RSA-2048私钥(PEM格式)
|
||||
func generateRSAPrivateKeyInternal() (string, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privateKeyBytes,
|
||||
})
|
||||
|
||||
return string(privateKeyPEM), nil
|
||||
}
|
||||
|
||||
|
||||
215
internal/service/texture_service_impl.go
Normal file
215
internal/service/texture_service_impl.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// textureServiceImpl TextureService的实现
|
||||
type textureServiceImpl struct {
|
||||
textureRepo repository.TextureRepository
|
||||
userRepo repository.UserRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTextureService 创建TextureService实例
|
||||
func NewTextureService(
|
||||
textureRepo repository.TextureRepository,
|
||||
userRepo repository.UserRepository,
|
||||
logger *zap.Logger,
|
||||
) TextureService {
|
||||
return &textureServiceImpl{
|
||||
textureRepo: textureRepo,
|
||||
userRepo: userRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.FindByID(uploaderID)
|
||||
if err != nil || user == nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
// 检查Hash是否已存在
|
||||
existingTexture, err := s.textureRepo.FindByHash(hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existingTexture != nil {
|
||||
return nil, errors.New("该材质已存在")
|
||||
}
|
||||
|
||||
// 转换材质类型
|
||||
textureTypeEnum, err := parseTextureTypeInternal(textureType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建材质
|
||||
texture := &model.Texture{
|
||||
UploaderID: uploaderID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
Type: textureTypeEnum,
|
||||
URL: url,
|
||||
Hash: hash,
|
||||
Size: size,
|
||||
IsPublic: isPublic,
|
||||
IsSlim: isSlim,
|
||||
Status: 1,
|
||||
DownloadCount: 0,
|
||||
FavoriteCount: 0,
|
||||
}
|
||||
|
||||
if err := s.textureRepo.Create(texture); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return texture, nil
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) GetByID(id int64) (*model.Texture, error) {
|
||||
texture, err := s.textureRepo.FindByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if texture == nil {
|
||||
return nil, ErrTextureNotFound
|
||||
}
|
||||
if texture.Status == -1 {
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
return texture, nil
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
page, pageSize = NormalizePagination(page, pageSize)
|
||||
return s.textureRepo.FindByUploaderID(uploaderID, page, pageSize)
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
page, pageSize = NormalizePagination(page, pageSize)
|
||||
return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize)
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) {
|
||||
// 获取材质并验证权限
|
||||
texture, err := s.textureRepo.FindByID(textureID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if texture == nil {
|
||||
return nil, ErrTextureNotFound
|
||||
}
|
||||
if texture.UploaderID != uploaderID {
|
||||
return nil, ErrTextureNoPermission
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
updates := make(map[string]interface{})
|
||||
if name != "" {
|
||||
updates["name"] = name
|
||||
}
|
||||
if description != "" {
|
||||
updates["description"] = description
|
||||
}
|
||||
if isPublic != nil {
|
||||
updates["is_public"] = *isPublic
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := s.textureRepo.UpdateFields(textureID, updates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return s.textureRepo.FindByID(textureID)
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) Delete(textureID, uploaderID int64) error {
|
||||
// 获取材质并验证权限
|
||||
texture, err := s.textureRepo.FindByID(textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if texture == nil {
|
||||
return ErrTextureNotFound
|
||||
}
|
||||
if texture.UploaderID != uploaderID {
|
||||
return ErrTextureNoPermission
|
||||
}
|
||||
|
||||
return s.textureRepo.Delete(textureID)
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) ToggleFavorite(userID, textureID int64) (bool, error) {
|
||||
// 确保材质存在
|
||||
texture, err := s.textureRepo.FindByID(textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if texture == nil {
|
||||
return false, ErrTextureNotFound
|
||||
}
|
||||
|
||||
isFavorited, err := s.textureRepo.IsFavorited(userID, textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if isFavorited {
|
||||
// 已收藏 -> 取消收藏
|
||||
if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 未收藏 -> 添加收藏
|
||||
if err := s.textureRepo.AddFavorite(userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
page, pageSize = NormalizePagination(page, pageSize)
|
||||
return s.textureRepo.GetUserFavorites(userID, page, pageSize)
|
||||
}
|
||||
|
||||
func (s *textureServiceImpl) CheckUploadLimit(uploaderID int64, maxTextures int) error {
|
||||
count, err := s.textureRepo.CountByUploaderID(uploaderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count >= int64(maxTextures) {
|
||||
return fmt.Errorf("已达到最大上传数量限制(%d)", maxTextures)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseTextureTypeInternal 解析材质类型
|
||||
func parseTextureTypeInternal(textureType string) (model.TextureType, error) {
|
||||
switch textureType {
|
||||
case "SKIN":
|
||||
return model.TextureTypeSkin, nil
|
||||
case "CAPE":
|
||||
return model.TextureTypeCape, nil
|
||||
default:
|
||||
return "", errors.New("无效的材质类型")
|
||||
}
|
||||
}
|
||||
277
internal/service/token_service_impl.go
Normal file
277
internal/service/token_service_impl.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// tokenServiceImpl TokenService的实现
|
||||
type tokenServiceImpl struct {
|
||||
tokenRepo repository.TokenRepository
|
||||
profileRepo repository.ProfileRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTokenService 创建TokenService实例
|
||||
func NewTokenService(
|
||||
tokenRepo repository.TokenRepository,
|
||||
profileRepo repository.ProfileRepository,
|
||||
logger *zap.Logger,
|
||||
) TokenService {
|
||||
return &tokenServiceImpl{
|
||||
tokenRepo: tokenRepo,
|
||||
profileRepo: profileRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
tokenExtendedTimeout = 10 * time.Second
|
||||
tokensMaxCount = 10
|
||||
)
|
||||
|
||||
func (s *tokenServiceImpl) Create(userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
var (
|
||||
selectedProfileID *model.Profile
|
||||
availableProfiles []*model.Profile
|
||||
)
|
||||
|
||||
// 设置超时上下文
|
||||
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
if UUID != "" {
|
||||
_, err := s.profileRepo.FindByUUID(UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 生成令牌
|
||||
if clientToken == "" {
|
||||
clientToken = uuid.New().String()
|
||||
}
|
||||
|
||||
accessToken := uuid.New().String()
|
||||
token := model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
UserID: userID,
|
||||
Usable: true,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := s.profileRepo.FindByUserID(userID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果用户只有一个配置文件,自动选择
|
||||
if len(profiles) == 1 {
|
||||
selectedProfileID = profiles[0]
|
||||
token.ProfileId = selectedProfileID.UUID
|
||||
}
|
||||
availableProfiles = profiles
|
||||
|
||||
// 插入令牌
|
||||
err = s.tokenRepo.Create(&token)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 清理多余的令牌
|
||||
go s.checkAndCleanupExcessTokens(userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
func (s *tokenServiceImpl) Validate(accessToken, clientToken string) bool {
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.tokenRepo.FindByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !token.Usable {
|
||||
return false
|
||||
}
|
||||
|
||||
if clientToken == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return token.ClientToken == clientToken
|
||||
}
|
||||
|
||||
func (s *tokenServiceImpl) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
|
||||
// 查找旧令牌
|
||||
oldToken, err := s.tokenRepo.FindByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", errors.New("accessToken无效")
|
||||
}
|
||||
s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return "", "", fmt.Errorf("查询令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
s.logger.Error("验证Profile失败",
|
||||
zap.Error(err),
|
||||
zap.Int64("userId", oldToken.UserID),
|
||||
zap.String("profileId", selectedProfileID),
|
||||
)
|
||||
return "", "", fmt.Errorf("验证角色失败: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return "", "", errors.New("角色与用户不匹配")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 clientToken 是否有效
|
||||
if clientToken != "" && clientToken != oldToken.ClientToken {
|
||||
return "", "", errors.New("clientToken无效")
|
||||
}
|
||||
|
||||
// 检查 selectedProfileID 的逻辑
|
||||
if selectedProfileID != "" {
|
||||
if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID {
|
||||
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
|
||||
}
|
||||
} else {
|
||||
selectedProfileID = oldToken.ProfileId
|
||||
}
|
||||
|
||||
// 生成新令牌
|
||||
newAccessToken := uuid.New().String()
|
||||
newToken := model.Token{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: oldToken.ClientToken,
|
||||
UserID: oldToken.UserID,
|
||||
Usable: true,
|
||||
ProfileId: selectedProfileID,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 先插入新令牌,再删除旧令牌
|
||||
err = s.tokenRepo.Create(&newToken)
|
||||
if err != nil {
|
||||
s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return "", "", fmt.Errorf("创建新Token失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.tokenRepo.DeleteByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("删除旧Token失败,但新Token已创建",
|
||||
zap.Error(err),
|
||||
zap.String("oldToken", oldToken.AccessToken),
|
||||
zap.String("newToken", newAccessToken),
|
||||
)
|
||||
}
|
||||
|
||||
s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken))
|
||||
return newAccessToken, oldToken.ClientToken, nil
|
||||
}
|
||||
|
||||
func (s *tokenServiceImpl) Invalidate(accessToken string) {
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return
|
||||
}
|
||||
s.logger.Info("成功删除Token", zap.String("token", accessToken))
|
||||
}
|
||||
|
||||
func (s *tokenServiceImpl) InvalidateUserTokens(userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByUserID(userID)
|
||||
if err != nil {
|
||||
s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("成功删除用户Token", zap.Int64("userId", userID))
|
||||
}
|
||||
|
||||
func (s *tokenServiceImpl) GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
return s.tokenRepo.GetUUIDByAccessToken(accessToken)
|
||||
}
|
||||
|
||||
func (s *tokenServiceImpl) GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
return s.tokenRepo.GetUserIDByAccessToken(accessToken)
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenServiceImpl) checkAndCleanupExcessTokens(userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if len(tokens) <= tokensMaxCount {
|
||||
return
|
||||
}
|
||||
|
||||
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
|
||||
for i := tokensMaxCount; i < len(tokens); i++ {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if deletedCount > 0 {
|
||||
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenServiceImpl) validateProfileByUserID(userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := s.profileRepo.FindByUUID(UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
}
|
||||
return false, fmt.Errorf("验证配置文件失败: %w", err)
|
||||
}
|
||||
return profile.UserID == userID, nil
|
||||
}
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/auth"
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -286,24 +288,69 @@ func ValidateAvatarURL(avatarURL string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 允许的域名列表
|
||||
allowedDomains := []string{
|
||||
"rustfs.example.com",
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
}
|
||||
|
||||
for _, domain := range allowedDomains {
|
||||
if strings.Contains(avatarURL, domain) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 允许相对路径
|
||||
if strings.HasPrefix(avatarURL, "/") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("头像URL不在允许的域名列表中")
|
||||
return ValidateURLDomain(avatarURL)
|
||||
}
|
||||
|
||||
// ValidateURLDomain 验证URL的域名是否在允许列表中
|
||||
func ValidateURLDomain(rawURL string) error {
|
||||
// 解析URL
|
||||
parsedURL, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return errors.New("无效的URL格式")
|
||||
}
|
||||
|
||||
// 必须是HTTP或HTTPS协议
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return errors.New("URL必须使用http或https协议")
|
||||
}
|
||||
|
||||
// 获取主机名(不包含端口)
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
return errors.New("URL缺少主机名")
|
||||
}
|
||||
|
||||
// 从配置获取允许的域名列表
|
||||
cfg, err := config.GetConfig()
|
||||
if err != nil {
|
||||
// 如果配置获取失败,使用默认的安全域名列表
|
||||
allowedDomains := []string{"localhost", "127.0.0.1"}
|
||||
return checkDomainAllowed(host, allowedDomains)
|
||||
}
|
||||
|
||||
return checkDomainAllowed(host, cfg.Security.AllowedDomains)
|
||||
}
|
||||
|
||||
// checkDomainAllowed 检查域名是否在允许列表中
|
||||
func checkDomainAllowed(host string, allowedDomains []string) error {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
for _, allowed := range allowedDomains {
|
||||
allowed = strings.ToLower(strings.TrimSpace(allowed))
|
||||
if allowed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 精确匹配
|
||||
if host == allowed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 支持通配符子域名匹配 (如 *.example.com)
|
||||
if strings.HasPrefix(allowed, "*.") {
|
||||
suffix := allowed[1:] // 移除 "*",保留 ".example.com"
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("URL域名不在允许的列表中")
|
||||
}
|
||||
|
||||
// GetUserByEmail 根据邮箱获取用户
|
||||
|
||||
368
internal/service/user_service_impl.go
Normal file
368
internal/service/user_service_impl.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/auth"
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// userServiceImpl UserService的实现
|
||||
type userServiceImpl struct {
|
||||
userRepo repository.UserRepository
|
||||
configRepo repository.SystemConfigRepository
|
||||
jwtService *auth.JWTService
|
||||
redis *redis.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewUserService 创建UserService实例
|
||||
func NewUserService(
|
||||
userRepo repository.UserRepository,
|
||||
configRepo repository.SystemConfigRepository,
|
||||
jwtService *auth.JWTService,
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
) UserService {
|
||||
return &userServiceImpl{
|
||||
userRepo: userRepo,
|
||||
configRepo: configRepo,
|
||||
jwtService: jwtService,
|
||||
redis: redisClient,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) Register(username, password, email, avatar string) (*model.User, string, error) {
|
||||
// 检查用户名是否已存在
|
||||
existingUser, err := s.userRepo.FindByUsername(username)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingUser != nil {
|
||||
return nil, "", errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existingEmail, err := s.userRepo.FindByEmail(email)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingEmail != nil {
|
||||
return nil, "", errors.New("邮箱已被注册")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := auth.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 确定头像URL
|
||||
avatarURL := avatar
|
||||
if avatarURL != "" {
|
||||
if err := s.ValidateAvatarURL(avatarURL); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
} else {
|
||||
avatarURL = s.getDefaultAvatar()
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: username,
|
||||
Password: hashedPassword,
|
||||
Email: email,
|
||||
Avatar: avatarURL,
|
||||
Role: "user",
|
||||
Status: 1,
|
||||
Points: 0,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 检查账号是否被锁定
|
||||
if s.redis != nil {
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier)
|
||||
if err == nil && locked {
|
||||
return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1)
|
||||
}
|
||||
}
|
||||
|
||||
// 查找用户
|
||||
var user *model.User
|
||||
var err error
|
||||
|
||||
if strings.Contains(usernameOrEmail, "@") {
|
||||
user, err = s.userRepo.FindByEmail(usernameOrEmail)
|
||||
} else {
|
||||
user, err = s.userRepo.FindByUsername(usernameOrEmail)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if user == nil {
|
||||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if user.Status != 1 {
|
||||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用")
|
||||
return nil, "", errors.New("账号已被禁用")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !auth.CheckPassword(user.Password, password) {
|
||||
s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 登录成功,清除失败计数
|
||||
if s.redis != nil {
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
_ = ClearLoginAttempts(ctx, s.redis, identifier)
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
|
||||
// 更新最后登录时间
|
||||
now := time.Now()
|
||||
user.LastLoginAt = &now
|
||||
_ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{
|
||||
"last_login_at": now,
|
||||
})
|
||||
|
||||
// 记录成功登录日志
|
||||
s.logSuccessLogin(user.ID, ipAddress, userAgent)
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) GetByID(id int64) (*model.User, error) {
|
||||
return s.userRepo.FindByID(id)
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) GetByEmail(email string) (*model.User, error) {
|
||||
return s.userRepo.FindByEmail(email)
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) UpdateInfo(user *model.User) error {
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) UpdateAvatar(userID int64, avatarURL string) error {
|
||||
return s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
"avatar": avatarURL,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) ChangePassword(userID int64, oldPassword, newPassword string) error {
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
if !auth.CheckPassword(user.Password, oldPassword) {
|
||||
return errors.New("原密码错误")
|
||||
}
|
||||
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
return s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) ResetPassword(email, newPassword string) error {
|
||||
user, err := s.userRepo.FindByEmail(email)
|
||||
if err != nil || user == nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
return s.userRepo.UpdateFields(user.ID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) ChangeEmail(userID int64, newEmail string) error {
|
||||
existingUser, err := s.userRepo.FindByEmail(newEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existingUser != nil && existingUser.ID != userID {
|
||||
return errors.New("邮箱已被其他用户使用")
|
||||
}
|
||||
|
||||
return s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
"email": newEmail,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) ValidateAvatarURL(avatarURL string) error {
|
||||
if avatarURL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 允许相对路径
|
||||
if strings.HasPrefix(avatarURL, "/") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 解析URL
|
||||
parsedURL, err := url.Parse(avatarURL)
|
||||
if err != nil {
|
||||
return errors.New("无效的URL格式")
|
||||
}
|
||||
|
||||
// 必须是HTTP或HTTPS协议
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return errors.New("URL必须使用http或https协议")
|
||||
}
|
||||
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
return errors.New("URL缺少主机名")
|
||||
}
|
||||
|
||||
// 从配置获取允许的域名列表
|
||||
cfg, err := config.GetConfig()
|
||||
if err != nil {
|
||||
allowedDomains := []string{"localhost", "127.0.0.1"}
|
||||
return s.checkDomainAllowed(host, allowedDomains)
|
||||
}
|
||||
|
||||
return s.checkDomainAllowed(host, cfg.Security.AllowedDomains)
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) GetMaxProfilesPerUser() int {
|
||||
config, err := s.configRepo.GetByKey("max_profiles_per_user")
|
||||
if err != nil || config == nil {
|
||||
return 5
|
||||
}
|
||||
var value int
|
||||
fmt.Sscanf(config.Value, "%d", &value)
|
||||
if value <= 0 {
|
||||
return 5
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) GetMaxTexturesPerUser() int {
|
||||
config, err := s.configRepo.GetByKey("max_textures_per_user")
|
||||
if err != nil || config == nil {
|
||||
return 50
|
||||
}
|
||||
var value int
|
||||
fmt.Sscanf(config.Value, "%d", &value)
|
||||
if value <= 0 {
|
||||
return 50
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *userServiceImpl) getDefaultAvatar() string {
|
||||
config, err := s.configRepo.GetByKey("default_avatar")
|
||||
if err != nil || config == nil || config.Value == "" {
|
||||
return ""
|
||||
}
|
||||
return config.Value
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) checkDomainAllowed(host string, allowedDomains []string) error {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
for _, allowed := range allowedDomains {
|
||||
allowed = strings.ToLower(strings.TrimSpace(allowed))
|
||||
if allowed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if host == allowed {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(allowed, "*.") {
|
||||
suffix := allowed[1:]
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("URL域名不在允许的列表中")
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) {
|
||||
if s.redis != nil {
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
count, _ := RecordLoginFailure(ctx, s.redis, identifier)
|
||||
if count >= MaxLoginAttempts {
|
||||
s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定")
|
||||
return
|
||||
}
|
||||
}
|
||||
s.logFailedLogin(userID, ipAddress, userAgent, reason)
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) logSuccessLogin(userID int64, ipAddress, userAgent string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: true,
|
||||
}
|
||||
_ = s.userRepo.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
func (s *userServiceImpl) logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: false,
|
||||
FailureReason: reason,
|
||||
}
|
||||
_ = s.userRepo.CreateLoginLog(log)
|
||||
}
|
||||
@@ -55,6 +55,10 @@ func (j *JWTService) GenerateToken(userID int64, username, role string) (string,
|
||||
// ValidateToken 验证JWT Token
|
||||
func (j *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名算法,防止algorithm confusion攻击
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("不支持的签名算法")
|
||||
}
|
||||
return []byte(j.secretKey), nil
|
||||
})
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
@@ -22,6 +23,7 @@ type Config struct {
|
||||
Log LogConfig `mapstructure:"log"`
|
||||
Upload UploadConfig `mapstructure:"upload"`
|
||||
Email EmailConfig `mapstructure:"email"`
|
||||
Security SecurityConfig `mapstructure:"security"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
@@ -107,6 +109,12 @@ type EmailConfig struct {
|
||||
FromName string `mapstructure:"from_name"`
|
||||
}
|
||||
|
||||
// SecurityConfig 安全配置
|
||||
type SecurityConfig struct {
|
||||
AllowedOrigins []string `mapstructure:"allowed_origins"` // 允许的CORS来源
|
||||
AllowedDomains []string `mapstructure:"allowed_domains"` // 允许的头像/材质URL域名
|
||||
}
|
||||
|
||||
// Load 加载配置 - 完全从环境变量加载,不依赖YAML文件
|
||||
func Load() (*Config, error) {
|
||||
// 加载.env文件(如果存在)
|
||||
@@ -160,7 +168,7 @@ func setDefaults() {
|
||||
|
||||
// RustFS默认配置
|
||||
viper.SetDefault("rustfs.endpoint", "127.0.0.1:9000")
|
||||
viper.SetDefault("rustfs.public_url", "") // 为空时使用 endpoint 构建 URL
|
||||
viper.SetDefault("rustfs.public_url", "") // 为空时使用 endpoint 构建 URL
|
||||
viper.SetDefault("rustfs.use_ssl", false)
|
||||
|
||||
// JWT默认配置
|
||||
@@ -188,6 +196,10 @@ func setDefaults() {
|
||||
// 邮件默认配置
|
||||
viper.SetDefault("email.enabled", false)
|
||||
viper.SetDefault("email.smtp_port", 587)
|
||||
|
||||
// 安全默认配置
|
||||
viper.SetDefault("security.allowed_origins", []string{"*"})
|
||||
viper.SetDefault("security.allowed_domains", []string{"localhost", "127.0.0.1"})
|
||||
}
|
||||
|
||||
// setupEnvMappings 设置环境变量映射
|
||||
@@ -310,6 +322,15 @@ func overrideFromEnv(config *Config) {
|
||||
if env := os.Getenv("ENVIRONMENT"); env != "" {
|
||||
config.Environment = env
|
||||
}
|
||||
|
||||
// 处理安全配置
|
||||
if allowedOrigins := os.Getenv("SECURITY_ALLOWED_ORIGINS"); allowedOrigins != "" {
|
||||
config.Security.AllowedOrigins = strings.Split(allowedOrigins, ",")
|
||||
}
|
||||
|
||||
if allowedDomains := os.Getenv("SECURITY_ALLOWED_DOMAINS"); allowedDomains != "" {
|
||||
config.Security.AllowedDomains = strings.Split(allowedDomains, ",")
|
||||
}
|
||||
}
|
||||
|
||||
// IsTestEnvironment 判断是否为测试环境
|
||||
|
||||
@@ -62,6 +62,3 @@ func MustGetRustFSConfig() *RustFSConfig {
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user