chore: 初始化仓库,排除二进制文件和覆盖率文件
This commit is contained in:
85
.env.example
Normal file
85
.env.example
Normal file
@@ -0,0 +1,85 @@
|
||||
# CarrotSkin 环境配置文件示例
|
||||
# 复制此文件为 .env 并修改相应的配置值
|
||||
|
||||
# =============================================================================
|
||||
# 服务器配置
|
||||
# =============================================================================
|
||||
SERVER_PORT=:8080
|
||||
SERVER_MODE=debug
|
||||
SERVER_READ_TIMEOUT=30s
|
||||
SERVER_WRITE_TIMEOUT=30s
|
||||
|
||||
# =============================================================================
|
||||
# 数据库配置
|
||||
# =============================================================================
|
||||
DATABASE_DRIVER=postgres
|
||||
DATABASE_HOST=localhost
|
||||
DATABASE_PORT=5432
|
||||
DATABASE_USERNAME=postgres
|
||||
DATABASE_PASSWORD=your_password_here
|
||||
DATABASE_NAME=carrotskin
|
||||
DATABASE_SSL_MODE=disable
|
||||
DATABASE_TIMEZONE=Asia/Shanghai
|
||||
DATABASE_MAX_IDLE_CONNS=10
|
||||
DATABASE_MAX_OPEN_CONNS=100
|
||||
DATABASE_CONN_MAX_LIFETIME=1h
|
||||
|
||||
# =============================================================================
|
||||
# Redis配置
|
||||
# =============================================================================
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=
|
||||
REDIS_DATABASE=0
|
||||
REDIS_POOL_SIZE=10
|
||||
|
||||
# =============================================================================
|
||||
# RustFS对象存储配置 (S3兼容)
|
||||
# =============================================================================
|
||||
RUSTFS_ENDPOINT=127.0.0.1:9000
|
||||
RUSTFS_ACCESS_KEY=your_access_key
|
||||
RUSTFS_SECRET_KEY=your_secret_key
|
||||
RUSTFS_USE_SSL=false
|
||||
RUSTFS_BUCKET_TEXTURES=carrot-skin-textures
|
||||
RUSTFS_BUCKET_AVATARS=carrot-skin-avatars
|
||||
|
||||
# =============================================================================
|
||||
# JWT配置
|
||||
# =============================================================================
|
||||
JWT_SECRET=your-jwt-secret-key-change-this-in-production
|
||||
JWT_EXPIRE_HOURS=168
|
||||
|
||||
# =============================================================================
|
||||
# 日志配置
|
||||
# =============================================================================
|
||||
LOG_LEVEL=info
|
||||
LOG_FORMAT=json
|
||||
LOG_OUTPUT=logs/app.log
|
||||
LOG_MAX_SIZE=100
|
||||
LOG_MAX_BACKUPS=3
|
||||
LOG_MAX_AGE=28
|
||||
LOG_COMPRESS=true
|
||||
|
||||
# =============================================================================
|
||||
# 文件上传配置
|
||||
# =============================================================================
|
||||
UPLOAD_MAX_SIZE=10485760
|
||||
UPLOAD_TEXTURE_MAX_SIZE=2097152
|
||||
UPLOAD_AVATAR_MAX_SIZE=1048576
|
||||
|
||||
# =============================================================================
|
||||
# 安全配置
|
||||
# =============================================================================
|
||||
MAX_LOGIN_ATTEMPTS=5
|
||||
LOGIN_LOCK_DURATION=30m
|
||||
|
||||
# =============================================================================
|
||||
# 邮件配置(可选)
|
||||
# 腾讯企业邮箱SSL配置示例:smtp.exmail.qq.com, 端口465
|
||||
# =============================================================================
|
||||
EMAIL_ENABLED=false
|
||||
EMAIL_SMTP_HOST=smtp.example.com
|
||||
EMAIL_SMTP_PORT=587
|
||||
EMAIL_USERNAME=noreply@example.com
|
||||
EMAIL_PASSWORD=your-email-password
|
||||
EMAIL_FROM_NAME=CarrotSkin
|
||||
43
.gitea/workflows/sonarqube.yml
Normal file
43
.gitea/workflows/sonarqube.yml
Normal file
@@ -0,0 +1,43 @@
|
||||
name: SonarQube Analysis
|
||||
|
||||
on:
|
||||
push:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
sonarqube:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Shallow clones should be disabled for better analysis
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
|
||||
- name: Download and extract SonarQube Scanner
|
||||
run: |
|
||||
export SONAR_SCANNER_VERSION=7.2.0.5079
|
||||
export SONAR_SCANNER_HOME=$HOME/.sonar/sonar-scanner-$SONAR_SCANNER_VERSION-linux-x64
|
||||
curl --create-dirs -sSLo $HOME/.sonar/sonar-scanner.zip https://binaries.sonarsource.com/Distribution/sonar-scanner-cli/sonar-scanner-cli-$SONAR_SCANNER_VERSION-linux-x64.zip
|
||||
unzip -o $HOME/.sonar/sonar-scanner.zip -d $HOME/.sonar/
|
||||
export PATH=$SONAR_SCANNER_HOME/bin:$PATH
|
||||
echo "SONAR_SCANNER_HOME=$SONAR_SCANNER_HOME" >> $GITHUB_ENV
|
||||
echo "$SONAR_SCANNER_HOME/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Run SonarQube Scanner
|
||||
env:
|
||||
SONAR_TOKEN: sqp_b8a64837bd9e967b6876166e9ba27f0bc88626ed
|
||||
run: |
|
||||
export SONAR_SCANNER_VERSION=7.2.0.5079
|
||||
export SONAR_SCANNER_HOME=$HOME/.sonar/sonar-scanner-$SONAR_SCANNER_VERSION-linux-x64
|
||||
export PATH=$SONAR_SCANNER_HOME/bin:$PATH
|
||||
sonar-scanner \
|
||||
-Dsonar.projectKey=CarrotSkin \
|
||||
-Dsonar.sources=. \
|
||||
-Dsonar.host.url=https://sonar.littlelan.cn
|
||||
|
||||
104
.gitea/workflows/test.yml
Normal file
104
.gitea/workflows/test.yml
Normal file
@@ -0,0 +1,104 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
- develop
|
||||
- 'feature/**'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
- develop
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
cache-dependency-path: go.sum
|
||||
|
||||
- name: Download dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run tests
|
||||
run: go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
|
||||
|
||||
- name: Generate coverage report
|
||||
run: |
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
go tool cover -func=coverage.out -o coverage.txt
|
||||
|
||||
- name: Upload coverage reports
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: coverage-reports
|
||||
path: |
|
||||
coverage.out
|
||||
coverage.html
|
||||
coverage.txt
|
||||
|
||||
- name: Display coverage summary
|
||||
run: |
|
||||
echo "## Test Coverage Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
cat coverage.txt >> $GITHUB_STEP_SUMMARY
|
||||
echo '```' >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
cache-dependency-path: go.sum
|
||||
|
||||
- name: Download dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run golangci-lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=5m
|
||||
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test, lint]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23'
|
||||
cache-dependency-path: go.sum
|
||||
|
||||
- name: Download dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Build
|
||||
run: go build -v -o server ./cmd/server
|
||||
|
||||
- name: Upload build artifacts
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: build-artifacts
|
||||
path: server
|
||||
|
||||
107
.gitignore
vendored
Normal file
107
.gitignore
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
|
||||
# Build directories
|
||||
bin/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Compiled binaries
|
||||
server
|
||||
|
||||
# IDE files
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS generated files
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
|
||||
# Log files
|
||||
*.log
|
||||
logs/
|
||||
log/
|
||||
|
||||
# Configuration files (keep examples)
|
||||
configs/config.yaml
|
||||
!configs/config.yaml.example
|
||||
!configs/casbin/
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
.env.local
|
||||
.env.development
|
||||
.env.test
|
||||
.env.production
|
||||
|
||||
# Keep example files
|
||||
!.env.example
|
||||
|
||||
# Database files
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
# Backup files
|
||||
*.bak
|
||||
*.backup
|
||||
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
.tmp/
|
||||
|
||||
# Coverage files
|
||||
coverage.out
|
||||
coverage.html
|
||||
|
||||
# Air live reload
|
||||
.air.toml
|
||||
tmp/
|
||||
|
||||
# Testing
|
||||
test_results/
|
||||
test_coverage/
|
||||
|
||||
# Documentation generation
|
||||
api/swagger/docs.go
|
||||
api/swagger/*.json
|
||||
api/swagger/*.yaml
|
||||
|
||||
# Docker volumes
|
||||
docker/data/
|
||||
docker/logs/
|
||||
|
||||
# MinIO data
|
||||
minio-data/
|
||||
|
||||
# Compiled protobuf files
|
||||
*.pb.go
|
||||
|
||||
# Local development files
|
||||
local/
|
||||
dev/
|
||||
564
README.md
Normal file
564
README.md
Normal file
@@ -0,0 +1,564 @@
|
||||
# CarrotSkin Backend
|
||||
|
||||
一个功能完善的Minecraft皮肤站后端系统,采用单体架构设计,基于Go语言和Gin框架开发。
|
||||
|
||||
## ✨ 核心功能
|
||||
|
||||
- ✅ **用户认证系统** - 注册、登录、JWT认证、积分系统
|
||||
- ✅ **邮箱验证系统** - 注册验证、找回密码、更换邮箱(基于Redis的验证码)
|
||||
- ✅ **材质管理系统** - 皮肤/披风上传、搜索、收藏、下载统计
|
||||
- ✅ **角色档案系统** - Minecraft角色创建、管理、RSA密钥生成
|
||||
- ✅ **文件存储** - MinIO/RustFS对象存储集成、预签名URL上传
|
||||
- ✅ **缓存系统** - Redis缓存、验证码存储、频率限制
|
||||
- ✅ **权限管理** - Casbin RBAC权限控制
|
||||
- ✅ **数据审计** - 登录日志、操作审计、下载记录
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
backend/
|
||||
├── cmd/ # 应用程序入口
|
||||
│ └── server/ # 主服务器入口
|
||||
│ └── main.go # 服务初始化、路由注册
|
||||
├── internal/ # 私有应用代码
|
||||
│ ├── handler/ # HTTP处理器(函数式)
|
||||
│ │ ├── routes.go # 路由注册
|
||||
│ │ ├── auth_handler.go
|
||||
│ │ ├── user_handler.go
|
||||
│ │ └── ...
|
||||
│ ├── service/ # 业务逻辑服务(函数式)
|
||||
│ │ ├── common.go # 公共声明(jsoniter等)
|
||||
│ │ ├── user_service.go
|
||||
│ │ └── ...
|
||||
│ ├── repository/ # 数据访问层(函数式)
|
||||
│ │ ├── user_repository.go
|
||||
│ │ └── ...
|
||||
│ ├── model/ # 数据模型(GORM)
|
||||
│ ├── middleware/ # 中间件
|
||||
│ └── types/ # 类型定义
|
||||
├── pkg/ # 公共库代码
|
||||
│ ├── auth/ # 认证授权
|
||||
│ │ └── manager.go # JWT服务管理器
|
||||
│ ├── config/ # 配置管理
|
||||
│ │ └── manager.go # 配置管理器
|
||||
│ ├── database/ # 数据库连接
|
||||
│ │ ├── manager.go # 数据库管理器(AutoMigrate)
|
||||
│ │ └── postgres.go # PostgreSQL连接
|
||||
│ ├── email/ # 邮件服务
|
||||
│ │ └── manager.go # 邮件服务管理器
|
||||
│ ├── logger/ # 日志系统
|
||||
│ │ └── manager.go # 日志管理器
|
||||
│ ├── redis/ # Redis客户端
|
||||
│ │ └── manager.go # Redis管理器
|
||||
│ ├── storage/ # 文件存储(RustFS/MinIO)
|
||||
│ │ └── manager.go # 存储管理器
|
||||
│ ├── utils/ # 工具函数
|
||||
│ └── validator/ # 数据验证
|
||||
├── docs/ # API定义和文档(Swagger)
|
||||
├── configs/ # 配置文件
|
||||
│ └── casbin/ # Casbin权限配置
|
||||
├── go.mod # Go模块依赖
|
||||
├── go.sum # Go模块校验
|
||||
├── start.sh # Linux/Mac启动脚本
|
||||
├── .env # 环境变量配置
|
||||
└── README.md # 项目说明
|
||||
```
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **语言**: Go 1.23+
|
||||
- **框架**: Gin Web Framework
|
||||
- **数据库**: PostgreSQL 15+ (GORM ORM)
|
||||
- **缓存**: Redis 6.0+
|
||||
- **存储**: RustFS/MinIO (S3兼容对象存储)
|
||||
- **权限**: Casbin RBAC
|
||||
- **日志**: Zap (结构化日志)
|
||||
- **配置**: 环境变量 (.env) + Viper
|
||||
- **JSON**: jsoniter (高性能JSON序列化)
|
||||
- **文档**: Swagger/OpenAPI 3.0
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Go 1.21或更高版本
|
||||
- PostgreSQL 15或更高版本
|
||||
- Redis 6.0或更高版本
|
||||
- RustFS 或其他 S3 兼容对象存储服务
|
||||
|
||||
### 安装和运行
|
||||
|
||||
1. **克隆项目**
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
cd CarrotSkin/backend
|
||||
```
|
||||
|
||||
2. **安装依赖**
|
||||
```bash
|
||||
go mod download
|
||||
```
|
||||
|
||||
3. **配置环境**
|
||||
```bash
|
||||
# 复制环境变量文件
|
||||
cp .env.example .env
|
||||
# 编辑 .env 文件配置数据库、RustFS等服务连接信息
|
||||
```
|
||||
|
||||
**注意**:项目完全依赖 `.env` 文件进行配置,不再使用 YAML 配置文件,便于 Docker 容器化部署。
|
||||
|
||||
4. **初始化数据库**
|
||||
```bash
|
||||
# 创建数据库
|
||||
createdb carrotskin
|
||||
# 或者使用PostgreSQL客户端
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE carrotskin;"
|
||||
```
|
||||
|
||||
> 💡 **提示**: 项目使用 GORM 的 `AutoMigrate` 功能自动创建和更新数据库表结构,无需手动执行SQL脚本。首次启动时会自动创建所有表。
|
||||
|
||||
5. **运行服务**
|
||||
|
||||
方式一:使用启动脚本(推荐)
|
||||
```bash
|
||||
# Linux/Mac
|
||||
chmod +x start.sh
|
||||
./start.sh
|
||||
|
||||
# Windows
|
||||
start.bat
|
||||
```
|
||||
|
||||
方式二:直接运行
|
||||
```bash
|
||||
# 设置环境变量(或使用.env文件)
|
||||
export DATABASE_HOST=localhost
|
||||
export DATABASE_PORT=5432
|
||||
# ... 其他环境变量
|
||||
|
||||
# 运行服务
|
||||
go run cmd/server/main.go
|
||||
```
|
||||
|
||||
> 💡 **提示**:
|
||||
> - 启动脚本会自动加载 `.env` 文件中的环境变量
|
||||
> - 首次启动时会自动执行数据库迁移(AutoMigrate)
|
||||
> - 如果对象存储未配置,服务仍可启动(相关功能不可用)
|
||||
|
||||
服务启动后:
|
||||
- **服务地址**: http://localhost:8080
|
||||
- **Swagger文档**: http://localhost:8080/swagger/index.html
|
||||
- **健康检查**: http://localhost:8080/health
|
||||
|
||||
## API接口
|
||||
|
||||
### 认证相关
|
||||
- `POST /api/v1/auth/register` - 用户注册(需邮箱验证码)
|
||||
- `POST /api/v1/auth/login` - 用户登录(支持用户名/邮箱)
|
||||
- `POST /api/v1/auth/send-code` - 发送验证码(注册/重置密码/更换邮箱)
|
||||
- `POST /api/v1/auth/reset-password` - 重置密码(需验证码)
|
||||
|
||||
### 用户相关(需认证)
|
||||
- `GET /api/v1/user/profile` - 获取用户信息
|
||||
- `PUT /api/v1/user/profile` - 更新用户信息(头像、密码)
|
||||
- `POST /api/v1/user/avatar/upload-url` - 生成头像上传URL
|
||||
- `PUT /api/v1/user/avatar` - 更新头像
|
||||
- `POST /api/v1/user/change-email` - 更换邮箱(需验证码)
|
||||
|
||||
### 材质管理
|
||||
公开接口:
|
||||
- `GET /api/v1/texture` - 搜索材质
|
||||
- `GET /api/v1/texture/:id` - 获取材质详情
|
||||
|
||||
认证接口:
|
||||
- `POST /api/v1/texture/upload-url` - 生成材质上传URL
|
||||
- `POST /api/v1/texture` - 创建材质记录
|
||||
- `PUT /api/v1/texture/:id` - 更新材质
|
||||
- `DELETE /api/v1/texture/:id` - 删除材质
|
||||
- `POST /api/v1/texture/:id/favorite` - 切换收藏状态
|
||||
- `GET /api/v1/texture/my` - 我的材质列表
|
||||
- `GET /api/v1/texture/favorites` - 我的收藏列表
|
||||
|
||||
### 角色档案
|
||||
公开接口:
|
||||
- `GET /api/v1/profile/:uuid` - 获取档案详情
|
||||
|
||||
认证接口:
|
||||
- `POST /api/v1/profile` - 创建角色档案(UUID由后端生成)
|
||||
- `GET /api/v1/profile` - 我的档案列表
|
||||
- `PUT /api/v1/profile/:uuid` - 更新档案
|
||||
- `DELETE /api/v1/profile/:uuid` - 删除档案
|
||||
- `POST /api/v1/profile/:uuid/activate` - 设置活跃档案
|
||||
|
||||
### 系统配置
|
||||
- `GET /api/v1/system/config` - 获取系统配置
|
||||
|
||||
## 配置管理
|
||||
|
||||
### 环境变量配置
|
||||
|
||||
项目**完全依赖环境变量**进行配置,不使用 YAML 配置文件,便于容器化部署:
|
||||
|
||||
1. **配置来源**: 环境变量 或 `.env` 文件
|
||||
2. **环境变量格式**: 使用下划线分隔,全大写,如 `DATABASE_HOST`
|
||||
3. **容器部署**: 直接在容器运行时设置环境变量即可
|
||||
|
||||
**主要环境变量**:
|
||||
```bash
|
||||
# 数据库配置
|
||||
DATABASE_HOST=localhost
|
||||
DATABASE_PORT=5432
|
||||
DATABASE_USERNAME=postgres
|
||||
DATABASE_PASSWORD=your_password
|
||||
DATABASE_NAME=carrotskin
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=your_redis_password
|
||||
REDIS_DATABASE=0
|
||||
REDIS_POOL_SIZE=10
|
||||
|
||||
# RustFS对象存储配置 (S3兼容)
|
||||
RUSTFS_ENDPOINT=127.0.0.1:9000
|
||||
RUSTFS_ACCESS_KEY=your_access_key
|
||||
RUSTFS_SECRET_KEY=your_secret_key
|
||||
RUSTFS_USE_SSL=false
|
||||
RUSTFS_BUCKET_TEXTURES=carrot-skin-textures
|
||||
RUSTFS_BUCKET_AVATARS=carrot-skin-avatars
|
||||
|
||||
# JWT配置
|
||||
JWT_SECRET=your-jwt-secret-key
|
||||
JWT_EXPIRE_HOURS=168
|
||||
|
||||
# 邮件配置
|
||||
EMAIL_ENABLED=true
|
||||
EMAIL_SMTP_HOST=smtp.example.com
|
||||
EMAIL_SMTP_PORT=587
|
||||
EMAIL_USERNAME=noreply@example.com
|
||||
EMAIL_PASSWORD=your_email_password
|
||||
EMAIL_FROM_NAME=CarrotSkin
|
||||
```
|
||||
|
||||
**动态配置(存储在数据库中)**:
|
||||
- 积分系统配置(注册奖励、签到积分等)
|
||||
- 用户限制配置(最大材质数、最大角色数等)
|
||||
- 网站设置(站点名称、公告、维护模式等)
|
||||
|
||||
完整的环境变量列表请参考 `.env.example` 文件。
|
||||
|
||||
### 数据库自动迁移
|
||||
|
||||
项目使用 GORM 的 `AutoMigrate` 功能自动管理数据库表结构:
|
||||
|
||||
- **首次启动**: 自动创建所有表结构
|
||||
- **模型更新**: 自动添加新字段、索引等
|
||||
- **类型转换**: 自动处理字段类型变更(如枚举类型转为varchar)
|
||||
- **外键管理**: 自动管理外键关系
|
||||
|
||||
**注意事项**:
|
||||
- 生产环境建议先备份数据库再执行迁移
|
||||
- 某些复杂变更(如删除字段)可能需要手动处理
|
||||
- 枚举类型在PostgreSQL中存储为varchar,避免类型兼容问题
|
||||
|
||||
## 架构设计
|
||||
|
||||
### 面向过程的函数式架构
|
||||
|
||||
项目采用**面向过程的函数式架构**,摒弃不必要的面向对象抽象,使用独立函数和单例管理器模式,代码更简洁、可维护性更强:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ Handler 层 (函数) │ ← 路由处理、参数验证、响应格式化
|
||||
├─────────────────────────────────────┤
|
||||
│ Service 层 (函数) │ ← 业务逻辑、权限检查、数据验证
|
||||
├─────────────────────────────────────┤
|
||||
│ Repository 层 (函数) │ ← 数据库操作、关联查询
|
||||
├─────────────────────────────────────┤
|
||||
│ Manager 层 (单例模式) │ ← 核心依赖管理(线程安全)
|
||||
│ - database.MustGetDB() │
|
||||
│ - logger.MustGetLogger() │
|
||||
│ - auth.MustGetJWTService() │
|
||||
│ - redis.MustGetClient() │
|
||||
│ - email.MustGetService() │
|
||||
│ - storage.MustGetClient() │
|
||||
│ - config.MustGetConfig() │
|
||||
├──────────────┬──────────────────────┤
|
||||
│ PostgreSQL │ Redis │ RustFS │ ← 数据存储层
|
||||
└──────────────┴──────────────────────┘
|
||||
```
|
||||
|
||||
### 架构特点
|
||||
|
||||
1. **函数式设计**: 所有业务逻辑以独立函数形式实现,无结构体方法,降低耦合度
|
||||
2. **管理器模式**: 使用 `sync.Once` 实现线程安全的单例管理器,统一管理核心依赖
|
||||
3. **按需获取**: 通过管理器函数按需获取依赖,避免链式传递,代码更清晰
|
||||
4. **自动迁移**: 使用 GORM AutoMigrate 自动管理数据库表结构
|
||||
5. **高性能**: 使用 jsoniter 替代标准库 json,提升序列化性能
|
||||
|
||||
### 核心模块
|
||||
|
||||
1. **认证模块** (`internal/handler/auth_handler.go`)
|
||||
- JWT令牌生成和验证(通过 `auth.MustGetJWTService()` 获取)
|
||||
- bcrypt密码加密
|
||||
- 邮箱验证码注册
|
||||
- 密码重置功能
|
||||
- 登录日志记录(支持用户名/邮箱登录)
|
||||
|
||||
2. **用户模块** (`internal/handler/user_handler.go`)
|
||||
- 用户信息管理
|
||||
- 头像上传(预签名URL,通过 `storage.MustGetClient()` 获取)
|
||||
- 密码修改(需原密码验证)
|
||||
- 邮箱更换(需验证码)
|
||||
- 积分系统
|
||||
|
||||
3. **邮箱验证模块** (`internal/service/verification_service.go`)
|
||||
- 验证码生成(6位数字)
|
||||
- 验证码存储(Redis,10分钟有效期,通过 `redis.MustGetClient()` 获取)
|
||||
- 发送频率限制(1分钟)
|
||||
- 邮件发送(HTML格式,通过 `email.MustGetService()` 获取)
|
||||
|
||||
4. **材质模块** (`internal/handler/texture_handler.go`)
|
||||
- 材质上传(预签名URL)
|
||||
- 材质搜索和收藏
|
||||
- Hash去重
|
||||
- 下载统计
|
||||
|
||||
5. **档案模块** (`internal/handler/profile_handler.go`)
|
||||
- Minecraft角色管理
|
||||
- RSA密钥生成(RSA-2048)
|
||||
- 活跃状态管理
|
||||
- 档案数量限制
|
||||
|
||||
6. **管理器模块** (`pkg/*/manager.go`)
|
||||
- 数据库管理器:`database.MustGetDB()` - 线程安全的数据库连接
|
||||
- 日志管理器:`logger.MustGetLogger()` - 结构化日志实例
|
||||
- JWT管理器:`auth.MustGetJWTService()` - JWT服务实例
|
||||
- Redis管理器:`redis.MustGetClient()` - Redis客户端
|
||||
- 邮件管理器:`email.MustGetService()` - 邮件服务
|
||||
- 存储管理器:`storage.MustGetClient()` - 对象存储客户端
|
||||
- 配置管理器:`config.MustGetConfig()` - 应用配置
|
||||
|
||||
### 技术特性
|
||||
|
||||
- **架构优势**:
|
||||
- 面向过程的函数式设计,代码简洁清晰
|
||||
- 单例管理器模式,线程安全的依赖管理
|
||||
- 按需获取依赖,避免链式传递
|
||||
- 自动数据库迁移(AutoMigrate)
|
||||
|
||||
- **安全性**:
|
||||
- bcrypt密码加密、JWT令牌认证
|
||||
- 邮箱验证码(注册/重置密码/更换邮箱)
|
||||
- Casbin RBAC权限控制
|
||||
- 频率限制(防暴力破解)
|
||||
|
||||
- **性能**:
|
||||
- jsoniter 高性能JSON序列化(替代标准库)
|
||||
- PostgreSQL索引优化
|
||||
- Redis缓存(验证码、会话等)
|
||||
- 预签名URL减轻服务器压力
|
||||
- 连接池管理
|
||||
|
||||
- **可靠性**:
|
||||
- 事务保证数据一致性
|
||||
- 完整的错误处理和日志记录
|
||||
- 优雅关闭和资源清理
|
||||
- 对象存储连接失败时服务仍可启动
|
||||
|
||||
- **可扩展**:
|
||||
- 清晰的函数式架构
|
||||
- 管理器模式统一管理依赖
|
||||
- 环境变量配置(便于容器化)
|
||||
|
||||
- **审计**:
|
||||
- 登录日志(成功/失败)
|
||||
- 操作审计
|
||||
- 下载记录
|
||||
|
||||
## 开发指南
|
||||
|
||||
### 代码结构
|
||||
|
||||
- `cmd/server/` - 应用入口,初始化服务
|
||||
- `internal/handler/` - HTTP请求处理
|
||||
- `internal/service/` - 业务逻辑实现
|
||||
- `internal/repository/` - 数据库操作
|
||||
- `internal/model/` - 数据模型定义
|
||||
- `internal/types/` - 请求/响应类型定义
|
||||
- `internal/middleware/` - 中间件(JWT、CORS、日志等)
|
||||
- `pkg/` - 可复用的公共库
|
||||
|
||||
### 开发规范
|
||||
|
||||
1. **代码风格**: 遵循Go官方代码规范,使用 `gofmt` 格式化
|
||||
2. **架构模式**: 使用函数式设计,避免不必要的结构体和方法
|
||||
3. **依赖管理**: 通过管理器函数获取依赖(如 `database.MustGetDB()`),避免链式传递
|
||||
4. **错误处理**: 使用统一的错误响应格式 (`model.NewErrorResponse`)
|
||||
5. **日志记录**: 使用 Zap 结构化日志,通过 `logger.MustGetLogger()` 获取实例
|
||||
6. **JSON序列化**: 使用 jsoniter 替代标准库 json,提升性能
|
||||
7. **RESTful API**: 遵循 REST 设计原则,合理使用HTTP方法
|
||||
|
||||
### 添加新功能
|
||||
|
||||
1. 在 `internal/model/` 定义数据模型(GORM会自动迁移)
|
||||
2. 在 `internal/repository/` 实现数据访问函数(使用 `database.MustGetDB()` 获取数据库)
|
||||
3. 在 `internal/service/` 实现业务逻辑函数(按需使用管理器获取依赖)
|
||||
4. 在 `internal/handler/` 实现HTTP处理函数(使用管理器获取logger、jwtService等)
|
||||
5. 在 `internal/handler/routes.go` 注册路由
|
||||
|
||||
**示例**:
|
||||
```go
|
||||
// Repository层
|
||||
func FindUserByID(id uint) (*model.User, error) {
|
||||
db := database.MustGetDB()
|
||||
var user model.User
|
||||
err := db.First(&user, id).Error
|
||||
return &user, err
|
||||
}
|
||||
|
||||
// Service层
|
||||
func GetUserProfile(userID uint) (*model.User, error) {
|
||||
logger := logger.MustGetLogger()
|
||||
user, err := repository.FindUserByID(userID)
|
||||
if err != nil {
|
||||
logger.Error("获取用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Handler层
|
||||
func GetUserProfile(c *gin.Context) {
|
||||
logger := logger.MustGetLogger()
|
||||
jwtService := auth.MustGetJWTService()
|
||||
// ... 处理逻辑
|
||||
}
|
||||
```
|
||||
|
||||
## 部署
|
||||
|
||||
### 本地开发
|
||||
|
||||
```bash
|
||||
# 安装依赖
|
||||
go mod download
|
||||
|
||||
# 配置环境变量(创建.env文件或直接export)
|
||||
cp .env.example .env
|
||||
# 编辑 .env 文件
|
||||
|
||||
# 启动服务
|
||||
# 方式1: 使用启动脚本
|
||||
./start.sh # Linux/Mac
|
||||
start.bat # Windows
|
||||
|
||||
# 方式2: 直接运行
|
||||
go run cmd/server/main.go
|
||||
```
|
||||
|
||||
**首次启动**:
|
||||
- 会自动执行数据库迁移(AutoMigrate),创建所有表结构
|
||||
- 如果对象存储未配置,会记录警告但服务仍可启动
|
||||
- 检查日志确认所有服务初始化成功
|
||||
|
||||
### 生产部署
|
||||
|
||||
```bash
|
||||
# 构建二进制文件
|
||||
go build -o carrotskin-server cmd/server/main.go
|
||||
|
||||
# 运行服务
|
||||
./carrotskin-server
|
||||
```
|
||||
|
||||
### Docker部署
|
||||
|
||||
```bash
|
||||
# 构建镜像
|
||||
docker build -t carrotskin-backend:latest .
|
||||
|
||||
# 启动服务
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 常见问题
|
||||
|
||||
1. **数据库连接失败**
|
||||
- 检查 `.env` 中的数据库配置(`DATABASE_HOST`, `DATABASE_PORT`, `DATABASE_USERNAME`, `DATABASE_PASSWORD`, `DATABASE_NAME`)
|
||||
- 确认PostgreSQL服务已启动
|
||||
- 验证数据库用户权限
|
||||
- 确认数据库已创建:`createdb carrotskin` 或 `psql -c "CREATE DATABASE carrotskin;"`
|
||||
- 检查数据库迁移日志,确认表结构创建成功
|
||||
|
||||
2. **Redis连接失败**
|
||||
- 检查Redis服务是否运行:`redis-cli ping`
|
||||
- 验证 `.env` 中的Redis配置
|
||||
- 确认Redis密码是否正确
|
||||
- 检查防火墙规则
|
||||
|
||||
3. **RustFS/MinIO连接失败**
|
||||
- 检查存储服务是否运行
|
||||
- 验证访问密钥是否正确(`RUSTFS_ACCESS_KEY`, `RUSTFS_SECRET_KEY`)
|
||||
- 确认存储桶是否已创建(`RUSTFS_BUCKET_TEXTURES`, `RUSTFS_BUCKET_AVATARS`)
|
||||
- 检查网络连接和端口(`RUSTFS_ENDPOINT`)
|
||||
- **注意**: 如果对象存储连接失败,服务仍可启动,但上传功能不可用
|
||||
|
||||
4. **邮件发送失败**
|
||||
- 检查 `EMAIL_ENABLED=true`
|
||||
- 验证SMTP服务器地址和端口
|
||||
- 确认邮箱用户名和密码正确
|
||||
- 检查邮件服务商是否需要开启SMTP
|
||||
- 查看日志获取详细错误信息
|
||||
|
||||
5. **验证码相关问题**
|
||||
- 验证码过期(10分钟有效期)
|
||||
- 发送过于频繁(1分钟限制)
|
||||
- Redis存储失败(检查Redis连接)
|
||||
- 邮件未收到(检查垃圾邮件)
|
||||
|
||||
6. **JWT验证失败**
|
||||
- 检查 `JWT_SECRET` 是否配置
|
||||
- 验证令牌是否过期(默认168小时)
|
||||
- 确认请求头中包含 `Authorization: Bearer <token>`
|
||||
- Token格式是否正确
|
||||
|
||||
### 调试技巧
|
||||
|
||||
1. **查看日志**
|
||||
```bash
|
||||
# 实时查看日志
|
||||
tail -f logs/app.log
|
||||
|
||||
# 搜索错误日志
|
||||
grep "ERROR" logs/app.log
|
||||
```
|
||||
|
||||
2. **测试Redis连接**
|
||||
```bash
|
||||
redis-cli -h localhost -p 6379 -a your_password
|
||||
> PING
|
||||
> KEYS *
|
||||
```
|
||||
|
||||
3. **测试数据库连接**
|
||||
```bash
|
||||
psql -h localhost -U postgres -d carrotskin
|
||||
\dt # 查看所有表
|
||||
```
|
||||
|
||||
4. **测试邮件配置**
|
||||
- 使用Swagger文档测试 `/api/v1/auth/send-code` 接口
|
||||
- 检查邮件服务商是否限制发送频率
|
||||
|
||||
### 开发调试
|
||||
|
||||
启用详细日志:
|
||||
```bash
|
||||
# 在 .env 中设置
|
||||
LOG_LEVEL=debug
|
||||
SERVER_MODE=debug
|
||||
```
|
||||
14
configs/casbin/rbac_model.conf
Normal file
14
configs/casbin/rbac_model.conf
Normal file
@@ -0,0 +1,14 @@
|
||||
[request_definition]
|
||||
r = sub, obj, act
|
||||
|
||||
[policy_definition]
|
||||
p = sub, obj, act
|
||||
|
||||
[role_definition]
|
||||
g = _, _
|
||||
|
||||
[policy_effect]
|
||||
e = some(where (p.eft == allow))
|
||||
|
||||
[matchers]
|
||||
m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act
|
||||
1720
docs/docs.go
Normal file
1720
docs/docs.go
Normal file
File diff suppressed because it is too large
Load Diff
1691
docs/swagger.json
Normal file
1691
docs/swagger.json
Normal file
File diff suppressed because it is too large
Load Diff
1110
docs/swagger.yaml
Normal file
1110
docs/swagger.yaml
Normal file
File diff suppressed because it is too large
Load Diff
91
go.mod
Normal file
91
go.mod
Normal file
@@ -0,0 +1,91 @@
|
||||
module carrotskin
|
||||
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.24.2
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible
|
||||
github.com/minio/minio-go/v7 v7.0.66
|
||||
github.com/redis/go-redis/v9 v9.0.5
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/wenlng/go-captcha-assets v1.0.7
|
||||
github.com/wenlng/go-captcha/v2 v2.0.4
|
||||
go.uber.org/zap v1.26.0
|
||||
gorm.io/driver/postgres v1.5.4
|
||||
gorm.io/gorm v1.25.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||
golang.org/x/image v0.16.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||
github.com/PuerkitoBio/purell v1.1.1 // indirect
|
||||
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.19.5 // indirect
|
||||
github.com/go-openapi/jsonreference v0.19.6 // indirect
|
||||
github.com/go-openapi/spec v0.20.4 // indirect
|
||||
github.com/go-openapi/swag v0.19.15 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.15.1 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/uuid v1.5.0
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.4.3
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12
|
||||
github.com/klauspost/compress v1.17.4 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/mailru/easyjson v0.7.6 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/swaggo/swag v1.16.2
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.40.0
|
||||
golang.org/x/net v0.42.0 // indirect
|
||||
golang.org/x/sys v0.34.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
golang.org/x/tools v0.35.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
261
go.sum
Normal file
261
go.sum
Normal file
@@ -0,0 +1,261 @@
|
||||
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
|
||||
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
|
||||
github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI=
|
||||
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
|
||||
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M=
|
||||
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
|
||||
github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao=
|
||||
github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w=
|
||||
github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y=
|
||||
github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
|
||||
github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
|
||||
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
github.com/go-openapi/jsonreference v0.19.6 h1:UBIxjkht+AWIgYzCDSv2GN+E/togfwXUJFRTWhl2Jjs=
|
||||
github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns=
|
||||
github.com/go-openapi/spec v0.20.4 h1:O8hJrt0UMnhHcluhIdUgCLRWyM2x7QkBXRvOs7m+O1M=
|
||||
github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I=
|
||||
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
|
||||
github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM=
|
||||
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.15.1 h1:BSe8uhN+xQ4r5guV/ywQI4gO59C2raYcGffYWZEjZzM=
|
||||
github.com/go-playground/validator/v10 v10.15.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
|
||||
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
|
||||
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible h1:jdpOPRN1zP63Td1hDQbZW73xKmzDvZHzVdNYxhnTMDA=
|
||||
github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible/go.mod h1:1c7szIrayyPPB/987hsnvNzLushdWf4o/79s3P08L8A=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA=
|
||||
github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||
github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs=
|
||||
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
||||
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.0.5 h1:CuQcn5HIEeK7BgElubPP8CGtE0KakrnbBSTLjathl5o=
|
||||
github.com/redis/go-redis/v9 v9.0.5/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE=
|
||||
github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg=
|
||||
github.com/swaggo/gin-swagger v1.6.0 h1:y8sxvQ3E20/RCyrXeFfg60r6H0Z+SwpTjMYsMm+zy8M=
|
||||
github.com/swaggo/gin-swagger v1.6.0/go.mod h1:BG00cCEy294xtVpyIAHG6+e2Qzj/xKlRdOqDkvq0uzo=
|
||||
github.com/swaggo/swag v1.16.2 h1:28Pp+8DkQoV+HLzLx8RGJZXNGKbFqnuvSbAAtoxiY04=
|
||||
github.com/swaggo/swag v1.16.2/go.mod h1:6YzXnDcpr0767iOejs318CwYkCQqyGer6BizOg03f+E=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/wenlng/go-captcha-assets v1.0.7 h1:tfF84A4un/i4p+TbRVHDqDPeQeatvddOfB2xbKvLVq8=
|
||||
github.com/wenlng/go-captcha-assets v1.0.7/go.mod h1:zinRACsdYcL/S6pHgI9Iv7FKTU41d00+43pNX+b9+MM=
|
||||
github.com/wenlng/go-captcha/v2 v2.0.4 h1:5cSUF36ZyA03qeDMjKmeXGpbYJMXEexZIYK3Vga3ME0=
|
||||
github.com/wenlng/go-captcha/v2 v2.0.4/go.mod h1:5hac1em3uXoyC5ipZ0xFv9umNM/waQvYAQdr0cx/h34=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
|
||||
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/image v0.16.0 h1:9kloLAKhUufZhA12l5fwnx2NZW39/we1UhBesW433jw=
|
||||
golang.org/x/image v0.16.0/go.mod h1:ugSZItdV4nOxyqp56HmXwH0Ry0nBCpjnZdpDaIHdoPs=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
||||
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
249
internal/handler/auth_handler.go
Normal file
249
internal/handler/auth_handler.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/auth"
|
||||
"carrotskin/pkg/email"
|
||||
"carrotskin/pkg/logger"
|
||||
"carrotskin/pkg/redis"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Register 用户注册
|
||||
// @Summary 用户注册
|
||||
// @Description 注册新用户账号
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.RegisterRequest true "注册信息"
|
||||
// @Success 200 {object} model.Response "注册成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/register [post]
|
||||
func Register(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
jwtService := auth.MustGetJWTService()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
var req types.RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证邮箱验证码
|
||||
if err := service.VerifyCode(c.Request.Context(), redisClient, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil {
|
||||
loggerInstance.Warn("验证码验证失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用service层注册用户(传递可选的头像URL)
|
||||
user, token, err := service.RegisterUser(jwtService, req.Username, req.Password, req.Email, req.Avatar)
|
||||
if err != nil {
|
||||
loggerInstance.Error("用户注册失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.LoginResponse{
|
||||
Token: token,
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
// @Summary 用户登录
|
||||
// @Description 用户登录获取JWT Token,支持用户名或邮箱登录
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.LoginRequest true "登录信息(username字段支持用户名或邮箱)"
|
||||
// @Success 200 {object} model.Response{data=types.LoginResponse} "登录成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "登录失败"
|
||||
// @Router /api/v1/auth/login [post]
|
||||
func Login(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
jwtService := auth.MustGetJWTService()
|
||||
|
||||
var req types.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取IP和UserAgent
|
||||
ipAddress := c.ClientIP()
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
// 调用service层登录
|
||||
user, token, err := service.LoginUser(jwtService, req.Username, req.Password, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
loggerInstance.Warn("用户登录失败",
|
||||
zap.String("username_or_email", req.Username),
|
||||
zap.String("ip", ipAddress),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.LoginResponse{
|
||||
Token: token,
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
// @Summary 发送验证码
|
||||
// @Description 发送邮箱验证码(注册/重置密码/更换邮箱)
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.SendVerificationCodeRequest true "发送验证码请求"
|
||||
// @Success 200 {object} model.Response "发送成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/send-code [post]
|
||||
func SendVerificationCode(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
redisClient := redis.MustGetClient()
|
||||
emailService := email.MustGetService()
|
||||
|
||||
var req types.SendVerificationCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 发送验证码
|
||||
if err := service.SendVerificationCode(c.Request.Context(), redisClient, emailService, req.Email, req.Type); err != nil {
|
||||
loggerInstance.Error("发送验证码失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.String("type", req.Type),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(gin.H{
|
||||
"message": "验证码已发送,请查收邮件",
|
||||
}))
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// @Summary 重置密码
|
||||
// @Description 通过邮箱验证码重置密码
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.ResetPasswordRequest true "重置密码请求"
|
||||
// @Success 200 {object} model.Response "重置成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/reset-password [post]
|
||||
func ResetPassword(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
var req types.ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
if err := service.VerifyCode(c.Request.Context(), redisClient, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil {
|
||||
loggerInstance.Warn("验证码验证失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 重置密码
|
||||
if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil {
|
||||
loggerInstance.Error("重置密码失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(gin.H{
|
||||
"message": "密码重置成功",
|
||||
}))
|
||||
}
|
||||
155
internal/handler/auth_handler_test.go
Normal file
155
internal/handler/auth_handler_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAuthHandler_RequestValidation 测试认证请求验证逻辑
|
||||
func TestAuthHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
code string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的注册请求",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
code: "123456",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "有效的登录请求",
|
||||
username: "testuser",
|
||||
email: "",
|
||||
password: "password123",
|
||||
code: "",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
code: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码为空",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "",
|
||||
code: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "注册时验证码为空",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
code: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证请求参数逻辑
|
||||
isValid := tt.username != "" && tt.password != ""
|
||||
// 如果是注册请求,还需要验证码
|
||||
if tt.email != "" && tt.code == "" {
|
||||
isValid = false
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestAuthHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "参数错误",
|
||||
errType: "binding",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "验证码错误",
|
||||
errType: "verification",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "登录失败",
|
||||
errType: "login",
|
||||
wantCode: 401,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "注册失败",
|
||||
errType: "register",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if !tt.wantError {
|
||||
t.Error("Error handling test should expect error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_ResponseFormat 测试响应格式逻辑
|
||||
func TestAuthHandler_ResponseFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
success bool
|
||||
wantCode int
|
||||
hasToken bool
|
||||
}{
|
||||
{
|
||||
name: "注册成功",
|
||||
success: true,
|
||||
wantCode: 200,
|
||||
hasToken: true,
|
||||
},
|
||||
{
|
||||
name: "登录成功",
|
||||
success: true,
|
||||
wantCode: 200,
|
||||
hasToken: true,
|
||||
},
|
||||
{
|
||||
name: "发送验证码成功",
|
||||
success: true,
|
||||
wantCode: 200,
|
||||
hasToken: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证响应格式逻辑
|
||||
if tt.success && tt.wantCode != 200 {
|
||||
t.Errorf("Success response should have code 200, got %d", tt.wantCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
76
internal/handler/captcha_handler.go
Normal file
76
internal/handler/captcha_handler.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/pkg/redis"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Generate 生成验证码
|
||||
func Generate(c *gin.Context) {
|
||||
// 调用验证码服务生成验证码数据
|
||||
redisClient := redis.MustGetClient()
|
||||
masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), redisClient)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"msg": "生成验证码失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回验证码数据给前端
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"data": gin.H{
|
||||
"masterImage": masterImg, // 主图(base64格式)
|
||||
"tileImage": tileImg, // 滑块图(base64格式)
|
||||
"captchaId": captchaID, // 验证码唯一标识(用于后续验证)
|
||||
"y": y, // 滑块Y坐标(前端可用于定位滑块初始位置)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Verify 验证验证码
|
||||
func Verify(c *gin.Context) {
|
||||
// 定义请求参数结构体
|
||||
var req struct {
|
||||
CaptchaID string `json:"captchaId" binding:"required"` // 验证码唯一标识
|
||||
Dx int `json:"dx" binding:"required"` // 用户滑动的X轴偏移量
|
||||
}
|
||||
|
||||
// 解析并校验请求参数
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"msg": "参数错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 调用验证码服务验证偏移量
|
||||
redisClient := redis.MustGetClient()
|
||||
valid, err := service.VerifyCaptchaData(c.Request.Context(), redisClient, req.Dx, req.CaptchaID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"msg": "验证失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 根据验证结果返回响应
|
||||
if valid {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"msg": "验证成功",
|
||||
})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 400,
|
||||
"msg": "验证失败,请重试",
|
||||
})
|
||||
}
|
||||
}
|
||||
133
internal/handler/captcha_handler_test.go
Normal file
133
internal/handler/captcha_handler_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCaptchaHandler_RequestValidation 测试验证码请求验证逻辑
|
||||
func TestCaptchaHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
captchaID string
|
||||
dx int
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的请求参数",
|
||||
captchaID: "captcha-123",
|
||||
dx: 100,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "captchaID为空",
|
||||
captchaID: "",
|
||||
dx: 100,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "dx为0(可能有效)",
|
||||
captchaID: "captcha-123",
|
||||
dx: 0,
|
||||
wantValid: true, // dx为0也可能是有效的(用户没有滑动)
|
||||
},
|
||||
{
|
||||
name: "dx为负数(可能无效)",
|
||||
captchaID: "captcha-123",
|
||||
dx: -10,
|
||||
wantValid: true, // 负数也可能是有效的,取决于业务逻辑
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.captchaID != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCaptchaHandler_ResponseFormat 测试响应格式逻辑
|
||||
func TestCaptchaHandler_ResponseFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
valid bool
|
||||
wantCode int
|
||||
wantStatus string
|
||||
}{
|
||||
{
|
||||
name: "验证成功",
|
||||
valid: true,
|
||||
wantCode: 200,
|
||||
wantStatus: "验证成功",
|
||||
},
|
||||
{
|
||||
name: "验证失败",
|
||||
valid: false,
|
||||
wantCode: 400,
|
||||
wantStatus: "验证失败,请重试",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证响应格式逻辑
|
||||
var code int
|
||||
var status string
|
||||
if tt.valid {
|
||||
code = 200
|
||||
status = "验证成功"
|
||||
} else {
|
||||
code = 400
|
||||
status = "验证失败,请重试"
|
||||
}
|
||||
|
||||
if code != tt.wantCode {
|
||||
t.Errorf("Response code = %d, want %d", code, tt.wantCode)
|
||||
}
|
||||
if status != tt.wantStatus {
|
||||
t.Errorf("Response status = %q, want %q", status, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCaptchaHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestCaptchaHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasError bool
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "生成验证码失败",
|
||||
hasError: true,
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "验证验证码失败",
|
||||
hasError: true,
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "参数错误",
|
||||
hasError: true,
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if tt.hasError && !tt.wantError {
|
||||
t.Error("Error handling logic failed")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
398
internal/handler/profile_handler.go
Normal file
398
internal/handler/profile_handler.go
Normal file
@@ -0,0 +1,398 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/logger"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CreateProfile 创建档案
|
||||
// @Summary 创建Minecraft档案
|
||||
// @Description 创建新的Minecraft角色档案,UUID由后端自动生成
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.CreateProfileRequest true "档案信息(仅需提供角色名)"
|
||||
// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功,返回完整档案信息(含自动生成的UUID)"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误或已达档案数量上限"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile [post]
|
||||
func CreateProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var req types.CreateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误: "+err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 从配置或数据库读取限制
|
||||
maxProfiles := 5
|
||||
db := database.MustGetDB()
|
||||
// 检查档案数量限制
|
||||
if err := service.CheckProfileLimit(db, userID.(int64), maxProfiles); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 创建档案
|
||||
profile, err := service.CreateProfile(db, userID.(int64), req.Name)
|
||||
if err != nil {
|
||||
loggerInstance.Error("创建档案失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("name", req.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.ProfileInfo{
|
||||
UUID: profile.UUID,
|
||||
UserID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
SkinID: profile.SkinID,
|
||||
CapeID: profile.CapeID,
|
||||
IsActive: profile.IsActive,
|
||||
LastUsedAt: profile.LastUsedAt,
|
||||
CreatedAt: profile.CreatedAt,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetProfiles 获取档案列表
|
||||
// @Summary 获取档案列表
|
||||
// @Description 获取当前用户的所有档案
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile [get]
|
||||
func GetProfiles(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 查询档案列表
|
||||
profiles, err := service.GetUserProfiles(database.MustGetDB(), userID.(int64))
|
||||
if err != nil {
|
||||
loggerInstance.Error("获取档案列表失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]*types.ProfileInfo, 0, len(profiles))
|
||||
for _, profile := range profiles {
|
||||
result = append(result, &types.ProfileInfo{
|
||||
UUID: profile.UUID,
|
||||
UserID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
SkinID: profile.SkinID,
|
||||
CapeID: profile.CapeID,
|
||||
IsActive: profile.IsActive,
|
||||
LastUsedAt: profile.LastUsedAt,
|
||||
CreatedAt: profile.CreatedAt,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(result))
|
||||
}
|
||||
|
||||
// GetProfile 获取档案详情
|
||||
// @Summary 获取档案详情
|
||||
// @Description 根据UUID获取档案详细信息
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile/{uuid} [get]
|
||||
func GetProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
uuid := c.Param("uuid")
|
||||
|
||||
// 查询档案
|
||||
profile, err := service.GetProfileByUUID(database.MustGetDB(), uuid)
|
||||
if err != nil {
|
||||
loggerInstance.Error("获取档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.ProfileInfo{
|
||||
UUID: profile.UUID,
|
||||
UserID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
SkinID: profile.SkinID,
|
||||
CapeID: profile.CapeID,
|
||||
IsActive: profile.IsActive,
|
||||
LastUsedAt: profile.LastUsedAt,
|
||||
CreatedAt: profile.CreatedAt,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// UpdateProfile 更新档案
|
||||
// @Summary 更新档案
|
||||
// @Description 更新档案信息
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Param request body types.UpdateProfileRequest true "更新信息"
|
||||
// @Success 200 {object} model.Response "更新成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile/{uuid} [put]
|
||||
func UpdateProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
uuid := c.Param("uuid")
|
||||
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var req types.UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误: "+err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 更新档案
|
||||
var namePtr *string
|
||||
if req.Name != "" {
|
||||
namePtr = &req.Name
|
||||
}
|
||||
|
||||
profile, err := service.UpdateProfile(database.MustGetDB(), uuid, userID.(int64), namePtr, req.SkinID, req.CapeID)
|
||||
if err != nil {
|
||||
loggerInstance.Error("更新档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
statusCode := http.StatusInternalServerError
|
||||
if err.Error() == "档案不存在" {
|
||||
statusCode = http.StatusNotFound
|
||||
} else if err.Error() == "无权操作此档案" {
|
||||
statusCode = http.StatusForbidden
|
||||
}
|
||||
|
||||
c.JSON(statusCode, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.ProfileInfo{
|
||||
UUID: profile.UUID,
|
||||
UserID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
SkinID: profile.SkinID,
|
||||
CapeID: profile.CapeID,
|
||||
IsActive: profile.IsActive,
|
||||
LastUsedAt: profile.LastUsedAt,
|
||||
CreatedAt: profile.CreatedAt,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// DeleteProfile 删除档案
|
||||
// @Summary 删除档案
|
||||
// @Description 删除指定的Minecraft档案
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "删除成功"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile/{uuid} [delete]
|
||||
func DeleteProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
uuid := c.Param("uuid")
|
||||
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 删除档案
|
||||
err := service.DeleteProfile(database.MustGetDB(), uuid, userID.(int64))
|
||||
if err != nil {
|
||||
loggerInstance.Error("删除档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
statusCode := http.StatusInternalServerError
|
||||
if err.Error() == "档案不存在" {
|
||||
statusCode = http.StatusNotFound
|
||||
} else if err.Error() == "无权操作此档案" {
|
||||
statusCode = http.StatusForbidden
|
||||
}
|
||||
|
||||
c.JSON(statusCode, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(gin.H{
|
||||
"message": "删除成功",
|
||||
}))
|
||||
}
|
||||
|
||||
// SetActiveProfile 设置活跃档案
|
||||
// @Summary 设置活跃档案
|
||||
// @Description 将指定档案设置为活跃状态
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "设置成功"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile/{uuid}/activate [post]
|
||||
func SetActiveProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
uuid := c.Param("uuid")
|
||||
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置活跃状态
|
||||
err := service.SetActiveProfile(database.MustGetDB(), uuid, userID.(int64))
|
||||
if err != nil {
|
||||
loggerInstance.Error("设置活跃档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
statusCode := http.StatusInternalServerError
|
||||
if err.Error() == "档案不存在" {
|
||||
statusCode = http.StatusNotFound
|
||||
} else if err.Error() == "无权操作此档案" {
|
||||
statusCode = http.StatusForbidden
|
||||
}
|
||||
|
||||
c.JSON(statusCode, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(gin.H{
|
||||
"message": "设置成功",
|
||||
}))
|
||||
}
|
||||
151
internal/handler/profile_handler_test.go
Normal file
151
internal/handler/profile_handler_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProfileHandler_PermissionCheck 测试权限检查逻辑
|
||||
func TestProfileHandler_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID interface{}
|
||||
exists bool
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID",
|
||||
userID: int64(1),
|
||||
exists: true,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID不存在",
|
||||
userID: nil,
|
||||
exists: false,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证权限检查逻辑
|
||||
isValid := tt.exists
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Permission check failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileHandler_RequestValidation 测试请求验证逻辑
|
||||
func TestProfileHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileName string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的档案名",
|
||||
profileName: "PlayerName",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "档案名为空",
|
||||
profileName: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "档案名长度超过16",
|
||||
profileName: "ThisIsAVeryLongPlayerName",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证请求逻辑:档案名长度应该在1-16之间
|
||||
isValid := tt.profileName != "" && len(tt.profileName) >= 1 && len(tt.profileName) <= 16
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileHandler_LimitCheck 测试限制检查逻辑
|
||||
func TestProfileHandler_LimitCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentCount int
|
||||
maxCount int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "未达到限制",
|
||||
currentCount: 3,
|
||||
maxCount: 5,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "达到限制",
|
||||
currentCount: 5,
|
||||
maxCount: 5,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "超过限制",
|
||||
currentCount: 6,
|
||||
maxCount: 5,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证限制检查逻辑
|
||||
hasError := tt.currentCount >= tt.maxCount
|
||||
if hasError != tt.wantError {
|
||||
t.Errorf("Limit check failed: got error=%v, want error=%v", hasError, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestProfileHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "未授权",
|
||||
errType: "unauthorized",
|
||||
wantCode: 401,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "参数错误",
|
||||
errType: "bad_request",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "服务器错误",
|
||||
errType: "server_error",
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if !tt.wantError {
|
||||
t.Error("Error handling test should expect error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
139
internal/handler/routes.go
Normal file
139
internal/handler/routes.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/middleware"
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterRoutes 注册所有路由
|
||||
func RegisterRoutes(router *gin.Engine) {
|
||||
// 设置Swagger文档
|
||||
SetupSwagger(router)
|
||||
|
||||
// API路由组
|
||||
v1 := router.Group("/api/v1")
|
||||
{
|
||||
// 认证路由(无需JWT)
|
||||
authGroup := v1.Group("/auth")
|
||||
{
|
||||
authGroup.POST("/register", Register)
|
||||
authGroup.POST("/login", Login)
|
||||
authGroup.POST("/send-code", SendVerificationCode)
|
||||
authGroup.POST("/reset-password", ResetPassword)
|
||||
}
|
||||
|
||||
// 用户路由(需要JWT认证)
|
||||
userGroup := v1.Group("/user")
|
||||
userGroup.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
userGroup.GET("/profile", GetUserProfile)
|
||||
userGroup.PUT("/profile", UpdateUserProfile)
|
||||
|
||||
// 头像相关
|
||||
userGroup.POST("/avatar/upload-url", GenerateAvatarUploadURL)
|
||||
userGroup.PUT("/avatar", UpdateAvatar)
|
||||
|
||||
// 更换邮箱
|
||||
userGroup.POST("/change-email", ChangeEmail)
|
||||
}
|
||||
|
||||
// 材质路由
|
||||
textureGroup := v1.Group("/texture")
|
||||
{
|
||||
// 公开路由(无需认证)
|
||||
textureGroup.GET("", SearchTextures) // 搜索材质
|
||||
textureGroup.GET("/:id", GetTexture) // 获取材质详情
|
||||
|
||||
// 需要认证的路由
|
||||
textureAuth := textureGroup.Group("")
|
||||
textureAuth.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
textureAuth.POST("/upload-url", GenerateTextureUploadURL) // 生成上传URL
|
||||
textureAuth.POST("", CreateTexture) // 创建材质记录
|
||||
textureAuth.PUT("/:id", UpdateTexture) // 更新材质
|
||||
textureAuth.DELETE("/:id", DeleteTexture) // 删除材质
|
||||
textureAuth.POST("/:id/favorite", ToggleFavorite) // 切换收藏
|
||||
textureAuth.GET("/my", GetUserTextures) // 我的材质
|
||||
textureAuth.GET("/favorites", GetUserFavorites) // 我的收藏
|
||||
}
|
||||
}
|
||||
|
||||
// 档案路由
|
||||
profileGroup := v1.Group("/profile")
|
||||
{
|
||||
// 公开路由(无需认证)
|
||||
profileGroup.GET("/:uuid", GetProfile) // 获取档案详情
|
||||
|
||||
// 需要认证的路由
|
||||
profileAuth := profileGroup.Group("")
|
||||
profileAuth.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
profileAuth.POST("/", CreateProfile) // 创建档案
|
||||
profileAuth.GET("/", GetProfiles) // 获取我的档案列表
|
||||
profileAuth.PUT("/:uuid", UpdateProfile) // 更新档案
|
||||
profileAuth.DELETE("/:uuid", DeleteProfile) // 删除档案
|
||||
profileAuth.POST("/:uuid/activate", SetActiveProfile) // 设置活跃档案
|
||||
}
|
||||
}
|
||||
// 验证码路由
|
||||
captchaGroup := v1.Group("/captcha")
|
||||
{
|
||||
captchaGroup.GET("/generate", Generate) //生成验证码
|
||||
captchaGroup.POST("/verify", Verify) //验证验证码
|
||||
}
|
||||
|
||||
// Yggdrasil API路由组
|
||||
ygg := v1.Group("/yggdrasil")
|
||||
{
|
||||
ygg.GET("", GetMetaData)
|
||||
ygg.POST("/minecraftservices/player/certificates", GetPlayerCertificates)
|
||||
authserver := ygg.Group("/authserver")
|
||||
{
|
||||
authserver.POST("/authenticate", Authenticate)
|
||||
authserver.POST("/validate", ValidToken)
|
||||
authserver.POST("/refresh", RefreshToken)
|
||||
authserver.POST("/invalidate", InvalidToken)
|
||||
authserver.POST("/signout", SignOut)
|
||||
}
|
||||
sessionServer := ygg.Group("/sessionserver")
|
||||
{
|
||||
sessionServer.GET("/session/minecraft/profile/:uuid", GetProfileByUUID)
|
||||
sessionServer.POST("/session/minecraft/join", JoinServer)
|
||||
sessionServer.GET("/session/minecraft/hasJoined", HasJoinedServer)
|
||||
}
|
||||
api := ygg.Group("/api")
|
||||
profiles := api.Group("/profiles")
|
||||
{
|
||||
profiles.POST("/minecraft", GetProfilesByName)
|
||||
}
|
||||
}
|
||||
// 系统路由
|
||||
system := v1.Group("/system")
|
||||
{
|
||||
system.GET("/config", GetSystemConfig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 以下是系统配置相关的占位符函数,待后续实现
|
||||
|
||||
// GetSystemConfig 获取系统配置
|
||||
// @Summary 获取系统配置
|
||||
// @Description 获取公开的系统配置信息
|
||||
// @Tags system
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Router /api/v1/system/config [get]
|
||||
func GetSystemConfig(c *gin.Context) {
|
||||
// TODO: 实现从数据库读取系统配置
|
||||
c.JSON(200, model.NewSuccessResponse(gin.H{
|
||||
"site_name": "CarrotSkin",
|
||||
"site_description": "A Minecraft Skin Station",
|
||||
"registration_enabled": true,
|
||||
"max_textures_per_user": 100,
|
||||
"max_profiles_per_user": 5,
|
||||
}))
|
||||
}
|
||||
62
internal/handler/swagger.go
Normal file
62
internal/handler/swagger.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
)
|
||||
|
||||
// @title CarrotSkin API
|
||||
// @version 1.0
|
||||
// @description CarrotSkin 是一个优秀的 Minecraft 皮肤站 API 服务
|
||||
// @description
|
||||
// @description ## 功能特性
|
||||
// @description - 用户注册/登录/管理
|
||||
// @description - 材质上传/下载/管理
|
||||
// @description - Minecraft 档案管理
|
||||
// @description - 权限控制系统
|
||||
// @description - 积分系统
|
||||
// @description
|
||||
// @description ## 认证方式
|
||||
// @description 使用 JWT Token 进行身份认证,需要在请求头中包含:
|
||||
// @description ```
|
||||
// @description Authorization: Bearer <your-jwt-token>
|
||||
// @description ```
|
||||
|
||||
// @contact.name CarrotSkin Team
|
||||
// @contact.email support@carrotskin.com
|
||||
// @license.name MIT
|
||||
// @license.url https://opensource.org/licenses/MIT
|
||||
|
||||
// @host localhost:8080
|
||||
// @BasePath /api/v1
|
||||
|
||||
// @securityDefinitions.apikey BearerAuth
|
||||
// @in header
|
||||
// @name Authorization
|
||||
// @description Type "Bearer" followed by a space and JWT token.
|
||||
|
||||
func SetupSwagger(router *gin.Engine) {
|
||||
// Swagger文档路由
|
||||
router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
|
||||
// 健康检查接口
|
||||
router.GET("/health", HealthCheck)
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
// @Summary 健康检查
|
||||
// @Description 检查服务是否正常运行
|
||||
// @Tags system
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{} "成功"
|
||||
// @Router /health [get]
|
||||
func HealthCheck(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"message": "CarrotSkin API is running",
|
||||
})
|
||||
}
|
||||
599
internal/handler/texture_handler.go
Normal file
599
internal/handler/texture_handler.go
Normal file
@@ -0,0 +1,599 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/logger"
|
||||
"carrotskin/pkg/storage"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// GenerateTextureUploadURL 生成材质上传URL
|
||||
// @Summary 生成材质上传URL
|
||||
// @Description 生成预签名URL用于上传材质文件
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.GenerateTextureUploadURLRequest true "上传URL请求"
|
||||
// @Success 200 {object} model.Response "生成成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/texture/upload-url [post]
|
||||
func GenerateTextureUploadURL(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.GenerateTextureUploadURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用UploadService生成预签名URL
|
||||
storageClient := storage.MustGetClient()
|
||||
cfg := *config.MustGetRustFSConfig()
|
||||
result, err := service.GenerateTextureUploadURL(
|
||||
c.Request.Context(),
|
||||
storageClient,
|
||||
cfg,
|
||||
userID.(int64),
|
||||
req.FileName,
|
||||
string(req.TextureType),
|
||||
)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("生成材质上传URL失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("file_name", req.FileName),
|
||||
zap.String("texture_type", string(req.TextureType)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.GenerateTextureUploadURLResponse{
|
||||
PostURL: result.PostURL,
|
||||
FormData: result.FormData,
|
||||
TextureURL: result.FileURL,
|
||||
ExpiresIn: 900, // 15分钟 = 900秒
|
||||
}))
|
||||
}
|
||||
|
||||
// CreateTexture 创建材质记录
|
||||
// @Summary 创建材质记录
|
||||
// @Description 文件上传完成后,创建材质记录到数据库
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.CreateTextureRequest true "创建材质请求"
|
||||
// @Success 200 {object} model.Response "创建成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/texture [post]
|
||||
func CreateTexture(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.CreateTextureRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 从配置或数据库读取限制
|
||||
maxTextures := 100
|
||||
if err := service.CheckTextureUploadLimit(database.MustGetDB(), userID.(int64), maxTextures); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 创建材质
|
||||
texture, err := service.CreateTexture(database.MustGetDB(),
|
||||
userID.(int64),
|
||||
req.Name,
|
||||
req.Description,
|
||||
string(req.Type),
|
||||
req.URL,
|
||||
req.Hash,
|
||||
req.Size,
|
||||
req.IsPublic,
|
||||
req.IsSlim,
|
||||
)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("创建材质失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("name", req.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetTexture 获取材质详情
|
||||
// @Summary 获取材质详情
|
||||
// @Description 根据ID获取材质详细信息
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path int true "材质ID"
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 404 {object} model.ErrorResponse "材质不存在"
|
||||
// @Router /api/v1/texture/{id} [get]
|
||||
func GetTexture(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"无效的材质ID",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
texture, err := service.GetTextureByID(database.MustGetDB(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// SearchTextures 搜索材质
|
||||
// @Summary 搜索材质
|
||||
// @Description 根据关键词和类型搜索材质
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param keyword query string false "关键词"
|
||||
// @Param type query string false "材质类型(SKIN/CAPE)"
|
||||
// @Param public_only query bool false "只看公开材质"
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(20)
|
||||
// @Success 200 {object} model.PaginationResponse "搜索成功"
|
||||
// @Router /api/v1/texture [get]
|
||||
func SearchTextures(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
textureTypeStr := c.Query("type")
|
||||
publicOnly := c.Query("public_only") == "true"
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
var textureType model.TextureType
|
||||
switch textureTypeStr {
|
||||
case "SKIN":
|
||||
textureType = model.TextureTypeSkin
|
||||
case "CAPE":
|
||||
textureType = model.TextureTypeCape
|
||||
}
|
||||
|
||||
textures, total, err := service.SearchTextures(database.MustGetDB(), keyword, textureType, publicOnly, page, pageSize)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("搜索材质失败",
|
||||
zap.String("keyword", keyword),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"搜索材质失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为TextureInfo
|
||||
textureInfos := make([]*types.TextureInfo, len(textures))
|
||||
for i, texture := range textures {
|
||||
textureInfos[i] = &types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPaginationResponse(textureInfos, total, page, pageSize))
|
||||
}
|
||||
|
||||
// UpdateTexture 更新材质
|
||||
// @Summary 更新材质
|
||||
// @Description 更新材质信息(仅上传者可操作)
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "材质ID"
|
||||
// @Param request body types.UpdateTextureRequest true "更新材质请求"
|
||||
// @Success 200 {object} model.Response "更新成功"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Router /api/v1/texture/{id} [put]
|
||||
func UpdateTexture(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
idStr := c.Param("id")
|
||||
textureID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"无效的材质ID",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.UpdateTextureRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
texture, err := service.UpdateTexture(database.MustGetDB(), textureID, userID.(int64), req.Name, req.Description, req.IsPublic)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("更新材质失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusForbidden, model.NewErrorResponse(
|
||||
model.CodeForbidden,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// DeleteTexture 删除材质
|
||||
// @Summary 删除材质
|
||||
// @Description 删除材质(软删除,仅上传者可操作)
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "材质ID"
|
||||
// @Success 200 {object} model.Response "删除成功"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Router /api/v1/texture/{id} [delete]
|
||||
func DeleteTexture(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
idStr := c.Param("id")
|
||||
textureID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"无效的材质ID",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.DeleteTexture(database.MustGetDB(), textureID, userID.(int64)); err != nil {
|
||||
logger.MustGetLogger().Error("删除材质失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusForbidden, model.NewErrorResponse(
|
||||
model.CodeForbidden,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(nil))
|
||||
}
|
||||
|
||||
// ToggleFavorite 切换收藏状态
|
||||
// @Summary 切换收藏状态
|
||||
// @Description 收藏或取消收藏材质
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "材质ID"
|
||||
// @Success 200 {object} model.Response "切换成功"
|
||||
// @Router /api/v1/texture/{id}/favorite [post]
|
||||
func ToggleFavorite(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
idStr := c.Param("id")
|
||||
textureID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"无效的材质ID",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
isFavorited, err := service.ToggleTextureFavorite(database.MustGetDB(), userID.(int64), textureID)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("切换收藏状态失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(map[string]bool{
|
||||
"is_favorited": isFavorited,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetUserTextures 获取用户上传的材质列表
|
||||
// @Summary 获取用户上传的材质列表
|
||||
// @Description 获取当前用户上传的所有材质
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(20)
|
||||
// @Success 200 {object} model.PaginationResponse "获取成功"
|
||||
// @Router /api/v1/texture/my [get]
|
||||
func GetUserTextures(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
textures, total, err := service.GetUserTextures(database.MustGetDB(), userID.(int64), page, pageSize)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("获取用户材质列表失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"获取材质列表失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为TextureInfo
|
||||
textureInfos := make([]*types.TextureInfo, len(textures))
|
||||
for i, texture := range textures {
|
||||
textureInfos[i] = &types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPaginationResponse(textureInfos, total, page, pageSize))
|
||||
}
|
||||
|
||||
// GetUserFavorites 获取用户收藏的材质列表
|
||||
// @Summary 获取用户收藏的材质列表
|
||||
// @Description 获取当前用户收藏的所有材质
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(20)
|
||||
// @Success 200 {object} model.PaginationResponse "获取成功"
|
||||
// @Router /api/v1/texture/favorites [get]
|
||||
func GetUserFavorites(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
textures, total, err := service.GetUserTextureFavorites(database.MustGetDB(), userID.(int64), page, pageSize)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("获取用户收藏列表失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"获取收藏列表失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为TextureInfo
|
||||
textureInfos := make([]*types.TextureInfo, len(textures))
|
||||
for i, texture := range textures {
|
||||
textureInfos[i] = &types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPaginationResponse(textureInfos, total, page, pageSize))
|
||||
}
|
||||
415
internal/handler/user_handler.go
Normal file
415
internal/handler/user_handler.go
Normal file
@@ -0,0 +1,415 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/logger"
|
||||
"carrotskin/pkg/redis"
|
||||
"carrotskin/pkg/storage"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// GetUserProfile 获取用户信息
|
||||
// @Summary 获取用户信息
|
||||
// @Description 获取当前登录用户的详细信息
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Router /api/v1/user/profile [get]
|
||||
func GetUserProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
// 从上下文获取用户ID (由JWT中间件设置)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || user == nil {
|
||||
loggerInstance.Error("获取用户信息失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回用户信息
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// UpdateUserProfile 更新用户信息
|
||||
// @Summary 更新用户信息
|
||||
// @Description 更新当前登录用户的头像和密码(修改邮箱请使用 /change-email 接口)
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.UpdateUserRequest true "更新信息(修改密码时需同时提供old_password和new_password)"
|
||||
// @Success 200 {object} model.Response{data=types.UserInfo} "更新成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 404 {object} model.ErrorResponse "用户不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/user/profile [put]
|
||||
func UpdateUserProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户
|
||||
user, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || user == nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 处理密码修改
|
||||
if req.NewPassword != "" {
|
||||
// 如果提供了新密码,必须同时提供旧密码
|
||||
if req.OldPassword == "" {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"修改密码需要提供原密码",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用修改密码服务
|
||||
if err := service.ChangeUserPassword(userID.(int64), req.OldPassword, req.NewPassword); err != nil {
|
||||
loggerInstance.Error("修改密码失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info("用户修改密码成功",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
)
|
||||
}
|
||||
|
||||
// 更新头像
|
||||
if req.Avatar != "" {
|
||||
user.Avatar = req.Avatar
|
||||
}
|
||||
|
||||
// 保存更新(仅当有头像修改时)
|
||||
if req.Avatar != "" {
|
||||
if err := service.UpdateUserInfo(user); err != nil {
|
||||
loggerInstance.Error("更新用户信息失败",
|
||||
zap.Int64("user_id", user.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"更新失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 重新获取更新后的用户信息
|
||||
updatedUser, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || updatedUser == nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的用户信息
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.UserInfo{
|
||||
ID: updatedUser.ID,
|
||||
Username: updatedUser.Username,
|
||||
Email: updatedUser.Email,
|
||||
Avatar: updatedUser.Avatar,
|
||||
Points: updatedUser.Points,
|
||||
Role: updatedUser.Role,
|
||||
Status: updatedUser.Status,
|
||||
LastLoginAt: updatedUser.LastLoginAt,
|
||||
CreatedAt: updatedUser.CreatedAt,
|
||||
UpdatedAt: updatedUser.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURL 生成头像上传URL
|
||||
// @Summary 生成头像上传URL
|
||||
// @Description 生成预签名URL用于上传用户头像
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.GenerateAvatarUploadURLRequest true "文件名"
|
||||
// @Success 200 {object} model.Response "生成成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/user/avatar/upload-url [post]
|
||||
func GenerateAvatarUploadURL(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.GenerateAvatarUploadURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用UploadService生成预签名URL
|
||||
storageClient := storage.MustGetClient()
|
||||
cfg := *config.MustGetRustFSConfig()
|
||||
result, err := service.GenerateAvatarUploadURL(c.Request.Context(), storageClient, cfg, userID.(int64), req.FileName)
|
||||
if err != nil {
|
||||
loggerInstance.Error("生成头像上传URL失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("file_name", req.FileName),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.GenerateAvatarUploadURLResponse{
|
||||
PostURL: result.PostURL,
|
||||
FormData: result.FormData,
|
||||
AvatarURL: result.FileURL,
|
||||
ExpiresIn: 900, // 15分钟 = 900秒
|
||||
}))
|
||||
}
|
||||
|
||||
// UpdateAvatar 更新头像URL
|
||||
// @Summary 更新头像URL
|
||||
// @Description 上传完成后更新用户的头像URL到数据库
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param avatar_url query string true "头像URL"
|
||||
// @Success 200 {object} model.Response "更新成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/user/avatar [put]
|
||||
func UpdateAvatar(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
avatarURL := c.Query("avatar_url")
|
||||
if avatarURL == "" {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"头像URL不能为空",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 更新头像
|
||||
if err := service.UpdateUserAvatar(userID.(int64), avatarURL); err != nil {
|
||||
loggerInstance.Error("更新头像失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("avatar_url", avatarURL),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"更新头像失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的用户信息
|
||||
user, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || user == nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的用户信息
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// ChangeEmail 更换邮箱
|
||||
// @Summary 更换邮箱
|
||||
// @Description 通过验证码更换用户邮箱
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.ChangeEmailRequest true "更换邮箱请求"
|
||||
// @Success 200 {object} model.Response{data=types.UserInfo} "更换成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Router /api/v1/user/change-email [post]
|
||||
func ChangeEmail(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.ChangeEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
redisClient := redis.MustGetClient()
|
||||
if err := service.VerifyCode(c.Request.Context(), redisClient, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil {
|
||||
loggerInstance.Warn("验证码验证失败",
|
||||
zap.String("new_email", req.NewEmail),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 更换邮箱
|
||||
if err := service.ChangeUserEmail(userID.(int64), req.NewEmail); err != nil {
|
||||
loggerInstance.Error("更换邮箱失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("new_email", req.NewEmail),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的用户信息
|
||||
user, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || user == nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
151
internal/handler/user_handler_test.go
Normal file
151
internal/handler/user_handler_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestUserHandler_PermissionCheck 测试权限检查逻辑
|
||||
func TestUserHandler_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID interface{}
|
||||
exists bool
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID",
|
||||
userID: int64(1),
|
||||
exists: true,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID不存在",
|
||||
userID: nil,
|
||||
exists: false,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID类型错误",
|
||||
userID: "invalid",
|
||||
exists: true,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证权限检查逻辑
|
||||
isValid := tt.exists
|
||||
if tt.exists {
|
||||
// 验证类型转换
|
||||
if _, ok := tt.userID.(int64); !ok {
|
||||
isValid = false
|
||||
}
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Permission check failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserHandler_RequestValidation 测试请求验证逻辑
|
||||
func TestUserHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
avatar string
|
||||
oldPass string
|
||||
newPass string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "只更新头像",
|
||||
avatar: "https://example.com/avatar.png",
|
||||
oldPass: "",
|
||||
newPass: "",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "更新密码(提供旧密码和新密码)",
|
||||
avatar: "",
|
||||
oldPass: "oldpass123",
|
||||
newPass: "newpass123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "只提供新密码(无效)",
|
||||
avatar: "",
|
||||
oldPass: "",
|
||||
newPass: "newpass123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "只提供旧密码(无效)",
|
||||
avatar: "",
|
||||
oldPass: "oldpass123",
|
||||
newPass: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证请求逻辑:更新密码时需要同时提供旧密码和新密码
|
||||
isValid := true
|
||||
if tt.newPass != "" && tt.oldPass == "" {
|
||||
isValid = false
|
||||
}
|
||||
if tt.oldPass != "" && tt.newPass == "" {
|
||||
isValid = false
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestUserHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "未授权",
|
||||
errType: "unauthorized",
|
||||
wantCode: 401,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
errType: "not_found",
|
||||
wantCode: 404,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "参数错误",
|
||||
errType: "bad_request",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "服务器错误",
|
||||
errType: "server_error",
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if !tt.wantError {
|
||||
t.Error("Error handling test should expect error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
666
internal/handler/yggdrasil_handler.go
Normal file
666
internal/handler/yggdrasil_handler.go
Normal file
@@ -0,0 +1,666 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/logger"
|
||||
"carrotskin/pkg/redis"
|
||||
"carrotskin/pkg/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
ErrInternalServer = "服务器内部错误"
|
||||
// 错误类型
|
||||
ErrInvalidEmailFormat = "邮箱格式不正确"
|
||||
ErrInvalidPassword = "密码必须至少包含8个字符,只能包含字母、数字和特殊字符"
|
||||
ErrWrongPassword = "密码错误"
|
||||
ErrUserNotMatch = "用户不匹配"
|
||||
|
||||
// 错误消息
|
||||
ErrInvalidRequest = "请求格式无效"
|
||||
ErrJoinServerFailed = "加入服务器失败"
|
||||
ErrServerIDRequired = "服务器ID不能为空"
|
||||
ErrUsernameRequired = "用户名不能为空"
|
||||
ErrSessionVerifyFailed = "会话验证失败"
|
||||
ErrProfileNotFound = "未找到用户配置文件"
|
||||
ErrInvalidParams = "无效的请求参数"
|
||||
ErrEmptyUserID = "用户ID为空"
|
||||
ErrUnauthorized = "无权操作此配置文件"
|
||||
ErrGetProfileService = "获取配置文件服务失败"
|
||||
|
||||
// 成功信息
|
||||
SuccessProfileCreated = "创建成功"
|
||||
MsgRegisterSuccess = "注册成功"
|
||||
|
||||
// 错误消息
|
||||
ErrGetProfile = "获取配置文件失败"
|
||||
ErrGetTextureService = "获取材质服务失败"
|
||||
ErrInvalidContentType = "无效的请求内容类型"
|
||||
ErrParseMultipartForm = "解析多部分表单失败"
|
||||
ErrGetFileFromForm = "从表单获取文件失败"
|
||||
ErrInvalidFileType = "无效的文件类型,仅支持PNG图片"
|
||||
ErrSaveTexture = "保存材质失败"
|
||||
ErrSetTexture = "设置材质失败"
|
||||
ErrGetTexture = "获取材质失败"
|
||||
|
||||
// 内存限制
|
||||
MaxMultipartMemory = 32 << 20 // 32 MB
|
||||
|
||||
// 材质类型
|
||||
TextureTypeSkin = "SKIN"
|
||||
TextureTypeCape = "CAPE"
|
||||
|
||||
// 内容类型
|
||||
ContentTypePNG = "image/png"
|
||||
ContentTypeMultipart = "multipart/form-data"
|
||||
|
||||
// 表单参数
|
||||
FormKeyModel = "model"
|
||||
FormKeyFile = "file"
|
||||
|
||||
// 元数据键
|
||||
MetaKeyModel = "model"
|
||||
)
|
||||
|
||||
// 正则表达式
|
||||
var (
|
||||
// 邮箱正则表达式
|
||||
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
||||
|
||||
// 密码强度正则表达式(最少8位,只允许字母、数字和特定特殊字符)
|
||||
passwordRegex = regexp.MustCompile(`^[a-zA-Z0-9!@#$%^&*]{8,}$`)
|
||||
)
|
||||
|
||||
// 请求结构体
|
||||
type (
|
||||
// AuthenticateRequest 认证请求
|
||||
AuthenticateRequest struct {
|
||||
Agent map[string]interface{} `json:"agent"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
Identifier string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
RequestUser bool `json:"requestUser"`
|
||||
}
|
||||
|
||||
// ValidTokenRequest 验证令牌请求
|
||||
ValidTokenRequest struct {
|
||||
AccessToken string `json:"accessToken" binding:"required"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
}
|
||||
|
||||
// RefreshRequest 刷新令牌请求
|
||||
RefreshRequest struct {
|
||||
AccessToken string `json:"accessToken" binding:"required"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
RequestUser bool `json:"requestUser"`
|
||||
SelectedProfile map[string]interface{} `json:"selectedProfile"`
|
||||
}
|
||||
|
||||
// SignOutRequest 登出请求
|
||||
SignOutRequest struct {
|
||||
Email string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
JoinServerRequest struct {
|
||||
ServerID string `json:"serverId" binding:"required"`
|
||||
AccessToken string `json:"accessToken" binding:"required"`
|
||||
SelectedProfile string `json:"selectedProfile" binding:"required"`
|
||||
}
|
||||
)
|
||||
|
||||
// 响应结构体
|
||||
type (
|
||||
// AuthenticateResponse 认证响应
|
||||
AuthenticateResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
SelectedProfile map[string]interface{} `json:"selectedProfile,omitempty"`
|
||||
AvailableProfiles []map[string]interface{} `json:"availableProfiles"`
|
||||
User map[string]interface{} `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// RefreshResponse 刷新令牌响应
|
||||
RefreshResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
SelectedProfile map[string]interface{} `json:"selectedProfile,omitempty"`
|
||||
User map[string]interface{} `json:"user,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
type APIResponse struct {
|
||||
Status int `json:"status"`
|
||||
Data interface{} `json:"data"`
|
||||
Error interface{} `json:"error"`
|
||||
}
|
||||
|
||||
// standardResponse 生成标准响应
|
||||
func standardResponse(c *gin.Context, status int, data interface{}, err interface{}) {
|
||||
c.JSON(status, APIResponse{
|
||||
Status: status,
|
||||
Data: data,
|
||||
Error: err,
|
||||
})
|
||||
}
|
||||
|
||||
// Authenticate 用户认证
|
||||
func Authenticate(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
// 读取并保存原始请求体,以便多次读取
|
||||
rawData, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 读取请求体失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(rawData))
|
||||
|
||||
// 绑定JSON数据到请求结构体
|
||||
var request AuthenticateRequest
|
||||
if err = c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析认证请求失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 根据标识符类型(邮箱或用户名)获取用户
|
||||
var userId int64
|
||||
var profile *model.Profile
|
||||
var UUID string
|
||||
if emailRegex.MatchString(request.Identifier) {
|
||||
userId, err = service.GetUserIDByEmail(db, request.Identifier)
|
||||
} else {
|
||||
profile, err = service.GetProfileByProfileName(db, request.Identifier)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 用户名不存在: ", zap.String("标识符", request.Identifier), zap.Error(err))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
userId = profile.UserID
|
||||
UUID = profile.UUID
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
loggerInstance.Warn("[WARN] 认证失败: 用户不存在",
|
||||
zap.String("标识符:", request.Identifier),
|
||||
zap.Error(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
err = service.VerifyPassword(db, request.Password, userId)
|
||||
if err != nil {
|
||||
loggerInstance.Warn("[WARN] 认证失败:", zap.Error(err))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": ErrWrongPassword})
|
||||
return
|
||||
}
|
||||
// 生成新令牌
|
||||
selectedProfile, availableProfiles, accessToken, clientToken, err := service.NewToken(db, loggerInstance, userId, UUID, request.ClientToken)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 生成令牌失败:", zap.Error(err), zap.Any("用户ID:", userId))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := service.GetUserByID(userId)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] id查找错误:", zap.Error(err), zap.Any("ID:", userId))
|
||||
}
|
||||
// 处理可用的配置文件
|
||||
redisClient := redis.MustGetClient()
|
||||
availableProfilesData := make([]map[string]interface{}, 0, len(availableProfiles))
|
||||
for _, profile := range availableProfiles {
|
||||
availableProfilesData = append(availableProfilesData, service.SerializeProfile(db, loggerInstance, redisClient, *profile))
|
||||
}
|
||||
response := AuthenticateResponse{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
AvailableProfiles: availableProfilesData,
|
||||
}
|
||||
if selectedProfile != nil {
|
||||
response.SelectedProfile = service.SerializeProfile(db, loggerInstance, redisClient, *selectedProfile)
|
||||
}
|
||||
if request.RequestUser {
|
||||
response.User = map[string]interface{}{
|
||||
"id": userId,
|
||||
"properties": user.Properties,
|
||||
}
|
||||
}
|
||||
|
||||
// 返回认证响应
|
||||
loggerInstance.Info("[INFO] 用户认证成功", zap.Any("用户ID:", userId))
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// ValidToken 验证令牌
|
||||
func ValidToken(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var request ValidTokenRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析验证令牌请求失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// 验证令牌
|
||||
if service.ValidToken(db, request.AccessToken, request.ClientToken) {
|
||||
loggerInstance.Info("[INFO] 令牌验证成功", zap.Any("访问令牌:", request.AccessToken))
|
||||
c.JSON(http.StatusNoContent, gin.H{"valid": true})
|
||||
} else {
|
||||
loggerInstance.Warn("[WARN] 令牌验证失败", zap.Any("访问令牌:", request.AccessToken))
|
||||
c.JSON(http.StatusForbidden, gin.H{"valid": false})
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshToken 刷新令牌
|
||||
func RefreshToken(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var request RefreshRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析刷新令牌请求失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户ID和用户信息
|
||||
UUID, err := service.GetUUIDByAccessToken(db, request.AccessToken)
|
||||
if err != nil {
|
||||
loggerInstance.Warn("[WARN] 刷新令牌失败: 无效的访问令牌", zap.Any("令牌:", request.AccessToken), zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
userID, _ := service.GetUserIDByAccessToken(db, request.AccessToken)
|
||||
// 格式化UUID 这里是因为HMCL的传入参数是HEX格式,为了兼容HMCL,在此做处理
|
||||
UUID = utils.FormatUUID(UUID)
|
||||
|
||||
profile, err := service.GetProfileByUUID(db, UUID)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 刷新令牌失败: 无法获取用户信息 错误: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 准备响应数据
|
||||
var profileData map[string]interface{}
|
||||
var userData map[string]interface{}
|
||||
var profileID string
|
||||
|
||||
// 处理选定的配置文件
|
||||
if request.SelectedProfile != nil {
|
||||
// 验证profileID是否存在
|
||||
profileIDValue, ok := request.SelectedProfile["id"]
|
||||
if !ok {
|
||||
loggerInstance.Error("[ERROR] 刷新令牌失败: 缺少配置文件ID", zap.Any("ID:", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少配置文件ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// 类型断言
|
||||
profileID, ok = profileIDValue.(string)
|
||||
if !ok {
|
||||
loggerInstance.Error("[ERROR] 刷新令牌失败: 配置文件ID类型错误 ", zap.Any("用户ID:", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "配置文件ID必须是字符串"})
|
||||
return
|
||||
}
|
||||
|
||||
// 格式化profileID
|
||||
profileID = utils.FormatUUID(profileID)
|
||||
|
||||
// 验证配置文件所属用户
|
||||
if profile.UserID != userID {
|
||||
loggerInstance.Warn("[WARN] 刷新令牌失败: 用户不匹配 ", zap.Any("用户ID:", userID), zap.Any("配置文件用户ID:", profile.UserID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrUserNotMatch})
|
||||
return
|
||||
}
|
||||
|
||||
profileData = service.SerializeProfile(db, loggerInstance, redis.MustGetClient(), *profile)
|
||||
}
|
||||
user, _ := service.GetUserByID(userID)
|
||||
// 添加用户信息(如果请求了)
|
||||
if request.RequestUser {
|
||||
userData = service.SerializeUser(loggerInstance, user, UUID)
|
||||
}
|
||||
|
||||
// 刷新令牌
|
||||
newAccessToken, newClientToken, err := service.RefreshToken(db, loggerInstance,
|
||||
request.AccessToken,
|
||||
request.ClientToken,
|
||||
profileID,
|
||||
)
|
||||
if err != nil {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
loggerInstance.Error("[ERROR] 刷新令牌失败: ", zap.Error(err), zap.Any("用户ID: ", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
loggerInstance.Info("[INFO] 刷新令牌成功", zap.Any("用户ID:", userID))
|
||||
c.JSON(http.StatusOK, RefreshResponse{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: newClientToken,
|
||||
SelectedProfile: profileData,
|
||||
User: userData,
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidToken 使令牌失效
|
||||
func InvalidToken(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var request ValidTokenRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析使令牌失效请求失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// 使令牌失效
|
||||
service.InvalidToken(db, loggerInstance, request.AccessToken)
|
||||
loggerInstance.Info("[INFO] 令牌已使失效", zap.Any("访问令牌:", request.AccessToken))
|
||||
c.JSON(http.StatusNoContent, gin.H{})
|
||||
}
|
||||
|
||||
// SignOut 用户登出
|
||||
func SignOut(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var request SignOutRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析登出请求失败: %v", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证邮箱格式
|
||||
if !emailRegex.MatchString(request.Email) {
|
||||
loggerInstance.Warn("[WARN] 登出失败: 邮箱格式不正确 ", zap.Any(" ", request.Email))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidEmailFormat})
|
||||
return
|
||||
}
|
||||
|
||||
// 通过邮箱获取用户
|
||||
user, err := service.GetUserByEmail(request.Email)
|
||||
if err != nil {
|
||||
loggerInstance.Warn(
|
||||
"登出失败: 用户不存在",
|
||||
zap.String("邮箱", request.Email),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
password, err := service.GetPasswordByUserId(db, user.ID)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 邮箱查找失败", zap.Any("UserId:", user.ID), zap.Error(err))
|
||||
}
|
||||
// 验证密码
|
||||
if password != request.Password {
|
||||
loggerInstance.Warn("[WARN] 登出失败: 密码错误", zap.Any("用户ID:", user.ID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrWrongPassword})
|
||||
return
|
||||
}
|
||||
|
||||
// 使该用户的所有令牌失效
|
||||
service.InvalidUserTokens(db, loggerInstance, user.ID)
|
||||
loggerInstance.Info("[INFO] 用户登出成功", zap.Any("用户ID:", user.ID))
|
||||
c.JSON(http.StatusNoContent, gin.H{"valid": true})
|
||||
}
|
||||
|
||||
func GetProfileByUUID(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
// 获取并格式化UUID
|
||||
uuid := utils.FormatUUID(c.Param("uuid"))
|
||||
loggerInstance.Info("[INFO] 接收到获取配置文件请求", zap.Any("UUID:", uuid))
|
||||
|
||||
// 获取配置文件
|
||||
profile, err := service.GetProfileByUUID(db, uuid)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 获取配置文件失败:", zap.Error(err), zap.String("UUID:", uuid))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回配置文件信息
|
||||
loggerInstance.Info("[INFO] 成功获取配置文件", zap.String("UUID:", uuid), zap.String("名称:", profile.Name))
|
||||
c.JSON(http.StatusOK, service.SerializeProfile(db, loggerInstance, redisClient, *profile))
|
||||
}
|
||||
|
||||
func JoinServer(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
var request JoinServerRequest
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 解析请求参数
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error(
|
||||
"解析加入服务器请求失败",
|
||||
zap.Error(err),
|
||||
zap.String("IP", clientIP),
|
||||
)
|
||||
standardResponse(c, http.StatusBadRequest, nil, ErrInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info(
|
||||
"收到加入服务器请求",
|
||||
zap.String("服务器ID", request.ServerID),
|
||||
zap.String("用户UUID", request.SelectedProfile),
|
||||
zap.String("IP", clientIP),
|
||||
)
|
||||
|
||||
// 处理加入服务器请求
|
||||
if err := service.JoinServer(db, loggerInstance, redisClient, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil {
|
||||
loggerInstance.Error(
|
||||
"加入服务器失败",
|
||||
zap.Error(err),
|
||||
zap.String("服务器ID", request.ServerID),
|
||||
zap.String("用户UUID", request.SelectedProfile),
|
||||
zap.String("IP", clientIP),
|
||||
)
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrJoinServerFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// 加入成功,返回204状态码
|
||||
loggerInstance.Info(
|
||||
"加入服务器成功",
|
||||
zap.String("服务器ID", request.ServerID),
|
||||
zap.String("用户UUID", request.SelectedProfile),
|
||||
zap.String("IP", clientIP),
|
||||
)
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func HasJoinedServer(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
clientIP, _ := c.GetQuery("ip")
|
||||
|
||||
// 获取并验证服务器ID参数
|
||||
serverID, exists := c.GetQuery("serverId")
|
||||
if !exists || serverID == "" {
|
||||
loggerInstance.Warn("[WARN] 缺少服务器ID参数", zap.Any("IP:", clientIP))
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrServerIDRequired)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取并验证用户名参数
|
||||
username, exists := c.GetQuery("username")
|
||||
if !exists || username == "" {
|
||||
loggerInstance.Warn("[WARN] 缺少用户名参数", zap.Any("服务器ID:", serverID), zap.Any("IP:", clientIP))
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrUsernameRequired)
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info("[INFO] 收到会话验证请求", zap.Any("服务器ID:", serverID), zap.Any("用户名: ", username), zap.Any("IP: ", clientIP))
|
||||
|
||||
// 验证玩家是否已加入服务器
|
||||
if err := service.HasJoinedServer(loggerInstance, redisClient, serverID, username, clientIP); err != nil {
|
||||
loggerInstance.Warn("[WARN] 会话验证失败",
|
||||
zap.Error(err),
|
||||
zap.String("serverID", serverID),
|
||||
zap.String("username", username),
|
||||
zap.String("clientIP", clientIP),
|
||||
)
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrSessionVerifyFailed)
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := service.GetProfileByUUID(db, username)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 获取用户配置文件失败: %v - 用户名: %s",
|
||||
zap.Error(err), // 错误详情(zap 原生支持,保留错误链)
|
||||
zap.String("username", username), // 结构化存储用户名(便于检索)
|
||||
)
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrProfileNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回玩家配置文件
|
||||
loggerInstance.Info("[INFO] 会话验证成功 - 服务器ID: %s, 用户名: %s, UUID: %s",
|
||||
zap.String("serverID", serverID), // 结构化存储服务器ID
|
||||
zap.String("username", username), // 结构化存储用户名
|
||||
zap.String("UUID", profile.UUID), // 结构化存储UUID
|
||||
)
|
||||
c.JSON(200, service.SerializeProfile(db, loggerInstance, redisClient, *profile))
|
||||
}
|
||||
|
||||
func GetProfilesByName(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var names []string
|
||||
|
||||
// 解析请求参数
|
||||
if err := c.ShouldBindJSON(&names); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析名称数组请求失败: ",
|
||||
zap.Error(err),
|
||||
)
|
||||
standardResponse(c, http.StatusBadRequest, nil, ErrInvalidParams)
|
||||
return
|
||||
}
|
||||
loggerInstance.Info("[INFO] 接收到批量获取配置文件请求",
|
||||
zap.Int("名称数量:", len(names)), // 结构化存储名称数量
|
||||
)
|
||||
|
||||
// 批量获取配置文件
|
||||
profiles, err := service.GetProfilesDataByNames(db, names)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 获取配置文件失败: ",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
// 改造:zap 兼容原有 INFO 日志格式
|
||||
loggerInstance.Info("[INFO] 成功获取配置文件",
|
||||
zap.Int("请求名称数:", len(names)),
|
||||
zap.Int("返回结果数: ", len(profiles)),
|
||||
)
|
||||
|
||||
c.JSON(http.StatusOK, profiles)
|
||||
}
|
||||
|
||||
func GetMetaData(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
meta := gin.H{
|
||||
"implementationName": "CellAuth",
|
||||
"implementationVersion": "0.0.1",
|
||||
"serverName": "LittleLan's Yggdrasil Server Implementation.",
|
||||
"links": gin.H{
|
||||
"homepage": "https://skin.littlelan.cn",
|
||||
"register": "https://skin.littlelan.cn/auth",
|
||||
},
|
||||
"feature.non_email_login": true,
|
||||
"feature.enable_profile_key": true,
|
||||
}
|
||||
skinDomains := []string{".hitwh.games", ".littlelan.cn"}
|
||||
signature, err := service.GetPublicKeyFromRedisFunc(loggerInstance, redisClient)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 获取公钥失败: ", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info("[INFO] 提供元数据")
|
||||
c.JSON(http.StatusOK, gin.H{"meta": meta,
|
||||
"skinDomains": skinDomains,
|
||||
"signaturePublickey": signature})
|
||||
}
|
||||
|
||||
func GetPlayerCertificates(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
var uuid string
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header not provided"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否以 Bearer 开头并提取 sessionID
|
||||
bearerPrefix := "Bearer "
|
||||
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
tokenID := authHeader[len(bearerPrefix):]
|
||||
if tokenID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
var err error
|
||||
uuid, err = service.GetUUIDByAccessToken(db, tokenID)
|
||||
|
||||
if uuid == "" {
|
||||
loggerInstance.Error("[ERROR] 获取玩家UUID失败: ", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
// 格式化UUID
|
||||
uuid = utils.FormatUUID(uuid)
|
||||
|
||||
// 生成玩家证书
|
||||
certificate, err := service.GeneratePlayerCertificate(db, loggerInstance, redisClient, uuid)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 生成玩家证书失败: ", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info("[INFO] 成功生成玩家证书")
|
||||
c.JSON(http.StatusOK, certificate)
|
||||
}
|
||||
157
internal/handler/yggdrasil_handler_test.go
Normal file
157
internal/handler/yggdrasil_handler_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestYggdrasilHandler_EmailValidation 测试邮箱验证逻辑
|
||||
func TestYggdrasilHandler_EmailValidation(t *testing.T) {
|
||||
// 使用简单的邮箱正则表达式
|
||||
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的邮箱",
|
||||
email: "test@example.com",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效的邮箱格式",
|
||||
email: "invalid-email",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "缺少@符号",
|
||||
email: "testexample.com",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "缺少域名",
|
||||
email: "test@",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空邮箱",
|
||||
email: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := emailRegex.MatchString(tt.email)
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Email validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestYggdrasilHandler_RequestValidation 测试请求验证逻辑
|
||||
func TestYggdrasilHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
serverID string
|
||||
username string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的请求",
|
||||
accessToken: "token-123",
|
||||
serverID: "server-456",
|
||||
username: "player",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "accessToken为空",
|
||||
accessToken: "",
|
||||
serverID: "server-456",
|
||||
username: "player",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverID为空",
|
||||
accessToken: "token-123",
|
||||
serverID: "",
|
||||
username: "player",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "username为空",
|
||||
accessToken: "token-123",
|
||||
serverID: "server-456",
|
||||
username: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.accessToken != "" && tt.serverID != "" && tt.username != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestYggdrasilHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestYggdrasilHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "参数错误",
|
||||
errType: "bad_request",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "未授权",
|
||||
errType: "forbidden",
|
||||
wantCode: 403,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "服务器错误",
|
||||
errType: "server_error",
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if !tt.wantError {
|
||||
t.Error("Error handling test should expect error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestYggdrasilHandler_Constants 测试常量定义
|
||||
func TestYggdrasilHandler_Constants(t *testing.T) {
|
||||
// 验证常量定义
|
||||
if MaxMultipartMemory != 32<<20 {
|
||||
t.Errorf("MaxMultipartMemory = %d, want %d", MaxMultipartMemory, 32<<20)
|
||||
}
|
||||
|
||||
if TextureTypeSkin != "SKIN" {
|
||||
t.Errorf("TextureTypeSkin = %q, want 'SKIN'", TextureTypeSkin)
|
||||
}
|
||||
|
||||
if TextureTypeCape != "CAPE" {
|
||||
t.Errorf("TextureTypeCape = %q, want 'CAPE'", TextureTypeCape)
|
||||
}
|
||||
}
|
||||
|
||||
78
internal/middleware/auth.go
Normal file
78
internal/middleware/auth.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"carrotskin/pkg/auth"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthMiddleware JWT认证中间件
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
jwtService := auth.MustGetJWTService()
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "缺少Authorization头",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Bearer token格式
|
||||
tokenParts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "无效的Authorization头格式",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token := tokenParts[1]
|
||||
claims, err := jwtService.ValidateToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "无效的token",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存储到上下文中
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("role", claims.Role)
|
||||
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选的JWT认证中间件
|
||||
func OptionalAuthMiddleware() gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
jwtService := auth.MustGetJWTService()
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
tokenParts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(tokenParts) == 2 && tokenParts[0] == "Bearer" {
|
||||
token := tokenParts[1]
|
||||
claims, err := jwtService.ValidateToken(token)
|
||||
if err == nil {
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("role", claims.Role)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
158
internal/middleware/auth_test.go
Normal file
158
internal/middleware/auth_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"carrotskin/pkg/auth"
|
||||
)
|
||||
|
||||
// TestAuthMiddleware_MissingHeader 测试缺少Authorization头的情况
|
||||
// 注意:这个测试需要auth服务初始化,暂时跳过实际执行
|
||||
func TestAuthMiddleware_MissingHeader(t *testing.T) {
|
||||
// 测试逻辑:缺少Authorization头应该返回401
|
||||
// 由于需要auth服务初始化,这里只测试逻辑部分
|
||||
hasHeader := false
|
||||
if hasHeader {
|
||||
t.Error("测试场景应该没有Authorization头")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_InvalidFormat 测试无效的Authorization头格式
|
||||
// 注意:这个测试需要auth服务初始化,这里只测试解析逻辑
|
||||
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "缺少Bearer前缀",
|
||||
header: "token123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "只有Bearer没有token",
|
||||
header: "Bearer",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
header: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "错误的格式",
|
||||
header: "Token token123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "标准格式",
|
||||
header: "Bearer token123",
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试header解析逻辑
|
||||
tokenParts := strings.SplitN(tt.header, " ", 2)
|
||||
isValid := len(tokenParts) == 2 && tokenParts[0] == "Bearer"
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Header validation: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_ValidToken 测试有效token的情况
|
||||
// 注意:这个测试需要auth服务初始化,这里只测试token格式
|
||||
func TestAuthMiddleware_ValidToken(t *testing.T) {
|
||||
// 创建JWT服务并生成token
|
||||
jwtService := auth.NewJWTService("test-secret-key", 24)
|
||||
token, err := jwtService.GenerateToken(1, "testuser", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("生成token失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证token格式
|
||||
if token == "" {
|
||||
t.Error("生成的token不应为空")
|
||||
}
|
||||
|
||||
// 验证可以解析token
|
||||
claims, err := jwtService.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("验证token失败: %v", err)
|
||||
}
|
||||
|
||||
if claims.UserID != 1 {
|
||||
t.Errorf("UserID = %d, want 1", claims.UserID)
|
||||
}
|
||||
if claims.Username != "testuser" {
|
||||
t.Errorf("Username = %q, want 'testuser'", claims.Username)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptionalAuthMiddleware_NoHeader 测试可选认证中间件无header的情况
|
||||
// 注意:这个测试需要auth服务初始化,这里只测试逻辑
|
||||
func TestOptionalAuthMiddleware_NoHeader(t *testing.T) {
|
||||
// 测试逻辑:可选认证中间件在没有header时应该允许请求继续
|
||||
hasHeader := false
|
||||
shouldContinue := true // 可选认证应该允许继续
|
||||
|
||||
if hasHeader && !shouldContinue {
|
||||
t.Error("可选认证逻辑错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_HeaderParsing 测试Authorization头解析逻辑
|
||||
func TestAuthMiddleware_HeaderParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
wantValid bool
|
||||
wantToken string
|
||||
}{
|
||||
{
|
||||
name: "标准Bearer格式",
|
||||
header: "Bearer token123",
|
||||
wantValid: true,
|
||||
wantToken: "token123",
|
||||
},
|
||||
{
|
||||
name: "Bearer后多个空格",
|
||||
header: "Bearer token123",
|
||||
wantValid: true,
|
||||
wantToken: " token123", // SplitN只分割一次
|
||||
},
|
||||
{
|
||||
name: "缺少Bearer",
|
||||
header: "token123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "只有Bearer",
|
||||
header: "Bearer",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenParts := strings.SplitN(tt.header, " ", 2)
|
||||
if len(tokenParts) == 2 && tokenParts[0] == "Bearer" {
|
||||
if !tt.wantValid {
|
||||
t.Errorf("应该无效但被识别为有效")
|
||||
}
|
||||
if tokenParts[1] != tt.wantToken {
|
||||
t.Errorf("Token = %q, want %q", tokenParts[1], tt.wantToken)
|
||||
}
|
||||
} else {
|
||||
if tt.wantValid {
|
||||
t.Errorf("应该有效但被识别为无效")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
22
internal/middleware/cors.go
Normal file
22
internal/middleware/cors.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
134
internal/middleware/cors_test.go
Normal file
134
internal/middleware/cors_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TestCORS_Headers 测试CORS中间件设置的响应头
|
||||
func TestCORS_Headers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证CORS响应头
|
||||
expectedHeaders := map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
|
||||
}
|
||||
|
||||
for header, expectedValue := range expectedHeaders {
|
||||
actualValue := w.Header().Get(header)
|
||||
if actualValue != expectedValue {
|
||||
t.Errorf("Header %s = %q, want %q", header, actualValue, expectedValue)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证Access-Control-Allow-Headers包含必要字段
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
if allowHeaders == "" {
|
||||
t.Error("Access-Control-Allow-Headers 不应为空")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORS_OPTIONS 测试OPTIONS请求处理
|
||||
func TestCORS_OPTIONS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("OPTIONS", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// OPTIONS请求应该返回204状态码
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("OPTIONS请求状态码 = %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORS_AllowMethods 测试允许的HTTP方法
|
||||
func TestCORS_AllowMethods(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE"}
|
||||
for _, method := range methods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(method, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证允许的方法头包含该方法
|
||||
allowMethods := w.Header().Get("Access-Control-Allow-Methods")
|
||||
if allowMethods == "" {
|
||||
t.Error("Access-Control-Allow-Methods 不应为空")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORS_AllowHeaders 测试允许的请求头
|
||||
func TestCORS_AllowHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
expectedHeaders := []string{"Content-Type", "Authorization", "Accept"}
|
||||
|
||||
for _, expectedHeader := range expectedHeaders {
|
||||
if !contains(allowHeaders, expectedHeader) {
|
||||
t.Errorf("Access-Control-Allow-Headers 应包含 %s", expectedHeader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:检查字符串是否包含子字符串(简单实现)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(s) < len(substr) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
39
internal/middleware/logger.go
Normal file
39
internal/middleware/logger.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logger 日志中间件
|
||||
func Logger(logger *zap.Logger) gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 记录日志
|
||||
latency := time.Since(start)
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
logger.Info("HTTP请求",
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("ip", clientIP),
|
||||
zap.Duration("latency", latency),
|
||||
zap.String("user_agent", c.Request.UserAgent()),
|
||||
)
|
||||
})
|
||||
}
|
||||
185
internal/middleware/logger_test.go
Normal file
185
internal/middleware/logger_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestLogger_Middleware 测试日志中间件基本功能
|
||||
func TestLogger_Middleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
start := time.Now()
|
||||
router.ServeHTTP(w, req)
|
||||
duration := time.Since(start)
|
||||
|
||||
// 验证请求成功处理
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 验证处理时间合理(应该很短)
|
||||
if duration > 1*time.Second {
|
||||
t.Errorf("处理时间过长: %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogger_RequestInfo 测试日志中间件记录的请求信息
|
||||
func TestLogger_RequestInfo(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{
|
||||
name: "GET请求",
|
||||
method: "GET",
|
||||
path: "/test",
|
||||
},
|
||||
{
|
||||
name: "POST请求",
|
||||
method: "POST",
|
||||
path: "/test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(tt.method, tt.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证请求被正确处理
|
||||
if w.Code != http.StatusOK && w.Code != http.StatusNotFound {
|
||||
t.Errorf("状态码 = %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogger_QueryParams 测试带查询参数的请求
|
||||
func TestLogger_QueryParams(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test?page=1&size=20", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证请求成功处理
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogger_StatusCodes 测试不同状态码的日志记录
|
||||
func TestLogger_StatusCodes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
|
||||
router.GET("/success", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
router.GET("/notfound", func(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"message": "not found"})
|
||||
})
|
||||
router.GET("/error", func(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"message": "error"})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "成功请求",
|
||||
path: "/success",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "404请求",
|
||||
path: "/notfound",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "500请求",
|
||||
path: "/error",
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", tt.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogger_Latency 测试延迟计算
|
||||
func TestLogger_Latency(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// 模拟一些处理时间
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
start := time.Now()
|
||||
router.ServeHTTP(w, req)
|
||||
duration := time.Since(start)
|
||||
|
||||
// 验证延迟计算合理(应该包含处理时间)
|
||||
if duration < 10*time.Millisecond {
|
||||
t.Errorf("延迟计算可能不正确: %v", duration)
|
||||
}
|
||||
}
|
||||
29
internal/middleware/recovery.go
Normal file
29
internal/middleware/recovery.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Recovery 恢复中间件
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||
if err, ok := recovered.(string); ok {
|
||||
logger.Error("服务器恐慌",
|
||||
zap.String("error", err),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("ip", c.ClientIP()),
|
||||
zap.String("stack", string(debug.Stack())),
|
||||
)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "服务器内部错误",
|
||||
})
|
||||
})
|
||||
}
|
||||
153
internal/middleware/recovery_test.go
Normal file
153
internal/middleware/recovery_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestRecovery_PanicHandling 测试恢复中间件处理panic
|
||||
func TestRecovery_PanicHandling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
// 创建一个会panic的路由
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// 应该不会导致测试panic,而是返回500错误
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证返回500状态码
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_StringPanic 测试字符串类型的panic
|
||||
func TestRecovery_StringPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("string panic message")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证返回500状态码
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_ErrorPanic 测试error类型的panic
|
||||
func TestRecovery_ErrorPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic(http.ErrBodyReadAfterClose)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// 应该不会导致测试panic
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证返回500状态码
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_NilPanic 测试nil panic
|
||||
func TestRecovery_NilPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
// 直接panic模拟nil pointer错误,避免linter警告
|
||||
panic("runtime error: invalid memory address or nil pointer dereference")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证返回500状态码
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_ResponseFormat 测试恢复后的响应格式
|
||||
func TestRecovery_ResponseFormat(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证响应体包含错误信息
|
||||
body := w.Body.String()
|
||||
if body == "" {
|
||||
t.Error("响应体不应为空")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_NormalRequest 测试正常请求不受影响
|
||||
func TestRecovery_NormalRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/normal", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/normal", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 正常请求应该不受影响
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
45
internal/model/audit_log.go
Normal file
45
internal/model/audit_log.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditLog 审计日志模型
|
||||
type AuditLog struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserID *int64 `gorm:"column:user_id;type:bigint;index" json:"user_id,omitempty"`
|
||||
Action string `gorm:"column:action;type:varchar(100);not null;index" json:"action"`
|
||||
ResourceType string `gorm:"column:resource_type;type:varchar(50);not null;index:idx_audit_logs_resource" json:"resource_type"`
|
||||
ResourceID string `gorm:"column:resource_id;type:varchar(50);index:idx_audit_logs_resource" json:"resource_id,omitempty"`
|
||||
OldValues string `gorm:"column:old_values;type:jsonb" json:"old_values,omitempty"` // JSONB 格式
|
||||
NewValues string `gorm:"column:new_values;type:jsonb" json:"new_values,omitempty"` // JSONB 格式
|
||||
IPAddress string `gorm:"column:ip_address;type:inet;not null" json:"ip_address"`
|
||||
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_audit_logs_created_at,sort:desc" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
// CasbinRule Casbin 权限规则模型
|
||||
type CasbinRule struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
PType string `gorm:"column:ptype;type:varchar(100);not null;index;uniqueIndex:uk_casbin_rule" json:"ptype"`
|
||||
V0 string `gorm:"column:v0;type:varchar(100);not null;default:'';index;uniqueIndex:uk_casbin_rule" json:"v0"`
|
||||
V1 string `gorm:"column:v1;type:varchar(100);not null;default:'';index;uniqueIndex:uk_casbin_rule" json:"v1"`
|
||||
V2 string `gorm:"column:v2;type:varchar(100);not null;default:'';uniqueIndex:uk_casbin_rule" json:"v2"`
|
||||
V3 string `gorm:"column:v3;type:varchar(100);not null;default:'';uniqueIndex:uk_casbin_rule" json:"v3"`
|
||||
V4 string `gorm:"column:v4;type:varchar(100);not null;default:'';uniqueIndex:uk_casbin_rule" json:"v4"`
|
||||
V5 string `gorm:"column:v5;type:varchar(100);not null;default:'';uniqueIndex:uk_casbin_rule" json:"v5"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (CasbinRule) TableName() string {
|
||||
return "casbin_rule"
|
||||
}
|
||||
63
internal/model/profile.go
Normal file
63
internal/model/profile.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Profile Minecraft 档案模型
|
||||
type Profile struct {
|
||||
UUID string `gorm:"column:uuid;type:varchar(36);primaryKey" json:"uuid"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index" json:"user_id"`
|
||||
Name string `gorm:"column:name;type:varchar(16);not null;uniqueIndex" json:"name"` // Minecraft 角色名
|
||||
SkinID *int64 `gorm:"column:skin_id;type:bigint" json:"skin_id,omitempty"`
|
||||
CapeID *int64 `gorm:"column:cape_id;type:bigint" json:"cape_id,omitempty"`
|
||||
RSAPrivateKey string `gorm:"column:rsa_private_key;type:text;not null" json:"-"` // RSA 私钥不返回给前端
|
||||
IsActive bool `gorm:"column:is_active;not null;default:true;index" json:"is_active"`
|
||||
LastUsedAt *time.Time `gorm:"column:last_used_at;type:timestamp" json:"last_used_at,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Skin *Texture `gorm:"foreignKey:SkinID" json:"skin,omitempty"`
|
||||
Cape *Texture `gorm:"foreignKey:CapeID" json:"cape,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Profile) TableName() string {
|
||||
return "profiles"
|
||||
}
|
||||
|
||||
// ProfileResponse 档案响应(包含完整的皮肤/披风信息)
|
||||
type ProfileResponse struct {
|
||||
UUID string `json:"uuid"`
|
||||
Name string `json:"name"`
|
||||
Textures ProfileTexturesData `json:"textures"`
|
||||
IsActive bool `json:"is_active"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ProfileTexturesData Minecraft 材质数据结构
|
||||
type ProfileTexturesData struct {
|
||||
Skin *ProfileTexture `json:"SKIN,omitempty"`
|
||||
Cape *ProfileTexture `json:"CAPE,omitempty"`
|
||||
}
|
||||
|
||||
// ProfileTexture 单个材质信息
|
||||
type ProfileTexture struct {
|
||||
URL string `json:"url"`
|
||||
Metadata *ProfileTextureMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ProfileTextureMetadata 材质元数据
|
||||
type ProfileTextureMetadata struct {
|
||||
Model string `json:"model,omitempty"` // "slim" or "classic"
|
||||
}
|
||||
|
||||
type KeyPair struct {
|
||||
PrivateKey string `json:"private_key" bson:"private_key"`
|
||||
PublicKey string `json:"public_key" bson:"public_key"`
|
||||
Expiration time.Time `json:"expiration" bson:"expiration"`
|
||||
Refresh time.Time `json:"refresh" bson:"refresh"`
|
||||
}
|
||||
85
internal/model/response.go
Normal file
85
internal/model/response.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package model
|
||||
|
||||
// Response 通用API响应结构
|
||||
type Response struct {
|
||||
Code int `json:"code"` // 业务状态码
|
||||
Message string `json:"message"` // 响应消息
|
||||
Data interface{} `json:"data,omitempty"` // 响应数据
|
||||
}
|
||||
|
||||
// PaginationResponse 分页响应结构
|
||||
type PaginationResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Total int64 `json:"total"` // 总记录数
|
||||
Page int `json:"page"` // 当前页码
|
||||
PerPage int `json:"per_page"` // 每页数量
|
||||
}
|
||||
|
||||
// ErrorResponse 错误响应
|
||||
type ErrorResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Error string `json:"error,omitempty"` // 详细错误信息(仅开发环境)
|
||||
}
|
||||
|
||||
// 常用状态码
|
||||
const (
|
||||
CodeSuccess = 200 // 成功
|
||||
CodeCreated = 201 // 创建成功
|
||||
CodeBadRequest = 400 // 请求参数错误
|
||||
CodeUnauthorized = 401 // 未授权
|
||||
CodeForbidden = 403 // 禁止访问
|
||||
CodeNotFound = 404 // 资源不存在
|
||||
CodeConflict = 409 // 资源冲突
|
||||
CodeServerError = 500 // 服务器错误
|
||||
)
|
||||
|
||||
// 常用响应消息
|
||||
const (
|
||||
MsgSuccess = "操作成功"
|
||||
MsgCreated = "创建成功"
|
||||
MsgBadRequest = "请求参数错误"
|
||||
MsgUnauthorized = "未授权,请先登录"
|
||||
MsgForbidden = "权限不足"
|
||||
MsgNotFound = "资源不存在"
|
||||
MsgConflict = "资源已存在"
|
||||
MsgServerError = "服务器内部错误"
|
||||
MsgInvalidToken = "无效的令牌"
|
||||
MsgTokenExpired = "令牌已过期"
|
||||
MsgInvalidCredentials = "用户名或密码错误"
|
||||
)
|
||||
|
||||
// NewSuccessResponse 创建成功响应
|
||||
func NewSuccessResponse(data interface{}) *Response {
|
||||
return &Response{
|
||||
Code: CodeSuccess,
|
||||
Message: MsgSuccess,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorResponse 创建错误响应
|
||||
func NewErrorResponse(code int, message string, err error) *ErrorResponse {
|
||||
resp := &ErrorResponse{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewPaginationResponse 创建分页响应
|
||||
func NewPaginationResponse(data interface{}, total int64, page, perPage int) *PaginationResponse {
|
||||
return &PaginationResponse{
|
||||
Code: CodeSuccess,
|
||||
Message: MsgSuccess,
|
||||
Data: data,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PerPage: perPage,
|
||||
}
|
||||
}
|
||||
257
internal/model/response_test.go
Normal file
257
internal/model/response_test.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewSuccessResponse 测试创建成功响应
|
||||
func TestNewSuccessResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
}{
|
||||
{
|
||||
name: "字符串数据",
|
||||
data: "success",
|
||||
},
|
||||
{
|
||||
name: "map数据",
|
||||
data: map[string]string{
|
||||
"id": "1",
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil数据",
|
||||
data: nil,
|
||||
},
|
||||
{
|
||||
name: "数组数据",
|
||||
data: []string{"a", "b", "c"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := NewSuccessResponse(tt.data)
|
||||
if resp == nil {
|
||||
t.Fatal("NewSuccessResponse() 返回nil")
|
||||
}
|
||||
if resp.Code != CodeSuccess {
|
||||
t.Errorf("Code = %d, want %d", resp.Code, CodeSuccess)
|
||||
}
|
||||
if resp.Message != MsgSuccess {
|
||||
t.Errorf("Message = %q, want %q", resp.Message, MsgSuccess)
|
||||
}
|
||||
// 对于可比较类型直接比较,对于不可比较类型只验证不为nil
|
||||
switch v := tt.data.(type) {
|
||||
case string, nil:
|
||||
// 数组不能直接比较,只验证不为nil
|
||||
if tt.data != nil && resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
if tt.data == nil && resp.Data != nil {
|
||||
t.Error("Data 应为nil")
|
||||
}
|
||||
case []string:
|
||||
// 数组不能直接比较,只验证不为nil
|
||||
if resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
default:
|
||||
// 对于map等不可比较类型,只验证不为nil
|
||||
if tt.data != nil && resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
_ = v
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewErrorResponse 测试创建错误响应
|
||||
func TestNewErrorResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
message string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "带错误信息",
|
||||
code: CodeBadRequest,
|
||||
message: "请求参数错误",
|
||||
err: errors.New("具体错误信息"),
|
||||
},
|
||||
{
|
||||
name: "无错误信息",
|
||||
code: CodeUnauthorized,
|
||||
message: "未授权",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "服务器错误",
|
||||
code: CodeServerError,
|
||||
message: "服务器内部错误",
|
||||
err: errors.New("数据库连接失败"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := NewErrorResponse(tt.code, tt.message, tt.err)
|
||||
if resp == nil {
|
||||
t.Fatal("NewErrorResponse() 返回nil")
|
||||
}
|
||||
if resp.Code != tt.code {
|
||||
t.Errorf("Code = %d, want %d", resp.Code, tt.code)
|
||||
}
|
||||
if resp.Message != tt.message {
|
||||
t.Errorf("Message = %q, want %q", resp.Message, tt.message)
|
||||
}
|
||||
if tt.err != nil {
|
||||
if resp.Error != tt.err.Error() {
|
||||
t.Errorf("Error = %q, want %q", resp.Error, tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
if resp.Error != "" {
|
||||
t.Errorf("Error 应为空,实际为 %q", resp.Error)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewPaginationResponse 测试创建分页响应
|
||||
func TestNewPaginationResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
total int64
|
||||
page int
|
||||
perPage int
|
||||
}{
|
||||
{
|
||||
name: "正常分页",
|
||||
data: []string{"a", "b", "c"},
|
||||
total: 100,
|
||||
page: 1,
|
||||
perPage: 20,
|
||||
},
|
||||
{
|
||||
name: "空数据",
|
||||
data: []string{},
|
||||
total: 0,
|
||||
page: 1,
|
||||
perPage: 20,
|
||||
},
|
||||
{
|
||||
name: "最后一页",
|
||||
data: []string{"a", "b"},
|
||||
total: 22,
|
||||
page: 3,
|
||||
perPage: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := NewPaginationResponse(tt.data, tt.total, tt.page, tt.perPage)
|
||||
if resp == nil {
|
||||
t.Fatal("NewPaginationResponse() 返回nil")
|
||||
}
|
||||
if resp.Code != CodeSuccess {
|
||||
t.Errorf("Code = %d, want %d", resp.Code, CodeSuccess)
|
||||
}
|
||||
if resp.Message != MsgSuccess {
|
||||
t.Errorf("Message = %q, want %q", resp.Message, MsgSuccess)
|
||||
}
|
||||
// 对于可比较类型直接比较,对于不可比较类型只验证不为nil
|
||||
switch v := tt.data.(type) {
|
||||
case string, nil:
|
||||
// 数组不能直接比较,只验证不为nil
|
||||
if tt.data != nil && resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
if tt.data == nil && resp.Data != nil {
|
||||
t.Error("Data 应为nil")
|
||||
}
|
||||
case []string:
|
||||
// 数组不能直接比较,只验证不为nil
|
||||
if resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
default:
|
||||
// 对于map等不可比较类型,只验证不为nil
|
||||
if tt.data != nil && resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
_ = v
|
||||
}
|
||||
if resp.Total != tt.total {
|
||||
t.Errorf("Total = %d, want %d", resp.Total, tt.total)
|
||||
}
|
||||
if resp.Page != tt.page {
|
||||
t.Errorf("Page = %d, want %d", resp.Page, tt.page)
|
||||
}
|
||||
if resp.PerPage != tt.perPage {
|
||||
t.Errorf("PerPage = %d, want %d", resp.PerPage, tt.perPage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponseConstants 测试响应常量
|
||||
func TestResponseConstants(t *testing.T) {
|
||||
// 测试状态码常量
|
||||
statusCodes := map[string]int{
|
||||
"CodeSuccess": CodeSuccess,
|
||||
"CodeCreated": CodeCreated,
|
||||
"CodeBadRequest": CodeBadRequest,
|
||||
"CodeUnauthorized": CodeUnauthorized,
|
||||
"CodeForbidden": CodeForbidden,
|
||||
"CodeNotFound": CodeNotFound,
|
||||
"CodeConflict": CodeConflict,
|
||||
"CodeServerError": CodeServerError,
|
||||
}
|
||||
|
||||
expectedCodes := map[string]int{
|
||||
"CodeSuccess": 200,
|
||||
"CodeCreated": 201,
|
||||
"CodeBadRequest": 400,
|
||||
"CodeUnauthorized": 401,
|
||||
"CodeForbidden": 403,
|
||||
"CodeNotFound": 404,
|
||||
"CodeConflict": 409,
|
||||
"CodeServerError": 500,
|
||||
}
|
||||
|
||||
for name, code := range statusCodes {
|
||||
expected := expectedCodes[name]
|
||||
if code != expected {
|
||||
t.Errorf("%s = %d, want %d", name, code, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试消息常量不为空
|
||||
messages := []string{
|
||||
MsgSuccess,
|
||||
MsgCreated,
|
||||
MsgBadRequest,
|
||||
MsgUnauthorized,
|
||||
MsgForbidden,
|
||||
MsgNotFound,
|
||||
MsgConflict,
|
||||
MsgServerError,
|
||||
MsgInvalidToken,
|
||||
MsgTokenExpired,
|
||||
MsgInvalidCredentials,
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg == "" {
|
||||
t.Error("响应消息常量不应为空")
|
||||
}
|
||||
}
|
||||
}
|
||||
41
internal/model/system_config.go
Normal file
41
internal/model/system_config.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConfigType 配置类型
|
||||
type ConfigType string
|
||||
|
||||
const (
|
||||
ConfigTypeString ConfigType = "STRING"
|
||||
ConfigTypeInteger ConfigType = "INTEGER"
|
||||
ConfigTypeBoolean ConfigType = "BOOLEAN"
|
||||
ConfigTypeJSON ConfigType = "JSON"
|
||||
)
|
||||
|
||||
// SystemConfig 系统配置模型
|
||||
type SystemConfig struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Key string `gorm:"column:key;type:varchar(100);not null;uniqueIndex" json:"key"`
|
||||
Value string `gorm:"column:value;type:text;not null" json:"value"`
|
||||
Description string `gorm:"column:description;type:varchar(255);not null;default:''" json:"description"`
|
||||
Type ConfigType `gorm:"column:type;type:varchar(50);not null;default:'STRING'" json:"type"` // STRING, INTEGER, BOOLEAN, JSON
|
||||
IsPublic bool `gorm:"column:is_public;not null;default:false;index" json:"is_public"` // 是否可被前端获取
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SystemConfig) TableName() string {
|
||||
return "system_config"
|
||||
}
|
||||
|
||||
// SystemConfigPublicResponse 公开配置响应
|
||||
type SystemConfigPublicResponse struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteDescription string `json:"site_description"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
MaintenanceMode bool `json:"maintenance_mode"`
|
||||
Announcement string `json:"announcement"`
|
||||
}
|
||||
76
internal/model/texture.go
Normal file
76
internal/model/texture.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TextureType 材质类型
|
||||
type TextureType string
|
||||
|
||||
const (
|
||||
TextureTypeSkin TextureType = "SKIN"
|
||||
TextureTypeCape TextureType = "CAPE"
|
||||
)
|
||||
|
||||
// Texture 材质模型
|
||||
type Texture struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UploaderID int64 `gorm:"column:uploader_id;not null;index" json:"uploader_id"`
|
||||
Name string `gorm:"column:name;type:varchar(100);not null;default:''" json:"name"`
|
||||
Description string `gorm:"column:description;type:text" json:"description,omitempty"`
|
||||
Type TextureType `gorm:"column:type;type:varchar(50);not null" json:"type"` // SKIN, CAPE
|
||||
URL string `gorm:"column:url;type:varchar(255);not null" json:"url"`
|
||||
Hash string `gorm:"column:hash;type:varchar(64);not null;uniqueIndex" json:"hash"` // SHA-256
|
||||
Size int `gorm:"column:size;type:integer;not null;default:0" json:"size"`
|
||||
IsPublic bool `gorm:"column:is_public;not null;default:false;index:idx_textures_public_type_status" json:"is_public"`
|
||||
DownloadCount int `gorm:"column:download_count;type:integer;not null;default:0;index:idx_textures_download_count,sort:desc" json:"download_count"`
|
||||
FavoriteCount int `gorm:"column:favorite_count;type:integer;not null;default:0;index:idx_textures_favorite_count,sort:desc" json:"favorite_count"`
|
||||
IsSlim bool `gorm:"column:is_slim;not null;default:false" json:"is_slim"` // Alex(细) or Steve(粗)
|
||||
Status int16 `gorm:"column:status;type:smallint;not null;default:1;index:idx_textures_public_type_status" json:"status"` // 1:正常, 0:审核中, -1:已删除
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
|
||||
// 关联
|
||||
Uploader *User `gorm:"foreignKey:UploaderID" json:"uploader,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Texture) TableName() string {
|
||||
return "textures"
|
||||
}
|
||||
|
||||
// UserTextureFavorite 用户材质收藏
|
||||
type UserTextureFavorite struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index;uniqueIndex:uk_user_texture" json:"user_id"`
|
||||
TextureID int64 `gorm:"column:texture_id;not null;index;uniqueIndex:uk_user_texture" json:"texture_id"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Texture *Texture `gorm:"foreignKey:TextureID" json:"texture,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (UserTextureFavorite) TableName() string {
|
||||
return "user_texture_favorites"
|
||||
}
|
||||
|
||||
// TextureDownloadLog 材质下载记录
|
||||
type TextureDownloadLog struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
TextureID int64 `gorm:"column:texture_id;not null;index" json:"texture_id"`
|
||||
UserID *int64 `gorm:"column:user_id;type:bigint;index" json:"user_id,omitempty"`
|
||||
IPAddress string `gorm:"column:ip_address;type:inet;not null;index" json:"ip_address"`
|
||||
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_download_logs_created_at,sort:desc" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
Texture *Texture `gorm:"foreignKey:TextureID" json:"texture,omitempty"`
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (TextureDownloadLog) TableName() string {
|
||||
return "texture_download_logs"
|
||||
}
|
||||
14
internal/model/token.go
Normal file
14
internal/model/token.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type Token struct {
|
||||
AccessToken string `json:"_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ClientToken string `json:"client_token"`
|
||||
ProfileId string `json:"profile_id"`
|
||||
Usable bool `json:"usable"`
|
||||
IssueDate time.Time `json:"issue_date"`
|
||||
}
|
||||
|
||||
func (Token) TableName() string { return "token" }
|
||||
70
internal/model/user.go
Normal file
70
internal/model/user.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// User 用户模型
|
||||
type User struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Username string `gorm:"column:username;type:varchar(255);not null;uniqueIndex" json:"username"`
|
||||
Password string `gorm:"column:password;type:varchar(255);not null" json:"-"` // 密码不返回给前端
|
||||
Email string `gorm:"column:email;type:varchar(255);not null;uniqueIndex" json:"email"`
|
||||
Avatar string `gorm:"column:avatar;type:varchar(255);not null;default:''" json:"avatar"`
|
||||
Points int `gorm:"column:points;type:integer;not null;default:0" json:"points"`
|
||||
Role string `gorm:"column:role;type:varchar(50);not null;default:'user'" json:"role"`
|
||||
Status int16 `gorm:"column:status;type:smallint;not null;default:1" json:"status"` // 1:正常, 0:禁用, -1:删除
|
||||
Properties string `gorm:"column:properties;type:jsonb" json:"properties"` // JSON字符串,存储为PostgreSQL的JSONB类型
|
||||
LastLoginAt *time.Time `gorm:"column:last_login_at;type:timestamp" json:"last_login_at,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (User) TableName() string {
|
||||
return "user"
|
||||
}
|
||||
|
||||
// UserPointLog 用户积分变更记录
|
||||
type UserPointLog struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index" json:"user_id"`
|
||||
ChangeType string `gorm:"column:change_type;type:varchar(50);not null" json:"change_type"` // EARN, SPEND, ADMIN_ADJUST
|
||||
Amount int `gorm:"column:amount;type:integer;not null" json:"amount"`
|
||||
BalanceBefore int `gorm:"column:balance_before;type:integer;not null" json:"balance_before"`
|
||||
BalanceAfter int `gorm:"column:balance_after;type:integer;not null" json:"balance_after"`
|
||||
Reason string `gorm:"column:reason;type:varchar(255);not null" json:"reason"`
|
||||
ReferenceType string `gorm:"column:reference_type;type:varchar(50)" json:"reference_type,omitempty"`
|
||||
ReferenceID *int64 `gorm:"column:reference_id;type:bigint" json:"reference_id,omitempty"`
|
||||
OperatorID *int64 `gorm:"column:operator_id;type:bigint" json:"operator_id,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_point_logs_created_at,sort:desc" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Operator *User `gorm:"foreignKey:OperatorID" json:"operator,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (UserPointLog) TableName() string {
|
||||
return "user_point_logs"
|
||||
}
|
||||
|
||||
// UserLoginLog 用户登录日志
|
||||
type UserLoginLog struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index" json:"user_id"`
|
||||
IPAddress string `gorm:"column:ip_address;type:inet;not null;index" json:"ip_address"`
|
||||
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent,omitempty"`
|
||||
LoginMethod string `gorm:"column:login_method;type:varchar(50);not null;default:'PASSWORD'" json:"login_method"`
|
||||
IsSuccess bool `gorm:"column:is_success;not null;index" json:"is_success"`
|
||||
FailureReason string `gorm:"column:failure_reason;type:varchar(255)" json:"failure_reason,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_login_logs_created_at,sort:desc" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (UserLoginLog) TableName() string {
|
||||
return "user_login_logs"
|
||||
}
|
||||
48
internal/model/yggdrasil.go
Normal file
48
internal/model/yggdrasil.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 定义随机字符集
|
||||
const passwordChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
// Yggdrasil ygg密码与用户id绑定
|
||||
type Yggdrasil struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;not null" json:"id"`
|
||||
Password string `gorm:"column:password;not null" json:"password"`
|
||||
// 关联 - Yggdrasil的ID引用User的ID,但不自动创建外键约束(避免循环依赖)
|
||||
User *User `gorm:"foreignKey:ID;references:ID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (Yggdrasil) TableName() string { return "Yggdrasil" }
|
||||
|
||||
// AfterCreate User创建后自动同步生成GeneratePassword记录
|
||||
func (u *User) AfterCreate(tx *gorm.DB) error {
|
||||
randomPwd := GenerateRandomPassword(16)
|
||||
|
||||
// 创建GeneratePassword记录
|
||||
gp := Yggdrasil{
|
||||
ID: u.ID, // 关联User的ID
|
||||
Password: randomPwd, // 16位随机密码
|
||||
}
|
||||
|
||||
if err := tx.Create(&gp).Error; err != nil {
|
||||
// 若同步失败,可记录日志或回滚事务(根据业务需求处理)
|
||||
return fmt.Errorf("同步生成密码失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateRandomPassword 生成指定长度的随机字符串
|
||||
func GenerateRandomPassword(length int) string {
|
||||
rand.Seed(time.Now().UnixNano()) // 初始化随机数种子
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = passwordChars[rand.Intn(len(passwordChars))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
199
internal/repository/profile_repository.go
Normal file
199
internal/repository/profile_repository.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateProfile 创建档案
|
||||
func CreateProfile(profile *model.Profile) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(profile).Error
|
||||
}
|
||||
|
||||
// FindProfileByUUID 根据UUID查找档案
|
||||
func FindProfileByUUID(uuid string) (*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profile model.Profile
|
||||
err := db.Where("uuid = ?", uuid).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// FindProfileByName 根据角色名查找档案
|
||||
func FindProfileByName(name string) (*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profile model.Profile
|
||||
err := db.Where("name = ?", name).First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// FindProfilesByUserID 获取用户的所有档案
|
||||
func FindProfilesByUserID(userID int64) ([]*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profiles []*model.Profile
|
||||
err := db.Where("user_id = ?", userID).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
Order("created_at DESC").
|
||||
Find(&profiles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新档案
|
||||
func UpdateProfile(profile *model.Profile) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(profile).Error
|
||||
}
|
||||
|
||||
// UpdateProfileFields 更新指定字段
|
||||
func UpdateProfileFields(uuid string, updates map[string]interface{}) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteProfile 删除档案
|
||||
func DeleteProfile(uuid string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
|
||||
}
|
||||
|
||||
// CountProfilesByUserID 统计用户的档案数量
|
||||
func CountProfilesByUserID(userID int64) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var count int64
|
||||
err := db.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// SetActiveProfile 设置档案为活跃状态(同时将用户的其他档案设置为非活跃)
|
||||
func SetActiveProfile(uuid string, userID int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// 将用户的所有档案设置为非活跃
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("is_active", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 将指定档案设置为活跃
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
Where("uuid = ? AND user_id = ?", uuid, userID).
|
||||
Update("is_active", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateProfileLastUsedAt 更新最后使用时间
|
||||
func UpdateProfileLastUsedAt(uuid string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
|
||||
}
|
||||
|
||||
// FindOneProfileByUserID 根据id找一个角色
|
||||
func FindOneProfileByUserID(userID int64) (*model.Profile, error) {
|
||||
profiles, err := FindProfilesByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile := profiles[0]
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func GetProfilesByNames(names []string) ([]*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profiles []*model.Profile
|
||||
err := db.Where("name in (?)", names).Find(&profiles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func GetProfileKeyPair(profileId string) (*model.KeyPair, error) {
|
||||
db := database.MustGetDB()
|
||||
// 1. 参数校验(保持原逻辑)
|
||||
if profileId == "" {
|
||||
return nil, errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
// 2. GORM 查询:只查询 key_pair 字段(对应原 mongo 投影)
|
||||
var profile *model.Profile
|
||||
// 条件:id = profileId(PostgreSQL 主键),只选择 key_pair 字段
|
||||
result := db.WithContext(context.Background()).
|
||||
Select("key_pair"). // 只查询需要的字段(投影)
|
||||
Where("id = ?", profileId). // 查询条件(GORM 自动处理占位符,避免 SQL 注入)
|
||||
First(&profile) // 查单条记录
|
||||
|
||||
// 3. 错误处理(适配 GORM 错误类型)
|
||||
if result.Error != nil {
|
||||
// 空结果判断(对应原 mongo.ErrNoDocuments / pgx.ErrNoRows)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("key pair未找到")
|
||||
}
|
||||
// 保持原错误封装格式
|
||||
return nil, fmt.Errorf("获取key pair失败: %w", result.Error)
|
||||
}
|
||||
|
||||
// 4. JSONB 反序列化为 model.KeyPair
|
||||
keyPair := &model.KeyPair{}
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
func UpdateProfileKeyPair(profileId string, keyPair *model.KeyPair) error {
|
||||
db := database.MustGetDB()
|
||||
// 仅保留最必要的入参校验(避免无效数据库请求)
|
||||
if profileId == "" {
|
||||
return errors.New("profileId 不能为空")
|
||||
}
|
||||
if keyPair == nil {
|
||||
return errors.New("keyPair 不能为 nil")
|
||||
}
|
||||
|
||||
// 事务内执行核心更新(保证原子性,出错自动回滚)
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// 核心更新逻辑:按 profileId 匹配,直接更新 key_pair 相关字段
|
||||
result := tx.WithContext(context.Background()).
|
||||
Table("profiles"). // 目标表名(与 PostgreSQL 表一致)
|
||||
Where("id = ?", profileId). // 更新条件:profileId 匹配
|
||||
// 直接映射字段(无需序列化,依赖 GORM 自动字段匹配)
|
||||
UpdateColumns(map[string]interface{}{
|
||||
"private_key": keyPair.PrivateKey, // 数据库 private_key 字段
|
||||
"public_key": keyPair.PublicKey, // 数据库 public_key 字段
|
||||
// 若 key_pair 是单个字段(非拆分),替换为:"key_pair": keyPair
|
||||
})
|
||||
|
||||
// 仅处理数据库层面的致命错误
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新 keyPair 失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
184
internal/repository/profile_repository_test.go
Normal file
184
internal/repository/profile_repository_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProfileRepository_QueryConditions 测试档案查询条件逻辑
|
||||
func TestProfileRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid string
|
||||
userID int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的UUID",
|
||||
uuid: "123e4567-e89b-12d3-a456-426614174000",
|
||||
userID: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "UUID为空",
|
||||
uuid: "",
|
||||
userID: 1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0",
|
||||
uuid: "123e4567-e89b-12d3-a456-426614174000",
|
||||
userID: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.uuid != "" && tt.userID > 0
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileRepository_SetActiveLogic 测试设置活跃档案的逻辑
|
||||
func TestProfileRepository_SetActiveLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid string
|
||||
userID int64
|
||||
otherProfiles int
|
||||
wantAllInactive bool
|
||||
}{
|
||||
{
|
||||
name: "设置一个档案为活跃,其他应该变为非活跃",
|
||||
uuid: "profile-1",
|
||||
userID: 1,
|
||||
otherProfiles: 2,
|
||||
wantAllInactive: true,
|
||||
},
|
||||
{
|
||||
name: "只有一个档案时",
|
||||
uuid: "profile-1",
|
||||
userID: 1,
|
||||
otherProfiles: 0,
|
||||
wantAllInactive: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:设置一个档案为活跃时,应该先将所有档案设为非活跃
|
||||
if !tt.wantAllInactive {
|
||||
t.Error("Setting active profile should first set all profiles to inactive")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileRepository_CountLogic 测试统计逻辑
|
||||
func TestProfileRepository_CountLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
wantCount int64
|
||||
}{
|
||||
{
|
||||
name: "有效用户ID",
|
||||
userID: 1,
|
||||
wantCount: 0, // 实际值取决于数据库
|
||||
},
|
||||
{
|
||||
name: "用户ID为0",
|
||||
userID: 0,
|
||||
wantCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证统计逻辑:用户ID应该大于0
|
||||
if tt.userID <= 0 && tt.wantCount != 0 {
|
||||
t.Error("Invalid userID should not count profiles")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileRepository_UpdateFieldsLogic 测试更新字段逻辑
|
||||
func TestProfileRepository_UpdateFieldsLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid string
|
||||
updates map[string]interface{}
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的更新",
|
||||
uuid: "123e4567-e89b-12d3-a456-426614174000",
|
||||
updates: map[string]interface{}{
|
||||
"name": "NewName",
|
||||
"skin_id": int64(1),
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "UUID为空",
|
||||
uuid: "",
|
||||
updates: map[string]interface{}{"name": "NewName"},
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "更新字段为空",
|
||||
uuid: "123e4567-e89b-12d3-a456-426614174000",
|
||||
updates: map[string]interface{}{},
|
||||
wantValid: true, // 空更新也是有效的,只是不会更新任何字段
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.uuid != "" && tt.updates != nil
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Update fields validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileRepository_FindOneProfileLogic 测试查找单个档案的逻辑
|
||||
func TestProfileRepository_FindOneProfileLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileCount int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "有档案时返回第一个",
|
||||
profileCount: 1,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "多个档案时返回第一个",
|
||||
profileCount: 3,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "没有档案时应该错误",
|
||||
profileCount: 0,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:如果没有档案,访问索引0会panic或返回错误
|
||||
hasError := tt.profileCount == 0
|
||||
if hasError != tt.wantError {
|
||||
t.Errorf("FindOneProfile logic failed: got error=%v, want error=%v", hasError, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
57
internal/repository/system_config_repository.go
Normal file
57
internal/repository/system_config_repository.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GetSystemConfigByKey 根据键获取配置
|
||||
func GetSystemConfigByKey(key string) (*model.SystemConfig, error) {
|
||||
db := database.MustGetDB()
|
||||
var config model.SystemConfig
|
||||
err := db.Where("key = ?", key).First(&config).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// GetPublicSystemConfigs 获取所有公开配置
|
||||
func GetPublicSystemConfigs() ([]model.SystemConfig, error) {
|
||||
db := database.MustGetDB()
|
||||
var configs []model.SystemConfig
|
||||
err := db.Where("is_public = ?", true).Find(&configs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
// GetAllSystemConfigs 获取所有配置(管理员用)
|
||||
func GetAllSystemConfigs() ([]model.SystemConfig, error) {
|
||||
db := database.MustGetDB()
|
||||
var configs []model.SystemConfig
|
||||
err := db.Find(&configs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
// UpdateSystemConfig 更新配置
|
||||
func UpdateSystemConfig(config *model.SystemConfig) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(config).Error
|
||||
}
|
||||
|
||||
// UpdateSystemConfigValue 更新配置值
|
||||
func UpdateSystemConfigValue(key, value string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
|
||||
}
|
||||
146
internal/repository/system_config_repository_test.go
Normal file
146
internal/repository/system_config_repository_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSystemConfigRepository_QueryConditions 测试系统配置查询条件逻辑
|
||||
func TestSystemConfigRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
isPublic bool
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的配置键",
|
||||
key: "site_name",
|
||||
isPublic: true,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "配置键为空",
|
||||
key: "",
|
||||
isPublic: true,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "公开配置查询",
|
||||
key: "site_name",
|
||||
isPublic: true,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "私有配置查询",
|
||||
key: "secret_key",
|
||||
isPublic: false,
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.key != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSystemConfigRepository_PublicConfigLogic 测试公开配置逻辑
|
||||
func TestSystemConfigRepository_PublicConfigLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isPublic bool
|
||||
wantInclude bool
|
||||
}{
|
||||
{
|
||||
name: "只获取公开配置",
|
||||
isPublic: true,
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "私有配置不应包含",
|
||||
isPublic: false,
|
||||
wantInclude: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:GetPublicSystemConfigs应该只返回is_public=true的配置
|
||||
if tt.isPublic != tt.wantInclude {
|
||||
t.Errorf("Public config logic failed: isPublic=%v, wantInclude=%v", tt.isPublic, tt.wantInclude)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSystemConfigRepository_UpdateValueLogic 测试更新配置值逻辑
|
||||
func TestSystemConfigRepository_UpdateValueLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的键值对",
|
||||
key: "site_name",
|
||||
value: "CarrotSkin",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "键为空",
|
||||
key: "",
|
||||
value: "CarrotSkin",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "值为空(可能有效)",
|
||||
key: "site_name",
|
||||
value: "",
|
||||
wantValid: true, // 空值也可能是有效的
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.key != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Update value validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSystemConfigRepository_ErrorHandling 测试错误处理逻辑
|
||||
func TestSystemConfigRepository_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isNotFound bool
|
||||
wantNilConfig bool
|
||||
}{
|
||||
{
|
||||
name: "记录未找到应该返回nil配置",
|
||||
isNotFound: true,
|
||||
wantNilConfig: true,
|
||||
},
|
||||
{
|
||||
name: "找到记录应该返回配置",
|
||||
isNotFound: false,
|
||||
wantNilConfig: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑:如果是RecordNotFound,返回nil配置
|
||||
if tt.isNotFound != tt.wantNilConfig {
|
||||
t.Errorf("Error handling logic failed: isNotFound=%v, wantNilConfig=%v", tt.isNotFound, tt.wantNilConfig)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
231
internal/repository/texture_repository.go
Normal file
231
internal/repository/texture_repository.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateTexture 创建材质
|
||||
func CreateTexture(texture *model.Texture) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(texture).Error
|
||||
}
|
||||
|
||||
// FindTextureByID 根据ID查找材质
|
||||
func FindTextureByID(id int64) (*model.Texture, error) {
|
||||
db := database.MustGetDB()
|
||||
var texture model.Texture
|
||||
err := db.Preload("Uploader").First(&texture, id).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &texture, nil
|
||||
}
|
||||
|
||||
// FindTextureByHash 根据Hash查找材质
|
||||
func FindTextureByHash(hash string) (*model.Texture, error) {
|
||||
db := database.MustGetDB()
|
||||
var texture model.Texture
|
||||
err := db.Where("hash = ?", hash).First(&texture).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &texture, nil
|
||||
}
|
||||
|
||||
// FindTexturesByUploaderID 根据上传者ID查找材质列表
|
||||
func FindTexturesByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
// SearchTextures 搜索材质
|
||||
func SearchTextures(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := db.Model(&model.Texture{}).Where("status = 1")
|
||||
|
||||
// 公开筛选
|
||||
if publicOnly {
|
||||
query = query.Where("is_public = ?", true)
|
||||
}
|
||||
|
||||
// 类型筛选
|
||||
if textureType != "" {
|
||||
query = query.Where("type = ?", textureType)
|
||||
}
|
||||
|
||||
// 关键词搜索
|
||||
if keyword != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
// UpdateTexture 更新材质
|
||||
func UpdateTexture(texture *model.Texture) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(texture).Error
|
||||
}
|
||||
|
||||
// UpdateTextureFields 更新材质指定字段
|
||||
func UpdateTextureFields(id int64, fields map[string]interface{}) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
// DeleteTexture 删除材质(软删除)
|
||||
func DeleteTexture(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
// IncrementTextureDownloadCount 增加下载次数
|
||||
func IncrementTextureDownloadCount(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
// IncrementTextureFavoriteCount 增加收藏次数
|
||||
func IncrementTextureFavoriteCount(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
// DecrementTextureFavoriteCount 减少收藏次数
|
||||
func DecrementTextureFavoriteCount(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error
|
||||
}
|
||||
|
||||
// CreateTextureDownloadLog 创建下载日志
|
||||
func CreateTextureDownloadLog(log *model.TextureDownloadLog) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(log).Error
|
||||
}
|
||||
|
||||
// IsTextureFavorited 检查是否已收藏
|
||||
func IsTextureFavorited(userID, textureID int64) (bool, error) {
|
||||
db := database.MustGetDB()
|
||||
var count int64
|
||||
err := db.Model(&model.UserTextureFavorite{}).
|
||||
Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// AddTextureFavorite 添加收藏
|
||||
func AddTextureFavorite(userID, textureID int64) error {
|
||||
db := database.MustGetDB()
|
||||
favorite := &model.UserTextureFavorite{
|
||||
UserID: userID,
|
||||
TextureID: textureID,
|
||||
}
|
||||
return db.Create(favorite).Error
|
||||
}
|
||||
|
||||
// RemoveTextureFavorite 取消收藏
|
||||
func RemoveTextureFavorite(userID, textureID int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Delete(&model.UserTextureFavorite{}).Error
|
||||
}
|
||||
|
||||
// GetUserTextureFavorites 获取用户收藏的材质列表
|
||||
func GetUserTextureFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
// 子查询获取收藏的材质ID
|
||||
subQuery := db.Model(&model.UserTextureFavorite{}).
|
||||
Select("texture_id").
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
query := db.Model(&model.Texture{}).
|
||||
Where("id IN (?) AND status = 1", subQuery)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
// CountTexturesByUploaderID 统计用户上传的材质数量
|
||||
func CountTexturesByUploaderID(uploaderID int64) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var count int64
|
||||
err := db.Model(&model.Texture{}).
|
||||
Where("uploader_id = ? AND status != -1", uploaderID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
89
internal/repository/token_repository.go
Normal file
89
internal/repository/token_repository.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
)
|
||||
|
||||
func CreateToken(token *model.Token) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(token).Error
|
||||
}
|
||||
|
||||
func GetTokensByUserId(userId int64) ([]*model.Token, error) {
|
||||
db := database.MustGetDB()
|
||||
tokens := make([]*model.Token, 0)
|
||||
err := db.Where("user_id = ?", userId).Find(&tokens).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func BatchDeleteTokens(tokensToDelete []string) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
if len(tokensToDelete) == 0 {
|
||||
return 0, nil // 无需要删除的令牌,直接返回
|
||||
}
|
||||
result := db.Where("access_token IN ?", tokensToDelete).Delete(&model.Token{})
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func FindTokenByID(accessToken string) (*model.Token, error) {
|
||||
db := database.MustGetDB()
|
||||
var tokens []*model.Token
|
||||
err := db.Where("_id = ?", accessToken).Find(&tokens).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens[0], nil
|
||||
}
|
||||
|
||||
func GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
db := database.MustGetDB()
|
||||
var token model.Token
|
||||
err := db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
|
||||
func GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var token model.Token
|
||||
err := db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return token.UserID, nil
|
||||
}
|
||||
|
||||
func GetTokenByAccessToken(accessToken string) (*model.Token, error) {
|
||||
db := database.MustGetDB()
|
||||
var token model.Token
|
||||
err := db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func DeleteTokenByAccessToken(accessToken string) error {
|
||||
db := database.MustGetDB()
|
||||
err := db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteTokenByUserId(userId int64) error {
|
||||
db := database.MustGetDB()
|
||||
err := db.Where("user_id = ?", userId).Delete(&model.Token{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
123
internal/repository/token_repository_test.go
Normal file
123
internal/repository/token_repository_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTokenRepository_BatchDeleteLogic 测试批量删除逻辑
|
||||
func TestTokenRepository_BatchDeleteLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokensToDelete []string
|
||||
wantCount int64
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "有效的token列表",
|
||||
tokensToDelete: []string{"token1", "token2", "token3"},
|
||||
wantCount: 3,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "空列表应该返回0",
|
||||
tokensToDelete: []string{},
|
||||
wantCount: 0,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "单个token",
|
||||
tokensToDelete: []string{"token1"},
|
||||
wantCount: 1,
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证批量删除逻辑:空列表应该直接返回0
|
||||
if len(tt.tokensToDelete) == 0 {
|
||||
if tt.wantCount != 0 {
|
||||
t.Errorf("Empty list should return count 0, got %d", tt.wantCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenRepository_QueryConditions 测试token查询条件逻辑
|
||||
func TestTokenRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
userID int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的access token",
|
||||
accessToken: "valid-token-123",
|
||||
userID: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "access token为空",
|
||||
accessToken: "",
|
||||
userID: 1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0",
|
||||
accessToken: "valid-token-123",
|
||||
userID: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.accessToken != "" && tt.userID > 0
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenRepository_FindTokenByIDLogic 测试根据ID查找token的逻辑
|
||||
func TestTokenRepository_FindTokenByIDLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
resultCount int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "找到token",
|
||||
accessToken: "token-123",
|
||||
resultCount: 1,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "未找到token",
|
||||
accessToken: "token-123",
|
||||
resultCount: 0,
|
||||
wantError: true, // 访问索引0会panic
|
||||
},
|
||||
{
|
||||
name: "找到多个token(异常情况)",
|
||||
accessToken: "token-123",
|
||||
resultCount: 2,
|
||||
wantError: false, // 返回第一个
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:如果结果为空,访问索引0会出错
|
||||
hasError := tt.resultCount == 0
|
||||
if hasError != tt.wantError {
|
||||
t.Errorf("FindTokenByID logic failed: got error=%v, want error=%v", hasError, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
136
internal/repository/user_repository.go
Normal file
136
internal/repository/user_repository.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateUser 创建用户
|
||||
func CreateUser(user *model.User) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(user).Error
|
||||
}
|
||||
|
||||
// FindUserByID 根据ID查找用户
|
||||
func FindUserByID(id int64) (*model.User, error) {
|
||||
db := database.MustGetDB()
|
||||
var user model.User
|
||||
err := db.Where("id = ? AND status != -1", id).First(&user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindUserByUsername 根据用户名查找用户
|
||||
func FindUserByUsername(username string) (*model.User, error) {
|
||||
db := database.MustGetDB()
|
||||
var user model.User
|
||||
err := db.Where("username = ? AND status != -1", username).First(&user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindUserByEmail 根据邮箱查找用户
|
||||
func FindUserByEmail(email string) (*model.User, error) {
|
||||
db := database.MustGetDB()
|
||||
var user model.User
|
||||
err := db.Where("email = ? AND status != -1", email).First(&user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户
|
||||
func UpdateUser(user *model.User) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(user).Error
|
||||
}
|
||||
|
||||
// UpdateUserFields 更新指定字段
|
||||
func UpdateUserFields(id int64, fields map[string]interface{}) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
// DeleteUser 软删除用户
|
||||
func DeleteUser(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
// CreateLoginLog 创建登录日志
|
||||
func CreateLoginLog(log *model.UserLoginLog) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(log).Error
|
||||
}
|
||||
|
||||
// CreatePointLog 创建积分日志
|
||||
func CreatePointLog(log *model.UserPointLog) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(log).Error
|
||||
}
|
||||
|
||||
// UpdateUserPoints 更新用户积分(事务)
|
||||
func UpdateUserPoints(userID int64, amount int, changeType, reason string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// 获取当前用户积分
|
||||
var user model.User
|
||||
if err := tx.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
balanceBefore := user.Points
|
||||
balanceAfter := balanceBefore + amount
|
||||
|
||||
// 检查积分是否足够
|
||||
if balanceAfter < 0 {
|
||||
return errors.New("积分不足")
|
||||
}
|
||||
|
||||
// 更新用户积分
|
||||
if err := tx.Model(&user).Update("points", balanceAfter).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建积分日志
|
||||
log := &model.UserPointLog{
|
||||
UserID: userID,
|
||||
ChangeType: changeType,
|
||||
Amount: amount,
|
||||
BalanceBefore: balanceBefore,
|
||||
BalanceAfter: balanceAfter,
|
||||
Reason: reason,
|
||||
}
|
||||
|
||||
return tx.Create(log).Error
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAvatar 更新用户头像
|
||||
func UpdateUserAvatar(userID int64, avatarURL string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.User{}).Where("id = ?", userID).Update("avatar", avatarURL).Error
|
||||
}
|
||||
|
||||
// UpdateUserEmail 更新用户邮箱
|
||||
func UpdateUserEmail(userID int64, email string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.User{}).Where("id = ?", userID).Update("email", email).Error
|
||||
}
|
||||
155
internal/repository/user_repository_test.go
Normal file
155
internal/repository/user_repository_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestUserRepository_QueryConditions 测试用户查询条件逻辑
|
||||
func TestUserRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id int64
|
||||
status int16
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID和状态",
|
||||
id: 1,
|
||||
status: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0时无效",
|
||||
id: 0,
|
||||
status: 1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为-1(已删除)应该被排除",
|
||||
id: 1,
|
||||
status: -1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为0(禁用)可能有效",
|
||||
id: 1,
|
||||
status: 0,
|
||||
wantValid: true, // 查询条件中只排除-1
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试查询条件逻辑:status != -1
|
||||
isValid := tt.id > 0 && tt.status != -1
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_DeleteLogic 测试软删除逻辑
|
||||
func TestUserRepository_DeleteLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oldStatus int16
|
||||
newStatus int16
|
||||
}{
|
||||
{
|
||||
name: "软删除应该将状态设置为-1",
|
||||
oldStatus: 1,
|
||||
newStatus: -1,
|
||||
},
|
||||
{
|
||||
name: "从禁用状态删除",
|
||||
oldStatus: 0,
|
||||
newStatus: -1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证软删除逻辑:状态应该变为-1
|
||||
if tt.newStatus != -1 {
|
||||
t.Errorf("Delete should set status to -1, got %d", tt.newStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_UpdateFieldsLogic 测试更新字段逻辑
|
||||
func TestUserRepository_UpdateFieldsLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields map[string]interface{}
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的更新字段",
|
||||
fields: map[string]interface{}{
|
||||
"email": "new@example.com",
|
||||
"avatar": "https://example.com/avatar.png",
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "空字段映射",
|
||||
fields: map[string]interface{}{},
|
||||
wantValid: true, // 空映射也是有效的,只是不会更新任何字段
|
||||
},
|
||||
{
|
||||
name: "包含nil值的字段",
|
||||
fields: map[string]interface{}{
|
||||
"email": "new@example.com",
|
||||
"avatar": nil,
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证字段映射逻辑
|
||||
isValid := tt.fields != nil
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Update fields validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_ErrorHandling 测试错误处理逻辑
|
||||
func TestUserRepository_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
isNotFound bool
|
||||
wantNilUser bool
|
||||
}{
|
||||
{
|
||||
name: "记录未找到应该返回nil用户",
|
||||
err: nil, // 模拟gorm.ErrRecordNotFound
|
||||
isNotFound: true,
|
||||
wantNilUser: true,
|
||||
},
|
||||
{
|
||||
name: "其他错误应该返回错误",
|
||||
err: nil,
|
||||
isNotFound: false,
|
||||
wantNilUser: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试错误处理逻辑:如果是RecordNotFound,返回nil用户;否则返回错误
|
||||
if tt.isNotFound {
|
||||
if !tt.wantNilUser {
|
||||
t.Error("RecordNotFound should return nil user")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
16
internal/repository/yggdrasil_repository.go
Normal file
16
internal/repository/yggdrasil_repository.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
)
|
||||
|
||||
func GetYggdrasilPasswordById(Id int64) (string, error) {
|
||||
db := database.MustGetDB()
|
||||
var yggdrasil model.Yggdrasil
|
||||
err := db.Where("id = ?", Id).First(&yggdrasil).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return yggdrasil.Password, nil
|
||||
}
|
||||
165
internal/service/captcha_service.go
Normal file
165
internal/service/captcha_service.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wenlng/go-captcha-assets/resources/imagesv2"
|
||||
"github.com/wenlng/go-captcha-assets/resources/tiles"
|
||||
"github.com/wenlng/go-captcha/v2/slide"
|
||||
)
|
||||
|
||||
var (
|
||||
slideTileCapt slide.Captcha
|
||||
cfg *config.Config
|
||||
)
|
||||
|
||||
// 常量定义(业务相关配置,与Redis连接配置分离)
|
||||
const (
|
||||
redisKeyPrefix = "captcha:" // Redis键前缀(便于区分业务)
|
||||
paddingValue = 3 // 验证允许的误差像素(±3px)
|
||||
)
|
||||
|
||||
// Init 验证码图初始化
|
||||
func init() {
|
||||
cfg, _ = config.Load()
|
||||
// 从默认仓库中获取主图
|
||||
builder := slide.NewBuilder()
|
||||
bgImage, err := imagesv2.GetImages()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
// 滑块形状获取
|
||||
graphs := getSlideTileGraphArr()
|
||||
|
||||
builder.SetResources(
|
||||
slide.WithGraphImages(graphs),
|
||||
slide.WithBackgrounds(bgImage),
|
||||
)
|
||||
slideTileCapt = builder.Make()
|
||||
if slideTileCapt == nil {
|
||||
log.Fatalln("验证码实例初始化失败")
|
||||
}
|
||||
}
|
||||
|
||||
// getSlideTileGraphArr 滑块选择
|
||||
func getSlideTileGraphArr() []*slide.GraphImage {
|
||||
graphs, err := tiles.GetTiles()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var newGraphs = make([]*slide.GraphImage, 0, len(graphs))
|
||||
for i := 0; i < len(graphs); i++ {
|
||||
graph := graphs[i]
|
||||
newGraphs = append(newGraphs, &slide.GraphImage{
|
||||
OverlayImage: graph.OverlayImage,
|
||||
MaskImage: graph.MaskImage,
|
||||
ShadowImage: graph.ShadowImage,
|
||||
})
|
||||
}
|
||||
return newGraphs
|
||||
}
|
||||
|
||||
// RedisData 存储到Redis的验证信息(仅包含校验必需字段)
|
||||
type RedisData struct {
|
||||
Tx int `json:"tx"` // 滑块目标X坐标
|
||||
Ty int `json:"ty"` // 滑块目标Y坐标
|
||||
}
|
||||
|
||||
// GenerateCaptchaData 提取生成验证码的相关信息
|
||||
func GenerateCaptchaData(ctx context.Context, redisClient *redis.Client) (string, string, string, int, error) {
|
||||
// 生成uuid作为验证码进程唯一标识
|
||||
captchaID := uuid.NewString()
|
||||
if captchaID == "" {
|
||||
return "", "", "", 0, errors.New("生成验证码唯一标识失败")
|
||||
}
|
||||
|
||||
captData, err := slideTileCapt.Generate()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("生成验证码失败: %w", err)
|
||||
}
|
||||
blockData := captData.GetData()
|
||||
if blockData == nil {
|
||||
return "", "", "", 0, errors.New("获取验证码数据失败")
|
||||
}
|
||||
block, _ := json.Marshal(blockData)
|
||||
var blockMap map[string]interface{}
|
||||
|
||||
if err := json.Unmarshal(block, &blockMap); err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("反序列化为map失败: %w", err)
|
||||
}
|
||||
// 提取x和y并转换为int类型
|
||||
tx, ok := blockMap["x"].(float64)
|
||||
if !ok {
|
||||
return "", "", "", 0, errors.New("无法将x转换为float64")
|
||||
}
|
||||
var x = int(tx)
|
||||
ty, ok := blockMap["y"].(float64)
|
||||
if !ok {
|
||||
return "", "", "", 0, errors.New("无法将y转换为float64")
|
||||
}
|
||||
var y = int(ty)
|
||||
var mBase64, tBase64 string
|
||||
mBase64, err = captData.GetMasterImage().ToBase64()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("主图转换为base64失败: %w", err)
|
||||
}
|
||||
tBase64, err = captData.GetTileImage().ToBase64()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("滑块图转换为base64失败: %w", err)
|
||||
}
|
||||
redisData := RedisData{
|
||||
Tx: x,
|
||||
Ty: y,
|
||||
}
|
||||
redisDataJSON, _ := json.Marshal(redisData)
|
||||
redisKey := redisKeyPrefix + captchaID
|
||||
expireTime := 300 * time.Second
|
||||
|
||||
// 使用注入的Redis客户端
|
||||
if err := redisClient.Set(
|
||||
ctx,
|
||||
redisKey,
|
||||
redisDataJSON,
|
||||
expireTime,
|
||||
); err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("存储验证码到Redis失败: %w", err)
|
||||
}
|
||||
return mBase64, tBase64, captchaID, y - 10, nil
|
||||
}
|
||||
|
||||
// VerifyCaptchaData 验证用户验证码
|
||||
func VerifyCaptchaData(ctx context.Context, redisClient *redis.Client, dx int, id string) (bool, error) {
|
||||
redisKey := redisKeyPrefix + id
|
||||
|
||||
// 从Redis获取验证信息,使用注入的客户端
|
||||
dataJSON, err := redisClient.Get(ctx, redisKey)
|
||||
if err != nil {
|
||||
if redisClient.Nil(err) { // 使用封装客户端的Nil错误
|
||||
return false, errors.New("验证码已过期或无效")
|
||||
}
|
||||
return false, fmt.Errorf("Redis查询失败: %w", err)
|
||||
}
|
||||
var redisData RedisData
|
||||
if err := json.Unmarshal([]byte(dataJSON), &redisData); err != nil {
|
||||
return false, fmt.Errorf("解析Redis数据失败: %w", err)
|
||||
}
|
||||
tx := redisData.Tx
|
||||
ty := redisData.Ty
|
||||
ok := slide.Validate(dx, ty, tx, ty, paddingValue)
|
||||
|
||||
// 验证后立即删除Redis记录(防止重复使用)
|
||||
if ok {
|
||||
if err := redisClient.Del(ctx, redisKey); err != nil {
|
||||
// 记录警告但不影响验证结果
|
||||
log.Printf("删除验证码Redis记录失败: %v", err)
|
||||
}
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
174
internal/service/captcha_service_test.go
Normal file
174
internal/service/captcha_service_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCaptchaService_Constants 测试验证码服务常量
|
||||
func TestCaptchaService_Constants(t *testing.T) {
|
||||
if redisKeyPrefix != "captcha:" {
|
||||
t.Errorf("redisKeyPrefix = %s, want 'captcha:'", redisKeyPrefix)
|
||||
}
|
||||
|
||||
if paddingValue != 3 {
|
||||
t.Errorf("paddingValue = %d, want 3", paddingValue)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisData_Structure 测试RedisData结构
|
||||
func TestRedisData_Structure(t *testing.T) {
|
||||
data := RedisData{
|
||||
Tx: 100,
|
||||
Ty: 200,
|
||||
}
|
||||
|
||||
if data.Tx != 100 {
|
||||
t.Errorf("RedisData.Tx = %d, want 100", data.Tx)
|
||||
}
|
||||
|
||||
if data.Ty != 200 {
|
||||
t.Errorf("RedisData.Ty = %d, want 200", data.Ty)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateCaptchaData_Logic 测试生成验证码的逻辑部分
|
||||
func TestGenerateCaptchaData_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
captchaID string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "有效的captchaID",
|
||||
captchaID: "test-uuid-123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空的captchaID应该失败",
|
||||
captchaID: "",
|
||||
wantErr: true,
|
||||
errContains: "生成验证码唯一标识失败",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试UUID验证逻辑
|
||||
if tt.captchaID == "" {
|
||||
if !tt.wantErr {
|
||||
t.Error("空captchaID应该返回错误")
|
||||
}
|
||||
} else {
|
||||
if tt.wantErr {
|
||||
t.Error("非空captchaID不应该返回错误")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyCaptchaData_Logic 测试验证验证码的逻辑部分
|
||||
func TestVerifyCaptchaData_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dx int
|
||||
tx int
|
||||
ty int
|
||||
padding int
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "精确匹配",
|
||||
dx: 100,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "在误差范围内(+3)",
|
||||
dx: 103,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "在误差范围内(-3)",
|
||||
dx: 97,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "超出误差范围(+4)",
|
||||
dx: 104,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "超出误差范围(-4)",
|
||||
dx: 96,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:dx应该在[tx-padding, tx+padding]范围内
|
||||
diff := tt.dx - tt.tx
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
isValid := diff <= tt.padding
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v (dx=%d, tx=%d, padding=%d)", isValid, tt.wantValid, tt.dx, tt.tx, tt.padding)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyCaptchaData_RedisKey 测试Redis键生成逻辑
|
||||
func TestVerifyCaptchaData_RedisKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "生成正确的Redis键",
|
||||
id: "test-id-123",
|
||||
expected: "captcha:test-id-123",
|
||||
},
|
||||
{
|
||||
name: "空ID",
|
||||
id: "",
|
||||
expected: "captcha:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
redisKey := redisKeyPrefix + tt.id
|
||||
if redisKey != tt.expected {
|
||||
t.Errorf("Redis key = %s, want %s", redisKey, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateCaptchaData_ExpireTime 测试过期时间
|
||||
func TestGenerateCaptchaData_ExpireTime(t *testing.T) {
|
||||
expectedExpireTime := 300 * time.Second
|
||||
if expectedExpireTime != 5*time.Minute {
|
||||
t.Errorf("Expire time should be 5 minutes")
|
||||
}
|
||||
}
|
||||
13
internal/service/common.go
Normal file
13
internal/service/common.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
)
|
||||
|
||||
// 统一的json变量,用于整个service包
|
||||
var json = jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
|
||||
// DefaultTimeout 默认超时时间
|
||||
const DefaultTimeout = 5 * time.Second
|
||||
48
internal/service/common_test.go
Normal file
48
internal/service/common_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCommon_Constants 测试common包的常量
|
||||
func TestCommon_Constants(t *testing.T) {
|
||||
if DefaultTimeout != 5*time.Second {
|
||||
t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommon_JSON 测试JSON变量
|
||||
func TestCommon_JSON(t *testing.T) {
|
||||
// 验证json变量不为nil
|
||||
if json == nil {
|
||||
t.Error("json 变量不应为nil")
|
||||
}
|
||||
|
||||
// 测试JSON序列化
|
||||
testData := map[string]interface{}{
|
||||
"name": "test",
|
||||
"age": 25,
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() 失败: %v", err)
|
||||
}
|
||||
|
||||
if len(bytes) == 0 {
|
||||
t.Error("json.Marshal() 返回的字节不应为空")
|
||||
}
|
||||
|
||||
// 测试JSON反序列化
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(bytes, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() 失败: %v", err)
|
||||
}
|
||||
|
||||
if result["name"] != "test" {
|
||||
t.Errorf("反序列化结果 name = %v, want 'test'", result["name"])
|
||||
}
|
||||
}
|
||||
|
||||
252
internal/service/profile_service.go
Normal file
252
internal/service/profile_service.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateProfile 创建档案
|
||||
func CreateProfile(db *gorm.DB, userID int64, name string) (*model.Profile, error) {
|
||||
// 1. 验证用户存在
|
||||
user, err := repository.FindUserByID(userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("用户不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询用户失败: %w", err)
|
||||
}
|
||||
|
||||
if user.Status != 1 {
|
||||
return nil, fmt.Errorf("用户状态异常")
|
||||
}
|
||||
|
||||
// 2. 检查角色名是否已存在
|
||||
existingName, err := repository.FindProfileByName(name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
if existingName != nil {
|
||||
return nil, fmt.Errorf("角色名已被使用")
|
||||
}
|
||||
|
||||
// 3. 生成UUID
|
||||
profileUUID := uuid.New().String()
|
||||
|
||||
// 4. 生成RSA密钥对
|
||||
privateKey, err := generateRSAPrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成RSA密钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 创建档案
|
||||
profile := &model.Profile{
|
||||
UUID: profileUUID,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
RSAPrivateKey: privateKey,
|
||||
IsActive: true, // 新创建的档案默认为活跃状态
|
||||
}
|
||||
|
||||
if err := repository.CreateProfile(profile); err != nil {
|
||||
return nil, fmt.Errorf("创建档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 6. 将用户的其他档案设置为非活跃
|
||||
if err := repository.SetActiveProfile(profileUUID, userID); err != nil {
|
||||
return nil, fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// GetProfileByUUID 获取档案详情
|
||||
func GetProfileByUUID(db *gorm.DB, uuid string) (*model.Profile, error) {
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("档案不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// GetUserProfiles 获取用户的所有档案
|
||||
func GetUserProfiles(db *gorm.DB, userID int64) ([]*model.Profile, error) {
|
||||
profiles, err := repository.FindProfilesByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询档案列表失败: %w", err)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新档案
|
||||
func UpdateProfile(db *gorm.DB, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("档案不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return nil, fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 检查角色名是否重复
|
||||
if name != nil && *name != profile.Name {
|
||||
existingName, err := repository.FindProfileByName(*name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
if existingName != nil {
|
||||
return nil, fmt.Errorf("角色名已被使用")
|
||||
}
|
||||
profile.Name = *name
|
||||
}
|
||||
|
||||
// 4. 更新皮肤和披风
|
||||
if skinID != nil {
|
||||
profile.SkinID = skinID
|
||||
}
|
||||
if capeID != nil {
|
||||
profile.CapeID = capeID
|
||||
}
|
||||
|
||||
// 5. 保存更新
|
||||
if err := repository.UpdateProfile(profile); err != nil {
|
||||
return nil, fmt.Errorf("更新档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 6. 重新加载关联数据
|
||||
return repository.FindProfileByUUID(uuid)
|
||||
}
|
||||
|
||||
// DeleteProfile 删除档案
|
||||
func DeleteProfile(db *gorm.DB, uuid string, userID int64) error {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("档案不存在")
|
||||
}
|
||||
return fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 删除档案
|
||||
if err := repository.DeleteProfile(uuid); err != nil {
|
||||
return fmt.Errorf("删除档案失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetActiveProfile 设置活跃档案
|
||||
func SetActiveProfile(db *gorm.DB, uuid string, userID int64) error {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("档案不存在")
|
||||
}
|
||||
return fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 设置活跃状态
|
||||
if err := repository.SetActiveProfile(uuid, userID); err != nil {
|
||||
return fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 更新最后使用时间
|
||||
if err := repository.UpdateProfileLastUsedAt(uuid); err != nil {
|
||||
return fmt.Errorf("更新使用时间失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckProfileLimit 检查用户档案数量限制
|
||||
func CheckProfileLimit(db *gorm.DB, userID int64, maxProfiles int) error {
|
||||
count, err := repository.CountProfilesByUserID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询档案数量失败: %w", err)
|
||||
}
|
||||
|
||||
if int(count) >= maxProfiles {
|
||||
return fmt.Errorf("已达到档案数量上限(%d个)", maxProfiles)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateRSAPrivateKey 生成RSA-2048私钥(PEM格式)
|
||||
func generateRSAPrivateKey() (string, error) {
|
||||
// 生成2048位RSA密钥对
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将私钥编码为PEM格式
|
||||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privateKeyBytes,
|
||||
})
|
||||
|
||||
return string(privateKeyPEM), nil
|
||||
}
|
||||
|
||||
func ValidateProfileByUserID(db *gorm.DB, userId int64, UUID string) (bool, error) {
|
||||
if userId == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := repository.FindProfileByUUID(UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
}
|
||||
return false, fmt.Errorf("验证配置文件失败: %w", err)
|
||||
}
|
||||
return profile.UserID == userId, nil
|
||||
}
|
||||
|
||||
func GetProfilesDataByNames(db *gorm.DB, names []string) ([]*model.Profile, error) {
|
||||
profiles, err := repository.GetProfilesByNames(names)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找失败: %w", err)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// GetProfileKeyPair 从 PostgreSQL 获取密钥对(GORM 实现,无手动 SQL)
|
||||
func GetProfileKeyPair(db *gorm.DB, profileId string) (*model.KeyPair, error) {
|
||||
keyPair, err := repository.GetProfileKeyPair(profileId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找失败: %w", err)
|
||||
}
|
||||
return keyPair, nil
|
||||
}
|
||||
406
internal/service/profile_service_test.go
Normal file
406
internal/service/profile_service_test.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProfileService_Validation 测试Profile服务验证逻辑
|
||||
func TestProfileService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
profileName string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID和角色名",
|
||||
userID: 1,
|
||||
profileName: "TestProfile",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0时无效",
|
||||
userID: 0,
|
||||
profileName: "TestProfile",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "角色名为空时无效",
|
||||
userID: 1,
|
||||
profileName: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.userID > 0 && tt.profileName != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileService_StatusValidation 测试用户状态验证
|
||||
func TestProfileService_StatusValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int16
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "状态为1(正常)时有效",
|
||||
status: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "状态为0(禁用)时无效",
|
||||
status: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为-1(删除)时无效",
|
||||
status: -1,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.status == 1
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Status validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileService_IsActiveDefault 测试Profile默认活跃状态
|
||||
func TestProfileService_IsActiveDefault(t *testing.T) {
|
||||
// 新创建的档案默认为活跃状态
|
||||
isActive := true
|
||||
if !isActive {
|
||||
t.Error("新创建的Profile应该默认为活跃状态")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateProfile_PermissionCheck 测试更新Profile的权限检查逻辑
|
||||
func TestUpdateProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许操作",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝操作",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateProfile_NameValidation 测试更新Profile时名称验证逻辑
|
||||
func TestUpdateProfile_NameValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentName string
|
||||
newName *string
|
||||
shouldCheck bool
|
||||
}{
|
||||
{
|
||||
name: "名称未改变,不检查",
|
||||
currentName: "TestProfile",
|
||||
newName: stringPtr("TestProfile"),
|
||||
shouldCheck: false,
|
||||
},
|
||||
{
|
||||
name: "名称改变,需要检查",
|
||||
currentName: "TestProfile",
|
||||
newName: stringPtr("NewProfile"),
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "名称为nil,不检查",
|
||||
currentName: "TestProfile",
|
||||
newName: nil,
|
||||
shouldCheck: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldCheck := tt.newName != nil && *tt.newName != tt.currentName
|
||||
if shouldCheck != tt.shouldCheck {
|
||||
t.Errorf("Name validation check failed: got %v, want %v", shouldCheck, tt.shouldCheck)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteProfile_PermissionCheck 测试删除Profile的权限检查
|
||||
func TestDeleteProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许删除",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝删除",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetActiveProfile_PermissionCheck 测试设置活跃Profile的权限检查
|
||||
func TestSetActiveProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许设置",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝设置",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckProfileLimit_Logic 测试Profile数量限制检查逻辑
|
||||
func TestCheckProfileLimit_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int
|
||||
maxProfiles int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "未达到上限",
|
||||
count: 5,
|
||||
maxProfiles: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "达到上限",
|
||||
count: 10,
|
||||
maxProfiles: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "超过上限",
|
||||
count: 15,
|
||||
maxProfiles: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.count >= tt.maxProfiles
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Limit check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateProfileByUserID_InputValidation 测试ValidateProfileByUserID输入验证
|
||||
func TestValidateProfileByUserID_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
uuid string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "有效输入",
|
||||
userID: 1,
|
||||
uuid: "test-uuid",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "userID为0",
|
||||
userID: 0,
|
||||
uuid: "test-uuid",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "uuid为空",
|
||||
userID: 1,
|
||||
uuid: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "两者都无效",
|
||||
userID: 0,
|
||||
uuid: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.userID == 0 || tt.uuid == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateProfileByUserID_UserIDMatching 测试用户ID匹配逻辑
|
||||
func TestValidateProfileByUserID_UserIDMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.profileUserID == tt.requestUserID
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("UserID matching failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateRSAPrivateKey 测试RSA私钥生成
|
||||
func TestGenerateRSAPrivateKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "生成RSA私钥",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
privateKey, err := generateRSAPrivateKey()
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("generateRSAPrivateKey() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if privateKey == "" {
|
||||
t.Error("generateRSAPrivateKey() 返回的私钥不应为空")
|
||||
}
|
||||
// 验证PEM格式
|
||||
if len(privateKey) < 100 {
|
||||
t.Errorf("generateRSAPrivateKey() 返回的私钥长度异常: %d", len(privateKey))
|
||||
}
|
||||
// 验证包含PEM头部
|
||||
if !contains(privateKey, "BEGIN RSA PRIVATE KEY") {
|
||||
t.Error("generateRSAPrivateKey() 返回的私钥应包含PEM头部")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateRSAPrivateKey_Uniqueness 测试RSA私钥唯一性
|
||||
func TestGenerateRSAPrivateKey_Uniqueness(t *testing.T) {
|
||||
keys := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
key, err := generateRSAPrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generateRSAPrivateKey() 失败: %v", err)
|
||||
}
|
||||
if keys[key] {
|
||||
t.Errorf("第%d次生成的密钥与之前重复", i+1)
|
||||
}
|
||||
keys[key] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr ||
|
||||
(len(s) > len(substr) && (s[:len(substr)] == substr ||
|
||||
s[len(s)-len(substr):] == substr ||
|
||||
containsMiddle(s, substr))))
|
||||
}
|
||||
|
||||
func containsMiddle(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
97
internal/service/serialize_service.go
Normal file
97
internal/service/serialize_service.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/redis"
|
||||
"encoding/base64"
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Property struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
func SerializeProfile(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, p model.Profile) map[string]interface{} {
|
||||
var err error
|
||||
|
||||
// 创建基本材质数据
|
||||
texturesMap := make(map[string]interface{})
|
||||
textures := map[string]interface{}{
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
"profileId": p.UUID,
|
||||
"profileName": p.Name,
|
||||
"textures": texturesMap,
|
||||
}
|
||||
|
||||
// 处理皮肤
|
||||
if p.SkinID != nil {
|
||||
skin, err := GetTextureByID(db, *p.SkinID)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取皮肤失败:", zap.Error(err), zap.Any("SkinID:", *p.SkinID))
|
||||
} else {
|
||||
texturesMap["SKIN"] = map[string]interface{}{
|
||||
"url": skin.URL,
|
||||
"metadata": skin.Size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理披风
|
||||
if p.CapeID != nil {
|
||||
cape, err := GetTextureByID(db, *p.CapeID)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取披风失败:", zap.Error(err), zap.Any("capeID:", *p.CapeID))
|
||||
} else {
|
||||
texturesMap["CAPE"] = map[string]interface{}{
|
||||
"url": cape.URL,
|
||||
"metadata": cape.Size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 将textures编码为base64
|
||||
bytes, err := json.Marshal(textures)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 序列化textures失败: ", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
textureData := base64.StdEncoding.EncodeToString(bytes)
|
||||
signature, err := SignStringWithSHA1withRSA(logger, redisClient, textureData)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名textures失败: ", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 构建结果
|
||||
data := map[string]interface{}{
|
||||
"id": p.UUID,
|
||||
"name": p.Name,
|
||||
"properties": []Property{
|
||||
{
|
||||
Name: "textures",
|
||||
Value: textureData,
|
||||
Signature: signature,
|
||||
},
|
||||
},
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func SerializeUser(logger *zap.Logger, u *model.User, UUID string) map[string]interface{} {
|
||||
if u == nil {
|
||||
logger.Error("[ERROR] 尝试序列化空用户")
|
||||
return nil
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"id": UUID,
|
||||
"properties": u.Properties,
|
||||
}
|
||||
return data
|
||||
}
|
||||
172
internal/service/serialize_service_test.go
Normal file
172
internal/service/serialize_service_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestSerializeUser_NilUser 实际调用SerializeUser函数测试nil用户
|
||||
func TestSerializeUser_NilUser(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
result := SerializeUser(logger, nil, "test-uuid")
|
||||
if result != nil {
|
||||
t.Error("SerializeUser() 对于nil用户应返回nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeUser_ActualCall 实际调用SerializeUser函数
|
||||
func TestSerializeUser_ActualCall(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
user := &model.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Properties: "{}",
|
||||
}
|
||||
|
||||
result := SerializeUser(logger, user, "test-uuid-123")
|
||||
if result == nil {
|
||||
t.Fatal("SerializeUser() 返回的结果不应为nil")
|
||||
}
|
||||
|
||||
if result["id"] != "test-uuid-123" {
|
||||
t.Errorf("id = %v, want 'test-uuid-123'", result["id"])
|
||||
}
|
||||
|
||||
if result["properties"] == nil {
|
||||
t.Error("properties 不应为nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProperty_Structure 测试Property结构
|
||||
func TestProperty_Structure(t *testing.T) {
|
||||
prop := Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
}
|
||||
|
||||
if prop.Name == "" {
|
||||
t.Error("Property name should not be empty")
|
||||
}
|
||||
|
||||
if prop.Value == "" {
|
||||
t.Error("Property value should not be empty")
|
||||
}
|
||||
|
||||
// Signature是可选的
|
||||
if prop.Signature == "" {
|
||||
t.Log("Property signature is optional")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeService_PropertyFields 测试Property字段
|
||||
func TestSerializeService_PropertyFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
property Property
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的Property",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "缺少Name的Property",
|
||||
property: Property{
|
||||
Name: "",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "缺少Value的Property",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "没有Signature的Property(有效)",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "",
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.property.Name != "" && tt.property.Value != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Property validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeUser_InputValidation 测试SerializeUser输入验证
|
||||
func TestSerializeUser_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user *struct{}
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "用户不为nil",
|
||||
user: &struct{}{},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户为nil",
|
||||
user: nil,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.user != nil
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Input validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeProfile_Structure 测试SerializeProfile返回结构
|
||||
func TestSerializeProfile_Structure(t *testing.T) {
|
||||
// 测试返回的数据结构应该包含的字段
|
||||
expectedFields := []string{"id", "name", "properties"}
|
||||
|
||||
// 验证字段名称
|
||||
for _, field := range expectedFields {
|
||||
if field == "" {
|
||||
t.Error("Field name should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证properties应该是数组
|
||||
// 注意:这里只测试逻辑,不测试实际序列化
|
||||
}
|
||||
|
||||
// TestSerializeProfile_PropertyName 测试Property名称
|
||||
func TestSerializeProfile_PropertyName(t *testing.T) {
|
||||
// textures是固定的属性名
|
||||
propertyName := "textures"
|
||||
if propertyName != "textures" {
|
||||
t.Errorf("Property name = %s, want 'textures'", propertyName)
|
||||
}
|
||||
}
|
||||
605
internal/service/signature_service.go
Normal file
605
internal/service/signature_service.go
Normal file
@@ -0,0 +1,605 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
// RSA密钥长度
|
||||
RSAKeySize = 4096
|
||||
|
||||
// Redis密钥名称
|
||||
PrivateKeyRedisKey = "private_key"
|
||||
PublicKeyRedisKey = "public_key"
|
||||
|
||||
// 密钥过期时间
|
||||
KeyExpirationTime = time.Hour * 24 * 7
|
||||
|
||||
// 证书相关
|
||||
CertificateRefreshInterval = time.Hour * 24 // 证书刷新时间间隔
|
||||
CertificateExpirationPeriod = time.Hour * 24 * 7 // 证书过期时间
|
||||
)
|
||||
|
||||
// PlayerCertificate 表示玩家证书信息
|
||||
type PlayerCertificate struct {
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
RefreshedAfter string `json:"refreshedAfter"`
|
||||
PublicKeySignature string `json:"publicKeySignature,omitempty"`
|
||||
PublicKeySignatureV2 string `json:"publicKeySignatureV2,omitempty"`
|
||||
KeyPair struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
} `json:"keyPair"`
|
||||
}
|
||||
// SignatureService 保留结构体以保持向后兼容,但推荐使用函数式版本
|
||||
type SignatureService struct {
|
||||
logger *zap.Logger
|
||||
redisClient *redis.Client
|
||||
}
|
||||
|
||||
func NewSignatureService(logger *zap.Logger, redisClient *redis.Client) *SignatureService {
|
||||
return &SignatureService{
|
||||
logger: logger,
|
||||
redisClient: redisClient,
|
||||
}
|
||||
}
|
||||
|
||||
// SignStringWithSHA1withRSA 使用SHA1withRSA签名字符串并返回Base64编码的签名(函数式版本)
|
||||
func SignStringWithSHA1withRSA(logger *zap.Logger, redisClient *redis.Client, data string) (string, error) {
|
||||
if data == "" {
|
||||
return "", fmt.Errorf("签名数据不能为空")
|
||||
}
|
||||
|
||||
// 获取私钥
|
||||
privateKey, err := DecodePrivateKeyFromPEM(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 解码私钥失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("解码私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 计算SHA1哈希
|
||||
hashed := sha1.Sum([]byte(data))
|
||||
|
||||
// 使用RSA-PKCS1v15算法签名
|
||||
signature, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA1, hashed[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] RSA签名失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("RSA签名失败: %w", err)
|
||||
}
|
||||
|
||||
// Base64编码签名
|
||||
encodedSignature := base64.StdEncoding.EncodeToString(signature)
|
||||
|
||||
logger.Info("[INFO] 成功使用SHA1withRSA生成签名,", zap.Any("数据长度:", len(data)))
|
||||
return encodedSignature, nil
|
||||
}
|
||||
|
||||
// SignStringWithSHA1withRSAService 使用SHA1withRSA签名字符串并返回Base64编码的签名(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) SignStringWithSHA1withRSA(data string) (string, error) {
|
||||
return SignStringWithSHA1withRSA(s.logger, s.redisClient, data)
|
||||
}
|
||||
|
||||
// DecodePrivateKeyFromPEM 从Redis获取并解码PEM格式的私钥(函数式版本)
|
||||
func DecodePrivateKeyFromPEM(logger *zap.Logger, redisClient *redis.Client) (*rsa.PrivateKey, error) {
|
||||
// 从Redis获取私钥
|
||||
privateKeyString, err := GetPrivateKeyFromRedis(logger, redisClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从Redis获取私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 解码PEM格式
|
||||
privateKeyBlock, rest := pem.Decode([]byte(privateKeyString))
|
||||
if privateKeyBlock == nil || len(rest) > 0 {
|
||||
logger.Error("[ERROR] 无效的PEM格式私钥")
|
||||
return nil, fmt.Errorf("无效的PEM格式私钥")
|
||||
}
|
||||
|
||||
// 解析PKCS1格式的私钥
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 解析私钥失败: ", zap.Error(err))
|
||||
return nil, fmt.Errorf("解析私钥失败: %w", err)
|
||||
}
|
||||
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
// GetPrivateKeyFromRedis 从Redis获取私钥(PEM格式)(函数式版本)
|
||||
func GetPrivateKeyFromRedis(logger *zap.Logger, redisClient *redis.Client) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
pemBytes, err := redisClient.GetBytes(ctx, PrivateKeyRedisKey)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 从Redis获取私钥失败,尝试生成新的密钥对: ", zap.Error(err))
|
||||
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
|
||||
// 递归获取生成的密钥
|
||||
return GetPrivateKeyFromRedis(logger, redisClient)
|
||||
}
|
||||
|
||||
return string(pemBytes), nil
|
||||
}
|
||||
|
||||
// DecodePrivateKeyFromPEMService 从Redis获取并解码PEM格式的私钥(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) DecodePrivateKeyFromPEM() (*rsa.PrivateKey, error) {
|
||||
return DecodePrivateKeyFromPEM(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// GetPrivateKeyFromRedisService 从Redis获取私钥(PEM格式)(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GetPrivateKeyFromRedis() (string, error) {
|
||||
return GetPrivateKeyFromRedis(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// GenerateRSAKeyPair 生成新的RSA密钥对(函数式版本)
|
||||
func GenerateRSAKeyPair(logger *zap.Logger, redisClient *redis.Client) error {
|
||||
logger.Info("[INFO] 开始生成RSA密钥对", zap.Int("keySize", RSAKeySize))
|
||||
|
||||
// 生成私钥
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, RSAKeySize)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA私钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("生成RSA私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 编码私钥为PEM格式
|
||||
pemPrivateKey, err := EncodePrivateKeyToPEM(privateKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码RSA私钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("编码RSA私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取公钥并编码为PEM格式
|
||||
pubKey := privateKey.PublicKey
|
||||
pemPublicKey, err := EncodePublicKeyToPEM(logger, &pubKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码RSA公钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("编码RSA公钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存密钥对到Redis
|
||||
return SaveKeyPairToRedis(logger, redisClient, string(pemPrivateKey), string(pemPublicKey))
|
||||
}
|
||||
|
||||
// GenerateRSAKeyPairService 生成新的RSA密钥对(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GenerateRSAKeyPair() error {
|
||||
return GenerateRSAKeyPair(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// EncodePrivateKeyToPEM 将私钥编码为PEM格式(函数式版本)
|
||||
func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey, keyType ...string) ([]byte, error) {
|
||||
if privateKey == nil {
|
||||
return nil, fmt.Errorf("私钥不能为空")
|
||||
}
|
||||
|
||||
// 默认使用 "PRIVATE KEY" 类型
|
||||
pemType := "PRIVATE KEY"
|
||||
|
||||
// 如果指定了类型参数且为 "RSA",则使用 "RSA PRIVATE KEY"
|
||||
if len(keyType) > 0 && keyType[0] == "RSA" {
|
||||
pemType = "RSA PRIVATE KEY"
|
||||
}
|
||||
|
||||
// 将私钥转换为PKCS1格式
|
||||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
|
||||
// 编码为PEM格式
|
||||
pemBlock := &pem.Block{
|
||||
Type: pemType,
|
||||
Bytes: privateKeyBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(pemBlock), nil
|
||||
}
|
||||
|
||||
// EncodePublicKeyToPEM 将公钥编码为PEM格式(函数式版本)
|
||||
func EncodePublicKeyToPEM(logger *zap.Logger, publicKey *rsa.PublicKey, keyType ...string) ([]byte, error) {
|
||||
if publicKey == nil {
|
||||
return nil, fmt.Errorf("公钥不能为空")
|
||||
}
|
||||
|
||||
// 默认使用 "PUBLIC KEY" 类型
|
||||
pemType := "PUBLIC KEY"
|
||||
var publicKeyBytes []byte
|
||||
var err error
|
||||
|
||||
// 如果指定了类型参数且为 "RSA",则使用 "RSA PUBLIC KEY"
|
||||
if len(keyType) > 0 && keyType[0] == "RSA" {
|
||||
pemType = "RSA PUBLIC KEY"
|
||||
publicKeyBytes = x509.MarshalPKCS1PublicKey(publicKey)
|
||||
} else {
|
||||
// 默认将公钥转换为PKIX格式
|
||||
publicKeyBytes, err = x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 序列化公钥失败: ", zap.Error(err))
|
||||
return nil, fmt.Errorf("序列化公钥失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 编码为PEM格式
|
||||
pemBlock := &pem.Block{
|
||||
Type: pemType,
|
||||
Bytes: publicKeyBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(pemBlock), nil
|
||||
}
|
||||
|
||||
// SaveKeyPairToRedis 将RSA密钥对保存到Redis(函数式版本)
|
||||
func SaveKeyPairToRedis(logger *zap.Logger, redisClient *redis.Client, privateKey, publicKey string) error {
|
||||
// 创建上下文并设置超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 使用事务确保两个操作的原子性
|
||||
tx := redisClient.TxPipeline()
|
||||
|
||||
tx.Set(ctx, PrivateKeyRedisKey, privateKey, KeyExpirationTime)
|
||||
tx.Set(ctx, PublicKeyRedisKey, publicKey, KeyExpirationTime)
|
||||
|
||||
// 执行事务
|
||||
_, err := tx.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 保存RSA密钥对到Redis失败: ", zap.Error(err))
|
||||
return fmt.Errorf("保存RSA密钥对到Redis失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功保存RSA密钥对到Redis")
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodePrivateKeyToPEMService 将私钥编码为PEM格式(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey, keyType ...string) ([]byte, error) {
|
||||
return EncodePrivateKeyToPEM(privateKey, keyType...)
|
||||
}
|
||||
|
||||
// EncodePublicKeyToPEMService 将公钥编码为PEM格式(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) EncodePublicKeyToPEM(publicKey *rsa.PublicKey, keyType ...string) ([]byte, error) {
|
||||
return EncodePublicKeyToPEM(s.logger, publicKey, keyType...)
|
||||
}
|
||||
|
||||
// SaveKeyPairToRedisService 将RSA密钥对保存到Redis(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) SaveKeyPairToRedis(privateKey, publicKey string) error {
|
||||
return SaveKeyPairToRedis(s.logger, s.redisClient, privateKey, publicKey)
|
||||
}
|
||||
|
||||
// GetPublicKeyFromRedisFunc 从Redis获取公钥(PEM格式,函数式版本)
|
||||
func GetPublicKeyFromRedisFunc(logger *zap.Logger, redisClient *redis.Client) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
pemBytes, err := redisClient.GetBytes(ctx, PublicKeyRedisKey)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 从Redis获取公钥失败,尝试生成新的密钥对: ", zap.Error(err))
|
||||
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
|
||||
// 递归获取生成的密钥
|
||||
return GetPublicKeyFromRedisFunc(logger, redisClient)
|
||||
}
|
||||
|
||||
// 检查获取到的公钥是否为空(key不存在时GetBytes返回nil, nil)
|
||||
if len(pemBytes) == 0 {
|
||||
logger.Info("[INFO] Redis中公钥为空,尝试生成新的密钥对")
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
// 递归获取生成的密钥
|
||||
return GetPublicKeyFromRedisFunc(logger, redisClient)
|
||||
}
|
||||
|
||||
return string(pemBytes), nil
|
||||
}
|
||||
|
||||
// GetPublicKeyFromRedis 从Redis获取公钥(PEM格式,结构体方法版本)
|
||||
func (s *SignatureService) GetPublicKeyFromRedis() (string, error) {
|
||||
return GetPublicKeyFromRedisFunc(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
|
||||
// GeneratePlayerCertificate 生成玩家证书(函数式版本)
|
||||
func GeneratePlayerCertificate(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, uuid string) (*PlayerCertificate, error) {
|
||||
if uuid == "" {
|
||||
return nil, fmt.Errorf("UUID不能为空")
|
||||
}
|
||||
logger.Info("[INFO] 开始生成玩家证书,用户UUID: %s",
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
|
||||
keyPair, err := repository.GetProfileKeyPair(uuid)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 获取用户密钥对失败,将创建新密钥对: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
keyPair = nil
|
||||
}
|
||||
|
||||
// 如果没有找到密钥对或密钥对已过期,创建一个新的
|
||||
now := time.Now().UTC()
|
||||
if keyPair == nil || keyPair.Refresh.Before(now) || keyPair.PrivateKey == "" || keyPair.PublicKey == "" {
|
||||
logger.Info("[INFO] 为用户创建新的密钥对: %s",
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
keyPair, err = NewKeyPair(logger)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成玩家证书密钥对失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
return nil, fmt.Errorf("生成玩家证书密钥对失败: %w", err)
|
||||
}
|
||||
// 保存密钥对到数据库
|
||||
err = repository.UpdateProfileKeyPair(uuid, keyPair)
|
||||
if err != nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Warn("[WARN] 更新用户密钥对失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
// 继续执行,即使保存失败
|
||||
}
|
||||
}
|
||||
|
||||
// 计算expiresAt的毫秒时间戳
|
||||
expiresAtMillis := keyPair.Expiration.UnixMilli()
|
||||
|
||||
// 准备签名
|
||||
publicKeySignature := ""
|
||||
publicKeySignatureV2 := ""
|
||||
|
||||
// 获取服务器私钥用于签名
|
||||
serverPrivateKey, err := DecodePrivateKeyFromPEM(logger, redisClient)
|
||||
if err != nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Error("[ERROR] 获取服务器私钥失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
return nil, fmt.Errorf("获取服务器私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 提取公钥DER编码
|
||||
pubPEMBlock, _ := pem.Decode([]byte(keyPair.PublicKey))
|
||||
if pubPEMBlock == nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Error("[ERROR] 解码公钥PEM失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("publicKey", keyPair.PublicKey),
|
||||
)
|
||||
return nil, fmt.Errorf("解码公钥PEM失败")
|
||||
}
|
||||
pubDER := pubPEMBlock.Bytes
|
||||
|
||||
// 准备publicKeySignature(用于MC 1.19)
|
||||
// Base64编码公钥,不包含换行
|
||||
pubBase64 := strings.ReplaceAll(base64.StdEncoding.EncodeToString(pubDER), "\n", "")
|
||||
|
||||
// 按76字符一行进行包装
|
||||
pubBase64Wrapped := WrapString(pubBase64, 76)
|
||||
|
||||
// 放入PEM格式
|
||||
pubMojangPEM := "-----BEGIN RSA PUBLIC KEY-----\n" +
|
||||
pubBase64Wrapped +
|
||||
"\n-----END RSA PUBLIC KEY-----\n"
|
||||
|
||||
// 签名数据: expiresAt毫秒时间戳 + 公钥PEM格式
|
||||
signedData := []byte(fmt.Sprintf("%d%s", expiresAtMillis, pubMojangPEM))
|
||||
|
||||
// 计算SHA1哈希并签名
|
||||
hash1 := sha1.Sum(signedData)
|
||||
signature, err := rsa.SignPKCS1v15(rand.Reader, serverPrivateKey, crypto.SHA1, hash1[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("expiresAtMillis", expiresAtMillis),
|
||||
)
|
||||
return nil, fmt.Errorf("签名失败: %w", err)
|
||||
}
|
||||
publicKeySignature = base64.StdEncoding.EncodeToString(signature)
|
||||
|
||||
// 准备publicKeySignatureV2(用于MC 1.19.1+)
|
||||
var uuidBytes []byte
|
||||
|
||||
// 如果提供了UUID,则使用它
|
||||
// 移除UUID中的连字符
|
||||
uuidStr := strings.ReplaceAll(uuid, "-", "")
|
||||
|
||||
// 将UUID转换为字节数组(16字节)
|
||||
if len(uuidStr) < 32 {
|
||||
logger.Warn("[WARN] UUID长度不足32字符,使用空UUID: %s",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("processedUuidStr", uuidStr),
|
||||
)
|
||||
uuidBytes = make([]byte, 16)
|
||||
} else {
|
||||
// 解析UUID字符串为字节
|
||||
uuidBytes = make([]byte, 16)
|
||||
parseErr := error(nil)
|
||||
for i := 0; i < 16; i++ {
|
||||
// 每两个字符转换为一个字节
|
||||
byteStr := uuidStr[i*2 : i*2+2]
|
||||
byteVal, err := strconv.ParseUint(byteStr, 16, 8)
|
||||
if err != nil {
|
||||
parseErr = err
|
||||
logger.Error("[ERROR] 解析UUID字节失败: %v, byteStr: %s",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("byteStr", byteStr),
|
||||
zap.Int("index", i),
|
||||
)
|
||||
uuidBytes = make([]byte, 16) // 出错时使用空UUID
|
||||
break
|
||||
}
|
||||
uuidBytes[i] = byte(byteVal)
|
||||
}
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("解析UUID字节失败: %w", parseErr)
|
||||
}
|
||||
}
|
||||
|
||||
// 准备签名数据:UUID + expiresAt时间戳 + DER编码的公钥
|
||||
signedDataV2 := make([]byte, 0, 24+len(pubDER)) // 预分配缓冲区
|
||||
|
||||
// 添加UUID(16字节)
|
||||
signedDataV2 = append(signedDataV2, uuidBytes...)
|
||||
|
||||
// 添加expiresAt毫秒时间戳(8字节,大端序)
|
||||
expiresAtBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(expiresAtBytes, uint64(expiresAtMillis))
|
||||
signedDataV2 = append(signedDataV2, expiresAtBytes...)
|
||||
|
||||
// 添加DER编码的公钥
|
||||
signedDataV2 = append(signedDataV2, pubDER...)
|
||||
|
||||
// 计算SHA1哈希并签名
|
||||
hash2 := sha1.Sum(signedDataV2)
|
||||
signatureV2, err := rsa.SignPKCS1v15(rand.Reader, serverPrivateKey, crypto.SHA1, hash2[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名V2失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("expiresAtMillis", expiresAtMillis),
|
||||
)
|
||||
return nil, fmt.Errorf("签名V2失败: %w", err)
|
||||
}
|
||||
publicKeySignatureV2 = base64.StdEncoding.EncodeToString(signatureV2)
|
||||
|
||||
// 创建玩家证书结构
|
||||
certificate := &PlayerCertificate{
|
||||
KeyPair: struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}{
|
||||
PrivateKey: keyPair.PrivateKey,
|
||||
PublicKey: keyPair.PublicKey,
|
||||
},
|
||||
PublicKeySignature: publicKeySignature,
|
||||
PublicKeySignatureV2: publicKeySignatureV2,
|
||||
ExpiresAt: keyPair.Expiration.Format(time.RFC3339Nano),
|
||||
RefreshedAfter: keyPair.Refresh.Format(time.RFC3339Nano),
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功生成玩家证书,过期时间: %s",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("expiresAt", certificate.ExpiresAt),
|
||||
zap.String("refreshedAfter", certificate.RefreshedAfter),
|
||||
)
|
||||
return certificate, nil
|
||||
}
|
||||
|
||||
// GeneratePlayerCertificateService 生成玩家证书(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GeneratePlayerCertificate(uuid string) (*PlayerCertificate, error) {
|
||||
return GeneratePlayerCertificate(nil, s.logger, s.redisClient, uuid) // TODO: 需要传入db参数
|
||||
}
|
||||
|
||||
// NewKeyPair 生成新的密钥对(函数式版本)
|
||||
func NewKeyPair(logger *zap.Logger) (*model.KeyPair, error) {
|
||||
// 生成新的RSA密钥对(用于玩家证书)
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) // 对玩家证书使用更小的密钥以提高性能
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成玩家证书私钥失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("生成玩家证书私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取DER编码的密钥
|
||||
keyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码私钥为PKCS8格式失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("编码私钥为PKCS8格式失败: %w", err)
|
||||
}
|
||||
|
||||
pubDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码公钥为PKIX格式失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("编码公钥为PKIX格式失败: %w", err)
|
||||
}
|
||||
|
||||
// 将密钥编码为PEM格式
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: keyDER,
|
||||
})
|
||||
|
||||
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: pubDER,
|
||||
})
|
||||
|
||||
// 创建证书过期和刷新时间
|
||||
now := time.Now().UTC()
|
||||
expiresAtTime := now.Add(CertificateExpirationPeriod)
|
||||
refreshedAfter := now.Add(CertificateRefreshInterval)
|
||||
keyPair := &model.KeyPair{
|
||||
Expiration: expiresAtTime,
|
||||
PrivateKey: string(keyPEM),
|
||||
PublicKey: string(pubPEM),
|
||||
Refresh: refreshedAfter,
|
||||
}
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
// WrapString 将字符串按指定宽度进行换行(函数式版本)
|
||||
func WrapString(str string, width int) string {
|
||||
if width <= 0 {
|
||||
return str
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(str); i += width {
|
||||
end := i + width
|
||||
if end > len(str) {
|
||||
end = len(str)
|
||||
}
|
||||
b.WriteString(str[i:end])
|
||||
if end < len(str) {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// NewKeyPairService 生成新的密钥对(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) NewKeyPair() (*model.KeyPair, error) {
|
||||
return NewKeyPair(s.logger)
|
||||
}
|
||||
358
internal/service/signature_service_test.go
Normal file
358
internal/service/signature_service_test.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestSignatureService_Constants 测试签名服务相关常量
|
||||
func TestSignatureService_Constants(t *testing.T) {
|
||||
if RSAKeySize != 4096 {
|
||||
t.Errorf("RSAKeySize = %d, want 4096", RSAKeySize)
|
||||
}
|
||||
|
||||
if PrivateKeyRedisKey == "" {
|
||||
t.Error("PrivateKeyRedisKey should not be empty")
|
||||
}
|
||||
|
||||
if PublicKeyRedisKey == "" {
|
||||
t.Error("PublicKeyRedisKey should not be empty")
|
||||
}
|
||||
|
||||
if KeyExpirationTime != 24*7*time.Hour {
|
||||
t.Errorf("KeyExpirationTime = %v, want 7 days", KeyExpirationTime)
|
||||
}
|
||||
|
||||
if CertificateRefreshInterval != 24*time.Hour {
|
||||
t.Errorf("CertificateRefreshInterval = %v, want 24 hours", CertificateRefreshInterval)
|
||||
}
|
||||
|
||||
if CertificateExpirationPeriod != 24*7*time.Hour {
|
||||
t.Errorf("CertificateExpirationPeriod = %v, want 7 days", CertificateExpirationPeriod)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSignatureService_DataValidation 测试签名数据验证逻辑
|
||||
func TestSignatureService_DataValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "非空数据有效",
|
||||
data: "test data",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "空数据无效",
|
||||
data: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.data != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Data validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlayerCertificate_Structure 测试PlayerCertificate结构
|
||||
func TestPlayerCertificate_Structure(t *testing.T) {
|
||||
cert := PlayerCertificate{
|
||||
ExpiresAt: "2025-01-01T00:00:00Z",
|
||||
RefreshedAfter: "2025-01-01T00:00:00Z",
|
||||
PublicKeySignature: "signature",
|
||||
PublicKeySignatureV2: "signaturev2",
|
||||
}
|
||||
|
||||
// 验证结构体字段
|
||||
if cert.ExpiresAt == "" {
|
||||
t.Error("ExpiresAt should not be empty")
|
||||
}
|
||||
|
||||
if cert.RefreshedAfter == "" {
|
||||
t.Error("RefreshedAfter should not be empty")
|
||||
}
|
||||
|
||||
// PublicKeySignature是可选的
|
||||
if cert.PublicKeySignature == "" {
|
||||
t.Log("PublicKeySignature is optional")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString 测试字符串换行函数
|
||||
func TestWrapString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
str string
|
||||
width int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "正常换行",
|
||||
str: "1234567890",
|
||||
width: 5,
|
||||
expected: "12345\n67890",
|
||||
},
|
||||
{
|
||||
name: "字符串长度等于width",
|
||||
str: "12345",
|
||||
width: 5,
|
||||
expected: "12345",
|
||||
},
|
||||
{
|
||||
name: "字符串长度小于width",
|
||||
str: "123",
|
||||
width: 5,
|
||||
expected: "123",
|
||||
},
|
||||
{
|
||||
name: "width为0,返回原字符串",
|
||||
str: "1234567890",
|
||||
width: 0,
|
||||
expected: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "width为负数,返回原字符串",
|
||||
str: "1234567890",
|
||||
width: -1,
|
||||
expected: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
str: "",
|
||||
width: 5,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "width为1",
|
||||
str: "12345",
|
||||
width: 1,
|
||||
expected: "1\n2\n3\n4\n5",
|
||||
},
|
||||
{
|
||||
name: "长字符串多次换行",
|
||||
str: "123456789012345",
|
||||
width: 5,
|
||||
expected: "12345\n67890\n12345",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := WrapString(tt.str, tt.width)
|
||||
if result != tt.expected {
|
||||
t.Errorf("WrapString(%q, %d) = %q, want %q", tt.str, tt.width, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString_LineCount 测试换行后的行数
|
||||
func TestWrapString_LineCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
str string
|
||||
width int
|
||||
wantLines int
|
||||
}{
|
||||
{
|
||||
name: "10个字符,width=5,应该2行",
|
||||
str: "1234567890",
|
||||
width: 5,
|
||||
wantLines: 2,
|
||||
},
|
||||
{
|
||||
name: "15个字符,width=5,应该3行",
|
||||
str: "123456789012345",
|
||||
width: 5,
|
||||
wantLines: 3,
|
||||
},
|
||||
{
|
||||
name: "5个字符,width=5,应该1行",
|
||||
str: "12345",
|
||||
width: 5,
|
||||
wantLines: 1,
|
||||
},
|
||||
{
|
||||
name: "width为0,应该1行",
|
||||
str: "1234567890",
|
||||
width: 0,
|
||||
wantLines: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := WrapString(tt.str, tt.width)
|
||||
lines := strings.Count(result, "\n") + 1
|
||||
if lines != tt.wantLines {
|
||||
t.Errorf("Line count = %d, want %d (result: %q)", lines, tt.wantLines, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString_NoTrailingNewline 测试末尾不换行
|
||||
func TestWrapString_NoTrailingNewline(t *testing.T) {
|
||||
str := "1234567890"
|
||||
result := WrapString(str, 5)
|
||||
|
||||
// 验证末尾没有换行符
|
||||
if strings.HasSuffix(result, "\n") {
|
||||
t.Error("Result should not end with newline")
|
||||
}
|
||||
|
||||
// 验证包含换行符(除了最后一行)
|
||||
if !strings.Contains(result, "\n") {
|
||||
t.Error("Result should contain newline for multi-line output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePrivateKeyToPEM_ActualCall 实际调用EncodePrivateKeyToPEM函数
|
||||
func TestEncodePrivateKeyToPEM_ActualCall(t *testing.T) {
|
||||
// 生成测试用的RSA私钥
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("生成RSA私钥失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "默认类型",
|
||||
keyType: []string{},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "RSA类型",
|
||||
keyType: []string{"RSA"},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pemBytes, err := EncodePrivateKeyToPEM(privateKey, tt.keyType...)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("EncodePrivateKeyToPEM() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if len(pemBytes) == 0 {
|
||||
t.Error("EncodePrivateKeyToPEM() 返回的PEM字节不应为空")
|
||||
}
|
||||
pemStr := string(pemBytes)
|
||||
// 验证PEM格式
|
||||
if !strings.Contains(pemStr, "BEGIN") || !strings.Contains(pemStr, "END") {
|
||||
t.Error("EncodePrivateKeyToPEM() 返回的PEM格式不正确")
|
||||
}
|
||||
// 验证类型
|
||||
if len(tt.keyType) > 0 && tt.keyType[0] == "RSA" {
|
||||
if !strings.Contains(pemStr, "RSA PRIVATE KEY") {
|
||||
t.Error("EncodePrivateKeyToPEM() 应包含 'RSA PRIVATE KEY'")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(pemStr, "PRIVATE KEY") {
|
||||
t.Error("EncodePrivateKeyToPEM() 应包含 'PRIVATE KEY'")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePublicKeyToPEM_ActualCall 实际调用EncodePublicKeyToPEM函数
|
||||
func TestEncodePublicKeyToPEM_ActualCall(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// 生成测试用的RSA密钥对
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("生成RSA密钥对失败: %v", err)
|
||||
}
|
||||
publicKey := &privateKey.PublicKey
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "默认类型",
|
||||
keyType: []string{},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "RSA类型",
|
||||
keyType: []string{"RSA"},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pemBytes, err := EncodePublicKeyToPEM(logger, publicKey, tt.keyType...)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("EncodePublicKeyToPEM() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if len(pemBytes) == 0 {
|
||||
t.Error("EncodePublicKeyToPEM() 返回的PEM字节不应为空")
|
||||
}
|
||||
pemStr := string(pemBytes)
|
||||
// 验证PEM格式
|
||||
if !strings.Contains(pemStr, "BEGIN") || !strings.Contains(pemStr, "END") {
|
||||
t.Error("EncodePublicKeyToPEM() 返回的PEM格式不正确")
|
||||
}
|
||||
// 验证类型
|
||||
if len(tt.keyType) > 0 && tt.keyType[0] == "RSA" {
|
||||
if !strings.Contains(pemStr, "RSA PUBLIC KEY") {
|
||||
t.Error("EncodePublicKeyToPEM() 应包含 'RSA PUBLIC KEY'")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(pemStr, "PUBLIC KEY") {
|
||||
t.Error("EncodePublicKeyToPEM() 应包含 'PUBLIC KEY'")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePublicKeyToPEM_NilKey 测试nil公钥
|
||||
func TestEncodePublicKeyToPEM_NilKey(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
_, err := EncodePublicKeyToPEM(logger, nil)
|
||||
if err == nil {
|
||||
t.Error("EncodePublicKeyToPEM() 对于nil公钥应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSignatureService 测试创建SignatureService
|
||||
func TestNewSignatureService(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
// 注意:这里需要实际的redis client,但我们只测试结构体创建
|
||||
// 在实际测试中,可以使用mock redis client
|
||||
service := NewSignatureService(logger, nil)
|
||||
if service == nil {
|
||||
t.Error("NewSignatureService() 不应返回nil")
|
||||
}
|
||||
if service.logger != logger {
|
||||
t.Error("NewSignatureService() logger 设置不正确")
|
||||
}
|
||||
}
|
||||
251
internal/service/texture_service.go
Normal file
251
internal/service/texture_service.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateTexture 创建材质
|
||||
func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) {
|
||||
// 验证用户存在
|
||||
user, err := repository.FindUserByID(uploaderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 检查Hash是否已存在
|
||||
existingTexture, err := repository.FindTextureByHash(hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existingTexture != nil {
|
||||
return nil, errors.New("该材质已存在")
|
||||
}
|
||||
|
||||
// 转换材质类型
|
||||
var textureTypeEnum model.TextureType
|
||||
switch textureType {
|
||||
case "SKIN":
|
||||
textureTypeEnum = model.TextureTypeSkin
|
||||
case "CAPE":
|
||||
textureTypeEnum = model.TextureTypeCape
|
||||
default:
|
||||
return nil, errors.New("无效的材质类型")
|
||||
}
|
||||
|
||||
// 创建材质
|
||||
texture := &model.Texture{
|
||||
UploaderID: uploaderID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
Type: textureTypeEnum,
|
||||
URL: url,
|
||||
Hash: hash,
|
||||
Size: size,
|
||||
IsPublic: isPublic,
|
||||
IsSlim: isSlim,
|
||||
Status: 1,
|
||||
DownloadCount: 0,
|
||||
FavoriteCount: 0,
|
||||
}
|
||||
|
||||
if err := repository.CreateTexture(texture); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return texture, nil
|
||||
}
|
||||
|
||||
// GetTextureByID 根据ID获取材质
|
||||
func GetTextureByID(db *gorm.DB, id int64) (*model.Texture, error) {
|
||||
texture, err := repository.FindTextureByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if texture == nil {
|
||||
return nil, errors.New("材质不存在")
|
||||
}
|
||||
if texture.Status == -1 {
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
return texture, nil
|
||||
}
|
||||
|
||||
// GetUserTextures 获取用户上传的材质列表
|
||||
func GetUserTextures(db *gorm.DB, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.FindTexturesByUploaderID(uploaderID, page, pageSize)
|
||||
}
|
||||
|
||||
// SearchTextures 搜索材质
|
||||
func SearchTextures(db *gorm.DB, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.SearchTextures(keyword, textureType, publicOnly, page, pageSize)
|
||||
}
|
||||
|
||||
// UpdateTexture 更新材质
|
||||
func UpdateTexture(db *gorm.DB, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) {
|
||||
// 获取材质
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if texture == nil {
|
||||
return nil, errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查权限:只有上传者可以修改
|
||||
if texture.UploaderID != uploaderID {
|
||||
return nil, errors.New("无权修改此材质")
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
updates := make(map[string]interface{})
|
||||
if name != "" {
|
||||
updates["name"] = name
|
||||
}
|
||||
if description != "" {
|
||||
updates["description"] = description
|
||||
}
|
||||
if isPublic != nil {
|
||||
updates["is_public"] = *isPublic
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := repository.UpdateTextureFields(textureID, updates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 返回更新后的材质
|
||||
return repository.FindTextureByID(textureID)
|
||||
}
|
||||
|
||||
// DeleteTexture 删除材质
|
||||
func DeleteTexture(db *gorm.DB, textureID, uploaderID int64) error {
|
||||
// 获取材质
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if texture == nil {
|
||||
return errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查权限:只有上传者可以删除
|
||||
if texture.UploaderID != uploaderID {
|
||||
return errors.New("无权删除此材质")
|
||||
}
|
||||
|
||||
return repository.DeleteTexture(textureID)
|
||||
}
|
||||
|
||||
// RecordTextureDownload 记录下载
|
||||
func RecordTextureDownload(db *gorm.DB, textureID int64, userID *int64, ipAddress, userAgent string) error {
|
||||
// 检查材质是否存在
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if texture == nil {
|
||||
return errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 增加下载次数
|
||||
if err := repository.IncrementTextureDownloadCount(textureID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建下载日志
|
||||
log := &model.TextureDownloadLog{
|
||||
TextureID: textureID,
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
}
|
||||
|
||||
return repository.CreateTextureDownloadLog(log)
|
||||
}
|
||||
|
||||
// ToggleTextureFavorite 切换收藏状态
|
||||
func ToggleTextureFavorite(db *gorm.DB, userID, textureID int64) (bool, error) {
|
||||
// 检查材质是否存在
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if texture == nil {
|
||||
return false, errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查是否已收藏
|
||||
isFavorited, err := repository.IsTextureFavorited(userID, textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if isFavorited {
|
||||
// 取消收藏
|
||||
if err := repository.RemoveTextureFavorite(userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := repository.DecrementTextureFavoriteCount(textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
// 添加收藏
|
||||
if err := repository.AddTextureFavorite(userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := repository.IncrementTextureFavoriteCount(textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserTextureFavorites 获取用户收藏的材质列表
|
||||
func GetUserTextureFavorites(db *gorm.DB, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.GetUserTextureFavorites(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// CheckTextureUploadLimit 检查用户上传材质数量限制
|
||||
func CheckTextureUploadLimit(db *gorm.DB, uploaderID int64, maxTextures int) error {
|
||||
count, err := repository.CountTexturesByUploaderID(uploaderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count >= int64(maxTextures) {
|
||||
return fmt.Errorf("已达到最大上传数量限制(%d)", maxTextures)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
471
internal/service/texture_service_test.go
Normal file
471
internal/service/texture_service_test.go
Normal file
@@ -0,0 +1,471 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTextureService_TypeValidation 测试材质类型验证
|
||||
func TestTextureService_TypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
textureType string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "SKIN类型有效",
|
||||
textureType: "SKIN",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "CAPE类型有效",
|
||||
textureType: "CAPE",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效类型",
|
||||
textureType: "INVALID",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空类型无效",
|
||||
textureType: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.textureType == "SKIN" || tt.textureType == "CAPE"
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Texture type validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureService_DefaultValues 测试材质默认值
|
||||
func TestTextureService_DefaultValues(t *testing.T) {
|
||||
// 测试默认状态
|
||||
defaultStatus := 1
|
||||
if defaultStatus != 1 {
|
||||
t.Errorf("默认状态应为1,实际为%d", defaultStatus)
|
||||
}
|
||||
|
||||
// 测试默认下载数
|
||||
defaultDownloadCount := 0
|
||||
if defaultDownloadCount != 0 {
|
||||
t.Errorf("默认下载数应为0,实际为%d", defaultDownloadCount)
|
||||
}
|
||||
|
||||
// 测试默认收藏数
|
||||
defaultFavoriteCount := 0
|
||||
if defaultFavoriteCount != 0 {
|
||||
t.Errorf("默认收藏数应为0,实际为%d", defaultFavoriteCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureService_StatusValidation 测试材质状态验证
|
||||
func TestTextureService_StatusValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int16
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "状态为1(正常)时有效",
|
||||
status: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "状态为-1(删除)时无效",
|
||||
status: -1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为0时可能有效(取决于业务逻辑)",
|
||||
status: 0,
|
||||
wantValid: true, // 状态为0(禁用)时,材质仍然存在,只是不可用,但查询时不会返回错误
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 材质状态为-1时表示已删除,无效
|
||||
isValid := tt.status != -1
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Status validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserTextures_Pagination 测试分页逻辑
|
||||
func TestGetUserTextures_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 2,
|
||||
pageSize: 20,
|
||||
wantPage: 2,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: 0,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize小于1,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 0,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 200,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchTextures_Pagination 测试搜索分页逻辑
|
||||
func TestSearchTextures_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPage: 1,
|
||||
wantSize: 10,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: -1,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 150,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTexture_PermissionCheck 测试更新材质的权限检查
|
||||
func TestUpdateTexture_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uploaderID int64
|
||||
requestID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "上传者ID匹配,允许更新",
|
||||
uploaderID: 1,
|
||||
requestID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "上传者ID不匹配,拒绝更新",
|
||||
uploaderID: 1,
|
||||
requestID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.uploaderID != tt.requestID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTexture_FieldUpdates 测试更新字段逻辑
|
||||
func TestUpdateTexture_FieldUpdates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nameValue string
|
||||
descValue string
|
||||
isPublic *bool
|
||||
wantUpdates int
|
||||
}{
|
||||
{
|
||||
name: "更新所有字段",
|
||||
nameValue: "NewName",
|
||||
descValue: "NewDesc",
|
||||
isPublic: boolPtr(true),
|
||||
wantUpdates: 3,
|
||||
},
|
||||
{
|
||||
name: "只更新名称",
|
||||
nameValue: "NewName",
|
||||
descValue: "",
|
||||
isPublic: nil,
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "只更新描述",
|
||||
nameValue: "",
|
||||
descValue: "NewDesc",
|
||||
isPublic: nil,
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "只更新公开状态",
|
||||
nameValue: "",
|
||||
descValue: "",
|
||||
isPublic: boolPtr(false),
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "没有更新",
|
||||
nameValue: "",
|
||||
descValue: "",
|
||||
isPublic: nil,
|
||||
wantUpdates: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
updates := 0
|
||||
if tt.nameValue != "" {
|
||||
updates++
|
||||
}
|
||||
if tt.descValue != "" {
|
||||
updates++
|
||||
}
|
||||
if tt.isPublic != nil {
|
||||
updates++
|
||||
}
|
||||
|
||||
if updates != tt.wantUpdates {
|
||||
t.Errorf("Updates count = %d, want %d", updates, tt.wantUpdates)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteTexture_PermissionCheck 测试删除材质的权限检查
|
||||
func TestDeleteTexture_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uploaderID int64
|
||||
requestID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "上传者ID匹配,允许删除",
|
||||
uploaderID: 1,
|
||||
requestID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "上传者ID不匹配,拒绝删除",
|
||||
uploaderID: 1,
|
||||
requestID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.uploaderID != tt.requestID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToggleTextureFavorite_Logic 测试切换收藏状态的逻辑
|
||||
func TestToggleTextureFavorite_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isFavorited bool
|
||||
wantResult bool
|
||||
}{
|
||||
{
|
||||
name: "已收藏,取消收藏",
|
||||
isFavorited: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "未收藏,添加收藏",
|
||||
isFavorited: false,
|
||||
wantResult: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := !tt.isFavorited
|
||||
if result != tt.wantResult {
|
||||
t.Errorf("Toggle favorite failed: got %v, want %v", result, tt.wantResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserTextureFavorites_Pagination 测试收藏列表分页
|
||||
func TestGetUserTextureFavorites_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: 0,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 200,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckTextureUploadLimit_Logic 测试上传限制检查逻辑
|
||||
func TestCheckTextureUploadLimit_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
maxTextures int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "未达到上限",
|
||||
count: 5,
|
||||
maxTextures: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "达到上限",
|
||||
count: 10,
|
||||
maxTextures: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "超过上限",
|
||||
count: 15,
|
||||
maxTextures: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.count >= int64(tt.maxTextures)
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Limit check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
277
internal/service/token_service.go
Normal file
277
internal/service/token_service.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"go.uber.org/zap"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
ExtendedTimeout = 10 * time.Second
|
||||
TokensMaxCount = 10 // 用户最多保留的token数量
|
||||
)
|
||||
|
||||
// NewToken 创建新令牌
|
||||
func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
var (
|
||||
selectedProfileID *model.Profile
|
||||
availableProfiles []*model.Profile
|
||||
)
|
||||
// 设置超时上下文
|
||||
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
_, err := repository.FindProfileByUUID(UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 生成令牌
|
||||
if clientToken == "" {
|
||||
clientToken = uuid.New().String()
|
||||
}
|
||||
|
||||
accessToken := uuid.New().String()
|
||||
token := model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
UserID: userId,
|
||||
Usable: true,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := repository.FindProfilesByUserID(userId)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果用户只有一个配置文件,自动选择
|
||||
if len(profiles) == 1 {
|
||||
selectedProfileID = profiles[0]
|
||||
token.ProfileId = selectedProfileID.UUID
|
||||
}
|
||||
availableProfiles = profiles
|
||||
|
||||
// 插入令牌到tokens集合
|
||||
_, insertCancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer insertCancel()
|
||||
|
||||
err = repository.CreateToken(&token)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
|
||||
}
|
||||
// 清理多余的令牌
|
||||
go CheckAndCleanupExcessTokens(db, logger, userId)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
// CheckAndCleanupExcessTokens 检查并清理用户多余的令牌,只保留最新的10个
|
||||
func CheckAndCleanupExcessTokens(db *gorm.DB, logger *zap.Logger, userId int64) {
|
||||
if userId == 0 {
|
||||
return
|
||||
}
|
||||
// 获取用户所有令牌,按发行日期降序排序
|
||||
tokens, err := repository.GetTokensByUserId(userId)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取用户Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
// 如果令牌数量不超过上限,无需清理
|
||||
if len(tokens) <= TokensMaxCount {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取需要删除的令牌ID列表
|
||||
tokensToDelete := make([]string, 0, len(tokens)-TokensMaxCount)
|
||||
for i := TokensMaxCount; i < len(tokens); i++ {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
// 执行批量删除,传入上下文和待删除的令牌列表(作为切片参数)
|
||||
DeletedCount, err := repository.BatchDeleteTokens(tokensToDelete)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 清理用户多余Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if DeletedCount > 0 {
|
||||
logger.Info("[INFO] 成功清理用户多余Token", zap.Any("userId:", userId), zap.Any("count:", DeletedCount))
|
||||
}
|
||||
}
|
||||
|
||||
// ValidToken 验证令牌有效性
|
||||
func ValidToken(db *gorm.DB, accessToken string, clientToken string) bool {
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 使用投影只获取需要的字段
|
||||
var token *model.Token
|
||||
token, err := repository.FindTokenByID(accessToken)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !token.Usable {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果客户端令牌为空,只验证访问令牌
|
||||
if clientToken == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 否则验证客户端令牌是否匹配
|
||||
return token.ClientToken == clientToken
|
||||
}
|
||||
|
||||
func GetUUIDByAccessToken(db *gorm.DB, accessToken string) (string, error) {
|
||||
return repository.GetUUIDByAccessToken(accessToken)
|
||||
}
|
||||
|
||||
func GetUserIDByAccessToken(db *gorm.DB, accessToken string) (int64, error) {
|
||||
return repository.GetUserIDByAccessToken(accessToken)
|
||||
}
|
||||
|
||||
// RefreshToken 刷新令牌
|
||||
func RefreshToken(db *gorm.DB, logger *zap.Logger, accessToken, clientToken string, selectedProfileID string) (string, string, error) {
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
|
||||
// 查找旧令牌
|
||||
oldToken, err := repository.GetTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", errors.New("accessToken无效")
|
||||
}
|
||||
logger.Error("[ERROR] 查询Token失败: ", zap.Error(err), zap.Any("accessToken:", accessToken))
|
||||
return "", "", fmt.Errorf("查询令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := ValidateProfileByUserID(db, oldToken.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
logger.Error(
|
||||
"验证Profile失败",
|
||||
zap.Error(err),
|
||||
zap.Any("userId", oldToken.UserID),
|
||||
zap.String("profileId", selectedProfileID),
|
||||
)
|
||||
return "", "", fmt.Errorf("验证角色失败: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return "", "", errors.New("角色与用户不匹配")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 clientToken 是否有效
|
||||
if clientToken != "" && clientToken != oldToken.ClientToken {
|
||||
return "", "", errors.New("clientToken无效")
|
||||
}
|
||||
|
||||
// 检查 selectedProfileID 的逻辑
|
||||
if selectedProfileID != "" {
|
||||
if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID {
|
||||
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
|
||||
}
|
||||
} else {
|
||||
selectedProfileID = oldToken.ProfileId // 如果未指定,则保持原角色
|
||||
}
|
||||
|
||||
// 生成新令牌
|
||||
newAccessToken := uuid.New().String()
|
||||
newToken := model.Token{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: oldToken.ClientToken, // 新令牌的 clientToken 与原令牌相同
|
||||
UserID: oldToken.UserID,
|
||||
Usable: true,
|
||||
ProfileId: selectedProfileID, // 绑定到指定角色或保持原角色
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 使用双重写入模式替代事务,先插入新令牌,再删除旧令牌
|
||||
|
||||
err = repository.CreateToken(&newToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"创建新Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return "", "", fmt.Errorf("创建新Token失败: %w", err)
|
||||
}
|
||||
|
||||
err = repository.DeleteTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
// 删除旧令牌失败,记录日志但不阻止操作,因为新令牌已成功创建
|
||||
logger.Warn(
|
||||
"删除旧Token失败,但新Token已创建",
|
||||
zap.Error(err),
|
||||
zap.String("oldToken", oldToken.AccessToken),
|
||||
zap.String("newToken", newAccessToken),
|
||||
)
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"成功刷新Token",
|
||||
zap.Any("userId", oldToken.UserID),
|
||||
zap.String("accessToken", newAccessToken),
|
||||
)
|
||||
return newAccessToken, oldToken.ClientToken, nil
|
||||
}
|
||||
|
||||
// InvalidToken 使令牌失效
|
||||
func InvalidToken(db *gorm.DB, logger *zap.Logger, accessToken string) {
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := repository.DeleteTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"删除Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return
|
||||
}
|
||||
logger.Info("[INFO] 成功删除", zap.Any("Token:", accessToken))
|
||||
|
||||
}
|
||||
|
||||
// InvalidUserTokens 使用户所有令牌失效
|
||||
func InvalidUserTokens(db *gorm.DB, logger *zap.Logger, userId int64) {
|
||||
if userId == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := repository.DeleteTokenByUserId(userId)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"[ERROR]删除用户Token失败",
|
||||
zap.Error(err),
|
||||
zap.Any("userId", userId),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功删除用户Token", zap.Any("userId:", userId))
|
||||
|
||||
}
|
||||
204
internal/service/token_service_test.go
Normal file
204
internal/service/token_service_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestTokenService_Constants 测试Token服务相关常量
|
||||
func TestTokenService_Constants(t *testing.T) {
|
||||
if ExtendedTimeout != 10*time.Second {
|
||||
t.Errorf("ExtendedTimeout = %v, want 10 seconds", ExtendedTimeout)
|
||||
}
|
||||
|
||||
if TokensMaxCount != 10 {
|
||||
t.Errorf("TokensMaxCount = %d, want 10", TokensMaxCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_Timeout 测试超时常量
|
||||
func TestTokenService_Timeout(t *testing.T) {
|
||||
if DefaultTimeout != 5*time.Second {
|
||||
t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout)
|
||||
}
|
||||
|
||||
if ExtendedTimeout <= DefaultTimeout {
|
||||
t.Errorf("ExtendedTimeout (%v) should be greater than DefaultTimeout (%v)", ExtendedTimeout, DefaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_Validation 测试Token验证逻辑
|
||||
func TestTokenService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "空token无效",
|
||||
accessToken: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "非空token可能有效",
|
||||
accessToken: "valid-token-string",
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试空token检查逻辑
|
||||
isValid := tt.accessToken != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Token validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_ClientTokenLogic 测试ClientToken逻辑
|
||||
func TestTokenService_ClientTokenLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientToken string
|
||||
shouldGenerate bool
|
||||
}{
|
||||
{
|
||||
name: "空的clientToken应该生成新的",
|
||||
clientToken: "",
|
||||
shouldGenerate: true,
|
||||
},
|
||||
{
|
||||
name: "非空的clientToken应该使用提供的",
|
||||
clientToken: "existing-client-token",
|
||||
shouldGenerate: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldGenerate := tt.clientToken == ""
|
||||
if shouldGenerate != tt.shouldGenerate {
|
||||
t.Errorf("ClientToken logic failed: got %v, want %v", shouldGenerate, tt.shouldGenerate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_ProfileSelection 测试Profile选择逻辑
|
||||
func TestTokenService_ProfileSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileCount int
|
||||
shouldAutoSelect bool
|
||||
}{
|
||||
{
|
||||
name: "只有一个profile时自动选择",
|
||||
profileCount: 1,
|
||||
shouldAutoSelect: true,
|
||||
},
|
||||
{
|
||||
name: "多个profile时不自动选择",
|
||||
profileCount: 2,
|
||||
shouldAutoSelect: false,
|
||||
},
|
||||
{
|
||||
name: "没有profile时不自动选择",
|
||||
profileCount: 0,
|
||||
shouldAutoSelect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldAutoSelect := tt.profileCount == 1
|
||||
if shouldAutoSelect != tt.shouldAutoSelect {
|
||||
t.Errorf("Profile selection logic failed: got %v, want %v", shouldAutoSelect, tt.shouldAutoSelect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_CleanupLogic 测试清理逻辑
|
||||
func TestTokenService_CleanupLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenCount int
|
||||
maxCount int
|
||||
shouldCleanup bool
|
||||
cleanupCount int
|
||||
}{
|
||||
{
|
||||
name: "token数量未超过上限,不需要清理",
|
||||
tokenCount: 5,
|
||||
maxCount: 10,
|
||||
shouldCleanup: false,
|
||||
cleanupCount: 0,
|
||||
},
|
||||
{
|
||||
name: "token数量超过上限,需要清理",
|
||||
tokenCount: 15,
|
||||
maxCount: 10,
|
||||
shouldCleanup: true,
|
||||
cleanupCount: 5,
|
||||
},
|
||||
{
|
||||
name: "token数量等于上限,不需要清理",
|
||||
tokenCount: 10,
|
||||
maxCount: 10,
|
||||
shouldCleanup: false,
|
||||
cleanupCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldCleanup := tt.tokenCount > tt.maxCount
|
||||
if shouldCleanup != tt.shouldCleanup {
|
||||
t.Errorf("Cleanup decision failed: got %v, want %v", shouldCleanup, tt.shouldCleanup)
|
||||
}
|
||||
|
||||
if shouldCleanup {
|
||||
expectedCleanupCount := tt.tokenCount - tt.maxCount
|
||||
if expectedCleanupCount != tt.cleanupCount {
|
||||
t.Errorf("Cleanup count failed: got %d, want %d", expectedCleanupCount, tt.cleanupCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_UserIDValidation 测试UserID验证
|
||||
func TestTokenService_UserIDValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
isValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的UserID",
|
||||
userID: 1,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "UserID为0时无效",
|
||||
userID: 0,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "负数UserID无效",
|
||||
userID: -1,
|
||||
isValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.userID > 0
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("UserID validation failed: got %v, want %v", isValid, tt.isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
160
internal/service/upload_service.go
Normal file
160
internal/service/upload_service.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/storage"
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileType 文件类型枚举
|
||||
type FileType string
|
||||
|
||||
const (
|
||||
FileTypeAvatar FileType = "avatar"
|
||||
FileTypeTexture FileType = "texture"
|
||||
)
|
||||
|
||||
// UploadConfig 上传配置
|
||||
type UploadConfig struct {
|
||||
AllowedExts map[string]bool // 允许的文件扩展名
|
||||
MinSize int64 // 最小文件大小(字节)
|
||||
MaxSize int64 // 最大文件大小(字节)
|
||||
Expires time.Duration // URL过期时间
|
||||
}
|
||||
|
||||
// GetUploadConfig 根据文件类型获取上传配置
|
||||
func GetUploadConfig(fileType FileType) *UploadConfig {
|
||||
switch fileType {
|
||||
case FileTypeAvatar:
|
||||
return &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".jpg": true,
|
||||
".jpeg": true,
|
||||
".png": true,
|
||||
".gif": true,
|
||||
".webp": true,
|
||||
},
|
||||
MinSize: 1024, // 1KB
|
||||
MaxSize: 5 * 1024 * 1024, // 5MB
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
case FileTypeTexture:
|
||||
return &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".png": true,
|
||||
},
|
||||
MinSize: 1024, // 1KB
|
||||
MaxSize: 10 * 1024 * 1024, // 10MB
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateFileName 验证文件名
|
||||
func ValidateFileName(fileName string, fileType FileType) error {
|
||||
if fileName == "" {
|
||||
return fmt.Errorf("文件名不能为空")
|
||||
}
|
||||
|
||||
uploadConfig := GetUploadConfig(fileType)
|
||||
if uploadConfig == nil {
|
||||
return fmt.Errorf("不支持的文件类型")
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(fileName))
|
||||
if !uploadConfig.AllowedExts[ext] {
|
||||
return fmt.Errorf("不支持的文件格式: %s", ext)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURL 生成头像上传URL
|
||||
func GenerateAvatarUploadURL(ctx context.Context, storageClient *storage.StorageClient, cfg config.RustFSConfig, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) {
|
||||
// 1. 验证文件名
|
||||
if err := ValidateFileName(fileName, FileTypeAvatar); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 获取上传配置
|
||||
uploadConfig := GetUploadConfig(FileTypeAvatar)
|
||||
|
||||
// 3. 获取存储桶名称
|
||||
bucketName, err := storageClient.GetBucket("avatars")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取存储桶失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 生成对象名称(路径)
|
||||
// 格式: user_{userId}/timestamp_{originalFileName}
|
||||
timestamp := time.Now().Format("20060102150405")
|
||||
objectName := fmt.Sprintf("user_%d/%s_%s", userID, timestamp, fileName)
|
||||
|
||||
// 5. 生成预签名POST URL
|
||||
result, err := storageClient.GeneratePresignedPostURL(
|
||||
ctx,
|
||||
bucketName,
|
||||
objectName,
|
||||
uploadConfig.MinSize,
|
||||
uploadConfig.MaxSize,
|
||||
uploadConfig.Expires,
|
||||
cfg.UseSSL,
|
||||
cfg.Endpoint,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成上传URL失败: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateTextureUploadURL 生成材质上传URL
|
||||
func GenerateTextureUploadURL(ctx context.Context, storageClient *storage.StorageClient, cfg config.RustFSConfig, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) {
|
||||
// 1. 验证文件名
|
||||
if err := ValidateFileName(fileName, FileTypeTexture); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 验证材质类型
|
||||
if textureType != "SKIN" && textureType != "CAPE" {
|
||||
return nil, fmt.Errorf("无效的材质类型: %s", textureType)
|
||||
}
|
||||
|
||||
// 3. 获取上传配置
|
||||
uploadConfig := GetUploadConfig(FileTypeTexture)
|
||||
|
||||
// 4. 获取存储桶名称
|
||||
bucketName, err := storageClient.GetBucket("textures")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取存储桶失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 生成对象名称(路径)
|
||||
// 格式: user_{userId}/{textureType}/timestamp_{originalFileName}
|
||||
timestamp := time.Now().Format("20060102150405")
|
||||
textureTypeFolder := strings.ToLower(textureType)
|
||||
objectName := fmt.Sprintf("user_%d/%s/%s_%s", userID, textureTypeFolder, timestamp, fileName)
|
||||
|
||||
// 6. 生成预签名POST URL
|
||||
result, err := storageClient.GeneratePresignedPostURL(
|
||||
ctx,
|
||||
bucketName,
|
||||
objectName,
|
||||
uploadConfig.MinSize,
|
||||
uploadConfig.MaxSize,
|
||||
uploadConfig.Expires,
|
||||
cfg.UseSSL,
|
||||
cfg.Endpoint,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成上传URL失败: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
279
internal/service/upload_service_test.go
Normal file
279
internal/service/upload_service_test.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestUploadService_FileTypes 测试文件类型常量
|
||||
func TestUploadService_FileTypes(t *testing.T) {
|
||||
if FileTypeAvatar == "" {
|
||||
t.Error("FileTypeAvatar should not be empty")
|
||||
}
|
||||
|
||||
if FileTypeTexture == "" {
|
||||
t.Error("FileTypeTexture should not be empty")
|
||||
}
|
||||
|
||||
if FileTypeAvatar == FileTypeTexture {
|
||||
t.Error("FileTypeAvatar and FileTypeTexture should be different")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig 测试获取上传配置
|
||||
func TestGetUploadConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fileType FileType
|
||||
wantConfig bool
|
||||
}{
|
||||
{
|
||||
name: "头像类型返回配置",
|
||||
fileType: FileTypeAvatar,
|
||||
wantConfig: true,
|
||||
},
|
||||
{
|
||||
name: "材质类型返回配置",
|
||||
fileType: FileTypeTexture,
|
||||
wantConfig: true,
|
||||
},
|
||||
{
|
||||
name: "无效类型返回nil",
|
||||
fileType: FileType("invalid"),
|
||||
wantConfig: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := GetUploadConfig(tt.fileType)
|
||||
hasConfig := config != nil
|
||||
if hasConfig != tt.wantConfig {
|
||||
t.Errorf("GetUploadConfig() = %v, want %v", hasConfig, tt.wantConfig)
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
// 验证配置字段
|
||||
if config.MinSize <= 0 {
|
||||
t.Error("MinSize should be greater than 0")
|
||||
}
|
||||
if config.MaxSize <= 0 {
|
||||
t.Error("MaxSize should be greater than 0")
|
||||
}
|
||||
if config.MaxSize < config.MinSize {
|
||||
t.Error("MaxSize should be greater than or equal to MinSize")
|
||||
}
|
||||
if config.Expires <= 0 {
|
||||
t.Error("Expires should be greater than 0")
|
||||
}
|
||||
if len(config.AllowedExts) == 0 {
|
||||
t.Error("AllowedExts should not be empty")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig_AvatarConfig 测试头像配置详情
|
||||
func TestGetUploadConfig_AvatarConfig(t *testing.T) {
|
||||
config := GetUploadConfig(FileTypeAvatar)
|
||||
if config == nil {
|
||||
t.Fatal("Avatar config should not be nil")
|
||||
}
|
||||
|
||||
// 验证允许的扩展名
|
||||
expectedExts := []string{".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
for _, ext := range expectedExts {
|
||||
if !config.AllowedExts[ext] {
|
||||
t.Errorf("Avatar config should allow %s extension", ext)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证文件大小限制
|
||||
if config.MinSize != 1024 {
|
||||
t.Errorf("Avatar MinSize = %d, want 1024", config.MinSize)
|
||||
}
|
||||
|
||||
if config.MaxSize != 5*1024*1024 {
|
||||
t.Errorf("Avatar MaxSize = %d, want 5MB", config.MaxSize)
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if config.Expires != 15*time.Minute {
|
||||
t.Errorf("Avatar Expires = %v, want 15 minutes", config.Expires)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig_TextureConfig 测试材质配置详情
|
||||
func TestGetUploadConfig_TextureConfig(t *testing.T) {
|
||||
config := GetUploadConfig(FileTypeTexture)
|
||||
if config == nil {
|
||||
t.Fatal("Texture config should not be nil")
|
||||
}
|
||||
|
||||
// 验证允许的扩展名(材质只允许PNG)
|
||||
if !config.AllowedExts[".png"] {
|
||||
t.Error("Texture config should allow .png extension")
|
||||
}
|
||||
|
||||
// 验证文件大小限制
|
||||
if config.MinSize != 1024 {
|
||||
t.Errorf("Texture MinSize = %d, want 1024", config.MinSize)
|
||||
}
|
||||
|
||||
if config.MaxSize != 10*1024*1024 {
|
||||
t.Errorf("Texture MaxSize = %d, want 10MB", config.MaxSize)
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if config.Expires != 15*time.Minute {
|
||||
t.Errorf("Texture Expires = %v, want 15 minutes", config.Expires)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName 测试文件名验证
|
||||
func TestValidateFileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fileName string
|
||||
fileType FileType
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "有效的头像文件名",
|
||||
fileName: "avatar.png",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "有效的材质文件名",
|
||||
fileName: "texture.png",
|
||||
fileType: FileTypeTexture,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "文件名为空",
|
||||
fileName: "",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: true,
|
||||
errContains: "文件名不能为空",
|
||||
},
|
||||
{
|
||||
name: "不支持的文件扩展名",
|
||||
fileName: "file.txt",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: true,
|
||||
errContains: "不支持的文件格式",
|
||||
},
|
||||
{
|
||||
name: "无效的文件类型",
|
||||
fileName: "file.png",
|
||||
fileType: FileType("invalid"),
|
||||
wantErr: true,
|
||||
errContains: "不支持的文件类型",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateFileName(tt.fileName, tt.fileType)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateFileName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && tt.errContains != "" {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("ValidateFileName() error = %v, should contain %s", err, tt.errContains)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName_Extensions 测试各种扩展名
|
||||
func TestValidateFileName_Extensions(t *testing.T) {
|
||||
avatarExts := []string{".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
for _, ext := range avatarExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeAvatar)
|
||||
if err != nil {
|
||||
t.Errorf("Avatar file with %s extension should be valid, got error: %v", ext, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 材质只支持PNG
|
||||
textureExts := []string{".png"}
|
||||
for _, ext := range textureExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeTexture)
|
||||
if err != nil {
|
||||
t.Errorf("Texture file with %s extension should be valid, got error: %v", ext, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试不支持的扩展名
|
||||
invalidExts := []string{".txt", ".pdf", ".doc"}
|
||||
for _, ext := range invalidExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeAvatar)
|
||||
if err == nil {
|
||||
t.Errorf("Avatar file with %s extension should be invalid", ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName_CaseInsensitive 测试扩展名大小写不敏感
|
||||
func TestValidateFileName_CaseInsensitive(t *testing.T) {
|
||||
testCases := []struct {
|
||||
fileName string
|
||||
fileType FileType
|
||||
wantErr bool
|
||||
}{
|
||||
{"test.PNG", FileTypeAvatar, false},
|
||||
{"test.JPG", FileTypeAvatar, false},
|
||||
{"test.JPEG", FileTypeAvatar, false},
|
||||
{"test.GIF", FileTypeAvatar, false},
|
||||
{"test.WEBP", FileTypeAvatar, false},
|
||||
{"test.PnG", FileTypeTexture, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.fileName, func(t *testing.T) {
|
||||
err := ValidateFileName(tc.fileName, tc.fileType)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("ValidateFileName(%s, %s) error = %v, wantErr %v", tc.fileName, tc.fileType, err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUploadConfig_Structure 测试UploadConfig结构
|
||||
func TestUploadConfig_Structure(t *testing.T) {
|
||||
config := &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".png": true,
|
||||
},
|
||||
MinSize: 1024,
|
||||
MaxSize: 5 * 1024 * 1024,
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
|
||||
if config.AllowedExts == nil {
|
||||
t.Error("AllowedExts should not be nil")
|
||||
}
|
||||
|
||||
if config.MinSize <= 0 {
|
||||
t.Error("MinSize should be greater than 0")
|
||||
}
|
||||
|
||||
if config.MaxSize <= config.MinSize {
|
||||
t.Error("MaxSize should be greater than MinSize")
|
||||
}
|
||||
|
||||
if config.Expires <= 0 {
|
||||
t.Error("Expires should be greater than 0")
|
||||
}
|
||||
}
|
||||
|
||||
248
internal/service/user_service.go
Normal file
248
internal/service/user_service.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/auth"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RegisterUser 用户注册
|
||||
func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) {
|
||||
// 检查用户名是否已存在
|
||||
existingUser, err := repository.FindUserByUsername(username)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingUser != nil {
|
||||
return nil, "", errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existingEmail, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingEmail != nil {
|
||||
return nil, "", errors.New("邮箱已被注册")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := auth.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 确定头像URL:优先使用用户提供的头像,否则使用默认头像
|
||||
avatarURL := avatar
|
||||
if avatarURL == "" {
|
||||
avatarURL = getDefaultAvatar()
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: username,
|
||||
Password: hashedPassword,
|
||||
Email: email,
|
||||
Avatar: avatarURL,
|
||||
Role: "user",
|
||||
Status: 1,
|
||||
Points: 0, // 初始积分可以从配置读取
|
||||
}
|
||||
|
||||
if err := repository.CreateUser(user); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
|
||||
// TODO: 添加注册奖励积分
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// LoginUser 用户登录(支持用户名或邮箱登录)
|
||||
func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
|
||||
// 查找用户:判断是用户名还是邮箱
|
||||
var user *model.User
|
||||
var err error
|
||||
|
||||
if strings.Contains(usernameOrEmail, "@") {
|
||||
// 包含@符号,认为是邮箱
|
||||
user, err = repository.FindUserByEmail(usernameOrEmail)
|
||||
} else {
|
||||
// 否则认为是用户名
|
||||
user, err = repository.FindUserByUsername(usernameOrEmail)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if user == nil {
|
||||
// 记录失败日志
|
||||
logFailedLogin(0, ipAddress, userAgent, "用户不存在")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if user.Status != 1 {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用")
|
||||
return nil, "", errors.New("账号已被禁用")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !auth.CheckPassword(user.Password, password) {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
|
||||
// 更新最后登录时间
|
||||
now := time.Now()
|
||||
user.LastLoginAt = &now
|
||||
_ = repository.UpdateUserFields(user.ID, map[string]interface{}{
|
||||
"last_login_at": now,
|
||||
})
|
||||
|
||||
// 记录成功登录日志
|
||||
logSuccessLogin(user.ID, ipAddress, userAgent)
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
func GetUserByID(id int64) (*model.User, error) {
|
||||
return repository.FindUserByID(id)
|
||||
}
|
||||
|
||||
// UpdateUserInfo 更新用户信息
|
||||
func UpdateUserInfo(user *model.User) error {
|
||||
return repository.UpdateUser(user)
|
||||
}
|
||||
|
||||
// UpdateUserAvatar 更新用户头像
|
||||
func UpdateUserAvatar(userID int64, avatarURL string) error {
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"avatar": avatarURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangeUserPassword 修改密码
|
||||
func ChangeUserPassword(userID int64, oldPassword, newPassword string) error {
|
||||
// 获取用户
|
||||
user, err := repository.FindUserByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !auth.CheckPassword(user.Password, oldPassword) {
|
||||
return errors.New("原密码错误")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// ResetUserPassword 重置密码(通过邮箱)
|
||||
func ResetUserPassword(email, newPassword string) error {
|
||||
// 查找用户
|
||||
user, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
return repository.UpdateUserFields(user.ID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangeUserEmail 更换邮箱
|
||||
func ChangeUserEmail(userID int64, newEmail string) error {
|
||||
// 检查新邮箱是否已被使用
|
||||
existingUser, err := repository.FindUserByEmail(newEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existingUser != nil && existingUser.ID != userID {
|
||||
return errors.New("邮箱已被其他用户使用")
|
||||
}
|
||||
|
||||
// 更新邮箱
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"email": newEmail,
|
||||
})
|
||||
}
|
||||
|
||||
// logSuccessLogin 记录成功登录
|
||||
func logSuccessLogin(userID int64, ipAddress, userAgent string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: true,
|
||||
}
|
||||
_ = repository.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// logFailedLogin 记录失败登录
|
||||
func logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: false,
|
||||
FailureReason: reason,
|
||||
}
|
||||
_ = repository.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// getDefaultAvatar 获取默认头像URL
|
||||
func getDefaultAvatar() string {
|
||||
// 如果数据库中不存在默认头像配置,返回错误信息
|
||||
const log = "数据库中不存在默认头像配置"
|
||||
|
||||
// 尝试从数据库读取配置
|
||||
config, err := repository.GetSystemConfigByKey("default_avatar")
|
||||
if err != nil || config == nil {
|
||||
return log
|
||||
}
|
||||
|
||||
return config.Value
|
||||
}
|
||||
|
||||
func GetUserByEmail(email string) (*model.User, error) {
|
||||
user, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, errors.New("邮箱查找失败")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
199
internal/service/user_service_test.go
Normal file
199
internal/service/user_service_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetDefaultAvatar 测试获取默认头像的逻辑
|
||||
// 注意:这个测试需要mock repository,但由于repository是函数式的,
|
||||
// 我们只测试逻辑部分
|
||||
func TestGetDefaultAvatar_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configExists bool
|
||||
configValue string
|
||||
expectedResult string
|
||||
}{
|
||||
{
|
||||
name: "配置存在时返回配置值",
|
||||
configExists: true,
|
||||
configValue: "https://example.com/avatar.png",
|
||||
expectedResult: "https://example.com/avatar.png",
|
||||
},
|
||||
{
|
||||
name: "配置不存在时返回错误信息",
|
||||
configExists: false,
|
||||
configValue: "",
|
||||
expectedResult: "数据库中不存在默认头像配置",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 这个测试只验证逻辑,不实际调用repository
|
||||
// 实际的repository调用测试需要集成测试或mock
|
||||
if tt.configExists {
|
||||
if tt.expectedResult != tt.configValue {
|
||||
t.Errorf("当配置存在时,应该返回配置值")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(tt.expectedResult, "数据库中不存在默认头像配置") {
|
||||
t.Errorf("当配置不存在时,应该返回错误信息")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoginUser_EmailDetection 测试登录时邮箱检测逻辑
|
||||
func TestLoginUser_EmailDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
usernameOrEmail string
|
||||
isEmail bool
|
||||
}{
|
||||
{
|
||||
name: "包含@符号,识别为邮箱",
|
||||
usernameOrEmail: "user@example.com",
|
||||
isEmail: true,
|
||||
},
|
||||
{
|
||||
name: "不包含@符号,识别为用户名",
|
||||
usernameOrEmail: "username",
|
||||
isEmail: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
usernameOrEmail: "",
|
||||
isEmail: false,
|
||||
},
|
||||
{
|
||||
name: "只有@符号",
|
||||
usernameOrEmail: "@",
|
||||
isEmail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isEmail := strings.Contains(tt.usernameOrEmail, "@")
|
||||
if isEmail != tt.isEmail {
|
||||
t.Errorf("Email detection failed: got %v, want %v", isEmail, tt.isEmail)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_Constants 测试用户服务相关常量
|
||||
func TestUserService_Constants(t *testing.T) {
|
||||
// 测试默认用户角色
|
||||
defaultRole := "user"
|
||||
if defaultRole == "" {
|
||||
t.Error("默认用户角色不能为空")
|
||||
}
|
||||
|
||||
// 测试默认用户状态
|
||||
defaultStatus := int16(1)
|
||||
if defaultStatus != 1 {
|
||||
t.Errorf("默认用户状态应为1(正常),实际为%d", defaultStatus)
|
||||
}
|
||||
|
||||
// 测试初始积分
|
||||
initialPoints := 0
|
||||
if initialPoints < 0 {
|
||||
t.Errorf("初始积分不应为负数,实际为%d", initialPoints)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_Validation 测试用户数据验证逻辑
|
||||
func TestUserService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户名和邮箱",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "邮箱为空",
|
||||
username: "testuser",
|
||||
email: "",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码为空",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "邮箱格式无效(缺少@)",
|
||||
username: "testuser",
|
||||
email: "invalid-email",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 简单的验证逻辑测试
|
||||
isValid := tt.username != "" && tt.email != "" && tt.password != "" && strings.Contains(tt.email, "@")
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_AvatarLogic 测试头像逻辑
|
||||
func TestUserService_AvatarLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providedAvatar string
|
||||
defaultAvatar string
|
||||
expectedAvatar string
|
||||
}{
|
||||
{
|
||||
name: "提供头像时使用提供的头像",
|
||||
providedAvatar: "https://example.com/custom.png",
|
||||
defaultAvatar: "https://example.com/default.png",
|
||||
expectedAvatar: "https://example.com/custom.png",
|
||||
},
|
||||
{
|
||||
name: "未提供头像时使用默认头像",
|
||||
providedAvatar: "",
|
||||
defaultAvatar: "https://example.com/default.png",
|
||||
expectedAvatar: "https://example.com/default.png",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
avatarURL := tt.providedAvatar
|
||||
if avatarURL == "" {
|
||||
avatarURL = tt.defaultAvatar
|
||||
}
|
||||
if avatarURL != tt.expectedAvatar {
|
||||
t.Errorf("Avatar logic failed: got %s, want %s", avatarURL, tt.expectedAvatar)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
118
internal/service/verification_service.go
Normal file
118
internal/service/verification_service.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"carrotskin/pkg/email"
|
||||
"carrotskin/pkg/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
// 验证码类型
|
||||
VerificationTypeRegister = "register"
|
||||
VerificationTypeResetPassword = "reset_password"
|
||||
VerificationTypeChangeEmail = "change_email"
|
||||
|
||||
// 验证码配置
|
||||
CodeLength = 6 // 验证码长度
|
||||
CodeExpiration = 10 * time.Minute // 验证码有效期
|
||||
CodeRateLimit = 1 * time.Minute // 发送频率限制
|
||||
)
|
||||
|
||||
// GenerateVerificationCode 生成6位数字验证码
|
||||
func GenerateVerificationCode() (string, error) {
|
||||
const digits = "0123456789"
|
||||
code := make([]byte, CodeLength)
|
||||
for i := range code {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code[i] = digits[num.Int64()]
|
||||
}
|
||||
return string(code), nil
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
func SendVerificationCode(ctx context.Context, redisClient *redis.Client, emailService *email.Service, email, codeType string) error {
|
||||
// 检查发送频率限制
|
||||
rateLimitKey := fmt.Sprintf("verification:rate_limit:%s:%s", codeType, email)
|
||||
exists, err := redisClient.Exists(ctx, rateLimitKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查发送频率失败: %w", err)
|
||||
}
|
||||
if exists > 0 {
|
||||
return fmt.Errorf("发送过于频繁,请稍后再试")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("生成验证码失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储验证码到Redis
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
if err := redisClient.Set(ctx, codeKey, code, CodeExpiration); err != nil {
|
||||
return fmt.Errorf("存储验证码失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置发送频率限制
|
||||
if err := redisClient.Set(ctx, rateLimitKey, "1", CodeRateLimit); err != nil {
|
||||
return fmt.Errorf("设置发送频率限制失败: %w", err)
|
||||
}
|
||||
|
||||
// 发送邮件
|
||||
if err := sendVerificationEmail(emailService, email, code, codeType); err != nil {
|
||||
// 发送失败,删除验证码
|
||||
_ = redisClient.Del(ctx, codeKey)
|
||||
return fmt.Errorf("发送邮件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyCode 验证验证码
|
||||
func VerifyCode(ctx context.Context, redisClient *redis.Client, email, code, codeType string) error {
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
|
||||
// 从Redis获取验证码
|
||||
storedCode, err := redisClient.Get(ctx, codeKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("验证码已过期或不存在")
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
if storedCode != code {
|
||||
return fmt.Errorf("验证码错误")
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
_ = redisClient.Del(ctx, codeKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteVerificationCode 删除验证码
|
||||
func DeleteVerificationCode(ctx context.Context, redisClient *redis.Client, email, codeType string) error {
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
return redisClient.Del(ctx, codeKey)
|
||||
}
|
||||
|
||||
// sendVerificationEmail 根据类型发送邮件
|
||||
func sendVerificationEmail(emailService *email.Service, to, code, codeType string) error {
|
||||
switch codeType {
|
||||
case VerificationTypeRegister:
|
||||
return emailService.SendEmailVerification(to, code)
|
||||
case VerificationTypeResetPassword:
|
||||
return emailService.SendResetPassword(to, code)
|
||||
case VerificationTypeChangeEmail:
|
||||
return emailService.SendChangeEmail(to, code)
|
||||
default:
|
||||
return emailService.SendVerificationCode(to, code, codeType)
|
||||
}
|
||||
}
|
||||
119
internal/service/verification_service_test.go
Normal file
119
internal/service/verification_service_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestGenerateVerificationCode 测试生成验证码函数
|
||||
func TestGenerateVerificationCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "生成6位验证码",
|
||||
wantLen: CodeLength,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, err := GenerateVerificationCode()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateVerificationCode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && len(code) != tt.wantLen {
|
||||
t.Errorf("GenerateVerificationCode() code length = %v, want %v", len(code), tt.wantLen)
|
||||
}
|
||||
// 验证验证码只包含数字
|
||||
for _, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("GenerateVerificationCode() code contains non-digit: %c", c)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 测试多次生成,验证码应该不同(概率上)
|
||||
codes := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateVerificationCode() failed: %v", err)
|
||||
}
|
||||
if codes[code] {
|
||||
t.Logf("发现重复验证码(这是正常的,因为只有6位数字): %s", code)
|
||||
}
|
||||
codes[code] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationConstants 测试验证码相关常量
|
||||
func TestVerificationConstants(t *testing.T) {
|
||||
if CodeLength != 6 {
|
||||
t.Errorf("CodeLength = %d, want 6", CodeLength)
|
||||
}
|
||||
|
||||
if CodeExpiration != 10*time.Minute {
|
||||
t.Errorf("CodeExpiration = %v, want 10 minutes", CodeExpiration)
|
||||
}
|
||||
|
||||
if CodeRateLimit != 1*time.Minute {
|
||||
t.Errorf("CodeRateLimit = %v, want 1 minute", CodeRateLimit)
|
||||
}
|
||||
|
||||
// 验证验证码类型常量
|
||||
types := []string{
|
||||
VerificationTypeRegister,
|
||||
VerificationTypeResetPassword,
|
||||
VerificationTypeChangeEmail,
|
||||
}
|
||||
|
||||
for _, vType := range types {
|
||||
if vType == "" {
|
||||
t.Error("验证码类型不能为空")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationCodeFormat 测试验证码格式
|
||||
func TestVerificationCodeFormat(t *testing.T) {
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateVerificationCode() failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证长度
|
||||
if len(code) != 6 {
|
||||
t.Errorf("验证码长度应为6位,实际为%d位", len(code))
|
||||
}
|
||||
|
||||
// 验证只包含数字
|
||||
for i, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("验证码第%d位包含非数字字符: %c", i+1, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationTypes 测试验证码类型
|
||||
func TestVerificationTypes(t *testing.T) {
|
||||
validTypes := map[string]bool{
|
||||
VerificationTypeRegister: true,
|
||||
VerificationTypeResetPassword: true,
|
||||
VerificationTypeChangeEmail: true,
|
||||
}
|
||||
|
||||
for vType, isValid := range validTypes {
|
||||
if !isValid {
|
||||
t.Errorf("验证码类型 %s 应该是有效的", vType)
|
||||
}
|
||||
if vType == "" {
|
||||
t.Error("验证码类型不能为空字符串")
|
||||
}
|
||||
}
|
||||
}
|
||||
201
internal/service/yggdrasil_service.go
Normal file
201
internal/service/yggdrasil_service.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/redis"
|
||||
"carrotskin/pkg/utils"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SessionKeyPrefix Redis会话键前缀
|
||||
const SessionKeyPrefix = "Join_"
|
||||
|
||||
// SessionTTL 会话超时时间 - 增加到15分钟
|
||||
const SessionTTL = 15 * time.Minute
|
||||
|
||||
type SessionData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
UserName string `json:"userName"`
|
||||
SelectedProfile string `json:"selectedProfile"`
|
||||
IP string `json:"ip"`
|
||||
}
|
||||
|
||||
// GetUserIDByEmail 根据邮箱返回用户id
|
||||
func GetUserIDByEmail(db *gorm.DB, Identifier string) (int64, error) {
|
||||
user, err := repository.FindUserByEmail(Identifier)
|
||||
if err != nil {
|
||||
return 0, errors.New("用户不存在")
|
||||
}
|
||||
return user.ID, nil
|
||||
}
|
||||
|
||||
// GetProfileByProfileName 根据用户名返回用户id
|
||||
func GetProfileByProfileName(db *gorm.DB, Identifier string) (*model.Profile, error) {
|
||||
profile, err := repository.FindProfileByName(Identifier)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户角色未创建")
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// VerifyPassword 验证密码是否一致
|
||||
func VerifyPassword(db *gorm.DB, password string, Id int64) error {
|
||||
passwordStore, err := repository.GetYggdrasilPasswordById(Id)
|
||||
if err != nil {
|
||||
return errors.New("未生成密码")
|
||||
}
|
||||
if passwordStore != password {
|
||||
return errors.New("密码错误")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetProfileByUserId(db *gorm.DB, userId int64) (*model.Profile, error) {
|
||||
profiles, err := repository.FindProfilesByUserID(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("角色查找失败")
|
||||
}
|
||||
if len(profiles) == 0 {
|
||||
return nil, errors.New("角色查找失败")
|
||||
}
|
||||
return profiles[0], nil
|
||||
}
|
||||
|
||||
func GetPasswordByUserId(db *gorm.DB, userId int64) (string, error) {
|
||||
passwordStore, err := repository.GetYggdrasilPasswordById(userId)
|
||||
if err != nil {
|
||||
return "", errors.New("yggdrasil密码查找失败")
|
||||
}
|
||||
return passwordStore, nil
|
||||
}
|
||||
|
||||
// JoinServer 记录玩家加入服务器的会话信息
|
||||
func JoinServer(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, serverId, accessToken, selectedProfile, ip string) error {
|
||||
// 输入验证
|
||||
if serverId == "" || accessToken == "" || selectedProfile == "" {
|
||||
return errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
// 验证serverId格式,防止注入攻击
|
||||
if len(serverId) > 100 || strings.ContainsAny(serverId, "<>\"'&") {
|
||||
return errors.New("服务器ID格式无效")
|
||||
}
|
||||
|
||||
// 验证IP格式
|
||||
if ip != "" {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return errors.New("IP地址格式无效")
|
||||
}
|
||||
}
|
||||
|
||||
// 获取和验证Token
|
||||
token, err := repository.GetTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"验证Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return fmt.Errorf("验证Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 格式化UUID并验证与Token关联的配置文件
|
||||
formattedProfile := utils.FormatUUID(selectedProfile)
|
||||
if token.ProfileId != formattedProfile {
|
||||
return errors.New("selectedProfile与Token不匹配")
|
||||
}
|
||||
|
||||
profile, err := repository.FindProfileByUUID(formattedProfile)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"获取Profile失败",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", formattedProfile),
|
||||
)
|
||||
return fmt.Errorf("获取Profile失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建会话数据
|
||||
data := SessionData{
|
||||
AccessToken: accessToken,
|
||||
UserName: profile.Name,
|
||||
SelectedProfile: formattedProfile,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
// 序列化会话数据
|
||||
marshaledData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"[ERROR]序列化会话数据失败",
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储会话数据到Redis
|
||||
sessionKey := SessionKeyPrefix + serverId
|
||||
ctx := context.Background()
|
||||
if err = redisClient.Set(ctx, sessionKey, marshaledData, SessionTTL); err != nil {
|
||||
logger.Error(
|
||||
"保存会话数据失败",
|
||||
zap.Error(err),
|
||||
zap.String("serverId", serverId),
|
||||
)
|
||||
return fmt.Errorf("保存会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"玩家成功加入服务器",
|
||||
zap.String("username", profile.Name),
|
||||
zap.String("serverId", serverId),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasJoinedServer 验证玩家是否已经加入了服务器
|
||||
func HasJoinedServer(logger *zap.Logger, redisClient *redis.Client, serverId, username, ip string) error {
|
||||
if serverId == "" || username == "" {
|
||||
return errors.New("服务器ID和用户名不能为空")
|
||||
}
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 从Redis获取会话数据
|
||||
sessionKey := SessionKeyPrefix + serverId
|
||||
data, err := redisClient.GetBytes(ctx, sessionKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取会话数据失败:", zap.Error(err), zap.Any("serverId:", serverId))
|
||||
return fmt.Errorf("获取会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 反序列化会话数据
|
||||
var sessionData SessionData
|
||||
if err = json.Unmarshal(data, &sessionData); err != nil {
|
||||
logger.Error("[ERROR] 解析会话数据失败: ", zap.Error(err))
|
||||
return fmt.Errorf("解析会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证用户名
|
||||
if sessionData.UserName != username {
|
||||
return errors.New("用户名不匹配")
|
||||
}
|
||||
|
||||
// 验证IP(如果提供)
|
||||
if ip != "" && sessionData.IP != ip {
|
||||
return errors.New("IP地址不匹配")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
350
internal/service/yggdrasil_service_test.go
Normal file
350
internal/service/yggdrasil_service_test.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestYggdrasilService_Constants 测试Yggdrasil服务常量
|
||||
func TestYggdrasilService_Constants(t *testing.T) {
|
||||
if SessionKeyPrefix != "Join_" {
|
||||
t.Errorf("SessionKeyPrefix = %s, want 'Join_'", SessionKeyPrefix)
|
||||
}
|
||||
|
||||
if SessionTTL != 15*time.Minute {
|
||||
t.Errorf("SessionTTL = %v, want 15 minutes", SessionTTL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionData_Structure 测试SessionData结构
|
||||
func TestSessionData_Structure(t *testing.T) {
|
||||
data := SessionData{
|
||||
AccessToken: "test-token",
|
||||
UserName: "TestUser",
|
||||
SelectedProfile: "test-profile-uuid",
|
||||
IP: "127.0.0.1",
|
||||
}
|
||||
|
||||
if data.AccessToken == "" {
|
||||
t.Error("AccessToken should not be empty")
|
||||
}
|
||||
|
||||
if data.UserName == "" {
|
||||
t.Error("UserName should not be empty")
|
||||
}
|
||||
|
||||
if data.SelectedProfile == "" {
|
||||
t.Error("SelectedProfile should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_InputValidation 测试JoinServer输入验证逻辑
|
||||
func TestJoinServer_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
accessToken string
|
||||
selectedProfile string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "所有参数有效",
|
||||
serverId: "test-server-123",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "serverId为空",
|
||||
serverId: "",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
{
|
||||
name: "accessToken为空",
|
||||
serverId: "test-server",
|
||||
accessToken: "",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
{
|
||||
name: "selectedProfile为空",
|
||||
serverId: "test-server",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.serverId == "" || tt.accessToken == "" || tt.selectedProfile == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_ServerIDValidation 测试服务器ID格式验证
|
||||
func TestJoinServer_ServerIDValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的serverId",
|
||||
serverId: "test-server-123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "serverId过长",
|
||||
serverId: strings.Repeat("a", 101),
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符<",
|
||||
serverId: "test<server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符>",
|
||||
serverId: "test>server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符\"",
|
||||
serverId: "test\"server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符'",
|
||||
serverId: "test'server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符&",
|
||||
serverId: "test&server",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := len(tt.serverId) <= 100 && !strings.ContainsAny(tt.serverId, "<>\"'&")
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("ServerID validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_IPValidation 测试IP地址验证逻辑
|
||||
func TestJoinServer_IPValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的IPv4地址",
|
||||
ip: "127.0.0.1",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "有效的IPv6地址",
|
||||
ip: "::1",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效的IP地址",
|
||||
ip: "invalid-ip",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空IP地址(可选)",
|
||||
ip: "",
|
||||
wantValid: true, // 空IP是允许的
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var isValid bool
|
||||
if tt.ip == "" {
|
||||
isValid = true // 空IP是允许的
|
||||
} else {
|
||||
isValid = net.ParseIP(tt.ip) != nil
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("IP validation failed: got %v, want %v (ip=%s)", isValid, tt.wantValid, tt.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_InputValidation 测试HasJoinedServer输入验证
|
||||
func TestHasJoinedServer_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
username string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "所有参数有效",
|
||||
serverId: "test-server",
|
||||
username: "TestUser",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "serverId为空",
|
||||
serverId: "",
|
||||
username: "TestUser",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "username为空",
|
||||
serverId: "test-server",
|
||||
username: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "两者都为空",
|
||||
serverId: "",
|
||||
username: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.serverId == "" || tt.username == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_UsernameMatching 测试用户名匹配逻辑
|
||||
func TestHasJoinedServer_UsernameMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionUser string
|
||||
requestUser string
|
||||
wantMatch bool
|
||||
}{
|
||||
{
|
||||
name: "用户名匹配",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "TestUser",
|
||||
wantMatch: true,
|
||||
},
|
||||
{
|
||||
name: "用户名不匹配",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "OtherUser",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "大小写敏感",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "testuser",
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := tt.sessionUser == tt.requestUser
|
||||
if matches != tt.wantMatch {
|
||||
t.Errorf("Username matching failed: got %v, want %v", matches, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_IPMatching 测试IP地址匹配逻辑
|
||||
func TestHasJoinedServer_IPMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionIP string
|
||||
requestIP string
|
||||
wantMatch bool
|
||||
shouldCheck bool
|
||||
}{
|
||||
{
|
||||
name: "IP匹配",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "127.0.0.1",
|
||||
wantMatch: true,
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "IP不匹配",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "192.168.1.1",
|
||||
wantMatch: false,
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "请求IP为空时不检查",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "",
|
||||
wantMatch: true,
|
||||
shouldCheck: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var matches bool
|
||||
if tt.requestIP == "" {
|
||||
matches = true // 空IP不检查
|
||||
} else {
|
||||
matches = tt.sessionIP == tt.requestIP
|
||||
}
|
||||
if matches != tt.wantMatch {
|
||||
t.Errorf("IP matching failed: got %v, want %v", matches, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_SessionKey 测试会话键生成
|
||||
func TestJoinServer_SessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "生成正确的会话键",
|
||||
serverId: "test-server-123",
|
||||
expected: "Join_test-server-123",
|
||||
},
|
||||
{
|
||||
name: "空serverId",
|
||||
serverId: "",
|
||||
expected: "Join_",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionKey := SessionKeyPrefix + tt.serverId
|
||||
if sessionKey != tt.expected {
|
||||
t.Errorf("Session key = %s, want %s", sessionKey, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
215
internal/types/common.go
Normal file
215
internal/types/common.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package types
|
||||
|
||||
import "time"
|
||||
|
||||
// BaseResponse 基础响应结构
|
||||
type BaseResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginationRequest 分页请求
|
||||
type PaginationRequest struct {
|
||||
Page int `json:"page" form:"page" binding:"omitempty,min=1"`
|
||||
PageSize int `json:"page_size" form:"page_size" binding:"omitempty,min=1,max=100"`
|
||||
}
|
||||
|
||||
// PaginationResponse 分页响应
|
||||
type PaginationResponse struct {
|
||||
List interface{} `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// LoginRequest 登录请求
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username" binding:"required" example:"testuser"` // 支持用户名或邮箱
|
||||
Password string `json:"password" binding:"required,min=6,max=128" example:"password123"`
|
||||
}
|
||||
|
||||
// RegisterRequest 注册请求
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=50" example:"newuser"`
|
||||
Email string `json:"email" binding:"required,email" example:"user@example.com"`
|
||||
Password string `json:"password" binding:"required,min=6,max=128" example:"password123"`
|
||||
VerificationCode string `json:"verification_code" binding:"required,len=6" example:"123456"` // 邮箱验证码
|
||||
Avatar string `json:"avatar" binding:"omitempty,url" example:"https://rustfs.example.com/avatars/user_1/avatar.png"` // 可选,用户自定义头像
|
||||
}
|
||||
|
||||
// UpdateUserRequest 更新用户请求
|
||||
type UpdateUserRequest struct {
|
||||
Avatar string `json:"avatar" binding:"omitempty,url" example:"https://example.com/new-avatar.png"`
|
||||
OldPassword string `json:"old_password" binding:"omitempty,min=6,max=128" example:"oldpassword123"` // 修改密码时必需
|
||||
NewPassword string `json:"new_password" binding:"omitempty,min=6,max=128" example:"newpassword123"` // 新密码
|
||||
}
|
||||
|
||||
// SendVerificationCodeRequest 发送验证码请求
|
||||
type SendVerificationCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email" example:"user@example.com"`
|
||||
Type string `json:"type" binding:"required,oneof=register reset_password change_email" example:"register"` // 类型: register/reset_password/change_email
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email" example:"user@example.com"`
|
||||
VerificationCode string `json:"verification_code" binding:"required,len=6" example:"123456"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6,max=128" example:"newpassword123"`
|
||||
}
|
||||
|
||||
// ChangeEmailRequest 更换邮箱请求
|
||||
type ChangeEmailRequest struct {
|
||||
NewEmail string `json:"new_email" binding:"required,email" example:"newemail@example.com"`
|
||||
VerificationCode string `json:"verification_code" binding:"required,len=6" example:"123456"`
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURLRequest 生成头像上传URL请求
|
||||
type GenerateAvatarUploadURLRequest struct {
|
||||
FileName string `json:"file_name" binding:"required" example:"avatar.png"`
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURLResponse 生成头像上传URL响应
|
||||
type GenerateAvatarUploadURLResponse struct {
|
||||
PostURL string `json:"post_url" example:"https://rustfs.example.com/avatars"`
|
||||
FormData map[string]string `json:"form_data"`
|
||||
AvatarURL string `json:"avatar_url" example:"https://rustfs.example.com/avatars/user_1/xxx.png"`
|
||||
ExpiresIn int `json:"expires_in" example:"900"` // 秒
|
||||
}
|
||||
|
||||
// CreateProfileRequest 创建档案请求
|
||||
type CreateProfileRequest struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=16" example:"PlayerName"`
|
||||
}
|
||||
|
||||
// UpdateTextureRequest 更新材质请求
|
||||
type UpdateTextureRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,min=1,max=100" example:"My Skin"`
|
||||
Description string `json:"description" binding:"omitempty,max=500" example:"A cool skin"`
|
||||
IsPublic *bool `json:"is_public" example:"true"`
|
||||
}
|
||||
|
||||
// GenerateTextureUploadURLRequest 生成材质上传URL请求
|
||||
type GenerateTextureUploadURLRequest struct {
|
||||
FileName string `json:"file_name" binding:"required" example:"skin.png"`
|
||||
TextureType TextureType `json:"texture_type" binding:"required,oneof=SKIN CAPE" example:"SKIN"`
|
||||
}
|
||||
|
||||
// GenerateTextureUploadURLResponse 生成材质上传URL响应
|
||||
type GenerateTextureUploadURLResponse struct {
|
||||
PostURL string `json:"post_url" example:"https://rustfs.example.com/textures"`
|
||||
FormData map[string]string `json:"form_data"`
|
||||
TextureURL string `json:"texture_url" example:"https://rustfs.example.com/textures/user_1/skin/xxx.png"`
|
||||
ExpiresIn int `json:"expires_in" example:"900"` // 秒
|
||||
}
|
||||
|
||||
// LoginResponse 登录响应
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
UserInfo *UserInfo `json:"user_info"`
|
||||
}
|
||||
|
||||
// UserInfo 用户信息
|
||||
type UserInfo struct {
|
||||
ID int64 `json:"id" example:"1"`
|
||||
Username string `json:"username" example:"testuser"`
|
||||
Email string `json:"email" example:"test@example.com"`
|
||||
Avatar string `json:"avatar" example:"https://example.com/avatar.png"`
|
||||
Points int `json:"points" example:"100"`
|
||||
Role string `json:"role" example:"user"`
|
||||
Status int16 `json:"status" example:"1"`
|
||||
LastLoginAt *time.Time `json:"last_login_at,omitempty" example:"2025-10-01T12:00:00Z"`
|
||||
CreatedAt time.Time `json:"created_at" example:"2025-10-01T10:00:00Z"`
|
||||
UpdatedAt time.Time `json:"updated_at" example:"2025-10-01T10:00:00Z"`
|
||||
}
|
||||
|
||||
// TextureType 材质类型
|
||||
type TextureType string
|
||||
|
||||
const (
|
||||
TextureTypeSkin TextureType = "SKIN"
|
||||
TextureTypeCape TextureType = "CAPE"
|
||||
)
|
||||
|
||||
// TextureInfo 材质信息
|
||||
type TextureInfo struct {
|
||||
ID int64 `json:"id" example:"1"`
|
||||
UploaderID int64 `json:"uploader_id" example:"1"`
|
||||
Name string `json:"name" example:"My Skin"`
|
||||
Description string `json:"description,omitempty" example:"A cool skin"`
|
||||
Type TextureType `json:"type" example:"SKIN"`
|
||||
URL string `json:"url" example:"https://rustfs.example.com/textures/xxx.png"`
|
||||
Hash string `json:"hash" example:"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"`
|
||||
Size int `json:"size" example:"2048"`
|
||||
IsPublic bool `json:"is_public" example:"true"`
|
||||
DownloadCount int `json:"download_count" example:"100"`
|
||||
FavoriteCount int `json:"favorite_count" example:"50"`
|
||||
IsSlim bool `json:"is_slim" example:"false"`
|
||||
Status int16 `json:"status" example:"1"`
|
||||
CreatedAt time.Time `json:"created_at" example:"2025-10-01T10:00:00Z"`
|
||||
UpdatedAt time.Time `json:"updated_at" example:"2025-10-01T10:00:00Z"`
|
||||
}
|
||||
|
||||
// ProfileInfo 角色信息
|
||||
type ProfileInfo struct {
|
||||
UUID string `json:"uuid" example:"550e8400-e29b-41d4-a716-446655440000"`
|
||||
UserID int64 `json:"user_id" example:"1"`
|
||||
Name string `json:"name" example:"PlayerName"`
|
||||
SkinID *int64 `json:"skin_id,omitempty" example:"1"`
|
||||
CapeID *int64 `json:"cape_id,omitempty" example:"2"`
|
||||
IsActive bool `json:"is_active" example:"true"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty" example:"2025-10-01T12:00:00Z"`
|
||||
CreatedAt time.Time `json:"created_at" example:"2025-10-01T10:00:00Z"`
|
||||
UpdatedAt time.Time `json:"updated_at" example:"2025-10-01T10:00:00Z"`
|
||||
}
|
||||
|
||||
// UploadURLRequest 上传URL请求
|
||||
type UploadURLRequest struct {
|
||||
Type TextureType `json:"type" binding:"required,oneof=SKIN CAPE"`
|
||||
Filename string `json:"filename" binding:"required"`
|
||||
}
|
||||
|
||||
// UploadURLResponse 上传URL响应
|
||||
type UploadURLResponse struct {
|
||||
PostURL string `json:"post_url"`
|
||||
FormData map[string]string `json:"form_data"`
|
||||
FileURL string `json:"file_url"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// CreateTextureRequest 创建材质请求
|
||||
type CreateTextureRequest struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=100" example:"My Cool Skin"`
|
||||
Description string `json:"description" binding:"max=500" example:"A very cool skin"`
|
||||
Type TextureType `json:"type" binding:"required,oneof=SKIN CAPE" example:"SKIN"`
|
||||
URL string `json:"url" binding:"required,url" example:"https://rustfs.example.com/textures/user_1/skin/xxx.png"`
|
||||
Hash string `json:"hash" binding:"required,len=64" example:"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"`
|
||||
Size int `json:"size" binding:"required,min=1" example:"2048"`
|
||||
IsPublic bool `json:"is_public" example:"true"`
|
||||
IsSlim bool `json:"is_slim" example:"false"` // Alex模型(细臂)为true,Steve模型(粗臂)为false
|
||||
}
|
||||
|
||||
// SearchTextureRequest 搜索材质请求
|
||||
type SearchTextureRequest struct {
|
||||
PaginationRequest
|
||||
Keyword string `json:"keyword" form:"keyword"`
|
||||
Type TextureType `json:"type" form:"type" binding:"omitempty,oneof=SKIN CAPE"`
|
||||
PublicOnly bool `json:"public_only" form:"public_only"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest 更新角色请求
|
||||
type UpdateProfileRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,min=1,max=16" example:"NewPlayerName"`
|
||||
SkinID *int64 `json:"skin_id,omitempty" example:"1"`
|
||||
CapeID *int64 `json:"cape_id,omitempty" example:"2"`
|
||||
}
|
||||
|
||||
// SystemConfigResponse 基础系统配置响应
|
||||
type SystemConfigResponse struct {
|
||||
SiteName string `json:"site_name" example:"CarrotSkin"`
|
||||
SiteDescription string `json:"site_description" example:"A Minecraft Skin Station"`
|
||||
RegistrationEnabled bool `json:"registration_enabled" example:"true"`
|
||||
MaxTexturesPerUser int `json:"max_textures_per_user" example:"100"`
|
||||
MaxProfilesPerUser int `json:"max_profiles_per_user" example:"5"`
|
||||
}
|
||||
384
internal/types/common_test.go
Normal file
384
internal/types/common_test.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPaginationRequest_Validation 测试分页请求验证逻辑
|
||||
func TestPaginationRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "page小于1应该无效",
|
||||
page: 0,
|
||||
pageSize: 20,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "pageSize小于1应该无效",
|
||||
page: 1,
|
||||
pageSize: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100应该无效",
|
||||
page: 1,
|
||||
pageSize: 200,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.page >= 1 && tt.pageSize >= 1 && tt.pageSize <= 100
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureType_Constants 测试材质类型常量
|
||||
func TestTextureType_Constants(t *testing.T) {
|
||||
if TextureTypeSkin != "SKIN" {
|
||||
t.Errorf("TextureTypeSkin = %q, want 'SKIN'", TextureTypeSkin)
|
||||
}
|
||||
|
||||
if TextureTypeCape != "CAPE" {
|
||||
t.Errorf("TextureTypeCape = %q, want 'CAPE'", TextureTypeCape)
|
||||
}
|
||||
|
||||
if TextureTypeSkin == TextureTypeCape {
|
||||
t.Error("TextureTypeSkin 和 TextureTypeCape 应该不同")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaginationResponse_Structure 测试分页响应结构
|
||||
func TestPaginationResponse_Structure(t *testing.T) {
|
||||
resp := PaginationResponse{
|
||||
List: []string{"a", "b", "c"},
|
||||
Total: 100,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
TotalPages: 5,
|
||||
}
|
||||
|
||||
if resp.Total != 100 {
|
||||
t.Errorf("Total = %d, want 100", resp.Total)
|
||||
}
|
||||
|
||||
if resp.Page != 1 {
|
||||
t.Errorf("Page = %d, want 1", resp.Page)
|
||||
}
|
||||
|
||||
if resp.PageSize != 20 {
|
||||
t.Errorf("PageSize = %d, want 20", resp.PageSize)
|
||||
}
|
||||
|
||||
if resp.TotalPages != 5 {
|
||||
t.Errorf("TotalPages = %d, want 5", resp.TotalPages)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaginationResponse_TotalPagesCalculation 测试总页数计算逻辑
|
||||
func TestPaginationResponse_TotalPagesCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
total int64
|
||||
pageSize int
|
||||
wantPages int
|
||||
}{
|
||||
{
|
||||
name: "正好整除",
|
||||
total: 100,
|
||||
pageSize: 20,
|
||||
wantPages: 5,
|
||||
},
|
||||
{
|
||||
name: "有余数",
|
||||
total: 101,
|
||||
pageSize: 20,
|
||||
wantPages: 6, // 向上取整
|
||||
},
|
||||
{
|
||||
name: "总数小于每页数量",
|
||||
total: 10,
|
||||
pageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
{
|
||||
name: "总数为0",
|
||||
total: 0,
|
||||
pageSize: 20,
|
||||
wantPages: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 计算总页数:向上取整
|
||||
var totalPages int
|
||||
if tt.total == 0 {
|
||||
totalPages = 0
|
||||
} else {
|
||||
totalPages = int((tt.total + int64(tt.pageSize) - 1) / int64(tt.pageSize))
|
||||
}
|
||||
|
||||
if totalPages != tt.wantPages {
|
||||
t.Errorf("TotalPages = %d, want %d", totalPages, tt.wantPages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseResponse_Structure 测试基础响应结构
|
||||
func TestBaseResponse_Structure(t *testing.T) {
|
||||
resp := BaseResponse{
|
||||
Code: 200,
|
||||
Message: "success",
|
||||
Data: "test data",
|
||||
}
|
||||
|
||||
if resp.Code != 200 {
|
||||
t.Errorf("Code = %d, want 200", resp.Code)
|
||||
}
|
||||
|
||||
if resp.Message != "success" {
|
||||
t.Errorf("Message = %q, want 'success'", resp.Message)
|
||||
}
|
||||
|
||||
if resp.Data != "test data" {
|
||||
t.Errorf("Data = %v, want 'test data'", resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoginRequest_Validation 测试登录请求验证逻辑
|
||||
func TestLoginRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
password string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的登录请求",
|
||||
username: "testuser",
|
||||
password: "password123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码为空",
|
||||
username: "testuser",
|
||||
password: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码长度小于6",
|
||||
username: "testuser",
|
||||
password: "12345",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码长度超过128",
|
||||
username: "testuser",
|
||||
password: string(make([]byte, 129)),
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.username != "" && len(tt.password) >= 6 && len(tt.password) <= 128
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterRequest_Validation 测试注册请求验证逻辑
|
||||
func TestRegisterRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
verificationCode string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的注册请求",
|
||||
username: "newuser",
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户名长度小于3",
|
||||
username: "ab",
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户名长度超过50",
|
||||
username: string(make([]byte, 51)),
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "邮箱格式无效",
|
||||
username: "newuser",
|
||||
email: "invalid-email",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "验证码长度不是6",
|
||||
username: "newuser",
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "12345",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.username != "" &&
|
||||
len(tt.username) >= 3 && len(tt.username) <= 50 &&
|
||||
tt.email != "" && contains(tt.email, "@") &&
|
||||
len(tt.password) >= 6 && len(tt.password) <= 128 &&
|
||||
len(tt.verificationCode) == 6
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestResetPasswordRequest_Validation 测试重置密码请求验证
|
||||
func TestResetPasswordRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
verificationCode string
|
||||
newPassword string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的重置密码请求",
|
||||
email: "user@example.com",
|
||||
verificationCode: "123456",
|
||||
newPassword: "newpassword123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "邮箱为空",
|
||||
email: "",
|
||||
verificationCode: "123456",
|
||||
newPassword: "newpassword123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "验证码长度不是6",
|
||||
email: "user@example.com",
|
||||
verificationCode: "12345",
|
||||
newPassword: "newpassword123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "新密码长度小于6",
|
||||
email: "user@example.com",
|
||||
verificationCode: "123456",
|
||||
newPassword: "12345",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.email != "" &&
|
||||
len(tt.verificationCode) == 6 &&
|
||||
len(tt.newPassword) >= 6 && len(tt.newPassword) <= 128
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateProfileRequest_Validation 测试创建档案请求验证
|
||||
func TestCreateProfileRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileName string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的档案名",
|
||||
profileName: "PlayerName",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "档案名为空",
|
||||
profileName: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "档案名长度超过16",
|
||||
profileName: string(make([]byte, 17)),
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.profileName != "" &&
|
||||
len(tt.profileName) >= 1 && len(tt.profileName) <= 16
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
70
pkg/auth/jwt.go
Normal file
70
pkg/auth/jwt.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// JWTService JWT服务
|
||||
type JWTService struct {
|
||||
secretKey string
|
||||
expireHours int
|
||||
}
|
||||
|
||||
// NewJWTService 创建新的JWT服务
|
||||
func NewJWTService(secretKey string, expireHours int) *JWTService {
|
||||
return &JWTService{
|
||||
secretKey: secretKey,
|
||||
expireHours: expireHours,
|
||||
}
|
||||
}
|
||||
|
||||
// Claims JWT声明
|
||||
type Claims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT Token (使用UserID和基本信息)
|
||||
func (j *JWTService) GenerateToken(userID int64, username, role string) (string, error) {
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Role: role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(j.expireHours) * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: "carrotskin",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(j.secretKey))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT Token
|
||||
func (j *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(j.secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("无效的token")
|
||||
}
|
||||
235
pkg/auth/jwt_test.go
Normal file
235
pkg/auth/jwt_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewJWTService 测试创建JWT服务
|
||||
func TestNewJWTService(t *testing.T) {
|
||||
secretKey := "test-secret-key"
|
||||
expireHours := 24
|
||||
|
||||
service := NewJWTService(secretKey, expireHours)
|
||||
if service == nil {
|
||||
t.Fatal("NewJWTService() 返回nil")
|
||||
}
|
||||
|
||||
if service.secretKey != secretKey {
|
||||
t.Errorf("secretKey = %q, want %q", service.secretKey, secretKey)
|
||||
}
|
||||
|
||||
if service.expireHours != expireHours {
|
||||
t.Errorf("expireHours = %d, want %d", service.expireHours, expireHours)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_GenerateToken 测试生成Token
|
||||
func TestJWTService_GenerateToken(t *testing.T) {
|
||||
service := NewJWTService("test-secret-key", 24)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
username string
|
||||
role string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "正常生成Token",
|
||||
userID: 1,
|
||||
username: "testuser",
|
||||
role: "user",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "空用户名",
|
||||
userID: 1,
|
||||
username: "",
|
||||
role: "user",
|
||||
wantError: false, // JWT允许空字符串
|
||||
},
|
||||
{
|
||||
name: "空角色",
|
||||
userID: 1,
|
||||
username: "testuser",
|
||||
role: "",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, err := service.GenerateToken(tt.userID, tt.username, tt.role)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("GenerateToken() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if token == "" {
|
||||
t.Error("GenerateToken() 返回的token不应为空")
|
||||
}
|
||||
// 验证token长度合理(JWT token通常很长)
|
||||
if len(token) < 50 {
|
||||
t.Errorf("GenerateToken() 返回的token长度异常: %d", len(token))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_ValidateToken 测试验证Token
|
||||
func TestJWTService_ValidateToken(t *testing.T) {
|
||||
secretKey := "test-secret-key"
|
||||
service := NewJWTService(secretKey, 24)
|
||||
|
||||
// 生成一个有效的token
|
||||
userID := int64(1)
|
||||
username := "testuser"
|
||||
role := "user"
|
||||
token, err := service.GenerateToken(userID, username, role)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantError bool
|
||||
wantUserID int64
|
||||
wantUsername string
|
||||
wantRole string
|
||||
}{
|
||||
{
|
||||
name: "有效token",
|
||||
token: token,
|
||||
wantError: false,
|
||||
wantUserID: userID,
|
||||
wantUsername: username,
|
||||
wantRole: role,
|
||||
},
|
||||
{
|
||||
name: "无效token",
|
||||
token: "invalid.token.here",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "空token",
|
||||
token: "",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "使用不同密钥签名的token",
|
||||
token: func() string {
|
||||
otherService := NewJWTService("different-secret", 24)
|
||||
token, _ := otherService.GenerateToken(1, "user", "role")
|
||||
return token
|
||||
}(),
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims, err := service.ValidateToken(tt.token)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateToken() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if claims == nil {
|
||||
t.Fatal("ValidateToken() 返回的claims不应为nil")
|
||||
}
|
||||
if claims.UserID != tt.wantUserID {
|
||||
t.Errorf("UserID = %d, want %d", claims.UserID, tt.wantUserID)
|
||||
}
|
||||
if claims.Username != tt.wantUsername {
|
||||
t.Errorf("Username = %q, want %q", claims.Username, tt.wantUsername)
|
||||
}
|
||||
if claims.Role != tt.wantRole {
|
||||
t.Errorf("Role = %q, want %q", claims.Role, tt.wantRole)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_TokenRoundTrip 测试Token的完整流程
|
||||
func TestJWTService_TokenRoundTrip(t *testing.T) {
|
||||
service := NewJWTService("test-secret-key", 24)
|
||||
|
||||
userID := int64(123)
|
||||
username := "testuser"
|
||||
role := "admin"
|
||||
|
||||
// 生成token
|
||||
token, err := service.GenerateToken(userID, username, role)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := service.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证claims内容
|
||||
if claims.UserID != userID {
|
||||
t.Errorf("UserID = %d, want %d", claims.UserID, userID)
|
||||
}
|
||||
if claims.Username != username {
|
||||
t.Errorf("Username = %q, want %q", claims.Username, username)
|
||||
}
|
||||
if claims.Role != role {
|
||||
t.Errorf("Role = %q, want %q", claims.Role, role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_TokenExpiration 测试Token过期时间
|
||||
func TestJWTService_TokenExpiration(t *testing.T) {
|
||||
expireHours := 24
|
||||
service := NewJWTService("test-secret-key", expireHours)
|
||||
|
||||
token, err := service.GenerateToken(1, "user", "role")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
claims, err := service.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if claims.ExpiresAt == nil {
|
||||
t.Error("ExpiresAt 不应为nil")
|
||||
} else {
|
||||
expectedExpiry := time.Now().Add(time.Duration(expireHours) * time.Hour)
|
||||
// 允许1分钟的误差
|
||||
diff := claims.ExpiresAt.Time.Sub(expectedExpiry)
|
||||
if diff < -time.Minute || diff > time.Minute {
|
||||
t.Errorf("ExpiresAt 时间异常: %v, 期望约 %v", claims.ExpiresAt.Time, expectedExpiry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_TokenIssuer 测试Token发行者
|
||||
func TestJWTService_TokenIssuer(t *testing.T) {
|
||||
service := NewJWTService("test-secret-key", 24)
|
||||
|
||||
token, err := service.GenerateToken(1, "user", "role")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
claims, err := service.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
expectedIssuer := "carrotskin"
|
||||
if claims.Issuer != expectedIssuer {
|
||||
t.Errorf("Issuer = %q, want %q", claims.Issuer, expectedIssuer)
|
||||
}
|
||||
}
|
||||
45
pkg/auth/manager.go
Normal file
45
pkg/auth/manager.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
// jwtServiceInstance 全局JWT服务实例
|
||||
jwtServiceInstance *JWTService
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化JWT服务(线程安全,只会执行一次)
|
||||
func Init(cfg config.JWTConfig) error {
|
||||
once.Do(func() {
|
||||
jwtServiceInstance = NewJWTService(cfg.Secret, cfg.ExpireHours)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetJWTService 获取JWT服务实例(线程安全)
|
||||
func GetJWTService() (*JWTService, error) {
|
||||
if jwtServiceInstance == nil {
|
||||
return nil, fmt.Errorf("JWT服务未初始化,请先调用 auth.Init()")
|
||||
}
|
||||
return jwtServiceInstance, nil
|
||||
}
|
||||
|
||||
// MustGetJWTService 获取JWT服务实例,如果未初始化则panic
|
||||
func MustGetJWTService() *JWTService {
|
||||
service, err := GetJWTService()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
86
pkg/auth/manager_test.go
Normal file
86
pkg/auth/manager_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetJWTService_NotInitialized 测试未初始化时获取JWT服务
|
||||
func TestGetJWTService_NotInitialized(t *testing.T) {
|
||||
_, err := GetJWTService()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "JWT服务未初始化,请先调用 auth.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetJWTService_Panic 测试MustGetJWTService在未初始化时panic
|
||||
func TestMustGetJWTService_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetJWTService 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetJWTService()
|
||||
}
|
||||
|
||||
// TestInit_JWTService 测试JWT服务初始化
|
||||
func TestInit_JWTService(t *testing.T) {
|
||||
cfg := config.JWTConfig{
|
||||
Secret: "test-secret-key",
|
||||
ExpireHours: 24,
|
||||
}
|
||||
|
||||
err := Init(cfg)
|
||||
if err != nil {
|
||||
t.Errorf("Init() 错误 = %v, want nil", err)
|
||||
}
|
||||
|
||||
// 验证可以获取服务
|
||||
service, err := GetJWTService()
|
||||
if err != nil {
|
||||
t.Errorf("GetJWTService() 错误 = %v, want nil", err)
|
||||
}
|
||||
if service == nil {
|
||||
t.Error("GetJWTService() 返回的服务不应为nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInit_JWTService_Once 测试Init只执行一次
|
||||
func TestInit_JWTService_Once(t *testing.T) {
|
||||
cfg := config.JWTConfig{
|
||||
Secret: "test-secret-key-1",
|
||||
ExpireHours: 24,
|
||||
}
|
||||
|
||||
// 第一次初始化
|
||||
err1 := Init(cfg)
|
||||
if err1 != nil {
|
||||
t.Fatalf("第一次Init() 错误 = %v", err1)
|
||||
}
|
||||
|
||||
service1, _ := GetJWTService()
|
||||
|
||||
// 第二次初始化(应该不会改变服务)
|
||||
cfg2 := config.JWTConfig{
|
||||
Secret: "test-secret-key-2",
|
||||
ExpireHours: 48,
|
||||
}
|
||||
err2 := Init(cfg2)
|
||||
if err2 != nil {
|
||||
t.Fatalf("第二次Init() 错误 = %v", err2)
|
||||
}
|
||||
|
||||
service2, _ := GetJWTService()
|
||||
|
||||
// 验证是同一个实例(sync.Once保证)
|
||||
if service1 != service2 {
|
||||
t.Error("Init应该只执行一次,返回同一个实例")
|
||||
}
|
||||
}
|
||||
|
||||
20
pkg/auth/password.go
Normal file
20
pkg/auth/password.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// HashPassword 使用bcrypt加密密码
|
||||
func HashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// CheckPassword 验证密码是否匹配
|
||||
func CheckPassword(hashedPassword, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
145
pkg/auth/password_test.go
Normal file
145
pkg/auth/password_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestHashPassword 测试密码加密
|
||||
func TestHashPassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "正常密码",
|
||||
password: "testpassword123",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "空密码",
|
||||
password: "",
|
||||
wantError: false, // bcrypt允许空密码
|
||||
},
|
||||
{
|
||||
name: "长密码",
|
||||
password: "thisisaverylongpasswordthatexceedsnormallength",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "包含特殊字符的密码",
|
||||
password: "P@ssw0rd!#$%",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hashed, err := HashPassword(tt.password)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("HashPassword() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
// 验证哈希值不为空
|
||||
if hashed == "" {
|
||||
t.Error("HashPassword() 返回的哈希值不应为空")
|
||||
}
|
||||
// 验证哈希值与原密码不同
|
||||
if hashed == tt.password {
|
||||
t.Error("HashPassword() 返回的哈希值不应与原密码相同")
|
||||
}
|
||||
// 验证哈希值长度合理(bcrypt哈希通常是60个字符)
|
||||
if len(hashed) < 50 {
|
||||
t.Errorf("HashPassword() 返回的哈希值长度异常: %d", len(hashed))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckPassword 测试密码验证
|
||||
func TestCheckPassword(t *testing.T) {
|
||||
// 先加密一个密码
|
||||
password := "testpassword123"
|
||||
hashed, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() 失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hashedPassword string
|
||||
password string
|
||||
wantMatch bool
|
||||
}{
|
||||
{
|
||||
name: "密码匹配",
|
||||
hashedPassword: hashed,
|
||||
password: password,
|
||||
wantMatch: true,
|
||||
},
|
||||
{
|
||||
name: "密码不匹配",
|
||||
hashedPassword: hashed,
|
||||
password: "wrongpassword",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "空密码与空哈希",
|
||||
hashedPassword: "",
|
||||
password: "",
|
||||
wantMatch: false, // 空哈希无法验证
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CheckPassword(tt.hashedPassword, tt.password)
|
||||
if result != tt.wantMatch {
|
||||
t.Errorf("CheckPassword() = %v, want %v", result, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashPassword_Uniqueness 测试每次加密结果不同
|
||||
func TestHashPassword_Uniqueness(t *testing.T) {
|
||||
password := "testpassword123"
|
||||
|
||||
// 多次加密同一密码
|
||||
hashes := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
hashed, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证每次加密的结果都不同(由于salt)
|
||||
if hashes[hashed] {
|
||||
t.Errorf("第%d次加密的结果与之前重复", i+1)
|
||||
}
|
||||
hashes[hashed] = true
|
||||
|
||||
// 但都能验证通过
|
||||
if !CheckPassword(hashed, password) {
|
||||
t.Errorf("第%d次加密的哈希无法验证原密码", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckPassword_Consistency 测试密码验证的一致性
|
||||
func TestCheckPassword_Consistency(t *testing.T) {
|
||||
password := "testpassword123"
|
||||
hashed, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 多次验证应该结果一致
|
||||
for i := 0; i < 10; i++ {
|
||||
if !CheckPassword(hashed, password) {
|
||||
t.Errorf("第%d次验证失败", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
304
pkg/config/config.go
Normal file
304
pkg/config/config.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Config 应用配置结构体
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
RustFS RustFSConfig `mapstructure:"rustfs"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Casbin CasbinConfig `mapstructure:"casbin"`
|
||||
Log LogConfig `mapstructure:"log"`
|
||||
Upload UploadConfig `mapstructure:"upload"`
|
||||
Email EmailConfig `mapstructure:"email"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
type ServerConfig struct {
|
||||
Port string `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Driver string `mapstructure:"driver"`
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
Database string `mapstructure:"database"`
|
||||
SSLMode string `mapstructure:"ssl_mode"`
|
||||
Timezone string `mapstructure:"timezone"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
|
||||
}
|
||||
|
||||
// RedisConfig Redis配置
|
||||
type RedisConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
Database int `mapstructure:"database"`
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
}
|
||||
|
||||
// RustFSConfig RustFS对象存储配置 (S3兼容)
|
||||
type RustFSConfig struct {
|
||||
Endpoint string `mapstructure:"endpoint"`
|
||||
AccessKey string `mapstructure:"access_key"`
|
||||
SecretKey string `mapstructure:"secret_key"`
|
||||
UseSSL bool `mapstructure:"use_ssl"`
|
||||
Buckets map[string]string `mapstructure:"buckets"`
|
||||
}
|
||||
|
||||
// JWTConfig JWT配置
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpireHours int `mapstructure:"expire_hours"`
|
||||
}
|
||||
|
||||
// CasbinConfig Casbin权限配置
|
||||
type CasbinConfig struct {
|
||||
ModelPath string `mapstructure:"model_path"`
|
||||
PolicyAdapter string `mapstructure:"policy_adapter"`
|
||||
}
|
||||
|
||||
// LogConfig 日志配置
|
||||
type LogConfig struct {
|
||||
Level string `mapstructure:"level"`
|
||||
Format string `mapstructure:"format"`
|
||||
Output string `mapstructure:"output"`
|
||||
MaxSize int `mapstructure:"max_size"`
|
||||
MaxBackups int `mapstructure:"max_backups"`
|
||||
MaxAge int `mapstructure:"max_age"`
|
||||
Compress bool `mapstructure:"compress"`
|
||||
}
|
||||
|
||||
// UploadConfig 文件上传配置
|
||||
type UploadConfig struct {
|
||||
MaxSize int64 `mapstructure:"max_size"`
|
||||
AllowedTypes []string `mapstructure:"allowed_types"`
|
||||
TextureMaxSize int64 `mapstructure:"texture_max_size"`
|
||||
AvatarMaxSize int64 `mapstructure:"avatar_max_size"`
|
||||
}
|
||||
|
||||
// EmailConfig 邮件配置
|
||||
type EmailConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
SMTPHost string `mapstructure:"smtp_host"`
|
||||
SMTPPort int `mapstructure:"smtp_port"`
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
FromName string `mapstructure:"from_name"`
|
||||
}
|
||||
|
||||
// Load 加载配置 - 完全从环境变量加载,不依赖YAML文件
|
||||
func Load() (*Config, error) {
|
||||
// 加载.env文件(如果存在)
|
||||
_ = godotenv.Load(".env")
|
||||
|
||||
// 设置默认值
|
||||
setDefaults()
|
||||
|
||||
// 设置环境变量前缀
|
||||
viper.SetEnvPrefix("CARROTSKIN")
|
||||
viper.AutomaticEnv()
|
||||
|
||||
// 手动设置环境变量映射
|
||||
setupEnvMappings()
|
||||
|
||||
// 直接从环境变量解析配置
|
||||
var config Config
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
return nil, fmt.Errorf("解析配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 从环境变量中覆盖配置
|
||||
overrideFromEnv(&config)
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// setDefaults 设置默认配置值
|
||||
func setDefaults() {
|
||||
// 服务器默认配置
|
||||
viper.SetDefault("server.port", ":8080")
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.read_timeout", "30s")
|
||||
viper.SetDefault("server.write_timeout", "30s")
|
||||
|
||||
// 数据库默认配置
|
||||
viper.SetDefault("database.driver", "postgres")
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
viper.SetDefault("database.ssl_mode", "disable")
|
||||
viper.SetDefault("database.timezone", "Asia/Shanghai")
|
||||
viper.SetDefault("database.max_idle_conns", 10)
|
||||
viper.SetDefault("database.max_open_conns", 100)
|
||||
viper.SetDefault("database.conn_max_lifetime", "1h")
|
||||
|
||||
// Redis默认配置
|
||||
viper.SetDefault("redis.host", "localhost")
|
||||
viper.SetDefault("redis.port", 6379)
|
||||
viper.SetDefault("redis.database", 0)
|
||||
viper.SetDefault("redis.pool_size", 10)
|
||||
|
||||
// RustFS默认配置
|
||||
viper.SetDefault("rustfs.endpoint", "127.0.0.1:9000")
|
||||
viper.SetDefault("rustfs.use_ssl", false)
|
||||
|
||||
// JWT默认配置
|
||||
viper.SetDefault("jwt.expire_hours", 168)
|
||||
|
||||
// Casbin默认配置
|
||||
viper.SetDefault("casbin.model_path", "configs/casbin/rbac_model.conf")
|
||||
viper.SetDefault("casbin.policy_adapter", "gorm")
|
||||
|
||||
// 日志默认配置
|
||||
viper.SetDefault("log.level", "info")
|
||||
viper.SetDefault("log.format", "json")
|
||||
viper.SetDefault("log.output", "logs/app.log")
|
||||
viper.SetDefault("log.max_size", 100)
|
||||
viper.SetDefault("log.max_backups", 3)
|
||||
viper.SetDefault("log.max_age", 28)
|
||||
viper.SetDefault("log.compress", true)
|
||||
|
||||
// 文件上传默认配置
|
||||
viper.SetDefault("upload.max_size", 10485760)
|
||||
viper.SetDefault("upload.texture_max_size", 2097152)
|
||||
viper.SetDefault("upload.avatar_max_size", 1048576)
|
||||
viper.SetDefault("upload.allowed_types", []string{"image/png", "image/jpeg"})
|
||||
|
||||
// 邮件默认配置
|
||||
viper.SetDefault("email.enabled", false)
|
||||
viper.SetDefault("email.smtp_port", 587)
|
||||
}
|
||||
|
||||
// setupEnvMappings 设置环境变量映射
|
||||
func setupEnvMappings() {
|
||||
// 服务器配置
|
||||
viper.BindEnv("server.port", "SERVER_PORT")
|
||||
viper.BindEnv("server.mode", "SERVER_MODE")
|
||||
viper.BindEnv("server.read_timeout", "SERVER_READ_TIMEOUT")
|
||||
viper.BindEnv("server.write_timeout", "SERVER_WRITE_TIMEOUT")
|
||||
|
||||
// 数据库配置
|
||||
viper.BindEnv("database.driver", "DATABASE_DRIVER")
|
||||
viper.BindEnv("database.host", "DATABASE_HOST")
|
||||
viper.BindEnv("database.port", "DATABASE_PORT")
|
||||
viper.BindEnv("database.username", "DATABASE_USERNAME")
|
||||
viper.BindEnv("database.password", "DATABASE_PASSWORD")
|
||||
viper.BindEnv("database.database", "DATABASE_NAME")
|
||||
viper.BindEnv("database.ssl_mode", "DATABASE_SSL_MODE")
|
||||
viper.BindEnv("database.timezone", "DATABASE_TIMEZONE")
|
||||
|
||||
// Redis配置
|
||||
viper.BindEnv("redis.host", "REDIS_HOST")
|
||||
viper.BindEnv("redis.port", "REDIS_PORT")
|
||||
viper.BindEnv("redis.password", "REDIS_PASSWORD")
|
||||
viper.BindEnv("redis.database", "REDIS_DATABASE")
|
||||
|
||||
// RustFS配置
|
||||
viper.BindEnv("rustfs.endpoint", "RUSTFS_ENDPOINT")
|
||||
viper.BindEnv("rustfs.access_key", "RUSTFS_ACCESS_KEY")
|
||||
viper.BindEnv("rustfs.secret_key", "RUSTFS_SECRET_KEY")
|
||||
viper.BindEnv("rustfs.use_ssl", "RUSTFS_USE_SSL")
|
||||
|
||||
// JWT配置
|
||||
viper.BindEnv("jwt.secret", "JWT_SECRET")
|
||||
viper.BindEnv("jwt.expire_hours", "JWT_EXPIRE_HOURS")
|
||||
|
||||
// 日志配置
|
||||
viper.BindEnv("log.level", "LOG_LEVEL")
|
||||
viper.BindEnv("log.format", "LOG_FORMAT")
|
||||
viper.BindEnv("log.output", "LOG_OUTPUT")
|
||||
|
||||
// 邮件配置
|
||||
viper.BindEnv("email.enabled", "EMAIL_ENABLED")
|
||||
viper.BindEnv("email.smtp_host", "EMAIL_SMTP_HOST")
|
||||
viper.BindEnv("email.smtp_port", "EMAIL_SMTP_PORT")
|
||||
viper.BindEnv("email.username", "EMAIL_USERNAME")
|
||||
viper.BindEnv("email.password", "EMAIL_PASSWORD")
|
||||
viper.BindEnv("email.from_name", "EMAIL_FROM_NAME")
|
||||
}
|
||||
|
||||
// overrideFromEnv 从环境变量中覆盖配置
|
||||
func overrideFromEnv(config *Config) {
|
||||
// 处理RustFS存储桶配置
|
||||
if texturesBucket := os.Getenv("RUSTFS_BUCKET_TEXTURES"); texturesBucket != "" {
|
||||
if config.RustFS.Buckets == nil {
|
||||
config.RustFS.Buckets = make(map[string]string)
|
||||
}
|
||||
config.RustFS.Buckets["textures"] = texturesBucket
|
||||
}
|
||||
|
||||
if avatarsBucket := os.Getenv("RUSTFS_BUCKET_AVATARS"); avatarsBucket != "" {
|
||||
if config.RustFS.Buckets == nil {
|
||||
config.RustFS.Buckets = make(map[string]string)
|
||||
}
|
||||
config.RustFS.Buckets["avatars"] = avatarsBucket
|
||||
}
|
||||
|
||||
// 处理数据库连接池配置
|
||||
if maxIdleConns := os.Getenv("DATABASE_MAX_IDLE_CONNS"); maxIdleConns != "" {
|
||||
if val, err := strconv.Atoi(maxIdleConns); err == nil {
|
||||
config.Database.MaxIdleConns = val
|
||||
}
|
||||
}
|
||||
|
||||
if maxOpenConns := os.Getenv("DATABASE_MAX_OPEN_CONNS"); maxOpenConns != "" {
|
||||
if val, err := strconv.Atoi(maxOpenConns); err == nil {
|
||||
config.Database.MaxOpenConns = val
|
||||
}
|
||||
}
|
||||
|
||||
if connMaxLifetime := os.Getenv("DATABASE_CONN_MAX_LIFETIME"); connMaxLifetime != "" {
|
||||
if val, err := time.ParseDuration(connMaxLifetime); err == nil {
|
||||
config.Database.ConnMaxLifetime = val
|
||||
}
|
||||
}
|
||||
|
||||
// 处理Redis池大小
|
||||
if poolSize := os.Getenv("REDIS_POOL_SIZE"); poolSize != "" {
|
||||
if val, err := strconv.Atoi(poolSize); err == nil {
|
||||
config.Redis.PoolSize = val
|
||||
}
|
||||
}
|
||||
|
||||
// 处理文件上传配置
|
||||
if maxSize := os.Getenv("UPLOAD_MAX_SIZE"); maxSize != "" {
|
||||
if val, err := strconv.ParseInt(maxSize, 10, 64); err == nil {
|
||||
config.Upload.MaxSize = val
|
||||
}
|
||||
}
|
||||
|
||||
if textureMaxSize := os.Getenv("UPLOAD_TEXTURE_MAX_SIZE"); textureMaxSize != "" {
|
||||
if val, err := strconv.ParseInt(textureMaxSize, 10, 64); err == nil {
|
||||
config.Upload.TextureMaxSize = val
|
||||
}
|
||||
}
|
||||
|
||||
if avatarMaxSize := os.Getenv("UPLOAD_AVATAR_MAX_SIZE"); avatarMaxSize != "" {
|
||||
if val, err := strconv.ParseInt(avatarMaxSize, 10, 64); err == nil {
|
||||
config.Upload.AvatarMaxSize = val
|
||||
}
|
||||
}
|
||||
|
||||
// 处理邮件配置
|
||||
if emailEnabled := os.Getenv("EMAIL_ENABLED"); emailEnabled != "" {
|
||||
config.Email.Enabled = emailEnabled == "true" || emailEnabled == "True" || emailEnabled == "TRUE" || emailEnabled == "1"
|
||||
}
|
||||
}
|
||||
67
pkg/config/manager.go
Normal file
67
pkg/config/manager.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
// configInstance 全局配置实例
|
||||
configInstance *Config
|
||||
// rustFSConfigInstance 全局RustFS配置实例
|
||||
rustFSConfigInstance *RustFSConfig
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化配置(线程安全,只会执行一次)
|
||||
func Init() error {
|
||||
once.Do(func() {
|
||||
configInstance, initError = Load()
|
||||
if initError != nil {
|
||||
return
|
||||
}
|
||||
rustFSConfigInstance = &configInstance.RustFS
|
||||
})
|
||||
return initError
|
||||
}
|
||||
|
||||
// GetConfig 获取配置实例(线程安全)
|
||||
func GetConfig() (*Config, error) {
|
||||
if configInstance == nil {
|
||||
return nil, fmt.Errorf("配置未初始化,请先调用 config.Init()")
|
||||
}
|
||||
return configInstance, nil
|
||||
}
|
||||
|
||||
// MustGetConfig 获取配置实例,如果未初始化则panic
|
||||
func MustGetConfig() *Config {
|
||||
cfg, err := GetConfig()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// GetRustFSConfig 获取RustFS配置实例(线程安全)
|
||||
func GetRustFSConfig() (*RustFSConfig, error) {
|
||||
if rustFSConfigInstance == nil {
|
||||
return nil, fmt.Errorf("配置未初始化,请先调用 config.Init()")
|
||||
}
|
||||
return rustFSConfigInstance, nil
|
||||
}
|
||||
|
||||
// MustGetRustFSConfig 获取RustFS配置实例,如果未初始化则panic
|
||||
func MustGetRustFSConfig() *RustFSConfig {
|
||||
cfg, err := GetRustFSConfig()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
70
pkg/config/manager_test.go
Normal file
70
pkg/config/manager_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetConfig_NotInitialized 测试未初始化时获取配置
|
||||
func TestGetConfig_NotInitialized(t *testing.T) {
|
||||
// 重置全局变量(在实际测试中可能需要更复杂的重置逻辑)
|
||||
// 注意:由于使用了 sync.Once,这个测试主要验证错误处理逻辑
|
||||
|
||||
// 测试未初始化时的错误消息
|
||||
_, err := GetConfig()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "配置未初始化,请先调用 config.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetConfig_Panic 测试MustGetConfig在未初始化时panic
|
||||
func TestMustGetConfig_Panic(t *testing.T) {
|
||||
// 注意:这个测试会触发panic,需要recover
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetConfig 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
// 尝试获取未初始化的配置
|
||||
_ = MustGetConfig()
|
||||
}
|
||||
|
||||
// TestGetRustFSConfig_NotInitialized 测试未初始化时获取RustFS配置
|
||||
func TestGetRustFSConfig_NotInitialized(t *testing.T) {
|
||||
_, err := GetRustFSConfig()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "配置未初始化,请先调用 config.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetRustFSConfig_Panic 测试MustGetRustFSConfig在未初始化时panic
|
||||
func TestMustGetRustFSConfig_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetRustFSConfig 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetRustFSConfig()
|
||||
}
|
||||
|
||||
// TestInit_Once 测试Init只执行一次的逻辑
|
||||
func TestInit_Once(t *testing.T) {
|
||||
// 注意:由于sync.Once的特性,这个测试主要验证逻辑
|
||||
// 实际测试中可能需要重置机制
|
||||
|
||||
// 验证Init函数可调用(函数不能直接比较nil)
|
||||
// 这里只验证函数存在
|
||||
_ = Init
|
||||
}
|
||||
|
||||
113
pkg/database/manager.go
Normal file
113
pkg/database/manager.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
// dbInstance 全局数据库实例
|
||||
dbInstance *gorm.DB
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化数据库连接(线程安全,只会执行一次)
|
||||
func Init(cfg config.DatabaseConfig, logger *zap.Logger) error {
|
||||
once.Do(func() {
|
||||
dbInstance, initError = New(cfg)
|
||||
if initError != nil {
|
||||
logger.Error("数据库初始化失败", zap.Error(initError))
|
||||
return
|
||||
}
|
||||
logger.Info("数据库连接成功")
|
||||
})
|
||||
return initError
|
||||
}
|
||||
|
||||
// GetDB 获取数据库实例(线程安全)
|
||||
func GetDB() (*gorm.DB, error) {
|
||||
if dbInstance == nil {
|
||||
return nil, fmt.Errorf("数据库未初始化,请先调用 database.Init()")
|
||||
}
|
||||
return dbInstance, nil
|
||||
}
|
||||
|
||||
// MustGetDB 获取数据库实例,如果未初始化则panic
|
||||
func MustGetDB() *gorm.DB {
|
||||
db, err := GetDB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// AutoMigrate 自动迁移数据库表结构
|
||||
func AutoMigrate(logger *zap.Logger) error {
|
||||
db, err := GetDB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("开始执行数据库迁移...")
|
||||
|
||||
// 迁移所有表 - 注意顺序:先创建被引用的表,再创建引用表
|
||||
err = db.AutoMigrate(
|
||||
// 用户相关表(先创建,因为其他表可能引用它)
|
||||
&model.User{},
|
||||
&model.UserPointLog{},
|
||||
&model.UserLoginLog{},
|
||||
|
||||
// 档案相关表
|
||||
&model.Profile{},
|
||||
|
||||
// 材质相关表
|
||||
&model.Texture{},
|
||||
&model.UserTextureFavorite{},
|
||||
&model.TextureDownloadLog{},
|
||||
|
||||
// 认证相关表
|
||||
&model.Token{},
|
||||
|
||||
// Yggdrasil相关表(在User之后创建,因为它引用User)
|
||||
&model.Yggdrasil{},
|
||||
|
||||
// 系统配置表
|
||||
&model.SystemConfig{},
|
||||
|
||||
// 审计日志表
|
||||
&model.AuditLog{},
|
||||
|
||||
// Casbin权限规则表
|
||||
&model.CasbinRule{},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Error("数据库迁移失败", zap.Error(err))
|
||||
return fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("数据库迁移完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func Close() error {
|
||||
if dbInstance == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := dbInstance.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sqlDB.Close()
|
||||
}
|
||||
85
pkg/database/manager_test.go
Normal file
85
pkg/database/manager_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestGetDB_NotInitialized 测试未初始化时获取数据库实例
|
||||
func TestGetDB_NotInitialized(t *testing.T) {
|
||||
_, err := GetDB()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "数据库未初始化,请先调用 database.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetDB_Panic 测试MustGetDB在未初始化时panic
|
||||
func TestMustGetDB_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetDB 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetDB()
|
||||
}
|
||||
|
||||
// TestInit_Database 测试数据库初始化逻辑
|
||||
func TestInit_Database(t *testing.T) {
|
||||
cfg := config.DatabaseConfig{
|
||||
Driver: "postgres",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "password",
|
||||
Database: "testdb",
|
||||
SSLMode: "disable",
|
||||
Timezone: "Asia/Shanghai",
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// 验证Init函数存在且可调用
|
||||
// 注意:实际连接可能失败,这是可以接受的
|
||||
err := Init(cfg, logger)
|
||||
if err != nil {
|
||||
t.Logf("Init() 返回错误(可能正常,如果数据库未运行): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoMigrate_ErrorHandling 测试AutoMigrate的错误处理逻辑
|
||||
func TestAutoMigrate_ErrorHandling(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// 测试未初始化时的错误处理
|
||||
err := AutoMigrate(logger)
|
||||
if err == nil {
|
||||
// 如果数据库已初始化,这是正常的
|
||||
t.Log("AutoMigrate() 成功(数据库可能已初始化)")
|
||||
} else {
|
||||
// 如果数据库未初始化,应该返回错误
|
||||
if err.Error() == "" {
|
||||
t.Error("AutoMigrate() 应该返回有意义的错误消息")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestClose_NotInitialized 测试未初始化时关闭数据库
|
||||
func TestClose_NotInitialized(t *testing.T) {
|
||||
// 未初始化时关闭应该不返回错误
|
||||
err := Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() 在未初始化时应该返回nil,实际返回: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
73
pkg/database/postgres.go
Normal file
73
pkg/database/postgres.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// New 创建新的PostgreSQL数据库连接
|
||||
func New(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
cfg.Host,
|
||||
cfg.Port,
|
||||
cfg.Username,
|
||||
cfg.Password,
|
||||
cfg.Database,
|
||||
cfg.SSLMode,
|
||||
cfg.Timezone,
|
||||
)
|
||||
|
||||
// 配置GORM日志级别
|
||||
var gormLogLevel logger.LogLevel
|
||||
switch {
|
||||
case cfg.Driver == "postgres":
|
||||
gormLogLevel = logger.Info
|
||||
default:
|
||||
gormLogLevel = logger.Silent
|
||||
}
|
||||
|
||||
// 打开数据库连接
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(gormLogLevel),
|
||||
DisableForeignKeyConstraintWhenMigrating: true, // 禁用自动创建外键约束,避免循环依赖问题
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接PostgreSQL数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取底层SQL数据库实例
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// GetDSN 获取数据源名称
|
||||
func GetDSN(cfg config.DatabaseConfig) string {
|
||||
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
cfg.Host,
|
||||
cfg.Port,
|
||||
cfg.Username,
|
||||
cfg.Password,
|
||||
cfg.Database,
|
||||
cfg.SSLMode,
|
||||
cfg.Timezone,
|
||||
)
|
||||
}
|
||||
162
pkg/email/email.go
Normal file
162
pkg/email/email.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"net/textproto"
|
||||
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"github.com/jordan-wright/email"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Service 邮件服务
|
||||
type Service struct {
|
||||
cfg config.EmailConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewService 创建邮件服务
|
||||
func NewService(cfg config.EmailConfig, logger *zap.Logger) *Service {
|
||||
return &Service{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码邮件
|
||||
func (s *Service) SendVerificationCode(to, code, purpose string) error {
|
||||
if !s.cfg.Enabled {
|
||||
s.logger.Warn("邮件服务未启用,跳过发送", zap.String("to", to))
|
||||
return fmt.Errorf("邮件服务未启用")
|
||||
}
|
||||
|
||||
subject := s.getSubject(purpose)
|
||||
body := s.getBody(code, purpose)
|
||||
|
||||
return s.send([]string{to}, subject, body)
|
||||
}
|
||||
|
||||
// SendResetPassword 发送重置密码邮件
|
||||
func (s *Service) SendResetPassword(to, code string) error {
|
||||
return s.SendVerificationCode(to, code, "reset_password")
|
||||
}
|
||||
|
||||
// SendEmailVerification 发送邮箱验证邮件
|
||||
func (s *Service) SendEmailVerification(to, code string) error {
|
||||
return s.SendVerificationCode(to, code, "email_verification")
|
||||
}
|
||||
|
||||
// SendChangeEmail 发送更换邮箱验证码
|
||||
func (s *Service) SendChangeEmail(to, code string) error {
|
||||
return s.SendVerificationCode(to, code, "change_email")
|
||||
}
|
||||
|
||||
// send 发送邮件
|
||||
func (s *Service) send(to []string, subject, body string) error {
|
||||
e := email.NewEmail()
|
||||
e.From = fmt.Sprintf("%s <%s>", s.cfg.FromName, s.cfg.Username)
|
||||
e.To = to
|
||||
e.Subject = subject
|
||||
e.HTML = []byte(body)
|
||||
e.Headers = textproto.MIMEHeader{}
|
||||
|
||||
// SMTP认证
|
||||
auth := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.SMTPHost)
|
||||
|
||||
// 发送邮件
|
||||
addr := fmt.Sprintf("%s:%d", s.cfg.SMTPHost, s.cfg.SMTPPort)
|
||||
|
||||
// 判断端口决定发送方式
|
||||
// 465端口使用SSL/TLS(隐式TLS)
|
||||
// 587端口使用STARTTLS(显式TLS)
|
||||
var err error
|
||||
if s.cfg.SMTPPort == 465 {
|
||||
// 使用SSL/TLS连接(适用于465端口)
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: s.cfg.SMTPHost,
|
||||
InsecureSkipVerify: false, // 生产环境建议设置为false
|
||||
}
|
||||
err = e.SendWithTLS(addr, auth, tlsConfig)
|
||||
} else {
|
||||
// 使用STARTTLS连接(适用于587端口等)
|
||||
err = e.Send(addr, auth)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
s.logger.Error("发送邮件失败",
|
||||
zap.Strings("to", to),
|
||||
zap.String("subject", subject),
|
||||
zap.String("smtp_host", s.cfg.SMTPHost),
|
||||
zap.Int("smtp_port", s.cfg.SMTPPort),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("发送邮件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("邮件发送成功",
|
||||
zap.Strings("to", to),
|
||||
zap.String("subject", subject),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSubject 获取邮件主题
|
||||
func (s *Service) getSubject(purpose string) string {
|
||||
switch purpose {
|
||||
case "email_verification":
|
||||
return "【CarrotSkin】邮箱验证"
|
||||
case "reset_password":
|
||||
return "【CarrotSkin】重置密码"
|
||||
case "change_email":
|
||||
return "【CarrotSkin】更换邮箱验证"
|
||||
default:
|
||||
return "【CarrotSkin】验证码"
|
||||
}
|
||||
}
|
||||
|
||||
// getBody 获取邮件正文
|
||||
func (s *Service) getBody(code, purpose string) string {
|
||||
var message string
|
||||
switch purpose {
|
||||
case "email_verification":
|
||||
message = "感谢注册CarrotSkin!请使用以下验证码完成邮箱验证:"
|
||||
case "reset_password":
|
||||
message = "您正在重置密码,请使用以下验证码:"
|
||||
case "change_email":
|
||||
message = "您正在更换邮箱,请使用以下验证码验证新邮箱:"
|
||||
default:
|
||||
message = "您的验证码为:"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>验证码</title>
|
||||
</head>
|
||||
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
|
||||
<div style="max-width: 600px; margin: 20px auto; background-color: #ffffff; padding: 30px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
|
||||
<div style="text-align: center; padding-bottom: 20px;">
|
||||
<h1 style="color: #ff6b35; margin: 0;">CarrotSkin</h1>
|
||||
</div>
|
||||
<div style="padding: 20px 0; border-top: 2px solid #ff6b35; border-bottom: 2px solid #ff6b35;">
|
||||
<p style="font-size: 16px; color: #333; margin: 0 0 20px 0;">%s</p>
|
||||
<div style="background-color: #f9f9f9; padding: 20px; text-align: center; border-radius: 4px; margin: 20px 0;">
|
||||
<span style="font-size: 32px; font-weight: bold; color: #ff6b35; letter-spacing: 5px;">%s</span>
|
||||
</div>
|
||||
<p style="font-size: 14px; color: #666; margin: 20px 0 0 0;">验证码有效期为10分钟,请及时使用。</p>
|
||||
<p style="font-size: 14px; color: #666; margin: 10px 0 0 0;">如果这不是您的操作,请忽略此邮件。</p>
|
||||
</div>
|
||||
<div style="text-align: center; padding-top: 20px;">
|
||||
<p style="font-size: 12px; color: #999; margin: 0;">© 2025 CarrotSkin. All rights reserved.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, message, code)
|
||||
}
|
||||
47
pkg/email/manager.go
Normal file
47
pkg/email/manager.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
// serviceInstance 全局邮件服务实例
|
||||
serviceInstance *Service
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化邮件服务(线程安全,只会执行一次)
|
||||
func Init(cfg config.EmailConfig, logger *zap.Logger) error {
|
||||
once.Do(func() {
|
||||
serviceInstance = NewService(cfg, logger)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetService 获取邮件服务实例(线程安全)
|
||||
func GetService() (*Service, error) {
|
||||
if serviceInstance == nil {
|
||||
return nil, fmt.Errorf("邮件服务未初始化,请先调用 email.Init()")
|
||||
}
|
||||
return serviceInstance, nil
|
||||
}
|
||||
|
||||
// MustGetService 获取邮件服务实例,如果未初始化则panic
|
||||
func MustGetService() *Service {
|
||||
service, err := GetService()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
61
pkg/email/manager_test.go
Normal file
61
pkg/email/manager_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestGetService_NotInitialized 测试未初始化时获取邮件服务
|
||||
func TestGetService_NotInitialized(t *testing.T) {
|
||||
_, err := GetService()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "邮件服务未初始化,请先调用 email.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetService_Panic 测试MustGetService在未初始化时panic
|
||||
func TestMustGetService_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetService 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetService()
|
||||
}
|
||||
|
||||
// TestInit_Email 测试邮件服务初始化
|
||||
func TestInit_Email(t *testing.T) {
|
||||
cfg := config.EmailConfig{
|
||||
Enabled: false,
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
Username: "user@example.com",
|
||||
Password: "password",
|
||||
FromName: "noreply@example.com",
|
||||
}
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
err := Init(cfg, logger)
|
||||
if err != nil {
|
||||
t.Errorf("Init() 错误 = %v, want nil", err)
|
||||
}
|
||||
|
||||
// 验证可以获取服务
|
||||
service, err := GetService()
|
||||
if err != nil {
|
||||
t.Errorf("GetService() 错误 = %v, want nil", err)
|
||||
}
|
||||
if service == nil {
|
||||
t.Error("GetService() 返回的服务不应为nil")
|
||||
}
|
||||
}
|
||||
|
||||
68
pkg/logger/logger.go
Normal file
68
pkg/logger/logger.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// New 创建新的日志记录器
|
||||
func New(cfg config.LogConfig) (*zap.Logger, error) {
|
||||
// 配置日志级别
|
||||
var level zapcore.Level
|
||||
switch cfg.Level {
|
||||
case "debug":
|
||||
level = zapcore.DebugLevel
|
||||
case "info":
|
||||
level = zapcore.InfoLevel
|
||||
case "warn":
|
||||
level = zapcore.WarnLevel
|
||||
case "error":
|
||||
level = zapcore.ErrorLevel
|
||||
default:
|
||||
level = zapcore.InfoLevel
|
||||
}
|
||||
|
||||
// 配置编码器
|
||||
var encoder zapcore.Encoder
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = "timestamp"
|
||||
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
|
||||
|
||||
if cfg.Format == "console" {
|
||||
encoder = zapcore.NewConsoleEncoder(encoderConfig)
|
||||
} else {
|
||||
encoder = zapcore.NewJSONEncoder(encoderConfig)
|
||||
}
|
||||
|
||||
// 配置输出
|
||||
var writeSyncer zapcore.WriteSyncer
|
||||
if cfg.Output == "" || cfg.Output == "stdout" {
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
} else {
|
||||
// 自动创建日志目录
|
||||
logDir := filepath.Dir(cfg.Output)
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(cfg.Output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writeSyncer = zapcore.AddSync(file)
|
||||
}
|
||||
|
||||
// 创建核心
|
||||
core := zapcore.NewCore(encoder, writeSyncer, level)
|
||||
|
||||
// 创建日志记录器
|
||||
logger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1))
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
50
pkg/logger/manager.go
Normal file
50
pkg/logger/manager.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
// loggerInstance 全局日志实例
|
||||
loggerInstance *zap.Logger
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化日志记录器(线程安全,只会执行一次)
|
||||
func Init(cfg config.LogConfig) error {
|
||||
once.Do(func() {
|
||||
loggerInstance, initError = New(cfg)
|
||||
if initError != nil {
|
||||
return
|
||||
}
|
||||
})
|
||||
return initError
|
||||
}
|
||||
|
||||
// GetLogger 获取日志实例(线程安全)
|
||||
func GetLogger() (*zap.Logger, error) {
|
||||
if loggerInstance == nil {
|
||||
return nil, fmt.Errorf("日志未初始化,请先调用 logger.Init()")
|
||||
}
|
||||
return loggerInstance, nil
|
||||
}
|
||||
|
||||
// MustGetLogger 获取日志实例,如果未初始化则panic
|
||||
func MustGetLogger() *zap.Logger {
|
||||
logger, err := GetLogger()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return logger
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
47
pkg/logger/manager_test.go
Normal file
47
pkg/logger/manager_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetLogger_NotInitialized 测试未初始化时获取日志实例
|
||||
func TestGetLogger_NotInitialized(t *testing.T) {
|
||||
_, err := GetLogger()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "日志未初始化,请先调用 logger.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetLogger_Panic 测试MustGetLogger在未初始化时panic
|
||||
func TestMustGetLogger_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetLogger 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetLogger()
|
||||
}
|
||||
|
||||
// TestInit_Logger 测试日志初始化逻辑
|
||||
func TestInit_Logger(t *testing.T) {
|
||||
cfg := config.LogConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
}
|
||||
|
||||
// 验证Init函数存在且可调用
|
||||
err := Init(cfg)
|
||||
if err != nil {
|
||||
// 初始化可能失败(例如缺少依赖),这是可以接受的
|
||||
t.Logf("Init() 返回错误(可能正常): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
50
pkg/redis/manager.go
Normal file
50
pkg/redis/manager.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
// clientInstance 全局Redis客户端实例
|
||||
clientInstance *Client
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化Redis客户端(线程安全,只会执行一次)
|
||||
func Init(cfg config.RedisConfig, logger *zap.Logger) error {
|
||||
once.Do(func() {
|
||||
clientInstance, initError = New(cfg, logger)
|
||||
if initError != nil {
|
||||
return
|
||||
}
|
||||
})
|
||||
return initError
|
||||
}
|
||||
|
||||
// GetClient 获取Redis客户端实例(线程安全)
|
||||
func GetClient() (*Client, error) {
|
||||
if clientInstance == nil {
|
||||
return nil, fmt.Errorf("Redis客户端未初始化,请先调用 redis.Init()")
|
||||
}
|
||||
return clientInstance, nil
|
||||
}
|
||||
|
||||
// MustGetClient 获取Redis客户端实例,如果未初始化则panic
|
||||
func MustGetClient() *Client {
|
||||
client, err := GetClient()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
53
pkg/redis/manager_test.go
Normal file
53
pkg/redis/manager_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestGetClient_NotInitialized 测试未初始化时获取Redis客户端
|
||||
func TestGetClient_NotInitialized(t *testing.T) {
|
||||
_, err := GetClient()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "Redis客户端未初始化,请先调用 redis.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetClient_Panic 测试MustGetClient在未初始化时panic
|
||||
func TestMustGetClient_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetClient 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetClient()
|
||||
}
|
||||
|
||||
// TestInit_Redis 测试Redis初始化逻辑
|
||||
func TestInit_Redis(t *testing.T) {
|
||||
cfg := config.RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Password: "",
|
||||
Database: 0,
|
||||
PoolSize: 10,
|
||||
}
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// 验证Init函数存在且可调用
|
||||
// 注意:实际连接可能失败,这是可以接受的
|
||||
err := Init(cfg, logger)
|
||||
if err != nil {
|
||||
t.Logf("Init() 返回错误(可能正常,如果Redis未运行): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
174
pkg/redis/redis.go
Normal file
174
pkg/redis/redis.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Client Redis客户端包装
|
||||
type Client struct {
|
||||
*redis.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// New 创建Redis客户端
|
||||
func New(cfg config.RedisConfig, logger *zap.Logger) (*Client, error) {
|
||||
// 创建Redis客户端
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Password: cfg.Password,
|
||||
DB: cfg.Database,
|
||||
PoolSize: cfg.PoolSize,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
})
|
||||
|
||||
// 测试连接
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("Redis连接失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("Redis连接成功",
|
||||
zap.String("host", cfg.Host),
|
||||
zap.Int("port", cfg.Port),
|
||||
zap.Int("database", cfg.Database),
|
||||
)
|
||||
|
||||
return &Client{
|
||||
Client: rdb,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close 关闭Redis连接
|
||||
func (c *Client) Close() error {
|
||||
c.logger.Info("正在关闭Redis连接")
|
||||
return c.Client.Close()
|
||||
}
|
||||
|
||||
// Set 设置键值对(带过期时间)
|
||||
func (c *Client) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
|
||||
return c.Client.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
||||
// Get 获取键值
|
||||
func (c *Client) Get(ctx context.Context, key string) (string, error) {
|
||||
return c.Client.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
// Del 删除键
|
||||
func (c *Client) Del(ctx context.Context, keys ...string) error {
|
||||
return c.Client.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (c *Client) Exists(ctx context.Context, keys ...string) (int64, error) {
|
||||
return c.Client.Exists(ctx, keys...).Result()
|
||||
}
|
||||
|
||||
// Expire 设置键的过期时间
|
||||
func (c *Client) Expire(ctx context.Context, key string, expiration time.Duration) error {
|
||||
return c.Client.Expire(ctx, key, expiration).Err()
|
||||
}
|
||||
|
||||
// Incr 自增
|
||||
func (c *Client) Incr(ctx context.Context, key string) (int64, error) {
|
||||
return c.Client.Incr(ctx, key).Result()
|
||||
}
|
||||
|
||||
// Decr 自减
|
||||
func (c *Client) Decr(ctx context.Context, key string) (int64, error) {
|
||||
return c.Client.Decr(ctx, key).Result()
|
||||
}
|
||||
|
||||
// HSet 设置哈希字段
|
||||
func (c *Client) HSet(ctx context.Context, key string, values ...interface{}) error {
|
||||
return c.Client.HSet(ctx, key, values...).Err()
|
||||
}
|
||||
|
||||
// HGet 获取哈希字段
|
||||
func (c *Client) HGet(ctx context.Context, key, field string) (string, error) {
|
||||
return c.Client.HGet(ctx, key, field).Result()
|
||||
}
|
||||
|
||||
// HGetAll 获取所有哈希字段
|
||||
func (c *Client) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||||
return c.Client.HGetAll(ctx, key).Result()
|
||||
}
|
||||
|
||||
// HDel 删除哈希字段
|
||||
func (c *Client) HDel(ctx context.Context, key string, fields ...string) error {
|
||||
return c.Client.HDel(ctx, key, fields...).Err()
|
||||
}
|
||||
|
||||
// SAdd 添加集合成员
|
||||
func (c *Client) SAdd(ctx context.Context, key string, members ...interface{}) error {
|
||||
return c.Client.SAdd(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
// SMembers 获取集合所有成员
|
||||
func (c *Client) SMembers(ctx context.Context, key string) ([]string, error) {
|
||||
return c.Client.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
// SRem 删除集合成员
|
||||
func (c *Client) SRem(ctx context.Context, key string, members ...interface{}) error {
|
||||
return c.Client.SRem(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
// SIsMember 检查是否是集合成员
|
||||
func (c *Client) SIsMember(ctx context.Context, key string, member interface{}) (bool, error) {
|
||||
return c.Client.SIsMember(ctx, key, member).Result()
|
||||
}
|
||||
|
||||
// ZAdd 添加有序集合成员
|
||||
func (c *Client) ZAdd(ctx context.Context, key string, members ...redis.Z) error {
|
||||
return c.Client.ZAdd(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
// ZRange 获取有序集合范围内的成员
|
||||
func (c *Client) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
|
||||
return c.Client.ZRange(ctx, key, start, stop).Result()
|
||||
}
|
||||
|
||||
// ZRem 删除有序集合成员
|
||||
func (c *Client) ZRem(ctx context.Context, key string, members ...interface{}) error {
|
||||
return c.Client.ZRem(ctx, key, members...).Err()
|
||||
}
|
||||
|
||||
// Pipeline 创建管道
|
||||
func (c *Client) Pipeline() redis.Pipeliner {
|
||||
return c.Client.Pipeline()
|
||||
}
|
||||
|
||||
// TxPipeline 创建事务管道
|
||||
func (c *Client) TxPipeline() redis.Pipeliner {
|
||||
return c.Client.TxPipeline()
|
||||
}
|
||||
|
||||
func (c *Client) Nil(err error) bool {
|
||||
return errors.Is(err, redis.Nil)
|
||||
}
|
||||
|
||||
// GetBytes 从Redis读取key对应的字节数据,统一处理错误
|
||||
func (c *Client) GetBytes(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := c.Client.Get(ctx, key).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) { // 处理key不存在的情况(返回nil,无错误)
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err // 其他错误(如连接失败)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
48
pkg/storage/manager.go
Normal file
48
pkg/storage/manager.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
// clientInstance 全局存储客户端实例
|
||||
clientInstance *StorageClient
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化存储客户端(线程安全,只会执行一次)
|
||||
func Init(cfg config.RustFSConfig) error {
|
||||
once.Do(func() {
|
||||
clientInstance, initError = NewStorage(cfg)
|
||||
if initError != nil {
|
||||
return
|
||||
}
|
||||
})
|
||||
return initError
|
||||
}
|
||||
|
||||
// GetClient 获取存储客户端实例(线程安全)
|
||||
func GetClient() (*StorageClient, error) {
|
||||
if clientInstance == nil {
|
||||
return nil, fmt.Errorf("存储客户端未初始化,请先调用 storage.Init()")
|
||||
}
|
||||
return clientInstance, nil
|
||||
}
|
||||
|
||||
// MustGetClient 获取存储客户端实例,如果未初始化则panic
|
||||
func MustGetClient() *StorageClient {
|
||||
client, err := GetClient()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
52
pkg/storage/manager_test.go
Normal file
52
pkg/storage/manager_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetClient_NotInitialized 测试未初始化时获取存储客户端
|
||||
func TestGetClient_NotInitialized(t *testing.T) {
|
||||
_, err := GetClient()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "存储客户端未初始化,请先调用 storage.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetClient_Panic 测试MustGetClient在未初始化时panic
|
||||
func TestMustGetClient_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetClient 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetClient()
|
||||
}
|
||||
|
||||
// TestInit_Storage 测试存储客户端初始化逻辑
|
||||
func TestInit_Storage(t *testing.T) {
|
||||
cfg := config.RustFSConfig{
|
||||
Endpoint: "http://localhost:9000",
|
||||
AccessKey: "minioadmin",
|
||||
SecretKey: "minioadmin",
|
||||
UseSSL: false,
|
||||
Buckets: map[string]string{
|
||||
"avatars": "avatars",
|
||||
"textures": "textures",
|
||||
},
|
||||
}
|
||||
|
||||
// 验证Init函数存在且可调用
|
||||
// 注意:实际连接可能失败,这是可以接受的
|
||||
err := Init(cfg)
|
||||
if err != nil {
|
||||
t.Logf("Init() 返回错误(可能正常,如果存储服务未运行): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
120
pkg/storage/minio.go
Normal file
120
pkg/storage/minio.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
||||
// StorageClient S3兼容对象存储客户端包装 (支持RustFS、MinIO等)
|
||||
type StorageClient struct {
|
||||
client *minio.Client
|
||||
buckets map[string]string
|
||||
}
|
||||
|
||||
// NewStorage 创建新的对象存储客户端 (S3兼容,支持RustFS)
|
||||
func NewStorage(cfg config.RustFSConfig) (*StorageClient, error) {
|
||||
// 创建S3兼容客户端
|
||||
// minio-go SDK支持所有S3兼容的存储,包括RustFS
|
||||
// 不指定Region,让SDK自动检测
|
||||
client, err := minio.New(cfg.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
|
||||
Secure: cfg.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建对象存储客户端失败: %w", err)
|
||||
}
|
||||
|
||||
// 测试连接(如果AccessKey和SecretKey为空,跳过测试)
|
||||
if cfg.AccessKey != "" && cfg.SecretKey != "" {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err = client.ListBuckets(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("对象存储连接测试失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storageClient := &StorageClient{
|
||||
client: client,
|
||||
buckets: cfg.Buckets,
|
||||
}
|
||||
|
||||
return storageClient, nil
|
||||
}
|
||||
|
||||
// GetClient 获取底层S3客户端
|
||||
func (s *StorageClient) GetClient() *minio.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
// GetBucket 获取存储桶名称
|
||||
func (s *StorageClient) GetBucket(name string) (string, error) {
|
||||
bucket, exists := s.buckets[name]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("存储桶 %s 不存在", name)
|
||||
}
|
||||
return bucket, nil
|
||||
}
|
||||
|
||||
// GeneratePresignedURL 生成预签名上传URL (PUT方法)
|
||||
func (s *StorageClient) GeneratePresignedURL(ctx context.Context, bucketName, objectName string, expires time.Duration) (string, error) {
|
||||
url, err := s.client.PresignedPutObject(ctx, bucketName, objectName, expires)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("生成预签名URL失败: %w", err)
|
||||
}
|
||||
return url.String(), nil
|
||||
}
|
||||
|
||||
// PresignedPostPolicyResult 预签名POST策略结果
|
||||
type PresignedPostPolicyResult struct {
|
||||
PostURL string // POST的URL
|
||||
FormData map[string]string // 表单数据
|
||||
FileURL string // 文件的最终访问URL
|
||||
}
|
||||
|
||||
// GeneratePresignedPostURL 生成预签名POST URL (支持表单上传)
|
||||
// 注意:使用时必须确保file字段是表单的最后一个字段
|
||||
func (s *StorageClient) GeneratePresignedPostURL(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration, useSSL bool, endpoint string) (*PresignedPostPolicyResult, error) {
|
||||
// 创建上传策略
|
||||
policy := minio.NewPostPolicy()
|
||||
|
||||
// 设置策略的基本信息
|
||||
policy.SetBucket(bucketName)
|
||||
policy.SetKey(objectName)
|
||||
policy.SetExpires(time.Now().UTC().Add(expires))
|
||||
|
||||
// 设置文件大小限制
|
||||
if err := policy.SetContentLengthRange(minSize, maxSize); err != nil {
|
||||
return nil, fmt.Errorf("设置文件大小限制失败: %w", err)
|
||||
}
|
||||
|
||||
// 使用MinIO客户端和策略生成预签名的POST URL和表单数据
|
||||
postURL, formData, err := s.client.PresignedPostPolicy(ctx, policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成预签名POST URL失败: %w", err)
|
||||
}
|
||||
|
||||
// 移除form_data中多余的bucket字段(MinIO Go SDK可能会添加这个字段,但会导致签名错误)
|
||||
// 注意:在Go中直接delete不存在的key是安全的
|
||||
delete(formData, "bucket")
|
||||
|
||||
// 构造文件的永久访问URL
|
||||
protocol := "http"
|
||||
if useSSL {
|
||||
protocol = "https"
|
||||
}
|
||||
fileURL := fmt.Sprintf("%s://%s/%s/%s", protocol, endpoint, bucketName, objectName)
|
||||
|
||||
return &PresignedPostPolicyResult{
|
||||
PostURL: postURL.String(),
|
||||
FormData: formData,
|
||||
FileURL: fileURL,
|
||||
}, nil
|
||||
}
|
||||
47
pkg/utils/format.go
Normal file
47
pkg/utils/format.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FormatUUID 将UUID格式化为带连字符的标准格式
|
||||
// 如果输入已经是标准格式,直接返回
|
||||
// 如果输入是32位十六进制字符串,添加连字符
|
||||
// 如果输入格式无效,返回错误
|
||||
func FormatUUID(uuid string) string {
|
||||
// 如果为空,直接返回
|
||||
if uuid == "" {
|
||||
return uuid
|
||||
}
|
||||
|
||||
// 如果已经是标准格式(8-4-4-4-12),直接返回
|
||||
if len(uuid) == 36 && uuid[8] == '-' && uuid[13] == '-' && uuid[18] == '-' && uuid[23] == '-' {
|
||||
return uuid
|
||||
}
|
||||
|
||||
// 如果是32位十六进制字符串,添加连字符
|
||||
if len(uuid) == 32 {
|
||||
// 预分配容量以提高性能
|
||||
var b strings.Builder
|
||||
b.Grow(36) // 最终长度为36(32个字符 + 4个连字符)
|
||||
|
||||
// 使用WriteString和WriteByte优化性能
|
||||
b.WriteString(uuid[0:8])
|
||||
b.WriteByte('-')
|
||||
b.WriteString(uuid[8:12])
|
||||
b.WriteByte('-')
|
||||
b.WriteString(uuid[12:16])
|
||||
b.WriteByte('-')
|
||||
b.WriteString(uuid[16:20])
|
||||
b.WriteByte('-')
|
||||
b.WriteString(uuid[20:32])
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// 如果长度不是32或36,说明格式无效,直接返回原值
|
||||
var logger *zap.Logger
|
||||
logger.Warn("[WARN] UUID格式无效: ", zap.String("uuid:", uuid))
|
||||
return uuid
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user