Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor
This commit is contained in:
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Build artifacts
|
||||||
|
server
|
||||||
|
*.tar
|
||||||
|
|
||||||
|
# Runtime files
|
||||||
|
data/
|
||||||
|
logs/
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Local docker metadata
|
||||||
|
.last_docker_tag
|
||||||
|
.last_docker_tar
|
||||||
|
|
||||||
|
# Local environment files
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
28
Dockerfile
Normal file
28
Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# 使用Debian bookworm-slim作为生产镜像
|
||||||
|
FROM debian:bookworm-slim
|
||||||
|
|
||||||
|
# 安装必要的运行时依赖
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
ca-certificates \
|
||||||
|
tzdata \
|
||||||
|
libsqlite3-0 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
|
&& apt-get clean
|
||||||
|
|
||||||
|
# 设置工作目录
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 从宿主机复制编译好的二进制文件
|
||||||
|
COPY server ./carrot_bbs
|
||||||
|
|
||||||
|
# 复制配置文件
|
||||||
|
COPY configs ./configs
|
||||||
|
|
||||||
|
# 给可执行文件添加执行权限
|
||||||
|
RUN chmod +x ./carrot_bbs
|
||||||
|
|
||||||
|
# 暴露端口
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
# 默认命令
|
||||||
|
CMD ["./carrot_bbs"]
|
||||||
172
configs/config.yaml
Normal file
172
configs/config.yaml
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# 服务器配置
|
||||||
|
# 环境变量: APP_SERVER_HOST, APP_SERVER_PORT, APP_SERVER_MODE
|
||||||
|
server:
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8080
|
||||||
|
mode: debug
|
||||||
|
|
||||||
|
# 数据库配置
|
||||||
|
# 环境变量:
|
||||||
|
# SQLite: APP_DATABASE_SQLITE_PATH
|
||||||
|
# Postgres: APP_DATABASE_POSTGRES_HOST, APP_DATABASE_POSTGRES_PORT, APP_DATABASE_POSTGRES_USER,
|
||||||
|
# APP_DATABASE_POSTGRES_PASSWORD, APP_DATABASE_POSTGRES_DBNAME
|
||||||
|
database:
|
||||||
|
type: sqlite # sqlite 或 postgres
|
||||||
|
sqlite:
|
||||||
|
path: ./data/carrot_bbs.db
|
||||||
|
postgres:
|
||||||
|
host: localhost
|
||||||
|
port: 5432
|
||||||
|
user: postgres
|
||||||
|
password: postgres
|
||||||
|
dbname: carrot_bbs
|
||||||
|
sslmode: disable
|
||||||
|
max_idle_conns: 10
|
||||||
|
max_open_conns: 100
|
||||||
|
log_level: warn
|
||||||
|
slow_threshold_ms: 200
|
||||||
|
ignore_record_not_found: true
|
||||||
|
parameterized_queries: true
|
||||||
|
|
||||||
|
# Redis配置
|
||||||
|
# 环境变量:
|
||||||
|
# 类型: APP_REDIS_TYPE (miniredis/redis)
|
||||||
|
# Redis: APP_REDIS_REDIS_HOST, APP_REDIS_REDIS_PORT, APP_REDIS_REDIS_PASSWORD, APP_REDIS_REDIS_DB
|
||||||
|
# Miniredis: APP_REDIS_MINIREDIS_HOST, APP_REDIS_MINIREDIS_PORT
|
||||||
|
redis:
|
||||||
|
type: miniredis # miniredis 或 redis
|
||||||
|
redis:
|
||||||
|
host: localhost
|
||||||
|
port: 6379
|
||||||
|
password: ""
|
||||||
|
db: 0
|
||||||
|
miniredis:
|
||||||
|
host: localhost
|
||||||
|
port: 6379
|
||||||
|
pool_size: 10
|
||||||
|
|
||||||
|
# 缓存配置
|
||||||
|
# 环境变量:
|
||||||
|
# APP_CACHE_ENABLED, APP_CACHE_KEY_PREFIX, APP_CACHE_DEFAULT_TTL, APP_CACHE_NULL_TTL
|
||||||
|
# APP_CACHE_JITTER_RATIO, APP_CACHE_DISABLE_FLUSHDB
|
||||||
|
# APP_CACHE_MODULES_POST_LIST_TTL, APP_CACHE_MODULES_CONVERSATION_TTL
|
||||||
|
# APP_CACHE_MODULES_UNREAD_COUNT_TTL, APP_CACHE_MODULES_GROUP_MEMBERS_TTL
|
||||||
|
cache:
|
||||||
|
enabled: true
|
||||||
|
key_prefix: ""
|
||||||
|
default_ttl: 30
|
||||||
|
null_ttl: 5
|
||||||
|
jitter_ratio: 0.1
|
||||||
|
disable_flushdb: true
|
||||||
|
modules:
|
||||||
|
post_list_ttl: 30
|
||||||
|
conversation_ttl: 60
|
||||||
|
unread_count_ttl: 30
|
||||||
|
group_members_ttl: 120
|
||||||
|
|
||||||
|
# S3对象存储配置
|
||||||
|
# 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN
|
||||||
|
s3:
|
||||||
|
endpoint: ""
|
||||||
|
access_key: ""
|
||||||
|
secret_key: ""
|
||||||
|
bucket: ""
|
||||||
|
use_ssl: true
|
||||||
|
region: us-east-1
|
||||||
|
domain: ""
|
||||||
|
# JWT配置
|
||||||
|
# 环境变量: APP_JWT_SECRET
|
||||||
|
jwt:
|
||||||
|
secret: your-jwt-secret-key-change-in-production
|
||||||
|
access_token_expire: 86400 # 24 hours in seconds
|
||||||
|
refresh_token_expire: 604800 # 7 days in seconds
|
||||||
|
|
||||||
|
log:
|
||||||
|
level: info
|
||||||
|
encoding: json
|
||||||
|
output_paths:
|
||||||
|
- stdout
|
||||||
|
- ./logs/app.log
|
||||||
|
|
||||||
|
rate_limit:
|
||||||
|
enabled: true
|
||||||
|
requests_per_minute: 60
|
||||||
|
|
||||||
|
upload:
|
||||||
|
max_file_size: 10485760 # 10MB
|
||||||
|
allowed_types:
|
||||||
|
- image/jpeg
|
||||||
|
- image/png
|
||||||
|
- image/gif
|
||||||
|
- image/webp
|
||||||
|
|
||||||
|
# 敏感词过滤配置
|
||||||
|
sensitive:
|
||||||
|
enabled: true
|
||||||
|
replace_str: "***"
|
||||||
|
min_match_len: 1
|
||||||
|
load_from_db: true
|
||||||
|
load_from_redis: false
|
||||||
|
redis_key_prefix: "sensitive_words"
|
||||||
|
|
||||||
|
# 内容审核服务配置
|
||||||
|
audit:
|
||||||
|
enabled: false # 暂时关闭第三方审核
|
||||||
|
# 审核服务提供商: local, aliyun, tencent, baidu
|
||||||
|
provider: "local"
|
||||||
|
auto_audit: true
|
||||||
|
timeout: 30
|
||||||
|
# 阿里云配置
|
||||||
|
aliyun_access_key: ""
|
||||||
|
aliyun_secret_key: ""
|
||||||
|
aliyun_region: "cn-shanghai"
|
||||||
|
# 腾讯云配置
|
||||||
|
tencent_secret_id: ""
|
||||||
|
tencent_secret_key: ""
|
||||||
|
# 百度云配置
|
||||||
|
baidu_api_key: ""
|
||||||
|
baidu_secret_key: ""
|
||||||
|
|
||||||
|
# Gorse推荐系统配置
|
||||||
|
# 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD
|
||||||
|
gorse:
|
||||||
|
enabled: false
|
||||||
|
address: "" # Gorse server地址
|
||||||
|
api_key: "" # API密钥
|
||||||
|
dashboard: "" # Gorse dashboard地址
|
||||||
|
import_password: "" # 导入数据密码
|
||||||
|
embedding_api_key: ""
|
||||||
|
embedding_url: "https://api.littlelan.cn/v1/embeddings"
|
||||||
|
embedding_model: "BAAI/bge-m3"
|
||||||
|
|
||||||
|
# OpenAI兼容接口配置(用于帖子审核,支持图文)
|
||||||
|
# 环境变量:
|
||||||
|
# APP_OPENAI_ENABLED, APP_OPENAI_BASE_URL, APP_OPENAI_API_KEY
|
||||||
|
# APP_OPENAI_MODERATION_MODEL, APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST
|
||||||
|
# APP_OPENAI_REQUEST_TIMEOUT, APP_OPENAI_STRICT_MODERATION
|
||||||
|
openai:
|
||||||
|
enabled: true
|
||||||
|
base_url: "https://api.littlelan.cn/"
|
||||||
|
api_key: ""
|
||||||
|
moderation_model: "qwen3.5-122b"
|
||||||
|
moderation_max_images_per_request: 1
|
||||||
|
request_timeout: 30
|
||||||
|
strict_moderation: false
|
||||||
|
|
||||||
|
# SMTP发信配置(gomail.v2)
|
||||||
|
# 环境变量:
|
||||||
|
# APP_EMAIL_ENABLED, APP_EMAIL_HOST, APP_EMAIL_PORT
|
||||||
|
# APP_EMAIL_USERNAME, APP_EMAIL_PASSWORD
|
||||||
|
# APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME
|
||||||
|
# APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT
|
||||||
|
email:
|
||||||
|
enabled: false
|
||||||
|
host: ""
|
||||||
|
port: 587
|
||||||
|
username: ""
|
||||||
|
password: ""
|
||||||
|
from_address: ""
|
||||||
|
from_name: "Carrot BBS"
|
||||||
|
use_tls: true
|
||||||
|
insecure_skip_verify: false
|
||||||
|
timeout: 15
|
||||||
78
docker-compose.yml
Normal file
78
docker-compose.yml
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
# 开发环境使用 SQLite 和 miniredis,不需要外部数据库服务
|
||||||
|
# 如需使用 PostgreSQL 和 Redis,切换 config.yaml 中的配置并取消注释以下服务
|
||||||
|
|
||||||
|
# PostgreSQL (生产环境使用)
|
||||||
|
# postgres:
|
||||||
|
# image: postgres:15-alpine
|
||||||
|
# container_name: carrot_bbs_postgres
|
||||||
|
# environment:
|
||||||
|
# POSTGRES_USER: postgres
|
||||||
|
# POSTGRES_PASSWORD: postgres
|
||||||
|
# POSTGRES_DB: carrot_bbs
|
||||||
|
# ports:
|
||||||
|
# - "5432:5432"
|
||||||
|
# volumes:
|
||||||
|
# - postgres_data:/var/lib/postgresql/data
|
||||||
|
# healthcheck:
|
||||||
|
# test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
|
# interval: 10s
|
||||||
|
# timeout: 5s
|
||||||
|
# retries: 5
|
||||||
|
|
||||||
|
# Redis (生产环境使用)
|
||||||
|
# redis:
|
||||||
|
# image: redis:7-alpine
|
||||||
|
# container_name: carrot_bbs_redis
|
||||||
|
# ports:
|
||||||
|
# - "6379:6379"
|
||||||
|
# volumes:
|
||||||
|
# - redis_data:/data
|
||||||
|
# healthcheck:
|
||||||
|
# test: ["CMD", "redis-cli", "ping"]
|
||||||
|
# interval: 10s
|
||||||
|
# timeout: 5s
|
||||||
|
# retries: 5
|
||||||
|
|
||||||
|
# MinIO (对象存储,可选)
|
||||||
|
minio:
|
||||||
|
image: minio/minio:latest
|
||||||
|
container_name: carrot_bbs_minio
|
||||||
|
environment:
|
||||||
|
MINIO_ROOT_USER: minioadmin
|
||||||
|
MINIO_ROOT_PASSWORD: minioadmin
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
- "9001:9001"
|
||||||
|
volumes:
|
||||||
|
- minio_data:/data
|
||||||
|
command: server /data --console-address ":9001"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 20s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
app:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
container_name: carrot_bbs_app
|
||||||
|
ports:
|
||||||
|
- "8080:8080"
|
||||||
|
depends_on:
|
||||||
|
minio:
|
||||||
|
condition: service_healthy
|
||||||
|
environment:
|
||||||
|
- CONFIG_PATH=/app/configs/config.yaml
|
||||||
|
volumes:
|
||||||
|
- ./configs:/app/configs
|
||||||
|
- ./logs:/app/logs
|
||||||
|
- ./data:/app/data
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
# postgres_data:
|
||||||
|
# redis_data:
|
||||||
|
minio_data:
|
||||||
86
go.mod
Normal file
86
go.mod
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
module carrot_bbs
|
||||||
|
|
||||||
|
go 1.25
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/alicebob/miniredis/v2 v2.31.0
|
||||||
|
github.com/gin-gonic/gin v1.9.1
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||||
|
github.com/google/uuid v1.5.0
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
|
github.com/gorse-io/gorse-go v0.5.0-alpha.3
|
||||||
|
github.com/minio/minio-go/v7 v7.0.66
|
||||||
|
github.com/redis/go-redis/v9 v9.3.1
|
||||||
|
github.com/spf13/viper v1.18.2
|
||||||
|
go.uber.org/zap v1.26.0
|
||||||
|
golang.org/x/crypto v0.17.0
|
||||||
|
golang.org/x/image v0.24.0
|
||||||
|
gorm.io/driver/postgres v1.5.4
|
||||||
|
gorm.io/driver/sqlite v1.5.4
|
||||||
|
gorm.io/gorm v1.25.5
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
|
||||||
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||||
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/felixge/httpsnoop v1.0.3 // indirect
|
||||||
|
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||||
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
|
github.com/go-logr/logr v1.2.3 // indirect
|
||||||
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||||
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||||
|
github.com/jackc/pgx/v5 v5.4.3 // indirect
|
||||||
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/klauspost/compress v1.17.4 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
|
||||||
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
|
github.com/magiconair/properties v1.8.7 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||||
|
github.com/minio/md5-simd v1.1.2 // indirect
|
||||||
|
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
||||||
|
github.com/rs/xid v1.5.0 // indirect
|
||||||
|
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||||
|
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||||
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
|
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||||
|
github.com/spf13/afero v1.11.0 // indirect
|
||||||
|
github.com/spf13/cast v1.6.0 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
|
github.com/yuin/gopher-lua v1.1.0 // indirect
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 // indirect
|
||||||
|
go.opentelemetry.io/otel v1.11.1 // indirect
|
||||||
|
go.opentelemetry.io/otel/metric v0.33.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/trace v1.11.1 // indirect
|
||||||
|
go.uber.org/multierr v1.10.0 // indirect
|
||||||
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||||
|
golang.org/x/net v0.19.0 // indirect
|
||||||
|
golang.org/x/sys v0.15.0 // indirect
|
||||||
|
golang.org/x/text v0.22.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.31.0 // indirect
|
||||||
|
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
|
||||||
|
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df // indirect
|
||||||
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
216
go.sum
Normal file
216
go.sum
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0=
|
||||||
|
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
|
||||||
|
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||||
|
github.com/alicebob/miniredis/v2 v2.31.0 h1:ObEFUNlJwoIiyjxdrYF0QIDE7qXcLc7D3WpSH4c22PU=
|
||||||
|
github.com/alicebob/miniredis/v2 v2.31.0/go.mod h1:UB/T2Uztp7MlFSDakaX1sTXUv5CASoprx0wulRT6HBg=
|
||||||
|
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||||
|
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||||
|
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||||
|
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||||
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
|
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||||
|
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
|
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||||
|
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
||||||
|
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
|
||||||
|
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
|
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||||
|
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||||
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
|
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||||
|
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||||
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
|
||||||
|
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
|
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||||
|
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||||
|
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||||
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||||
|
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
|
||||||
|
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/gorse-io/gorse-go v0.5.0-alpha.3 h1:GR/OWzq016VyFyTTxgQWeayGahRVzB1cGFIW/AaShC4=
|
||||||
|
github.com/gorse-io/gorse-go v0.5.0-alpha.3/go.mod h1:ZxmVHzZPKm5pmEIlqaRDwK0rkfTRHlfziO033XZ+RW0=
|
||||||
|
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||||
|
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||||
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
|
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
|
||||||
|
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||||
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||||
|
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||||
|
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||||
|
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||||
|
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
|
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||||
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
|
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||||
|
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||||
|
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||||
|
github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs=
|
||||||
|
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
||||||
|
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||||
|
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds=
|
||||||
|
github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
|
||||||
|
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||||
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
|
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||||
|
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||||
|
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
||||||
|
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
||||||
|
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||||
|
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
||||||
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
|
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||||
|
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||||
|
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||||
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
|
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||||
|
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
|
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||||
|
github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
|
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||||
|
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
|
github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE=
|
||||||
|
github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 h1:aUEBEdCa6iamGzg6fuYxDA8ThxvOG240mAvWDU+XLio=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4/go.mod h1:l2MdsbKTocpPS5nQZscqTR9jd8u96VYZdcpF8Sye7mA=
|
||||||
|
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
|
||||||
|
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
|
||||||
|
go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E=
|
||||||
|
go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI=
|
||||||
|
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
|
||||||
|
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
|
||||||
|
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||||
|
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||||
|
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||||
|
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
|
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
|
||||||
|
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
|
||||||
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
|
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||||
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
|
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||||
|
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||||
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
||||||
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
||||||
|
golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ=
|
||||||
|
golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8=
|
||||||
|
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
||||||
|
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
||||||
|
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||||
|
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||||
|
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
|
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||||
|
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
|
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
|
||||||
|
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE=
|
||||||
|
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw=
|
||||||
|
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||||
|
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
||||||
|
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||||
|
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
||||||
|
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
||||||
|
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||||
|
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
|
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||||
604
internal/cache/cache.go
vendored
Normal file
604
internal/cache/cache.go
vendored
Normal file
@@ -0,0 +1,604 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
|
||||||
|
redisPkg "carrot_bbs/internal/pkg/redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cache 缓存接口
|
||||||
|
type Cache interface {
|
||||||
|
// Set 设置缓存值,支持TTL
|
||||||
|
Set(key string, value interface{}, ttl time.Duration)
|
||||||
|
// Get 获取缓存值
|
||||||
|
Get(key string) (interface{}, bool)
|
||||||
|
// Delete 删除缓存
|
||||||
|
Delete(key string)
|
||||||
|
// DeleteByPrefix 根据前缀删除缓存
|
||||||
|
DeleteByPrefix(prefix string)
|
||||||
|
// Clear 清空所有缓存
|
||||||
|
Clear()
|
||||||
|
// Exists 检查键是否存在
|
||||||
|
Exists(key string) bool
|
||||||
|
// Increment 增加计数器的值
|
||||||
|
Increment(key string) int64
|
||||||
|
// IncrementBy 增加指定值
|
||||||
|
IncrementBy(key string, value int64) int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheItem 缓存项(用于内存缓存降级)
|
||||||
|
type cacheItem struct {
|
||||||
|
value interface{}
|
||||||
|
expiration int64 // 过期时间戳(纳秒)
|
||||||
|
}
|
||||||
|
|
||||||
|
const nullMarkerValue = "__carrot_cache_null__"
|
||||||
|
|
||||||
|
type cacheMetrics struct {
|
||||||
|
hit atomic.Int64
|
||||||
|
miss atomic.Int64
|
||||||
|
decodeError atomic.Int64
|
||||||
|
setError atomic.Int64
|
||||||
|
invalidate atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
var metrics cacheMetrics
|
||||||
|
var loadLocks sync.Map
|
||||||
|
|
||||||
|
type MetricsSnapshot struct {
|
||||||
|
Hit int64
|
||||||
|
Miss int64
|
||||||
|
DecodeError int64
|
||||||
|
SetError int64
|
||||||
|
Invalidate int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Settings struct {
|
||||||
|
Enabled bool
|
||||||
|
KeyPrefix string
|
||||||
|
DefaultTTL time.Duration
|
||||||
|
NullTTL time.Duration
|
||||||
|
JitterRatio float64
|
||||||
|
PostListTTL time.Duration
|
||||||
|
ConversationTTL time.Duration
|
||||||
|
UnreadCountTTL time.Duration
|
||||||
|
GroupMembersTTL time.Duration
|
||||||
|
DisableFlushDB bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var settings = Settings{
|
||||||
|
Enabled: true,
|
||||||
|
DefaultTTL: 30 * time.Second,
|
||||||
|
NullTTL: 5 * time.Second,
|
||||||
|
JitterRatio: 0.1,
|
||||||
|
PostListTTL: 30 * time.Second,
|
||||||
|
ConversationTTL: 60 * time.Second,
|
||||||
|
UnreadCountTTL: 30 * time.Second,
|
||||||
|
GroupMembersTTL: 120 * time.Second,
|
||||||
|
DisableFlushDB: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func Configure(s Settings) {
|
||||||
|
settings.Enabled = s.Enabled
|
||||||
|
if s.KeyPrefix != "" {
|
||||||
|
settings.KeyPrefix = s.KeyPrefix
|
||||||
|
}
|
||||||
|
if s.DefaultTTL > 0 {
|
||||||
|
settings.DefaultTTL = s.DefaultTTL
|
||||||
|
}
|
||||||
|
if s.NullTTL > 0 {
|
||||||
|
settings.NullTTL = s.NullTTL
|
||||||
|
}
|
||||||
|
if s.JitterRatio > 0 {
|
||||||
|
settings.JitterRatio = s.JitterRatio
|
||||||
|
}
|
||||||
|
if s.PostListTTL > 0 {
|
||||||
|
settings.PostListTTL = s.PostListTTL
|
||||||
|
}
|
||||||
|
if s.ConversationTTL > 0 {
|
||||||
|
settings.ConversationTTL = s.ConversationTTL
|
||||||
|
}
|
||||||
|
if s.UnreadCountTTL > 0 {
|
||||||
|
settings.UnreadCountTTL = s.UnreadCountTTL
|
||||||
|
}
|
||||||
|
if s.GroupMembersTTL > 0 {
|
||||||
|
settings.GroupMembersTTL = s.GroupMembersTTL
|
||||||
|
}
|
||||||
|
settings.DisableFlushDB = s.DisableFlushDB
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSettings() Settings {
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeKey(key string) string {
|
||||||
|
if settings.KeyPrefix == "" {
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
return settings.KeyPrefix + ":" + key
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePrefix(prefix string) string {
|
||||||
|
if settings.KeyPrefix == "" {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
return settings.KeyPrefix + ":" + prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMetricsSnapshot() MetricsSnapshot {
|
||||||
|
return MetricsSnapshot{
|
||||||
|
Hit: metrics.hit.Load(),
|
||||||
|
Miss: metrics.miss.Load(),
|
||||||
|
DecodeError: metrics.decodeError.Load(),
|
||||||
|
SetError: metrics.setError.Load(),
|
||||||
|
Invalidate: metrics.invalidate.Load(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isExpired 检查是否过期
|
||||||
|
func (item *cacheItem) isExpired() bool {
|
||||||
|
if item.expiration == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().UnixNano() > item.expiration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryCache 内存缓存实现(降级使用)
|
||||||
|
type MemoryCache struct {
|
||||||
|
items sync.Map
|
||||||
|
// cleanupInterval 清理过期缓存的间隔
|
||||||
|
cleanupInterval time.Duration
|
||||||
|
// stopCleanup 停止清理协程的通道
|
||||||
|
stopCleanup chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryCache 创建内存缓存
|
||||||
|
func NewMemoryCache() *MemoryCache {
|
||||||
|
c := &MemoryCache{
|
||||||
|
cleanupInterval: 1 * time.Minute,
|
||||||
|
stopCleanup: make(chan struct{}),
|
||||||
|
}
|
||||||
|
// 启动后台清理协程
|
||||||
|
go c.cleanup()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set 设置缓存值
|
||||||
|
func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
var expiration int64
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration = time.Now().Add(ttl).UnixNano()
|
||||||
|
}
|
||||||
|
c.items.Store(key, &cacheItem{
|
||||||
|
value: value,
|
||||||
|
expiration: expiration,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取缓存值
|
||||||
|
func (c *MemoryCache) Get(key string) (interface{}, bool) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
val, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
item := val.(*cacheItem)
|
||||||
|
if item.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return item.value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除缓存
|
||||||
|
func (c *MemoryCache) Delete(key string) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
metrics.invalidate.Add(1)
|
||||||
|
c.items.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByPrefix 根据前缀删除缓存
|
||||||
|
func (c *MemoryCache) DeleteByPrefix(prefix string) {
|
||||||
|
prefix = normalizePrefix(prefix)
|
||||||
|
c.items.Range(func(key, value interface{}) bool {
|
||||||
|
if keyStr, ok := key.(string); ok {
|
||||||
|
if strings.HasPrefix(keyStr, prefix) {
|
||||||
|
metrics.invalidate.Add(1)
|
||||||
|
c.items.Delete(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear 清空所有缓存
|
||||||
|
func (c *MemoryCache) Clear() {
|
||||||
|
c.items.Range(func(key, value interface{}) bool {
|
||||||
|
metrics.invalidate.Add(1)
|
||||||
|
c.items.Delete(key)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists 检查键是否存在
|
||||||
|
func (c *MemoryCache) Exists(key string) bool {
|
||||||
|
_, ok := c.Get(key)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment 增加计数器的值
|
||||||
|
func (c *MemoryCache) Increment(key string) int64 {
|
||||||
|
return c.IncrementBy(key, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementBy 增加指定值
|
||||||
|
func (c *MemoryCache) IncrementBy(key string, value int64) int64 {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
for {
|
||||||
|
val, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
// 键不存在,创建新值
|
||||||
|
c.items.Store(key, &cacheItem{
|
||||||
|
value: value,
|
||||||
|
expiration: 0,
|
||||||
|
})
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
item := val.(*cacheItem)
|
||||||
|
if item.isExpired() {
|
||||||
|
// 已过期,创建新值
|
||||||
|
c.items.Store(key, &cacheItem{
|
||||||
|
value: value,
|
||||||
|
expiration: 0,
|
||||||
|
})
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试更新
|
||||||
|
currentValue, ok := item.value.(int64)
|
||||||
|
if !ok {
|
||||||
|
// 类型不匹配,覆盖为新值
|
||||||
|
c.items.Store(key, &cacheItem{
|
||||||
|
value: value,
|
||||||
|
expiration: item.expiration,
|
||||||
|
})
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
newValue := currentValue + value
|
||||||
|
// 使用 CAS 操作确保并发安全
|
||||||
|
if c.items.CompareAndSwap(key, val, &cacheItem{
|
||||||
|
value: newValue,
|
||||||
|
expiration: item.expiration,
|
||||||
|
}) {
|
||||||
|
return newValue
|
||||||
|
}
|
||||||
|
// CAS 失败,重试
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup 定期清理过期缓存
|
||||||
|
func (c *MemoryCache) cleanup() {
|
||||||
|
ticker := time.NewTicker(c.cleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
c.cleanExpired()
|
||||||
|
case <-c.stopCleanup:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanExpired 清理过期缓存
|
||||||
|
func (c *MemoryCache) cleanExpired() {
|
||||||
|
count := 0
|
||||||
|
c.items.Range(func(key, value interface{}) bool {
|
||||||
|
item := value.(*cacheItem)
|
||||||
|
if item.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if count > 0 {
|
||||||
|
log.Printf("[Cache] Cleaned %d expired items", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止缓存清理协程
|
||||||
|
func (c *MemoryCache) Stop() {
|
||||||
|
close(c.stopCleanup)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedisCache Redis缓存实现
|
||||||
|
type RedisCache struct {
|
||||||
|
client *redisPkg.Client
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRedisCache 创建Redis缓存
|
||||||
|
func NewRedisCache(client *redisPkg.Client) *RedisCache {
|
||||||
|
return &RedisCache{
|
||||||
|
client: client,
|
||||||
|
ctx: context.Background(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set 设置缓存值
|
||||||
|
func (c *RedisCache) Set(key string, value interface{}, ttl time.Duration) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
// 将值序列化为JSON
|
||||||
|
data, err := json.Marshal(value)
|
||||||
|
if err != nil {
|
||||||
|
metrics.setError.Add(1)
|
||||||
|
log.Printf("[RedisCache] Failed to marshal value for key %s: %v", key, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.client.Set(c.ctx, key, data, ttl); err != nil {
|
||||||
|
metrics.setError.Add(1)
|
||||||
|
log.Printf("[RedisCache] Failed to set key %s: %v", key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取缓存值
|
||||||
|
func (c *RedisCache) Get(key string) (interface{}, bool) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
data, err := c.client.Get(c.ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
log.Printf("[RedisCache] Failed to get key %s: %v", key, err)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回原始字符串,由调用侧决定如何解码为目标类型
|
||||||
|
return data, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除缓存
|
||||||
|
func (c *RedisCache) Delete(key string) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
metrics.invalidate.Add(1)
|
||||||
|
if err := c.client.Del(c.ctx, key); err != nil {
|
||||||
|
log.Printf("[RedisCache] Failed to delete key %s: %v", key, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByPrefix 根据前缀删除缓存
|
||||||
|
func (c *RedisCache) DeleteByPrefix(prefix string) {
|
||||||
|
prefix = normalizePrefix(prefix)
|
||||||
|
// 使用原生客户端执行SCAN命令
|
||||||
|
rdb := c.client.GetClient()
|
||||||
|
var cursor uint64
|
||||||
|
for {
|
||||||
|
keys, nextCursor, err := rdb.Scan(c.ctx, cursor, prefix+"*", 100).Result()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[RedisCache] Failed to scan keys with prefix %s: %v", prefix, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) > 0 {
|
||||||
|
metrics.invalidate.Add(int64(len(keys)))
|
||||||
|
if err := c.client.Del(c.ctx, keys...); err != nil {
|
||||||
|
log.Printf("[RedisCache] Failed to delete keys with prefix %s: %v", prefix, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = nextCursor
|
||||||
|
if cursor == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear 清空所有缓存
|
||||||
|
func (c *RedisCache) Clear() {
|
||||||
|
if settings.DisableFlushDB {
|
||||||
|
log.Printf("[RedisCache] Skip FlushDB because cache.disable_flushdb=true")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
metrics.invalidate.Add(1)
|
||||||
|
rdb := c.client.GetClient()
|
||||||
|
if err := rdb.FlushDB(c.ctx).Err(); err != nil {
|
||||||
|
log.Printf("[RedisCache] Failed to clear cache: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists 检查键是否存在
|
||||||
|
func (c *RedisCache) Exists(key string) bool {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
n, err := c.client.Exists(c.ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[RedisCache] Failed to check existence of key %s: %v", key, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return n > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment 增加计数器的值
|
||||||
|
func (c *RedisCache) Increment(key string) int64 {
|
||||||
|
return c.IncrementBy(key, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementBy 增加指定值
|
||||||
|
func (c *RedisCache) IncrementBy(key string, value int64) int64 {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
rdb := c.client.GetClient()
|
||||||
|
result, err := rdb.IncrBy(c.ctx, key, value).Result()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[RedisCache] Failed to increment key %s: %v", key, err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// 全局缓存实例
|
||||||
|
var globalCache Cache
|
||||||
|
var once sync.Once
|
||||||
|
|
||||||
|
// InitCache 初始化全局缓存实例(使用Redis)
|
||||||
|
func InitCache(redisClient *redisPkg.Client) {
|
||||||
|
once.Do(func() {
|
||||||
|
if redisClient != nil {
|
||||||
|
globalCache = NewRedisCache(redisClient)
|
||||||
|
log.Println("[Cache] Initialized Redis cache")
|
||||||
|
} else {
|
||||||
|
globalCache = NewMemoryCache()
|
||||||
|
log.Println("[Cache] Initialized Memory cache (Redis not available)")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCache 获取全局缓存实例
|
||||||
|
func GetCache() Cache {
|
||||||
|
if globalCache == nil {
|
||||||
|
// 如果未初始化,返回内存缓存作为降级
|
||||||
|
log.Println("[Cache] Warning: Cache not initialized, using Memory cache")
|
||||||
|
return NewMemoryCache()
|
||||||
|
}
|
||||||
|
return globalCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRedisClient 从缓存中获取Redis客户端(仅在Redis模式下有效)
|
||||||
|
func GetRedisClient() (*redisPkg.Client, error) {
|
||||||
|
if redisCache, ok := globalCache.(*RedisCache); ok {
|
||||||
|
return redisCache.client, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("cache is not using Redis backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetWithJitter(c Cache, key string, value interface{}, ttl time.Duration, jitterRatio float64) {
|
||||||
|
if !settings.Enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(key, value, ApplyTTLJitter(ttl, jitterRatio))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetNull(c Cache, key string, ttl time.Duration) {
|
||||||
|
if !settings.Enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(key, nullMarkerValue, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplyTTLJitter(ttl time.Duration, jitterRatio float64) time.Duration {
|
||||||
|
if ttl <= 0 || jitterRatio <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
if jitterRatio > 1 {
|
||||||
|
jitterRatio = 1
|
||||||
|
}
|
||||||
|
maxJitter := int64(float64(ttl) * jitterRatio)
|
||||||
|
if maxJitter <= 0 {
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
delta := rand.Int63n(maxJitter + 1)
|
||||||
|
return ttl + time.Duration(delta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTyped[T any](c Cache, key string) (T, bool) {
|
||||||
|
var zero T
|
||||||
|
if !settings.Enabled {
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
raw, ok := c.Get(key)
|
||||||
|
if !ok {
|
||||||
|
metrics.miss.Add(1)
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
if str, ok := raw.(string); ok && str == nullMarkerValue {
|
||||||
|
metrics.hit.Add(1)
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if typed, ok := raw.(T); ok {
|
||||||
|
metrics.hit.Add(1)
|
||||||
|
return typed, true
|
||||||
|
}
|
||||||
|
|
||||||
|
var out T
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
if err := json.Unmarshal([]byte(v), &out); err != nil {
|
||||||
|
metrics.decodeError.Add(1)
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
metrics.hit.Add(1)
|
||||||
|
return out, true
|
||||||
|
case []byte:
|
||||||
|
if err := json.Unmarshal(v, &out); err != nil {
|
||||||
|
metrics.decodeError.Add(1)
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
metrics.hit.Add(1)
|
||||||
|
return out, true
|
||||||
|
default:
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
metrics.decodeError.Add(1)
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &out); err != nil {
|
||||||
|
metrics.decodeError.Add(1)
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
metrics.hit.Add(1)
|
||||||
|
return out, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOrLoadTyped[T any](
|
||||||
|
c Cache,
|
||||||
|
key string,
|
||||||
|
ttl time.Duration,
|
||||||
|
jitterRatio float64,
|
||||||
|
nullTTL time.Duration,
|
||||||
|
loader func() (T, error),
|
||||||
|
) (T, error) {
|
||||||
|
if cached, ok := GetTyped[T](c, key); ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lockValue, _ := loadLocks.LoadOrStore(key, &sync.Mutex{})
|
||||||
|
lock := lockValue.(*sync.Mutex)
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
if cached, ok := GetTyped[T](c, key); ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := loader()
|
||||||
|
if err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded, marshalErr := json.Marshal(loaded)
|
||||||
|
if marshalErr == nil && string(encoded) == "null" && nullTTL > 0 {
|
||||||
|
SetNull(c, key, nullTTL)
|
||||||
|
return loaded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
SetWithJitter(c, key, loaded, ttl, jitterRatio)
|
||||||
|
return loaded, nil
|
||||||
|
}
|
||||||
147
internal/cache/keys.go
vendored
Normal file
147
internal/cache/keys.go
vendored
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 缓存键前缀常量
|
||||||
|
const (
|
||||||
|
// 帖子相关
|
||||||
|
PrefixPostList = "posts:list"
|
||||||
|
PrefixPost = "posts:detail"
|
||||||
|
|
||||||
|
// 会话相关
|
||||||
|
PrefixConversationList = "conversations:list"
|
||||||
|
PrefixConversationDetail = "conversations:detail"
|
||||||
|
|
||||||
|
// 群组相关
|
||||||
|
PrefixGroupMembers = "groups:members"
|
||||||
|
PrefixGroupInfo = "groups:info"
|
||||||
|
|
||||||
|
// 未读数相关
|
||||||
|
PrefixUnreadSystem = "unread:system"
|
||||||
|
PrefixUnreadConversation = "unread:conversation"
|
||||||
|
PrefixUnreadDetail = "unread:detail"
|
||||||
|
|
||||||
|
// 用户相关
|
||||||
|
PrefixUserInfo = "users:info"
|
||||||
|
PrefixUserMe = "users:me"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostListKey 生成帖子列表缓存键
|
||||||
|
// postType: 帖子类型 (recommend, hot, follow, latest)
|
||||||
|
// page: 页码
|
||||||
|
// pageSize: 每页数量
|
||||||
|
// userID: 用户维度(仅在个性化列表如 follow 场景使用)
|
||||||
|
func PostListKey(postType string, userID string, page, pageSize int) string {
|
||||||
|
if userID == "" {
|
||||||
|
return fmt.Sprintf("%s:%s:%d:%d", PrefixPostList, postType, page, pageSize)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s:%s:%s:%d:%d", PrefixPostList, postType, userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostDetailKey 生成帖子详情缓存键
|
||||||
|
func PostDetailKey(postID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", PrefixPost, postID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationListKey 生成会话列表缓存键
|
||||||
|
func ConversationListKey(userID string, page, pageSize int) string {
|
||||||
|
return fmt.Sprintf("%s:%s:%d:%d", PrefixConversationList, userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationDetailKey 生成会话详情缓存键
|
||||||
|
func ConversationDetailKey(conversationID, userID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s:%s", PrefixConversationDetail, conversationID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMembersKey 生成群组成员缓存键
|
||||||
|
func GroupMembersKey(groupID string, page, pageSize int) string {
|
||||||
|
return fmt.Sprintf("%s:%s:page:%d:size:%d", PrefixGroupMembers, groupID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMembersAllKey 生成群组全量成员ID列表缓存键
|
||||||
|
func GroupMembersAllKey(groupID string) string {
|
||||||
|
return fmt.Sprintf("%s:all:%s", PrefixGroupMembers, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupInfoKey 生成群组信息缓存键
|
||||||
|
func GroupInfoKey(groupID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", PrefixGroupInfo, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnreadSystemKey 生成系统消息未读数缓存键
|
||||||
|
func UnreadSystemKey(userID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", PrefixUnreadSystem, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnreadConversationKey 生成会话未读总数缓存键
|
||||||
|
func UnreadConversationKey(userID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", PrefixUnreadConversation, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnreadDetailKey 生成单个会话未读数缓存键
|
||||||
|
func UnreadDetailKey(userID, conversationID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s:%s", PrefixUnreadDetail, userID, conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserInfoKey 生成用户信息缓存键
|
||||||
|
func UserInfoKey(userID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", PrefixUserInfo, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserMeKey 生成当前用户信息缓存键
|
||||||
|
func UserMeKey(userID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", PrefixUserMe, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidatePostList 失效帖子列表缓存
|
||||||
|
func InvalidatePostList(cache Cache) {
|
||||||
|
cache.DeleteByPrefix(PrefixPostList)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidatePostDetail 失效帖子详情缓存
|
||||||
|
func InvalidatePostDetail(cache Cache, postID string) {
|
||||||
|
cache.Delete(PostDetailKey(postID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateConversationList 失效会话列表缓存
|
||||||
|
func InvalidateConversationList(cache Cache, userID string) {
|
||||||
|
cache.DeleteByPrefix(PrefixConversationList + ":" + userID + ":")
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateConversationDetail 失效会话详情缓存
|
||||||
|
func InvalidateConversationDetail(cache Cache, conversationID, userID string) {
|
||||||
|
cache.Delete(ConversationDetailKey(conversationID, userID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateGroupMembers 失效群组成员缓存
|
||||||
|
func InvalidateGroupMembers(cache Cache, groupID string) {
|
||||||
|
cache.DeleteByPrefix(PrefixGroupMembers + ":" + groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateGroupInfo 失效群组信息缓存
|
||||||
|
func InvalidateGroupInfo(cache Cache, groupID string) {
|
||||||
|
cache.Delete(GroupInfoKey(groupID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUnreadSystem 失效系统消息未读数缓存
|
||||||
|
func InvalidateUnreadSystem(cache Cache, userID string) {
|
||||||
|
cache.Delete(UnreadSystemKey(userID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUnreadConversation 失效会话未读数缓存
|
||||||
|
func InvalidateUnreadConversation(cache Cache, userID string) {
|
||||||
|
cache.Delete(UnreadConversationKey(userID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUnreadDetail 失效单个会话未读数缓存
|
||||||
|
func InvalidateUnreadDetail(cache Cache, userID, conversationID string) {
|
||||||
|
cache.Delete(UnreadDetailKey(userID, conversationID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUserInfo 失效用户信息缓存
|
||||||
|
func InvalidateUserInfo(cache Cache, userID string) {
|
||||||
|
cache.Delete(UserInfoKey(userID))
|
||||||
|
cache.Delete(UserMeKey(userID))
|
||||||
|
}
|
||||||
393
internal/config/config.go
Normal file
393
internal/config/config.go
Normal file
@@ -0,0 +1,393 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/minio/minio-go/v7"
|
||||||
|
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Server ServerConfig `mapstructure:"server"`
|
||||||
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
|
Cache CacheConfig `mapstructure:"cache"`
|
||||||
|
S3 S3Config `mapstructure:"s3"`
|
||||||
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
|
Log LogConfig `mapstructure:"log"`
|
||||||
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
|
Upload UploadConfig `mapstructure:"upload"`
|
||||||
|
Gorse GorseConfig `mapstructure:"gorse"`
|
||||||
|
OpenAI OpenAIConfig `mapstructure:"openai"`
|
||||||
|
Email EmailConfig `mapstructure:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConfig struct {
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
Mode string `mapstructure:"mode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DatabaseConfig struct {
|
||||||
|
Type string `mapstructure:"type"`
|
||||||
|
SQLite SQLiteConfig `mapstructure:"sqlite"`
|
||||||
|
Postgres PostgresConfig `mapstructure:"postgres"`
|
||||||
|
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||||
|
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||||
|
LogLevel string `mapstructure:"log_level"`
|
||||||
|
SlowThresholdMs int `mapstructure:"slow_threshold_ms"`
|
||||||
|
IgnoreRecordNotFound bool `mapstructure:"ignore_record_not_found"`
|
||||||
|
ParameterizedQueries bool `mapstructure:"parameterized_queries"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SQLiteConfig struct {
|
||||||
|
Path string `mapstructure:"path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PostgresConfig struct {
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
User string `mapstructure:"user"`
|
||||||
|
Password string `mapstructure:"password"`
|
||||||
|
DBName string `mapstructure:"dbname"`
|
||||||
|
SSLMode string `mapstructure:"sslmode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d PostgresConfig) DSN() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||||
|
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
type RedisConfig struct {
|
||||||
|
Type string `mapstructure:"type"`
|
||||||
|
Redis RedisServerConfig `mapstructure:"redis"`
|
||||||
|
Miniredis MiniredisConfig `mapstructure:"miniredis"`
|
||||||
|
PoolSize int `mapstructure:"pool_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CacheConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
KeyPrefix string `mapstructure:"key_prefix"`
|
||||||
|
DefaultTTL int `mapstructure:"default_ttl"`
|
||||||
|
NullTTL int `mapstructure:"null_ttl"`
|
||||||
|
JitterRatio float64 `mapstructure:"jitter_ratio"`
|
||||||
|
DisableFlushDB bool `mapstructure:"disable_flushdb"`
|
||||||
|
Modules CacheModuleTTL `mapstructure:"modules"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CacheModuleTTL struct {
|
||||||
|
PostList int `mapstructure:"post_list_ttl"`
|
||||||
|
Conversation int `mapstructure:"conversation_ttl"`
|
||||||
|
UnreadCount int `mapstructure:"unread_count_ttl"`
|
||||||
|
GroupMembers int `mapstructure:"group_members_ttl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RedisServerConfig struct {
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
Password string `mapstructure:"password"`
|
||||||
|
DB int `mapstructure:"db"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r RedisServerConfig) Addr() string {
|
||||||
|
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MiniredisConfig struct {
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type S3Config struct {
|
||||||
|
Endpoint string `mapstructure:"endpoint"`
|
||||||
|
AccessKey string `mapstructure:"access_key"`
|
||||||
|
SecretKey string `mapstructure:"secret_key"`
|
||||||
|
Bucket string `mapstructure:"bucket"`
|
||||||
|
UseSSL bool `mapstructure:"use_ssl"`
|
||||||
|
Region string `mapstructure:"region"`
|
||||||
|
Domain string `mapstructure:"domain"` // 自定义域名,如 s3.carrot.skin
|
||||||
|
}
|
||||||
|
|
||||||
|
type JWTConfig struct {
|
||||||
|
Secret string `mapstructure:"secret"`
|
||||||
|
AccessTokenExpire time.Duration `mapstructure:"access_token_expire"`
|
||||||
|
RefreshTokenExpire time.Duration `mapstructure:"refresh_token_expire"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogConfig struct {
|
||||||
|
Level string `mapstructure:"level"`
|
||||||
|
Encoding string `mapstructure:"encoding"`
|
||||||
|
OutputPaths []string `mapstructure:"output_paths"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RateLimitConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
RequestsPerMinute int `mapstructure:"requests_per_minute"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UploadConfig struct {
|
||||||
|
MaxFileSize int64 `mapstructure:"max_file_size"`
|
||||||
|
AllowedTypes []string `mapstructure:"allowed_types"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GorseConfig struct {
|
||||||
|
Address string `mapstructure:"address"`
|
||||||
|
APIKey string `mapstructure:"api_key"`
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
Dashboard string `mapstructure:"dashboard"`
|
||||||
|
ImportPassword string `mapstructure:"import_password"`
|
||||||
|
EmbeddingAPIKey string `mapstructure:"embedding_api_key"`
|
||||||
|
EmbeddingURL string `mapstructure:"embedding_url"`
|
||||||
|
EmbeddingModel string `mapstructure:"embedding_model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
|
APIKey string `mapstructure:"api_key"`
|
||||||
|
ModerationModel string `mapstructure:"moderation_model"`
|
||||||
|
ModerationMaxImagesPerRequest int `mapstructure:"moderation_max_images_per_request"`
|
||||||
|
RequestTimeout int `mapstructure:"request_timeout"`
|
||||||
|
StrictModeration bool `mapstructure:"strict_moderation"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmailConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
Username string `mapstructure:"username"`
|
||||||
|
Password string `mapstructure:"password"`
|
||||||
|
FromAddress string `mapstructure:"from_address"`
|
||||||
|
FromName string `mapstructure:"from_name"`
|
||||||
|
UseTLS bool `mapstructure:"use_tls"`
|
||||||
|
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
|
||||||
|
Timeout int `mapstructure:"timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Load(configPath string) (*Config, error) {
|
||||||
|
viper.SetConfigFile(configPath)
|
||||||
|
viper.SetConfigType("yaml")
|
||||||
|
|
||||||
|
// 启用环境变量支持
|
||||||
|
viper.SetEnvPrefix("APP")
|
||||||
|
viper.AutomaticEnv()
|
||||||
|
// 允许环境变量使用下划线或连字符
|
||||||
|
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_"))
|
||||||
|
|
||||||
|
// Set default values
|
||||||
|
viper.SetDefault("server.port", 8080)
|
||||||
|
viper.SetDefault("server.mode", "debug")
|
||||||
|
viper.SetDefault("server.host", "0.0.0.0")
|
||||||
|
viper.SetDefault("database.type", "sqlite")
|
||||||
|
viper.SetDefault("database.sqlite.path", "./data/carrot_bbs.db")
|
||||||
|
viper.SetDefault("database.max_idle_conns", 10)
|
||||||
|
viper.SetDefault("database.max_open_conns", 100)
|
||||||
|
viper.SetDefault("database.log_level", "warn")
|
||||||
|
viper.SetDefault("database.slow_threshold_ms", 200)
|
||||||
|
viper.SetDefault("database.ignore_record_not_found", true)
|
||||||
|
viper.SetDefault("database.parameterized_queries", true)
|
||||||
|
viper.SetDefault("redis.type", "miniredis")
|
||||||
|
viper.SetDefault("redis.redis.host", "localhost")
|
||||||
|
viper.SetDefault("redis.redis.port", 6379)
|
||||||
|
viper.SetDefault("redis.redis.password", "")
|
||||||
|
viper.SetDefault("redis.redis.db", 0)
|
||||||
|
viper.SetDefault("redis.miniredis.host", "localhost")
|
||||||
|
viper.SetDefault("redis.miniredis.port", 6379)
|
||||||
|
viper.SetDefault("redis.pool_size", 10)
|
||||||
|
viper.SetDefault("cache.enabled", true)
|
||||||
|
viper.SetDefault("cache.key_prefix", "")
|
||||||
|
viper.SetDefault("cache.default_ttl", 30)
|
||||||
|
viper.SetDefault("cache.null_ttl", 5)
|
||||||
|
viper.SetDefault("cache.jitter_ratio", 0.1)
|
||||||
|
viper.SetDefault("cache.disable_flushdb", true)
|
||||||
|
viper.SetDefault("cache.modules.post_list_ttl", 30)
|
||||||
|
viper.SetDefault("cache.modules.conversation_ttl", 60)
|
||||||
|
viper.SetDefault("cache.modules.unread_count_ttl", 30)
|
||||||
|
viper.SetDefault("cache.modules.group_members_ttl", 120)
|
||||||
|
viper.SetDefault("jwt.secret", "your-jwt-secret-key-change-in-production")
|
||||||
|
viper.SetDefault("jwt.access_token_expire", 86400)
|
||||||
|
viper.SetDefault("jwt.refresh_token_expire", 604800)
|
||||||
|
viper.SetDefault("log.level", "info")
|
||||||
|
viper.SetDefault("log.encoding", "json")
|
||||||
|
viper.SetDefault("log.output_paths", []string{"stdout", "./logs/app.log"})
|
||||||
|
viper.SetDefault("rate_limit.enabled", true)
|
||||||
|
viper.SetDefault("rate_limit.requests_per_minute", 60)
|
||||||
|
viper.SetDefault("upload.max_file_size", 10485760)
|
||||||
|
viper.SetDefault("upload.allowed_types", []string{"image/jpeg", "image/png", "image/gif", "image/webp"})
|
||||||
|
viper.SetDefault("s3.endpoint", "")
|
||||||
|
viper.SetDefault("s3.access_key", "")
|
||||||
|
viper.SetDefault("s3.secret_key", "")
|
||||||
|
viper.SetDefault("s3.bucket", "")
|
||||||
|
viper.SetDefault("s3.use_ssl", true)
|
||||||
|
viper.SetDefault("s3.region", "us-east-1")
|
||||||
|
viper.SetDefault("s3.domain", "")
|
||||||
|
viper.SetDefault("sensitive.enabled", true)
|
||||||
|
viper.SetDefault("sensitive.replace_str", "***")
|
||||||
|
viper.SetDefault("audit.enabled", false)
|
||||||
|
viper.SetDefault("audit.provider", "local")
|
||||||
|
viper.SetDefault("gorse.enabled", false)
|
||||||
|
viper.SetDefault("gorse.address", "http://localhost:8087")
|
||||||
|
viper.SetDefault("gorse.api_key", "")
|
||||||
|
viper.SetDefault("gorse.dashboard", "http://localhost:8088")
|
||||||
|
viper.SetDefault("gorse.import_password", "")
|
||||||
|
viper.SetDefault("gorse.embedding_api_key", "")
|
||||||
|
viper.SetDefault("gorse.embedding_url", "https://api.littlelan.cn/v1/embeddings")
|
||||||
|
viper.SetDefault("gorse.embedding_model", "BAAI/bge-m3")
|
||||||
|
viper.SetDefault("openai.enabled", true)
|
||||||
|
viper.SetDefault("openai.base_url", "https://api.littlelan.cn/")
|
||||||
|
viper.SetDefault("openai.api_key", "")
|
||||||
|
viper.SetDefault("openai.moderation_model", "qwen3.5-122b")
|
||||||
|
viper.SetDefault("openai.moderation_max_images_per_request", 1)
|
||||||
|
viper.SetDefault("openai.request_timeout", 30)
|
||||||
|
viper.SetDefault("openai.strict_moderation", false)
|
||||||
|
viper.SetDefault("email.enabled", false)
|
||||||
|
viper.SetDefault("email.host", "")
|
||||||
|
viper.SetDefault("email.port", 587)
|
||||||
|
viper.SetDefault("email.username", "")
|
||||||
|
viper.SetDefault("email.password", "")
|
||||||
|
viper.SetDefault("email.from_address", "")
|
||||||
|
viper.SetDefault("email.from_name", "Carrot BBS")
|
||||||
|
viper.SetDefault("email.use_tls", true)
|
||||||
|
viper.SetDefault("email.insecure_skip_verify", false)
|
||||||
|
viper.SetDefault("email.timeout", 15)
|
||||||
|
|
||||||
|
if err := viper.ReadInConfig(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := viper.Unmarshal(&cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert seconds to duration
|
||||||
|
cfg.JWT.AccessTokenExpire = time.Duration(viper.GetInt("jwt.access_token_expire")) * time.Second
|
||||||
|
cfg.JWT.RefreshTokenExpire = time.Duration(viper.GetInt("jwt.refresh_token_expire")) * time.Second
|
||||||
|
|
||||||
|
// 环境变量覆盖(显式处理敏感配置)
|
||||||
|
cfg.JWT.Secret = getEnvOrDefault("APP_JWT_SECRET", cfg.JWT.Secret)
|
||||||
|
cfg.Database.SQLite.Path = getEnvOrDefault("APP_DATABASE_SQLITE_PATH", cfg.Database.SQLite.Path)
|
||||||
|
cfg.Database.Postgres.Host = getEnvOrDefault("APP_DATABASE_POSTGRES_HOST", cfg.Database.Postgres.Host)
|
||||||
|
cfg.Database.Postgres.Port, _ = strconv.Atoi(getEnvOrDefault("APP_DATABASE_POSTGRES_PORT", fmt.Sprintf("%d", cfg.Database.Postgres.Port)))
|
||||||
|
cfg.Database.Postgres.User = getEnvOrDefault("APP_DATABASE_POSTGRES_USER", cfg.Database.Postgres.User)
|
||||||
|
cfg.Database.Postgres.Password = getEnvOrDefault("APP_DATABASE_POSTGRES_PASSWORD", cfg.Database.Postgres.Password)
|
||||||
|
cfg.Database.Postgres.DBName = getEnvOrDefault("APP_DATABASE_POSTGRES_DBNAME", cfg.Database.Postgres.DBName)
|
||||||
|
cfg.Database.LogLevel = getEnvOrDefault("APP_DATABASE_LOG_LEVEL", cfg.Database.LogLevel)
|
||||||
|
cfg.Database.SlowThresholdMs, _ = strconv.Atoi(getEnvOrDefault("APP_DATABASE_SLOW_THRESHOLD_MS", fmt.Sprintf("%d", cfg.Database.SlowThresholdMs)))
|
||||||
|
cfg.Database.IgnoreRecordNotFound, _ = strconv.ParseBool(getEnvOrDefault("APP_DATABASE_IGNORE_RECORD_NOT_FOUND", fmt.Sprintf("%t", cfg.Database.IgnoreRecordNotFound)))
|
||||||
|
cfg.Database.ParameterizedQueries, _ = strconv.ParseBool(getEnvOrDefault("APP_DATABASE_PARAMETERIZED_QUERIES", fmt.Sprintf("%t", cfg.Database.ParameterizedQueries)))
|
||||||
|
cfg.Redis.Redis.Host = getEnvOrDefault("APP_REDIS_REDIS_HOST", cfg.Redis.Redis.Host)
|
||||||
|
cfg.Redis.Redis.Port, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_REDIS_PORT", fmt.Sprintf("%d", cfg.Redis.Redis.Port)))
|
||||||
|
cfg.Redis.Redis.Password = getEnvOrDefault("APP_REDIS_REDIS_PASSWORD", cfg.Redis.Redis.Password)
|
||||||
|
cfg.Redis.Redis.DB, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_REDIS_DB", fmt.Sprintf("%d", cfg.Redis.Redis.DB)))
|
||||||
|
cfg.Redis.Miniredis.Host = getEnvOrDefault("APP_REDIS_MINIREDIS_HOST", cfg.Redis.Miniredis.Host)
|
||||||
|
cfg.Redis.Miniredis.Port, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_MINIREDIS_PORT", fmt.Sprintf("%d", cfg.Redis.Miniredis.Port)))
|
||||||
|
cfg.Redis.Type = getEnvOrDefault("APP_REDIS_TYPE", cfg.Redis.Type)
|
||||||
|
cfg.Cache.KeyPrefix = getEnvOrDefault("APP_CACHE_KEY_PREFIX", cfg.Cache.KeyPrefix)
|
||||||
|
cfg.Cache.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_CACHE_ENABLED", fmt.Sprintf("%t", cfg.Cache.Enabled)))
|
||||||
|
cfg.Cache.DisableFlushDB, _ = strconv.ParseBool(getEnvOrDefault("APP_CACHE_DISABLE_FLUSHDB", fmt.Sprintf("%t", cfg.Cache.DisableFlushDB)))
|
||||||
|
cfg.Cache.DefaultTTL, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_DEFAULT_TTL", fmt.Sprintf("%d", cfg.Cache.DefaultTTL)))
|
||||||
|
cfg.Cache.NullTTL, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_NULL_TTL", fmt.Sprintf("%d", cfg.Cache.NullTTL)))
|
||||||
|
cfg.Cache.JitterRatio, _ = strconv.ParseFloat(getEnvOrDefault("APP_CACHE_JITTER_RATIO", fmt.Sprintf("%.2f", cfg.Cache.JitterRatio)), 64)
|
||||||
|
cfg.Cache.Modules.PostList, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_POST_LIST_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.PostList)))
|
||||||
|
cfg.Cache.Modules.Conversation, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_CONVERSATION_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.Conversation)))
|
||||||
|
cfg.Cache.Modules.UnreadCount, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_UNREAD_COUNT_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.UnreadCount)))
|
||||||
|
cfg.Cache.Modules.GroupMembers, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_GROUP_MEMBERS_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.GroupMembers)))
|
||||||
|
cfg.S3.Endpoint = getEnvOrDefault("APP_S3_ENDPOINT", cfg.S3.Endpoint)
|
||||||
|
cfg.S3.AccessKey = getEnvOrDefault("APP_S3_ACCESS_KEY", cfg.S3.AccessKey)
|
||||||
|
cfg.S3.SecretKey = getEnvOrDefault("APP_S3_SECRET_KEY", cfg.S3.SecretKey)
|
||||||
|
cfg.S3.Bucket = getEnvOrDefault("APP_S3_BUCKET", cfg.S3.Bucket)
|
||||||
|
cfg.S3.Domain = getEnvOrDefault("APP_S3_DOMAIN", cfg.S3.Domain)
|
||||||
|
cfg.Server.Host = getEnvOrDefault("APP_SERVER_HOST", cfg.Server.Host)
|
||||||
|
cfg.Server.Port, _ = strconv.Atoi(getEnvOrDefault("APP_SERVER_PORT", fmt.Sprintf("%d", cfg.Server.Port)))
|
||||||
|
cfg.Server.Mode = getEnvOrDefault("APP_SERVER_MODE", cfg.Server.Mode)
|
||||||
|
cfg.Gorse.Address = getEnvOrDefault("APP_GORSE_ADDRESS", cfg.Gorse.Address)
|
||||||
|
cfg.Gorse.APIKey = getEnvOrDefault("APP_GORSE_API_KEY", cfg.Gorse.APIKey)
|
||||||
|
cfg.Gorse.Dashboard = getEnvOrDefault("APP_GORSE_DASHBOARD", cfg.Gorse.Dashboard)
|
||||||
|
cfg.Gorse.ImportPassword = getEnvOrDefault("APP_GORSE_IMPORT_PASSWORD", cfg.Gorse.ImportPassword)
|
||||||
|
cfg.Gorse.EmbeddingAPIKey = getEnvOrDefault("APP_GORSE_EMBEDDING_API_KEY", cfg.Gorse.EmbeddingAPIKey)
|
||||||
|
cfg.Gorse.EmbeddingURL = getEnvOrDefault("APP_GORSE_EMBEDDING_URL", cfg.Gorse.EmbeddingURL)
|
||||||
|
cfg.Gorse.EmbeddingModel = getEnvOrDefault("APP_GORSE_EMBEDDING_MODEL", cfg.Gorse.EmbeddingModel)
|
||||||
|
cfg.OpenAI.BaseURL = getEnvOrDefault("APP_OPENAI_BASE_URL", cfg.OpenAI.BaseURL)
|
||||||
|
cfg.OpenAI.APIKey = getEnvOrDefault("APP_OPENAI_API_KEY", cfg.OpenAI.APIKey)
|
||||||
|
cfg.OpenAI.ModerationModel = getEnvOrDefault("APP_OPENAI_MODERATION_MODEL", cfg.OpenAI.ModerationModel)
|
||||||
|
cfg.OpenAI.ModerationMaxImagesPerRequest, _ = strconv.Atoi(getEnvOrDefault("APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST", fmt.Sprintf("%d", cfg.OpenAI.ModerationMaxImagesPerRequest)))
|
||||||
|
cfg.OpenAI.RequestTimeout, _ = strconv.Atoi(getEnvOrDefault("APP_OPENAI_REQUEST_TIMEOUT", fmt.Sprintf("%d", cfg.OpenAI.RequestTimeout)))
|
||||||
|
cfg.OpenAI.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_OPENAI_ENABLED", fmt.Sprintf("%t", cfg.OpenAI.Enabled)))
|
||||||
|
cfg.OpenAI.StrictModeration, _ = strconv.ParseBool(getEnvOrDefault("APP_OPENAI_STRICT_MODERATION", fmt.Sprintf("%t", cfg.OpenAI.StrictModeration)))
|
||||||
|
cfg.Email.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_ENABLED", fmt.Sprintf("%t", cfg.Email.Enabled)))
|
||||||
|
cfg.Email.Host = getEnvOrDefault("APP_EMAIL_HOST", cfg.Email.Host)
|
||||||
|
cfg.Email.Port, _ = strconv.Atoi(getEnvOrDefault("APP_EMAIL_PORT", fmt.Sprintf("%d", cfg.Email.Port)))
|
||||||
|
cfg.Email.Username = getEnvOrDefault("APP_EMAIL_USERNAME", cfg.Email.Username)
|
||||||
|
cfg.Email.Password = getEnvOrDefault("APP_EMAIL_PASSWORD", cfg.Email.Password)
|
||||||
|
cfg.Email.FromAddress = getEnvOrDefault("APP_EMAIL_FROM_ADDRESS", cfg.Email.FromAddress)
|
||||||
|
cfg.Email.FromName = getEnvOrDefault("APP_EMAIL_FROM_NAME", cfg.Email.FromName)
|
||||||
|
cfg.Email.UseTLS, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_USE_TLS", fmt.Sprintf("%t", cfg.Email.UseTLS)))
|
||||||
|
cfg.Email.InsecureSkipVerify, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_INSECURE_SKIP_VERIFY", fmt.Sprintf("%t", cfg.Email.InsecureSkipVerify)))
|
||||||
|
cfg.Email.Timeout, _ = strconv.Atoi(getEnvOrDefault("APP_EMAIL_TIMEOUT", fmt.Sprintf("%d", cfg.Email.Timeout)))
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvOrDefault 获取环境变量值,如果未设置则返回默认值
|
||||||
|
func getEnvOrDefault(key, defaultValue string) string {
|
||||||
|
if value := os.Getenv(key); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRedis 创建Redis客户端(真实Redis)
|
||||||
|
func NewRedis(cfg *RedisConfig) (*redis.Client, error) {
|
||||||
|
client := redis.NewClient(&redis.Options{
|
||||||
|
Addr: cfg.Redis.Addr(),
|
||||||
|
Password: cfg.Redis.Password,
|
||||||
|
DB: cfg.Redis.DB,
|
||||||
|
PoolSize: cfg.PoolSize,
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := client.Ping(ctx).Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewS3 创建S3客户端
|
||||||
|
func NewS3(cfg *S3Config) (*minio.Client, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client, err := minio.New(cfg.Endpoint, &minio.Options{
|
||||||
|
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
|
||||||
|
Secure: cfg.UseSSL,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create S3 client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := client.BucketExists(ctx, cfg.Bucket)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check bucket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{
|
||||||
|
Region: cfg.Region,
|
||||||
|
}); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create bucket: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
885
internal/dto/converter.go
Normal file
885
internal/dto/converter.go
Normal file
@@ -0,0 +1,885 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==================== User 转换 ====================
|
||||||
|
|
||||||
|
// getAvatarOrDefault 获取头像URL,如果为空则返回在线头像生成服务的URL
|
||||||
|
func getAvatarOrDefault(user *model.User) string {
|
||||||
|
return utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUserToResponse 将User转换为UserResponse
|
||||||
|
func ConvertUserToResponse(user *model.User) *UserResponse {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Nickname: user.Nickname,
|
||||||
|
Email: user.Email,
|
||||||
|
Phone: user.Phone,
|
||||||
|
EmailVerified: user.EmailVerified,
|
||||||
|
Avatar: getAvatarOrDefault(user),
|
||||||
|
CoverURL: user.CoverURL,
|
||||||
|
Bio: user.Bio,
|
||||||
|
Website: user.Website,
|
||||||
|
Location: user.Location,
|
||||||
|
PostsCount: user.PostsCount,
|
||||||
|
FollowersCount: user.FollowersCount,
|
||||||
|
FollowingCount: user.FollowingCount,
|
||||||
|
CreatedAt: FormatTime(user.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUserToResponseWithFollowing 将User转换为UserResponse(包含关注状态)
|
||||||
|
func ConvertUserToResponseWithFollowing(user *model.User, isFollowing bool) *UserResponse {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Nickname: user.Nickname,
|
||||||
|
Email: user.Email,
|
||||||
|
Phone: user.Phone,
|
||||||
|
EmailVerified: user.EmailVerified,
|
||||||
|
Avatar: getAvatarOrDefault(user),
|
||||||
|
CoverURL: user.CoverURL,
|
||||||
|
Bio: user.Bio,
|
||||||
|
Website: user.Website,
|
||||||
|
Location: user.Location,
|
||||||
|
PostsCount: user.PostsCount,
|
||||||
|
FollowersCount: user.FollowersCount,
|
||||||
|
FollowingCount: user.FollowingCount,
|
||||||
|
IsFollowing: isFollowing,
|
||||||
|
IsFollowingMe: false, // 默认false,需要单独计算
|
||||||
|
CreatedAt: FormatTime(user.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUserToResponseWithPostsCount 将User转换为UserResponse(使用实时计算的帖子数量)
|
||||||
|
func ConvertUserToResponseWithPostsCount(user *model.User, postsCount int) *UserResponse {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Nickname: user.Nickname,
|
||||||
|
Email: user.Email,
|
||||||
|
Phone: user.Phone,
|
||||||
|
EmailVerified: user.EmailVerified,
|
||||||
|
Avatar: getAvatarOrDefault(user),
|
||||||
|
CoverURL: user.CoverURL,
|
||||||
|
Bio: user.Bio,
|
||||||
|
Website: user.Website,
|
||||||
|
Location: user.Location,
|
||||||
|
PostsCount: postsCount,
|
||||||
|
FollowersCount: user.FollowersCount,
|
||||||
|
FollowingCount: user.FollowingCount,
|
||||||
|
CreatedAt: FormatTime(user.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUserToResponseWithMutualFollow 将User转换为UserResponse(包含双向关注状态)
|
||||||
|
func ConvertUserToResponseWithMutualFollow(user *model.User, isFollowing, isFollowingMe bool) *UserResponse {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Nickname: user.Nickname,
|
||||||
|
Email: user.Email,
|
||||||
|
Phone: user.Phone,
|
||||||
|
EmailVerified: user.EmailVerified,
|
||||||
|
Avatar: getAvatarOrDefault(user),
|
||||||
|
CoverURL: user.CoverURL,
|
||||||
|
Bio: user.Bio,
|
||||||
|
Website: user.Website,
|
||||||
|
Location: user.Location,
|
||||||
|
PostsCount: user.PostsCount,
|
||||||
|
FollowersCount: user.FollowersCount,
|
||||||
|
FollowingCount: user.FollowingCount,
|
||||||
|
IsFollowing: isFollowing,
|
||||||
|
IsFollowingMe: isFollowingMe,
|
||||||
|
CreatedAt: FormatTime(user.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUserToResponseWithMutualFollowAndPostsCount 将User转换为UserResponse(包含双向关注状态和实时计算的帖子数量)
|
||||||
|
func ConvertUserToResponseWithMutualFollowAndPostsCount(user *model.User, isFollowing, isFollowingMe bool, postsCount int) *UserResponse {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Nickname: user.Nickname,
|
||||||
|
Email: user.Email,
|
||||||
|
Phone: user.Phone,
|
||||||
|
EmailVerified: user.EmailVerified,
|
||||||
|
Avatar: getAvatarOrDefault(user),
|
||||||
|
CoverURL: user.CoverURL,
|
||||||
|
Bio: user.Bio,
|
||||||
|
Website: user.Website,
|
||||||
|
Location: user.Location,
|
||||||
|
PostsCount: postsCount,
|
||||||
|
FollowersCount: user.FollowersCount,
|
||||||
|
FollowingCount: user.FollowingCount,
|
||||||
|
IsFollowing: isFollowing,
|
||||||
|
IsFollowingMe: isFollowingMe,
|
||||||
|
CreatedAt: FormatTime(user.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUserToDetailResponse 将User转换为UserDetailResponse
|
||||||
|
func ConvertUserToDetailResponse(user *model.User) *UserDetailResponse {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserDetailResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Nickname: user.Nickname,
|
||||||
|
Email: user.Email,
|
||||||
|
EmailVerified: user.EmailVerified,
|
||||||
|
Avatar: getAvatarOrDefault(user),
|
||||||
|
CoverURL: user.CoverURL,
|
||||||
|
Bio: user.Bio,
|
||||||
|
Website: user.Website,
|
||||||
|
Location: user.Location,
|
||||||
|
PostsCount: user.PostsCount,
|
||||||
|
FollowersCount: user.FollowersCount,
|
||||||
|
FollowingCount: user.FollowingCount,
|
||||||
|
IsVerified: user.IsVerified,
|
||||||
|
CreatedAt: FormatTime(user.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUserToDetailResponseWithPostsCount 将User转换为UserDetailResponse(使用实时计算的帖子数量)
|
||||||
|
func ConvertUserToDetailResponseWithPostsCount(user *model.User, postsCount int) *UserDetailResponse {
|
||||||
|
if user == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &UserDetailResponse{
|
||||||
|
ID: user.ID,
|
||||||
|
Username: user.Username,
|
||||||
|
Nickname: user.Nickname,
|
||||||
|
Email: user.Email,
|
||||||
|
EmailVerified: user.EmailVerified,
|
||||||
|
Phone: user.Phone, // 仅当前用户自己可见
|
||||||
|
Avatar: getAvatarOrDefault(user),
|
||||||
|
CoverURL: user.CoverURL,
|
||||||
|
Bio: user.Bio,
|
||||||
|
Website: user.Website,
|
||||||
|
Location: user.Location,
|
||||||
|
PostsCount: postsCount,
|
||||||
|
FollowersCount: user.FollowersCount,
|
||||||
|
FollowingCount: user.FollowingCount,
|
||||||
|
IsVerified: user.IsVerified,
|
||||||
|
CreatedAt: FormatTime(user.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUsersToResponse 将User列表转换为响应列表
|
||||||
|
func ConvertUsersToResponse(users []*model.User) []*UserResponse {
|
||||||
|
result := make([]*UserResponse, 0, len(users))
|
||||||
|
for _, user := range users {
|
||||||
|
result = append(result, ConvertUserToResponse(user))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUsersToResponseWithMutualFollow 将User列表转换为响应列表(包含双向关注状态)
|
||||||
|
// followingStatusMap: key是用户ID,value是[isFollowing, isFollowingMe]
|
||||||
|
func ConvertUsersToResponseWithMutualFollow(users []*model.User, followingStatusMap map[string][2]bool) []*UserResponse {
|
||||||
|
result := make([]*UserResponse, 0, len(users))
|
||||||
|
for _, user := range users {
|
||||||
|
status, ok := followingStatusMap[user.ID]
|
||||||
|
if ok {
|
||||||
|
result = append(result, ConvertUserToResponseWithMutualFollow(user, status[0], status[1]))
|
||||||
|
} else {
|
||||||
|
result = append(result, ConvertUserToResponse(user))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertUsersToResponseWithMutualFollowAndPostsCount 将User列表转换为响应列表(包含双向关注状态和实时计算的帖子数量)
|
||||||
|
// followingStatusMap: key是用户ID,value是[isFollowing, isFollowingMe]
|
||||||
|
// postsCountMap: key是用户ID,value是帖子数量
|
||||||
|
func ConvertUsersToResponseWithMutualFollowAndPostsCount(users []*model.User, followingStatusMap map[string][2]bool, postsCountMap map[string]int64) []*UserResponse {
|
||||||
|
result := make([]*UserResponse, 0, len(users))
|
||||||
|
for _, user := range users {
|
||||||
|
status, hasStatus := followingStatusMap[user.ID]
|
||||||
|
postsCount, hasPostsCount := postsCountMap[user.ID]
|
||||||
|
|
||||||
|
// 如果没有帖子数量,使用数据库中的值
|
||||||
|
if !hasPostsCount {
|
||||||
|
postsCount = int64(user.PostsCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasStatus {
|
||||||
|
result = append(result, ConvertUserToResponseWithMutualFollowAndPostsCount(user, status[0], status[1], int(postsCount)))
|
||||||
|
} else {
|
||||||
|
result = append(result, ConvertUserToResponseWithPostsCount(user, int(postsCount)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Post 转换 ====================
|
||||||
|
|
||||||
|
// ConvertPostImageToResponse 将PostImage转换为PostImageResponse
|
||||||
|
func ConvertPostImageToResponse(img *model.PostImage) PostImageResponse {
|
||||||
|
if img == nil {
|
||||||
|
return PostImageResponse{}
|
||||||
|
}
|
||||||
|
return PostImageResponse{
|
||||||
|
ID: img.ID,
|
||||||
|
URL: img.URL,
|
||||||
|
ThumbnailURL: img.ThumbnailURL,
|
||||||
|
Width: img.Width,
|
||||||
|
Height: img.Height,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertPostImagesToResponse 将PostImage列表转换为响应列表
|
||||||
|
func ConvertPostImagesToResponse(images []model.PostImage) []PostImageResponse {
|
||||||
|
result := make([]PostImageResponse, 0, len(images))
|
||||||
|
for i := range images {
|
||||||
|
result = append(result, ConvertPostImageToResponse(&images[i]))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertPostToResponse 将Post转换为PostResponse(列表用)
|
||||||
|
func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostResponse {
|
||||||
|
if post == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
images := make([]PostImageResponse, 0)
|
||||||
|
for _, img := range post.Images {
|
||||||
|
images = append(images, ConvertPostImageToResponse(&img))
|
||||||
|
}
|
||||||
|
|
||||||
|
var author *UserResponse
|
||||||
|
if post.User != nil {
|
||||||
|
author = ConvertUserToResponse(post.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PostResponse{
|
||||||
|
ID: post.ID,
|
||||||
|
UserID: post.UserID,
|
||||||
|
Title: post.Title,
|
||||||
|
Content: post.Content,
|
||||||
|
Images: images,
|
||||||
|
LikesCount: post.LikesCount,
|
||||||
|
CommentsCount: post.CommentsCount,
|
||||||
|
FavoritesCount: post.FavoritesCount,
|
||||||
|
SharesCount: post.SharesCount,
|
||||||
|
ViewsCount: post.ViewsCount,
|
||||||
|
IsPinned: post.IsPinned,
|
||||||
|
IsLocked: post.IsLocked,
|
||||||
|
IsVote: post.IsVote,
|
||||||
|
CreatedAt: FormatTime(post.CreatedAt),
|
||||||
|
Author: author,
|
||||||
|
IsLiked: isLiked,
|
||||||
|
IsFavorited: isFavorited,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertPostToDetailResponse 将Post转换为PostDetailResponse
|
||||||
|
func ConvertPostToDetailResponse(post *model.Post, isLiked, isFavorited bool) *PostDetailResponse {
|
||||||
|
if post == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
images := make([]PostImageResponse, 0)
|
||||||
|
for _, img := range post.Images {
|
||||||
|
images = append(images, ConvertPostImageToResponse(&img))
|
||||||
|
}
|
||||||
|
|
||||||
|
var author *UserResponse
|
||||||
|
if post.User != nil {
|
||||||
|
author = ConvertUserToResponse(post.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PostDetailResponse{
|
||||||
|
ID: post.ID,
|
||||||
|
UserID: post.UserID,
|
||||||
|
Title: post.Title,
|
||||||
|
Content: post.Content,
|
||||||
|
Images: images,
|
||||||
|
Status: string(post.Status),
|
||||||
|
LikesCount: post.LikesCount,
|
||||||
|
CommentsCount: post.CommentsCount,
|
||||||
|
FavoritesCount: post.FavoritesCount,
|
||||||
|
SharesCount: post.SharesCount,
|
||||||
|
ViewsCount: post.ViewsCount,
|
||||||
|
IsPinned: post.IsPinned,
|
||||||
|
IsLocked: post.IsLocked,
|
||||||
|
IsVote: post.IsVote,
|
||||||
|
CreatedAt: FormatTime(post.CreatedAt),
|
||||||
|
UpdatedAt: FormatTime(post.UpdatedAt),
|
||||||
|
Author: author,
|
||||||
|
IsLiked: isLiked,
|
||||||
|
IsFavorited: isFavorited,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertPostsToResponse 将Post列表转换为响应列表(每个帖子独立检查点赞/收藏状态)
|
||||||
|
func ConvertPostsToResponse(posts []*model.Post, isLikedMap, isFavoritedMap map[string]bool) []*PostResponse {
|
||||||
|
result := make([]*PostResponse, 0, len(posts))
|
||||||
|
for _, post := range posts {
|
||||||
|
isLiked := false
|
||||||
|
isFavorited := false
|
||||||
|
if isLikedMap != nil {
|
||||||
|
isLiked = isLikedMap[post.ID]
|
||||||
|
}
|
||||||
|
if isFavoritedMap != nil {
|
||||||
|
isFavorited = isFavoritedMap[post.ID]
|
||||||
|
}
|
||||||
|
result = append(result, ConvertPostToResponse(post, isLiked, isFavorited))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Comment 转换 ====================
|
||||||
|
|
||||||
|
// ConvertCommentToResponse 将Comment转换为CommentResponse
|
||||||
|
func ConvertCommentToResponse(comment *model.Comment, isLiked bool) *CommentResponse {
|
||||||
|
if comment == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var author *UserResponse
|
||||||
|
if comment.User != nil {
|
||||||
|
author = ConvertUserToResponse(comment.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换子回复(扁平化结构)
|
||||||
|
var replies []*CommentResponse
|
||||||
|
if len(comment.Replies) > 0 {
|
||||||
|
replies = make([]*CommentResponse, 0, len(comment.Replies))
|
||||||
|
for _, reply := range comment.Replies {
|
||||||
|
replies = append(replies, ConvertCommentToResponse(reply, false))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TargetID 就是 ParentID,前端根据这个 ID 找到被回复用户的昵称
|
||||||
|
var targetID *string
|
||||||
|
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||||
|
targetID = comment.ParentID
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析图片JSON
|
||||||
|
var images []CommentImageResponse
|
||||||
|
if comment.Images != "" {
|
||||||
|
var urlList []string
|
||||||
|
if err := json.Unmarshal([]byte(comment.Images), &urlList); err == nil {
|
||||||
|
images = make([]CommentImageResponse, 0, len(urlList))
|
||||||
|
for _, url := range urlList {
|
||||||
|
images = append(images, CommentImageResponse{URL: url})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CommentResponse{
|
||||||
|
ID: comment.ID,
|
||||||
|
PostID: comment.PostID,
|
||||||
|
UserID: comment.UserID,
|
||||||
|
ParentID: comment.ParentID,
|
||||||
|
RootID: comment.RootID,
|
||||||
|
Content: comment.Content,
|
||||||
|
Images: images,
|
||||||
|
LikesCount: comment.LikesCount,
|
||||||
|
RepliesCount: comment.RepliesCount,
|
||||||
|
CreatedAt: FormatTime(comment.CreatedAt),
|
||||||
|
Author: author,
|
||||||
|
IsLiked: isLiked,
|
||||||
|
TargetID: targetID,
|
||||||
|
Replies: replies,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertCommentsToResponse 将Comment列表转换为响应列表
|
||||||
|
func ConvertCommentsToResponse(comments []*model.Comment, isLiked bool) []*CommentResponse {
|
||||||
|
result := make([]*CommentResponse, 0, len(comments))
|
||||||
|
for _, comment := range comments {
|
||||||
|
result = append(result, ConvertCommentToResponse(comment, isLiked))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLikedChecker 点赞状态检查器接口
|
||||||
|
type IsLikedChecker interface {
|
||||||
|
IsLiked(ctx context.Context, commentID, userID string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertCommentToResponseWithUser 将Comment转换为CommentResponse(根据用户ID检查点赞状态)
|
||||||
|
func ConvertCommentToResponseWithUser(comment *model.Comment, userID string, checker IsLikedChecker) *CommentResponse {
|
||||||
|
if comment == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查当前用户是否点赞了该评论
|
||||||
|
isLiked := false
|
||||||
|
if userID != "" && checker != nil {
|
||||||
|
isLiked = checker.IsLiked(context.Background(), comment.ID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var author *UserResponse
|
||||||
|
if comment.User != nil {
|
||||||
|
author = ConvertUserToResponse(comment.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换子回复(扁平化结构),递归检查点赞状态
|
||||||
|
var replies []*CommentResponse
|
||||||
|
if len(comment.Replies) > 0 {
|
||||||
|
replies = make([]*CommentResponse, 0, len(comment.Replies))
|
||||||
|
for _, reply := range comment.Replies {
|
||||||
|
replies = append(replies, ConvertCommentToResponseWithUser(reply, userID, checker))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TargetID 就是 ParentID,前端根据这个 ID 找到被回复用户的昵称
|
||||||
|
var targetID *string
|
||||||
|
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||||
|
targetID = comment.ParentID
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析图片JSON
|
||||||
|
var images []CommentImageResponse
|
||||||
|
if comment.Images != "" {
|
||||||
|
var urlList []string
|
||||||
|
if err := json.Unmarshal([]byte(comment.Images), &urlList); err == nil {
|
||||||
|
images = make([]CommentImageResponse, 0, len(urlList))
|
||||||
|
for _, url := range urlList {
|
||||||
|
images = append(images, CommentImageResponse{URL: url})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CommentResponse{
|
||||||
|
ID: comment.ID,
|
||||||
|
PostID: comment.PostID,
|
||||||
|
UserID: comment.UserID,
|
||||||
|
ParentID: comment.ParentID,
|
||||||
|
RootID: comment.RootID,
|
||||||
|
Content: comment.Content,
|
||||||
|
Images: images,
|
||||||
|
LikesCount: comment.LikesCount,
|
||||||
|
RepliesCount: comment.RepliesCount,
|
||||||
|
CreatedAt: FormatTime(comment.CreatedAt),
|
||||||
|
Author: author,
|
||||||
|
IsLiked: isLiked,
|
||||||
|
TargetID: targetID,
|
||||||
|
Replies: replies,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertCommentsToResponseWithUser 将Comment列表转换为响应列表(根据用户ID检查点赞状态)
|
||||||
|
func ConvertCommentsToResponseWithUser(comments []*model.Comment, userID string, checker IsLikedChecker) []*CommentResponse {
|
||||||
|
result := make([]*CommentResponse, 0, len(comments))
|
||||||
|
for _, comment := range comments {
|
||||||
|
result = append(result, ConvertCommentToResponseWithUser(comment, userID, checker))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Notification 转换 ====================
|
||||||
|
|
||||||
|
// ConvertNotificationToResponse 将Notification转换为NotificationResponse
|
||||||
|
func ConvertNotificationToResponse(notification *model.Notification) *NotificationResponse {
|
||||||
|
if notification == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &NotificationResponse{
|
||||||
|
ID: notification.ID,
|
||||||
|
UserID: notification.UserID,
|
||||||
|
Type: string(notification.Type),
|
||||||
|
Title: notification.Title,
|
||||||
|
Content: notification.Content,
|
||||||
|
Data: notification.Data,
|
||||||
|
IsRead: notification.IsRead,
|
||||||
|
CreatedAt: FormatTime(notification.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertNotificationsToResponse 将Notification列表转换为响应列表
|
||||||
|
func ConvertNotificationsToResponse(notifications []*model.Notification) []*NotificationResponse {
|
||||||
|
result := make([]*NotificationResponse, 0, len(notifications))
|
||||||
|
for _, n := range notifications {
|
||||||
|
result = append(result, ConvertNotificationToResponse(n))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Message 转换 ====================
|
||||||
|
|
||||||
|
// ConvertMessageToResponse 将Message转换为MessageResponse
|
||||||
|
func ConvertMessageToResponse(message *model.Message) *MessageResponse {
|
||||||
|
if message == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用 segments,不需要解析
|
||||||
|
segments := make(model.MessageSegments, len(message.Segments))
|
||||||
|
for i, seg := range message.Segments {
|
||||||
|
segments[i] = model.MessageSegment{
|
||||||
|
Type: seg.Type,
|
||||||
|
Data: seg.Data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &MessageResponse{
|
||||||
|
ID: message.ID,
|
||||||
|
ConversationID: message.ConversationID,
|
||||||
|
SenderID: message.SenderID,
|
||||||
|
Seq: message.Seq,
|
||||||
|
Segments: segments,
|
||||||
|
ReplyToID: message.ReplyToID,
|
||||||
|
Status: string(message.Status),
|
||||||
|
Category: string(message.Category),
|
||||||
|
CreatedAt: FormatTime(message.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertConversationToResponse 将Conversation转换为ConversationResponse
|
||||||
|
// participants: 会话参与者列表(用户信息)
|
||||||
|
// unreadCount: 当前用户的未读消息数
|
||||||
|
// lastMessage: 最后一条消息
|
||||||
|
func ConvertConversationToResponse(conv *model.Conversation, participants []*model.User, unreadCount int, lastMessage *model.Message, isPinned bool) *ConversationResponse {
|
||||||
|
if conv == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var participantResponses []*UserResponse
|
||||||
|
for _, p := range participants {
|
||||||
|
participantResponses = append(participantResponses, ConvertUserToResponse(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换群组信息
|
||||||
|
var groupResponse *GroupResponse
|
||||||
|
if conv.Group != nil {
|
||||||
|
groupResponse = GroupToResponse(conv.Group)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ConversationResponse{
|
||||||
|
ID: conv.ID,
|
||||||
|
Type: string(conv.Type),
|
||||||
|
IsPinned: isPinned,
|
||||||
|
Group: groupResponse,
|
||||||
|
LastSeq: conv.LastSeq,
|
||||||
|
LastMessage: ConvertMessageToResponse(lastMessage),
|
||||||
|
LastMessageAt: FormatTimePointer(conv.LastMsgTime),
|
||||||
|
UnreadCount: unreadCount,
|
||||||
|
Participants: participantResponses,
|
||||||
|
CreatedAt: FormatTime(conv.CreatedAt),
|
||||||
|
UpdatedAt: FormatTime(conv.UpdatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertConversationToDetailResponse 将Conversation转换为ConversationDetailResponse
|
||||||
|
func ConvertConversationToDetailResponse(conv *model.Conversation, participants []*model.User, unreadCount int64, lastMessage *model.Message, myLastReadSeq int64, otherLastReadSeq int64, isPinned bool) *ConversationDetailResponse {
|
||||||
|
if conv == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var participantResponses []*UserResponse
|
||||||
|
for _, p := range participants {
|
||||||
|
participantResponses = append(participantResponses, ConvertUserToResponse(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ConversationDetailResponse{
|
||||||
|
ID: conv.ID,
|
||||||
|
Type: string(conv.Type),
|
||||||
|
IsPinned: isPinned,
|
||||||
|
LastSeq: conv.LastSeq,
|
||||||
|
LastMessage: ConvertMessageToResponse(lastMessage),
|
||||||
|
LastMessageAt: FormatTimePointer(conv.LastMsgTime),
|
||||||
|
UnreadCount: unreadCount,
|
||||||
|
Participants: participantResponses,
|
||||||
|
MyLastReadSeq: myLastReadSeq,
|
||||||
|
OtherLastReadSeq: otherLastReadSeq,
|
||||||
|
CreatedAt: FormatTime(conv.CreatedAt),
|
||||||
|
UpdatedAt: FormatTime(conv.UpdatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertMessagesToResponse 将Message列表转换为响应列表
|
||||||
|
func ConvertMessagesToResponse(messages []*model.Message) []*MessageResponse {
|
||||||
|
result := make([]*MessageResponse, 0, len(messages))
|
||||||
|
for _, msg := range messages {
|
||||||
|
result = append(result, ConvertMessageToResponse(msg))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertConversationsToResponse 将Conversation列表转换为响应列表
|
||||||
|
func ConvertConversationsToResponse(convs []*model.Conversation) []*ConversationResponse {
|
||||||
|
result := make([]*ConversationResponse, 0, len(convs))
|
||||||
|
for _, conv := range convs {
|
||||||
|
result = append(result, ConvertConversationToResponse(conv, nil, 0, nil, false))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== PushRecord 转换 ====================
|
||||||
|
|
||||||
|
// PushRecordToResponse 将PushRecord转换为PushRecordResponse
|
||||||
|
func PushRecordToResponse(record *model.PushRecord) *PushRecordResponse {
|
||||||
|
if record == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resp := &PushRecordResponse{
|
||||||
|
ID: record.ID,
|
||||||
|
MessageID: record.MessageID,
|
||||||
|
PushChannel: string(record.PushChannel),
|
||||||
|
PushStatus: string(record.PushStatus),
|
||||||
|
RetryCount: record.RetryCount,
|
||||||
|
CreatedAt: record.CreatedAt,
|
||||||
|
}
|
||||||
|
if record.PushedAt != nil {
|
||||||
|
resp.PushedAt = *record.PushedAt
|
||||||
|
}
|
||||||
|
if record.DeliveredAt != nil {
|
||||||
|
resp.DeliveredAt = *record.DeliveredAt
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushRecordsToResponse 将PushRecord列表转换为响应列表
|
||||||
|
func PushRecordsToResponse(records []*model.PushRecord) []*PushRecordResponse {
|
||||||
|
result := make([]*PushRecordResponse, 0, len(records))
|
||||||
|
for _, record := range records {
|
||||||
|
result = append(result, PushRecordToResponse(record))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== DeviceToken 转换 ====================
|
||||||
|
|
||||||
|
// DeviceTokenToResponse 将DeviceToken转换为DeviceTokenResponse
|
||||||
|
func DeviceTokenToResponse(token *model.DeviceToken) *DeviceTokenResponse {
|
||||||
|
if token == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resp := &DeviceTokenResponse{
|
||||||
|
ID: token.ID,
|
||||||
|
DeviceID: token.DeviceID,
|
||||||
|
DeviceType: string(token.DeviceType),
|
||||||
|
IsActive: token.IsActive,
|
||||||
|
DeviceName: token.DeviceName,
|
||||||
|
CreatedAt: token.CreatedAt,
|
||||||
|
}
|
||||||
|
if token.LastUsedAt != nil {
|
||||||
|
resp.LastUsedAt = *token.LastUsedAt
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceTokensToResponse 将DeviceToken列表转换为响应列表
|
||||||
|
func DeviceTokensToResponse(tokens []*model.DeviceToken) []*DeviceTokenResponse {
|
||||||
|
result := make([]*DeviceTokenResponse, 0, len(tokens))
|
||||||
|
for _, token := range tokens {
|
||||||
|
result = append(result, DeviceTokenToResponse(token))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== SystemMessage 转换 ====================
|
||||||
|
|
||||||
|
// SystemMessageToResponse 将Message转换为SystemMessageResponse
|
||||||
|
func SystemMessageToResponse(msg *model.Message) *SystemMessageResponse {
|
||||||
|
if msg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 segments 中提取文本内容
|
||||||
|
content := ExtractTextContentFromModel(msg.Segments)
|
||||||
|
|
||||||
|
resp := &SystemMessageResponse{
|
||||||
|
ID: msg.ID,
|
||||||
|
SenderID: msg.SenderID,
|
||||||
|
ReceiverID: "", // 系统消息的接收者需要从上下文获取
|
||||||
|
Content: content,
|
||||||
|
Category: string(msg.Category),
|
||||||
|
SystemType: string(msg.SystemType),
|
||||||
|
CreatedAt: msg.CreatedAt,
|
||||||
|
}
|
||||||
|
if msg.ExtraData != nil {
|
||||||
|
resp.ExtraData = map[string]interface{}{
|
||||||
|
"actor_id": msg.ExtraData.ActorID,
|
||||||
|
"actor_name": msg.ExtraData.ActorName,
|
||||||
|
"avatar_url": msg.ExtraData.AvatarURL,
|
||||||
|
"target_id": msg.ExtraData.TargetID,
|
||||||
|
"target_type": msg.ExtraData.TargetType,
|
||||||
|
"action_url": msg.ExtraData.ActionURL,
|
||||||
|
"action_time": msg.ExtraData.ActionTime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemMessagesToResponse 将Message列表转换为SystemMessageResponse列表
|
||||||
|
func SystemMessagesToResponse(messages []*model.Message) []*SystemMessageResponse {
|
||||||
|
result := make([]*SystemMessageResponse, 0, len(messages))
|
||||||
|
for _, msg := range messages {
|
||||||
|
result = append(result, SystemMessageToResponse(msg))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemNotificationToResponse 将SystemNotification转换为SystemMessageResponse
|
||||||
|
func SystemNotificationToResponse(n *model.SystemNotification) *SystemMessageResponse {
|
||||||
|
if n == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resp := &SystemMessageResponse{
|
||||||
|
ID: strconv.FormatInt(n.ID, 10),
|
||||||
|
SenderID: model.SystemSenderIDStr, // 系统发送者
|
||||||
|
ReceiverID: n.ReceiverID,
|
||||||
|
Content: n.Content,
|
||||||
|
Category: "notification",
|
||||||
|
SystemType: string(n.Type),
|
||||||
|
IsRead: n.IsRead,
|
||||||
|
CreatedAt: n.CreatedAt,
|
||||||
|
}
|
||||||
|
if n.ExtraData != nil {
|
||||||
|
resp.ExtraData = map[string]interface{}{
|
||||||
|
"actor_id": n.ExtraData.ActorID,
|
||||||
|
"actor_id_str": n.ExtraData.ActorIDStr,
|
||||||
|
"actor_name": n.ExtraData.ActorName,
|
||||||
|
"avatar_url": n.ExtraData.AvatarURL,
|
||||||
|
"target_id": n.ExtraData.TargetID,
|
||||||
|
"target_title": n.ExtraData.TargetTitle,
|
||||||
|
"target_type": n.ExtraData.TargetType,
|
||||||
|
"action_url": n.ExtraData.ActionURL,
|
||||||
|
"action_time": n.ExtraData.ActionTime,
|
||||||
|
"group_id": n.ExtraData.GroupID,
|
||||||
|
"group_name": n.ExtraData.GroupName,
|
||||||
|
"group_avatar": n.ExtraData.GroupAvatar,
|
||||||
|
"group_description": n.ExtraData.GroupDescription,
|
||||||
|
"flag": n.ExtraData.Flag,
|
||||||
|
"request_type": n.ExtraData.RequestType,
|
||||||
|
"request_status": n.ExtraData.RequestStatus,
|
||||||
|
"reason": n.ExtraData.Reason,
|
||||||
|
"target_user_id": n.ExtraData.TargetUserID,
|
||||||
|
"target_user_name": n.ExtraData.TargetUserName,
|
||||||
|
"target_user_avatar": n.ExtraData.TargetUserAvatar,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemNotificationsToResponse 将SystemNotification列表转换为SystemMessageResponse列表
|
||||||
|
func SystemNotificationsToResponse(notifications []*model.SystemNotification) []*SystemMessageResponse {
|
||||||
|
result := make([]*SystemMessageResponse, 0, len(notifications))
|
||||||
|
for _, n := range notifications {
|
||||||
|
result = append(result, SystemNotificationToResponse(n))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Group 转换 ====================
|
||||||
|
|
||||||
|
// GroupToResponse 将Group转换为GroupResponse
|
||||||
|
func GroupToResponse(group *model.Group) *GroupResponse {
|
||||||
|
if group == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &GroupResponse{
|
||||||
|
ID: group.ID,
|
||||||
|
Name: group.Name,
|
||||||
|
Avatar: group.Avatar,
|
||||||
|
Description: group.Description,
|
||||||
|
OwnerID: group.OwnerID,
|
||||||
|
MemberCount: group.MemberCount,
|
||||||
|
MaxMembers: group.MaxMembers,
|
||||||
|
JoinType: int(group.JoinType),
|
||||||
|
MuteAll: group.MuteAll,
|
||||||
|
CreatedAt: FormatTime(group.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupsToResponse 将Group列表转换为GroupResponse列表
|
||||||
|
func GroupsToResponse(groups []model.Group) []*GroupResponse {
|
||||||
|
result := make([]*GroupResponse, 0, len(groups))
|
||||||
|
for i := range groups {
|
||||||
|
result = append(result, GroupToResponse(&groups[i]))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMemberToResponse 将GroupMember转换为GroupMemberResponse
|
||||||
|
func GroupMemberToResponse(member *model.GroupMember) *GroupMemberResponse {
|
||||||
|
if member == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &GroupMemberResponse{
|
||||||
|
ID: member.ID,
|
||||||
|
GroupID: member.GroupID,
|
||||||
|
UserID: member.UserID,
|
||||||
|
Role: member.Role,
|
||||||
|
Nickname: member.Nickname,
|
||||||
|
Muted: member.Muted,
|
||||||
|
JoinTime: FormatTime(member.JoinTime),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMemberToResponseWithUser 将GroupMember转换为GroupMemberResponse(包含用户信息)
|
||||||
|
func GroupMemberToResponseWithUser(member *model.GroupMember, user *model.User) *GroupMemberResponse {
|
||||||
|
if member == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resp := GroupMemberToResponse(member)
|
||||||
|
if user != nil {
|
||||||
|
resp.User = ConvertUserToResponse(user)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMembersToResponse 将GroupMember列表转换为GroupMemberResponse列表
|
||||||
|
func GroupMembersToResponse(members []model.GroupMember) []*GroupMemberResponse {
|
||||||
|
result := make([]*GroupMemberResponse, 0, len(members))
|
||||||
|
for i := range members {
|
||||||
|
result = append(result, GroupMemberToResponse(&members[i]))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupAnnouncementToResponse 将GroupAnnouncement转换为GroupAnnouncementResponse
|
||||||
|
func GroupAnnouncementToResponse(announcement *model.GroupAnnouncement) *GroupAnnouncementResponse {
|
||||||
|
if announcement == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &GroupAnnouncementResponse{
|
||||||
|
ID: announcement.ID,
|
||||||
|
GroupID: announcement.GroupID,
|
||||||
|
Content: announcement.Content,
|
||||||
|
AuthorID: announcement.AuthorID,
|
||||||
|
IsPinned: announcement.IsPinned,
|
||||||
|
CreatedAt: FormatTime(announcement.CreatedAt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupAnnouncementsToResponse 将GroupAnnouncement列表转换为GroupAnnouncementResponse列表
|
||||||
|
func GroupAnnouncementsToResponse(announcements []model.GroupAnnouncement) []*GroupAnnouncementResponse {
|
||||||
|
result := make([]*GroupAnnouncementResponse, 0, len(announcements))
|
||||||
|
for i := range announcements {
|
||||||
|
result = append(result, GroupAnnouncementToResponse(&announcements[i]))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
819
internal/dto/dto.go
Normal file
819
internal/dto/dto.go
Normal file
@@ -0,0 +1,819 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==================== User DTOs ====================
|
||||||
|
|
||||||
|
// UserResponse 用户信息响应
|
||||||
|
type UserResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Nickname string `json:"nickname"`
|
||||||
|
Email *string `json:"email,omitempty"`
|
||||||
|
Phone *string `json:"phone,omitempty"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
Avatar string `json:"avatar"`
|
||||||
|
CoverURL string `json:"cover_url"` // 头图URL
|
||||||
|
Bio string `json:"bio"`
|
||||||
|
Website string `json:"website"`
|
||||||
|
Location string `json:"location"`
|
||||||
|
PostsCount int `json:"posts_count"`
|
||||||
|
FollowersCount int `json:"followers_count"`
|
||||||
|
FollowingCount int `json:"following_count"`
|
||||||
|
IsFollowing bool `json:"is_following"` // 当前用户是否关注了该用户
|
||||||
|
IsFollowingMe bool `json:"is_following_me"` // 该用户是否关注了当前用户
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserDetailResponse 用户详情响应
|
||||||
|
type UserDetailResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Nickname string `json:"nickname"`
|
||||||
|
Email *string `json:"email"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
Phone *string `json:"phone,omitempty"` // 仅当前用户自己可见
|
||||||
|
Avatar string `json:"avatar"`
|
||||||
|
CoverURL string `json:"cover_url"` // 头图URL
|
||||||
|
Bio string `json:"bio"`
|
||||||
|
Website string `json:"website"`
|
||||||
|
Location string `json:"location"`
|
||||||
|
PostsCount int `json:"posts_count"`
|
||||||
|
FollowersCount int `json:"followers_count"`
|
||||||
|
FollowingCount int `json:"following_count"`
|
||||||
|
IsVerified bool `json:"is_verified"`
|
||||||
|
IsFollowing bool `json:"is_following"` // 当前用户是否关注了该用户
|
||||||
|
IsFollowingMe bool `json:"is_following_me"` // 该用户是否关注了当前用户
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Post DTOs ====================
|
||||||
|
|
||||||
|
// PostImageResponse 帖子图片响应
|
||||||
|
type PostImageResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
ThumbnailURL string `json:"thumbnail_url"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostResponse 帖子响应(列表用)
|
||||||
|
type PostResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Images []PostImageResponse `json:"images"`
|
||||||
|
LikesCount int `json:"likes_count"`
|
||||||
|
CommentsCount int `json:"comments_count"`
|
||||||
|
FavoritesCount int `json:"favorites_count"`
|
||||||
|
SharesCount int `json:"shares_count"`
|
||||||
|
ViewsCount int `json:"views_count"`
|
||||||
|
IsPinned bool `json:"is_pinned"`
|
||||||
|
IsLocked bool `json:"is_locked"`
|
||||||
|
IsVote bool `json:"is_vote"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
Author *UserResponse `json:"author"`
|
||||||
|
IsLiked bool `json:"is_liked"`
|
||||||
|
IsFavorited bool `json:"is_favorited"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostDetailResponse 帖子详情响应
|
||||||
|
type PostDetailResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Images []PostImageResponse `json:"images"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
LikesCount int `json:"likes_count"`
|
||||||
|
CommentsCount int `json:"comments_count"`
|
||||||
|
FavoritesCount int `json:"favorites_count"`
|
||||||
|
SharesCount int `json:"shares_count"`
|
||||||
|
ViewsCount int `json:"views_count"`
|
||||||
|
IsPinned bool `json:"is_pinned"`
|
||||||
|
IsLocked bool `json:"is_locked"`
|
||||||
|
IsVote bool `json:"is_vote"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
Author *UserResponse `json:"author"`
|
||||||
|
IsLiked bool `json:"is_liked"`
|
||||||
|
IsFavorited bool `json:"is_favorited"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Comment DTOs ====================
|
||||||
|
|
||||||
|
// CommentImageResponse 评论图片响应
|
||||||
|
type CommentImageResponse struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommentResponse 评论响应(扁平化结构,类似B站/抖音)
|
||||||
|
// 第一层级正常展示,第二三四五层级在第一层级的评论区扁平展示
|
||||||
|
type CommentResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
PostID string `json:"post_id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
ParentID *string `json:"parent_id"`
|
||||||
|
RootID *string `json:"root_id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Images []CommentImageResponse `json:"images"`
|
||||||
|
LikesCount int `json:"likes_count"`
|
||||||
|
RepliesCount int `json:"replies_count"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
Author *UserResponse `json:"author"`
|
||||||
|
IsLiked bool `json:"is_liked"`
|
||||||
|
TargetID *string `json:"target_id,omitempty"` // 被回复的评论ID,前端根据此ID找到被回复用户的昵称
|
||||||
|
Replies []*CommentResponse `json:"replies,omitempty"` // 子回复列表(扁平化,所有层级都在这里)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Notification DTOs ====================
|
||||||
|
|
||||||
|
// NotificationResponse 通知响应
|
||||||
|
type NotificationResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
IsRead bool `json:"is_read"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Message Segment DTOs ====================
|
||||||
|
|
||||||
|
// SegmentType Segment类型
|
||||||
|
type SegmentType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SegmentTypeText SegmentType = "text"
|
||||||
|
SegmentTypeImage SegmentType = "image"
|
||||||
|
SegmentTypeVoice SegmentType = "voice"
|
||||||
|
SegmentTypeVideo SegmentType = "video"
|
||||||
|
SegmentTypeFile SegmentType = "file"
|
||||||
|
SegmentTypeAt SegmentType = "at"
|
||||||
|
SegmentTypeReply SegmentType = "reply"
|
||||||
|
SegmentTypeFace SegmentType = "face"
|
||||||
|
SegmentTypeLink SegmentType = "link"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TextSegmentData 文本数据
|
||||||
|
type TextSegmentData struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageSegmentData 图片数据
|
||||||
|
type ImageSegmentData struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Width int `json:"width,omitempty"`
|
||||||
|
Height int `json:"height,omitempty"`
|
||||||
|
ThumbnailURL string `json:"thumbnail_url,omitempty"`
|
||||||
|
FileSize int64 `json:"file_size,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VoiceSegmentData 语音数据
|
||||||
|
type VoiceSegmentData struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Duration int `json:"duration,omitempty"` // 秒
|
||||||
|
FileSize int64 `json:"file_size,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoSegmentData 视频数据
|
||||||
|
type VideoSegmentData struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Width int `json:"width,omitempty"`
|
||||||
|
Height int `json:"height,omitempty"`
|
||||||
|
Duration int `json:"duration,omitempty"` // 秒
|
||||||
|
ThumbnailURL string `json:"thumbnail_url,omitempty"`
|
||||||
|
FileSize int64 `json:"file_size,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileSegmentData 文件数据
|
||||||
|
type FileSegmentData struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Size int64 `json:"size,omitempty"`
|
||||||
|
MimeType string `json:"mime_type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AtSegmentData @数据
|
||||||
|
type AtSegmentData struct {
|
||||||
|
UserID string `json:"user_id"` // "all" 表示@所有人
|
||||||
|
Nickname string `json:"nickname,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplySegmentData 回复数据
|
||||||
|
type ReplySegmentData struct {
|
||||||
|
ID string `json:"id"` // 被回复消息的ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// FaceSegmentData 表情数据
|
||||||
|
type FaceSegmentData struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkSegmentData 链接数据
|
||||||
|
type LinkSegmentData struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Message DTOs ====================
|
||||||
|
|
||||||
|
// MessageResponse 消息响应
|
||||||
|
type MessageResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
ConversationID string `json:"conversation_id"`
|
||||||
|
SenderID string `json:"sender_id"`
|
||||||
|
Seq int64 `json:"seq"`
|
||||||
|
Segments model.MessageSegments `json:"segments"` // 消息链(必须)
|
||||||
|
ReplyToID *string `json:"reply_to_id,omitempty"` // 被回复消息的ID(用于关联查找)
|
||||||
|
Status string `json:"status"`
|
||||||
|
Category string `json:"category,omitempty"` // 消息类别:chat, notification, announcement
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
Sender *UserResponse `json:"sender"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationResponse 会话响应
|
||||||
|
type ConversationResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
IsPinned bool `json:"is_pinned"`
|
||||||
|
Group *GroupResponse `json:"group,omitempty"`
|
||||||
|
LastSeq int64 `json:"last_seq"`
|
||||||
|
LastMessage *MessageResponse `json:"last_message"`
|
||||||
|
LastMessageAt string `json:"last_message_at"`
|
||||||
|
UnreadCount int `json:"unread_count"`
|
||||||
|
Participants []*UserResponse `json:"participants,omitempty"` // 私聊时使用
|
||||||
|
MemberCount int `json:"member_count,omitempty"` // 群聊时使用
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationParticipantResponse 会话参与者响应
|
||||||
|
type ConversationParticipantResponse struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
LastReadSeq int64 `json:"last_read_seq"`
|
||||||
|
Muted bool `json:"muted"`
|
||||||
|
IsPinned bool `json:"is_pinned"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Auth DTOs ====================
|
||||||
|
|
||||||
|
// LoginResponse 登录响应
|
||||||
|
type LoginResponse struct {
|
||||||
|
User *UserResponse `json:"user"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterResponse 注册响应
|
||||||
|
type RegisterResponse struct {
|
||||||
|
User *UserResponse `json:"user"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenResponse 刷新Token响应
|
||||||
|
type RefreshTokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Common DTOs ====================
|
||||||
|
|
||||||
|
// SuccessResponse 通用成功响应
|
||||||
|
type SuccessResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AvailableResponse 可用性检查响应
|
||||||
|
type AvailableResponse struct {
|
||||||
|
Available bool `json:"available"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountResponse 数量响应
|
||||||
|
type CountResponse struct {
|
||||||
|
Count int `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLResponse URL响应
|
||||||
|
type URLResponse struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Chat Request DTOs ====================
|
||||||
|
|
||||||
|
// CreateConversationRequest 创建会话请求
|
||||||
|
type CreateConversationRequest struct {
|
||||||
|
UserID string `json:"user_id" binding:"required"` // 目标用户ID (UUID格式)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessageRequest 发送消息请求
|
||||||
|
type SendMessageRequest struct {
|
||||||
|
Segments model.MessageSegments `json:"segments" binding:"required"` // 消息链(必须)
|
||||||
|
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID (string类型)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkReadRequest 标记已读请求
|
||||||
|
type MarkReadRequest struct {
|
||||||
|
LastReadSeq int64 `json:"last_read_seq" binding:"required"` // 已读到的seq位置
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConversationPinnedRequest 设置会话置顶请求
|
||||||
|
type SetConversationPinnedRequest struct {
|
||||||
|
ConversationID string `json:"conversation_id" binding:"required"`
|
||||||
|
IsPinned bool `json:"is_pinned"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Chat Response DTOs ====================
|
||||||
|
|
||||||
|
// ConversationListResponse 会话列表响应
|
||||||
|
type ConversationListResponse struct {
|
||||||
|
Conversations []*ConversationResponse `json:"list"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationDetailResponse 会话详情响应
|
||||||
|
type ConversationDetailResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
IsPinned bool `json:"is_pinned"`
|
||||||
|
LastSeq int64 `json:"last_seq"`
|
||||||
|
LastMessage *MessageResponse `json:"last_message"`
|
||||||
|
LastMessageAt string `json:"last_message_at"`
|
||||||
|
UnreadCount int64 `json:"unread_count"`
|
||||||
|
Participants []*UserResponse `json:"participants"`
|
||||||
|
MyLastReadSeq int64 `json:"my_last_read_seq"` // 当前用户的已读位置
|
||||||
|
OtherLastReadSeq int64 `json:"other_last_read_seq"` // 对方用户的已读位置
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnreadCountResponse 未读数响应
|
||||||
|
type UnreadCountResponse struct {
|
||||||
|
TotalUnreadCount int64 `json:"total_unread_count"` // 所有会话的未读总数
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationUnreadCountResponse 单个会话未读数响应
|
||||||
|
type ConversationUnreadCountResponse struct {
|
||||||
|
ConversationID string `json:"conversation_id"`
|
||||||
|
UnreadCount int64 `json:"unread_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageListResponse 消息列表响应
|
||||||
|
type MessageListResponse struct {
|
||||||
|
Messages []*MessageResponse `json:"messages"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageSyncResponse 消息同步响应(增量同步)
|
||||||
|
type MessageSyncResponse struct {
|
||||||
|
Messages []*MessageResponse `json:"messages"`
|
||||||
|
HasMore bool `json:"has_more"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 设备Token DTOs ====================
|
||||||
|
|
||||||
|
// RegisterDeviceRequest 注册设备请求
|
||||||
|
type RegisterDeviceRequest struct {
|
||||||
|
DeviceID string `json:"device_id" binding:"required"`
|
||||||
|
DeviceType string `json:"device_type" binding:"required,oneof=ios android web"`
|
||||||
|
PushToken string `json:"push_token"`
|
||||||
|
DeviceName string `json:"device_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceTokenResponse 设备Token响应
|
||||||
|
type DeviceTokenResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
DeviceID string `json:"device_id"`
|
||||||
|
DeviceType string `json:"device_type"`
|
||||||
|
IsActive bool `json:"is_active"`
|
||||||
|
DeviceName string `json:"device_name"`
|
||||||
|
LastUsedAt time.Time `json:"last_used_at,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 推送记录 DTOs ====================
|
||||||
|
|
||||||
|
// PushRecordResponse 推送记录响应
|
||||||
|
type PushRecordResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
MessageID string `json:"message_id"`
|
||||||
|
PushChannel string `json:"push_channel"`
|
||||||
|
PushStatus string `json:"push_status"`
|
||||||
|
RetryCount int `json:"retry_count"`
|
||||||
|
PushedAt time.Time `json:"pushed_at,omitempty"`
|
||||||
|
DeliveredAt time.Time `json:"delivered_at,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushRecordListResponse 推送记录列表响应
|
||||||
|
type PushRecordListResponse struct {
|
||||||
|
Records []*PushRecordResponse `json:"records"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 系统消息 DTOs ====================
|
||||||
|
|
||||||
|
// SystemMessageResponse 系统消息响应
|
||||||
|
type SystemMessageResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
SenderID string `json:"sender_id"`
|
||||||
|
ReceiverID string `json:"receiver_id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
SystemType string `json:"system_type"`
|
||||||
|
ExtraData map[string]interface{} `json:"extra_data,omitempty"`
|
||||||
|
IsRead bool `json:"is_read"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemMessageListResponse 系统消息列表响应
|
||||||
|
type SystemMessageListResponse struct {
|
||||||
|
Messages []*SystemMessageResponse `json:"messages"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemUnreadCountResponse 系统消息未读数响应
|
||||||
|
type SystemUnreadCountResponse struct {
|
||||||
|
UnreadCount int64 `json:"unread_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 时间格式化 ====================
|
||||||
|
|
||||||
|
// FormatTime 格式化时间
|
||||||
|
func FormatTime(t time.Time) string {
|
||||||
|
if t.IsZero() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return t.Format("2006-01-02T15:04:05Z07:00")
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatTimePointer 格式化时间指针
|
||||||
|
func FormatTimePointer(t *time.Time) string {
|
||||||
|
if t == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return FormatTime(*t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Group DTOs ====================
|
||||||
|
|
||||||
|
// CreateGroupRequest 创建群组请求
|
||||||
|
type CreateGroupRequest struct {
|
||||||
|
Name string `json:"name" binding:"required,max=50"`
|
||||||
|
Description string `json:"description" binding:"max=500"`
|
||||||
|
MemberIDs []string `json:"member_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateGroupRequest 更新群组请求
|
||||||
|
type UpdateGroupRequest struct {
|
||||||
|
Name string `json:"name" binding:"omitempty,max=50"`
|
||||||
|
Description string `json:"description" binding:"omitempty,max=500"`
|
||||||
|
Avatar string `json:"avatar" binding:"omitempty,url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InviteMembersRequest 邀请成员请求
|
||||||
|
type InviteMembersRequest struct {
|
||||||
|
MemberIDs []string `json:"member_ids" binding:"required,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransferOwnerRequest 转让群主请求
|
||||||
|
type TransferOwnerRequest struct {
|
||||||
|
NewOwnerID string `json:"new_owner_id" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRoleRequest 设置角色请求
|
||||||
|
type SetRoleRequest struct {
|
||||||
|
Role string `json:"role" binding:"required,oneof=admin member"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNicknameRequest 设置昵称请求
|
||||||
|
type SetNicknameRequest struct {
|
||||||
|
Nickname string `json:"nickname" binding:"max=50"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MuteMemberRequest 禁言成员请求
|
||||||
|
type MuteMemberRequest struct {
|
||||||
|
Muted bool `json:"muted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMuteAllRequest 设置全员禁言请求
|
||||||
|
type SetMuteAllRequest struct {
|
||||||
|
MuteAll bool `json:"mute_all"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetJoinTypeRequest 设置加群方式请求
|
||||||
|
type SetJoinTypeRequest struct {
|
||||||
|
JoinType int `json:"join_type" binding:"min=0,max=2"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAnnouncementRequest 创建群公告请求
|
||||||
|
type CreateAnnouncementRequest struct {
|
||||||
|
Content string `json:"content" binding:"required,max=2000"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupResponse 群组响应
|
||||||
|
type GroupResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Avatar string `json:"avatar"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
OwnerID string `json:"owner_id"`
|
||||||
|
MemberCount int `json:"member_count"`
|
||||||
|
MaxMembers int `json:"max_members"`
|
||||||
|
JoinType int `json:"join_type"`
|
||||||
|
MuteAll bool `json:"mute_all"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMemberResponse 群成员响应
|
||||||
|
type GroupMemberResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
GroupID string `json:"group_id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Nickname string `json:"nickname"`
|
||||||
|
Muted bool `json:"muted"`
|
||||||
|
JoinTime string `json:"join_time"`
|
||||||
|
User *UserResponse `json:"user,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupAnnouncementResponse 群公告响应
|
||||||
|
type GroupAnnouncementResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
GroupID string `json:"group_id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
AuthorID string `json:"author_id"`
|
||||||
|
IsPinned bool `json:"is_pinned"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupListResponse 群组列表响应
|
||||||
|
type GroupListResponse struct {
|
||||||
|
List []*GroupResponse `json:"list"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMemberListResponse 群成员列表响应
|
||||||
|
type GroupMemberListResponse struct {
|
||||||
|
List []*GroupMemberResponse `json:"list"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupAnnouncementListResponse 群公告列表响应
|
||||||
|
type GroupAnnouncementListResponse struct {
|
||||||
|
List []*GroupAnnouncementResponse `json:"list"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== WebSocket Event DTOs ====================
|
||||||
|
|
||||||
|
// WSEventResponse WebSocket事件响应结构体
|
||||||
|
// 用于后端推送消息给前端的标准格式
|
||||||
|
type WSEventResponse struct {
|
||||||
|
ID string `json:"id"` // 事件唯一ID (UUID)
|
||||||
|
Time int64 `json:"time"` // 时间戳 (毫秒)
|
||||||
|
Type string `json:"type"` // 事件类型 (message, notification, system等)
|
||||||
|
DetailType string `json:"detail_type"` // 详细类型 (private, group, like, comment等)
|
||||||
|
Seq string `json:"seq"` // 消息序列号
|
||||||
|
Segments model.MessageSegments `json:"segments"` // 消息段数组
|
||||||
|
SenderID string `json:"sender_id"` // 发送者用户ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== WebSocket Request DTOs ====================
|
||||||
|
|
||||||
|
// SendMessageParams 发送消息参数(用于 REST API)
|
||||||
|
type SendMessageParams struct {
|
||||||
|
DetailType string `json:"detail_type"` // 消息类型: private, group
|
||||||
|
ConversationID string `json:"conversation_id"` // 会话ID
|
||||||
|
Segments model.MessageSegments `json:"segments"` // 消息内容(消息段数组)
|
||||||
|
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMsgParams 撤回消息参数
|
||||||
|
type DeleteMsgParams struct {
|
||||||
|
MessageID string `json:"message_id"` // 消息ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Group Action Params ====================
|
||||||
|
|
||||||
|
// SetGroupKickParams 群组踢人参数
|
||||||
|
type SetGroupKickParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
UserID string `json:"user_id"` // 被踢用户ID
|
||||||
|
RejectAddRequest bool `json:"reject_add_request"` // 是否拒绝再次加群
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGroupBanParams 群组单人禁言参数
|
||||||
|
type SetGroupBanParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
UserID string `json:"user_id"` // 被禁言用户ID
|
||||||
|
Duration int64 `json:"duration"` // 禁言时长(秒),0表示解除禁言
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGroupWholeBanParams 群组全员禁言参数
|
||||||
|
type SetGroupWholeBanParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
Enable bool `json:"enable"` // 是否开启全员禁言
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGroupAdminParams 群组设置管理员参数
|
||||||
|
type SetGroupAdminParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
UserID string `json:"user_id"` // 被设置的用户ID
|
||||||
|
Enable bool `json:"enable"` // 是否设置为管理员
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGroupNameParams 设置群名参数
|
||||||
|
type SetGroupNameParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
GroupName string `json:"group_name"` // 新群名
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGroupAvatarParams 设置群头像参数
|
||||||
|
type SetGroupAvatarParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
Avatar string `json:"avatar"` // 头像URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGroupLeaveParams 退出群组参数
|
||||||
|
type SetGroupLeaveParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetGroupAddRequestParams 处理加群请求参数
|
||||||
|
type SetGroupAddRequestParams struct {
|
||||||
|
Flag string `json:"flag"` // 加群请求的flag标识
|
||||||
|
Approve bool `json:"approve"` // 是否同意
|
||||||
|
Reason string `json:"reason"` // 拒绝理由(当approve为false时)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationListParams 获取会话列表参数
|
||||||
|
type GetConversationListParams struct {
|
||||||
|
Page int `json:"page"` // 页码
|
||||||
|
PageSize int `json:"page_size"` // 每页数量
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroupInfoParams 获取群信息参数
|
||||||
|
type GetGroupInfoParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroupMemberListParams 获取群成员列表参数
|
||||||
|
type GetGroupMemberListParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
Page int `json:"page"` // 页码
|
||||||
|
PageSize int `json:"page_size"` // 每页数量
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Conversation Action Params ====================
|
||||||
|
|
||||||
|
// CreateConversationParams 创建会话参数
|
||||||
|
type CreateConversationParams struct {
|
||||||
|
UserID string `json:"user_id"` // 目标用户ID(私聊)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkReadParams 标记已读参数
|
||||||
|
type MarkReadParams struct {
|
||||||
|
ConversationID string `json:"conversation_id"` // 会话ID
|
||||||
|
LastReadSeq int64 `json:"last_read_seq"` // 最后已读消息序号
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConversationPinnedParams 设置会话置顶参数
|
||||||
|
type SetConversationPinnedParams struct {
|
||||||
|
ConversationID string `json:"conversation_id"` // 会话ID
|
||||||
|
IsPinned bool `json:"is_pinned"` // 是否置顶
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Group Action Params (Additional) ====================
|
||||||
|
|
||||||
|
// CreateGroupParams 创建群组参数
|
||||||
|
type CreateGroupParams struct {
|
||||||
|
Name string `json:"name"` // 群名
|
||||||
|
Description string `json:"description,omitempty"` // 群描述
|
||||||
|
MemberIDs []string `json:"member_ids,omitempty"` // 初始成员ID列表
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserGroupsParams 获取用户群组列表参数
|
||||||
|
type GetUserGroupsParams struct {
|
||||||
|
Page int `json:"page"` // 页码
|
||||||
|
PageSize int `json:"page_size"` // 每页数量
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransferOwnerParams 转让群主参数
|
||||||
|
type TransferOwnerParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
NewOwnerID string `json:"new_owner_id"` // 新群主ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// InviteMembersParams 邀请成员参数
|
||||||
|
type InviteMembersParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
MemberIDs []string `json:"member_ids"` // 被邀请的用户ID列表
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinGroupParams 加入群组参数
|
||||||
|
type JoinGroupParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNicknameParams 设置群内昵称参数
|
||||||
|
type SetNicknameParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
Nickname string `json:"nickname"` // 群内昵称
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetJoinTypeParams 设置加群方式参数
|
||||||
|
type SetJoinTypeParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
JoinType int `json:"join_type"` // 加群方式:0-允许任何人加入,1-需要审批,2-不允许加入
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAnnouncementParams 创建群公告参数
|
||||||
|
type CreateAnnouncementParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
Content string `json:"content"` // 公告内容
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAnnouncementsParams 获取群公告列表参数
|
||||||
|
type GetAnnouncementsParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
Page int `json:"page"` // 页码
|
||||||
|
PageSize int `json:"page_size"` // 每页数量
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAnnouncementParams 删除群公告参数
|
||||||
|
type DeleteAnnouncementParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
AnnouncementID string `json:"announcement_id"` // 公告ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// DissolveGroupParams 解散群组参数
|
||||||
|
type DissolveGroupParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMyMemberInfoParams 获取我在群组中的成员信息参数
|
||||||
|
type GetMyMemberInfoParams struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Vote DTOs ====================
|
||||||
|
|
||||||
|
// CreateVotePostRequest 创建投票帖子请求
|
||||||
|
type CreateVotePostRequest struct {
|
||||||
|
Title string `json:"title" binding:"required,max=200"`
|
||||||
|
Content string `json:"content" binding:"max=2000"`
|
||||||
|
CommunityID string `json:"community_id"`
|
||||||
|
Images []string `json:"images"`
|
||||||
|
VoteOptions []string `json:"vote_options" binding:"required,min=2,max=10"` // 投票选项,至少2个最多10个
|
||||||
|
}
|
||||||
|
|
||||||
|
// VoteOptionDTO 投票选项DTO
|
||||||
|
type VoteOptionDTO struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
VotesCount int `json:"votes_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VoteResultDTO 投票结果DTO
|
||||||
|
type VoteResultDTO struct {
|
||||||
|
Options []VoteOptionDTO `json:"options"`
|
||||||
|
TotalVotes int `json:"total_votes"`
|
||||||
|
HasVoted bool `json:"has_voted"`
|
||||||
|
VotedOptionID string `json:"voted_option_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== WebSocket Response DTOs ====================
|
||||||
|
|
||||||
|
// WSResponse WebSocket响应结构体
|
||||||
|
type WSResponse struct {
|
||||||
|
Success bool `json:"success"` // 是否成功
|
||||||
|
Action string `json:"action"` // 响应原action
|
||||||
|
Data interface{} `json:"data,omitempty"` // 响应数据
|
||||||
|
Error string `json:"error,omitempty"` // 错误信息
|
||||||
|
}
|
||||||
362
internal/dto/segment.go
Normal file
362
internal/dto/segment.go
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseSegmentData 解析Segment数据到目标结构体
|
||||||
|
func ParseSegmentData(segment model.MessageSegment, target interface{}) error {
|
||||||
|
dataBytes, err := json.Marshal(segment.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(dataBytes, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTextSegment 创建文本Segment
|
||||||
|
func NewTextSegment(content string) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeText),
|
||||||
|
Data: map[string]interface{}{"text": content},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewImageSegment 创建图片Segment
|
||||||
|
func NewImageSegment(url string, width, height int, thumbnailURL string) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeImage),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"url": url,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"thumbnail_url": thumbnailURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewImageSegmentWithSize 创建带文件大小的图片Segment
|
||||||
|
func NewImageSegmentWithSize(url string, width, height int, thumbnailURL string, fileSize int64) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeImage),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"url": url,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"thumbnail_url": thumbnailURL,
|
||||||
|
"file_size": fileSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewVoiceSegment 创建语音Segment
|
||||||
|
func NewVoiceSegment(url string, duration int, fileSize int64) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeVoice),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"url": url,
|
||||||
|
"duration": duration,
|
||||||
|
"file_size": fileSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewVideoSegment 创建视频Segment
|
||||||
|
func NewVideoSegment(url string, width, height, duration int, thumbnailURL string, fileSize int64) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeVideo),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"url": url,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"duration": duration,
|
||||||
|
"thumbnail_url": thumbnailURL,
|
||||||
|
"file_size": fileSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileSegment 创建文件Segment
|
||||||
|
func NewFileSegment(url, name string, size int64, mimeType string) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeFile),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"url": url,
|
||||||
|
"name": name,
|
||||||
|
"size": size,
|
||||||
|
"mime_type": mimeType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAtSegment 创建@Segment,只存储 user_id,昵称由前端根据群成员列表实时解析
|
||||||
|
func NewAtSegment(userID string) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeAt),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"user_id": userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAtAllSegment 创建@所有人Segment
|
||||||
|
func NewAtAllSegment() model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeAt),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"user_id": "all",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReplySegment 创建回复Segment
|
||||||
|
func NewReplySegment(messageID string) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeReply),
|
||||||
|
Data: map[string]interface{}{"id": messageID},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFaceSegment 创建表情Segment
|
||||||
|
func NewFaceSegment(id int, name, url string) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeFace),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"id": id,
|
||||||
|
"name": name,
|
||||||
|
"url": url,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLinkSegment 创建链接Segment
|
||||||
|
func NewLinkSegment(url, title, description, image string) model.MessageSegment {
|
||||||
|
return model.MessageSegment{
|
||||||
|
Type: string(SegmentTypeLink),
|
||||||
|
Data: map[string]interface{}{
|
||||||
|
"url": url,
|
||||||
|
"title": title,
|
||||||
|
"description": description,
|
||||||
|
"image": image,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractTextContentFromJSON 从JSON格式的segments中提取纯文本内容
|
||||||
|
// 用于从数据库读取的 []byte 格式的 segments
|
||||||
|
// 已废弃:现在数据库直接存储 model.MessageSegments 类型
|
||||||
|
func ExtractTextContentFromJSON(segmentsJSON []byte) string {
|
||||||
|
if len(segmentsJSON) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var segments model.MessageSegments
|
||||||
|
if err := json.Unmarshal(segmentsJSON, &segments); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return ExtractTextContentFromModel(segments)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractTextContentFromModel 从 model.MessageSegments 中提取纯文本内容
|
||||||
|
func ExtractTextContentFromModel(segments model.MessageSegments) string {
|
||||||
|
var result string
|
||||||
|
for _, segment := range segments {
|
||||||
|
switch segment.Type {
|
||||||
|
case "text":
|
||||||
|
if text, ok := segment.Data["text"].(string); ok {
|
||||||
|
result += text
|
||||||
|
}
|
||||||
|
case "at":
|
||||||
|
userID, _ := segment.Data["user_id"].(string)
|
||||||
|
if userID == "all" {
|
||||||
|
result += "@所有人 "
|
||||||
|
} else if userID != "" {
|
||||||
|
// 昵称由前端实时解析,后端文本提取仅用于推送通知兜底
|
||||||
|
result += "@某人 "
|
||||||
|
}
|
||||||
|
case "image":
|
||||||
|
result += "[图片]"
|
||||||
|
case "voice":
|
||||||
|
result += "[语音]"
|
||||||
|
case "video":
|
||||||
|
result += "[视频]"
|
||||||
|
case "file":
|
||||||
|
if name, ok := segment.Data["name"].(string); ok && name != "" {
|
||||||
|
result += fmt.Sprintf("[文件:%s]", name)
|
||||||
|
} else {
|
||||||
|
result += "[文件]"
|
||||||
|
}
|
||||||
|
case "face":
|
||||||
|
if name, ok := segment.Data["name"].(string); ok && name != "" {
|
||||||
|
result += fmt.Sprintf("[%s]", name)
|
||||||
|
} else {
|
||||||
|
result += "[表情]"
|
||||||
|
}
|
||||||
|
case "link":
|
||||||
|
if title, ok := segment.Data["title"].(string); ok && title != "" {
|
||||||
|
result += fmt.Sprintf("[链接:%s]", title)
|
||||||
|
} else {
|
||||||
|
result += "[链接]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractTextContent 从消息链中提取纯文本内容
|
||||||
|
// 用于搜索、通知展示等场景
|
||||||
|
func ExtractTextContent(segments model.MessageSegments) string {
|
||||||
|
return ExtractTextContentFromModel(segments)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractMentionedUsers 从消息链中提取被@的用户ID列表
|
||||||
|
// 不包括 "all"(@所有人)
|
||||||
|
func ExtractMentionedUsers(segments model.MessageSegments) []string {
|
||||||
|
var userIDs []string
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, segment := range segments {
|
||||||
|
if segment.Type == string(SegmentTypeAt) {
|
||||||
|
userID, _ := segment.Data["user_id"].(string)
|
||||||
|
if userID != "all" && userID != "" && !seen[userID] {
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
seen[userID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return userIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAtAll 检查消息是否@了所有人
|
||||||
|
func IsAtAll(segments model.MessageSegments) bool {
|
||||||
|
for _, segment := range segments {
|
||||||
|
if segment.Type == string(SegmentTypeAt) {
|
||||||
|
if userID, ok := segment.Data["user_id"].(string); ok && userID == "all" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReplyMessageID 从消息链中获取被回复的消息ID
|
||||||
|
// 如果没有回复segment,返回空字符串
|
||||||
|
func GetReplyMessageID(segments model.MessageSegments) string {
|
||||||
|
for _, segment := range segments {
|
||||||
|
if segment.Type == string(SegmentTypeReply) {
|
||||||
|
if id, ok := segment.Data["id"].(string); ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildSegmentsFromContent 从旧版content构建segments
|
||||||
|
// 用于兼容旧版本消息
|
||||||
|
func BuildSegmentsFromContent(contentType, content string, mediaURL *string) model.MessageSegments {
|
||||||
|
var segments model.MessageSegments
|
||||||
|
|
||||||
|
switch contentType {
|
||||||
|
case "text":
|
||||||
|
segments = append(segments, NewTextSegment(content))
|
||||||
|
case "image":
|
||||||
|
if mediaURL != nil {
|
||||||
|
segments = append(segments, NewImageSegment(*mediaURL, 0, 0, ""))
|
||||||
|
}
|
||||||
|
case "voice":
|
||||||
|
if mediaURL != nil {
|
||||||
|
segments = append(segments, NewVoiceSegment(*mediaURL, 0, 0))
|
||||||
|
}
|
||||||
|
case "video":
|
||||||
|
if mediaURL != nil {
|
||||||
|
segments = append(segments, NewVideoSegment(*mediaURL, 0, 0, 0, "", 0))
|
||||||
|
}
|
||||||
|
case "file":
|
||||||
|
if mediaURL != nil {
|
||||||
|
segments = append(segments, NewFileSegment(*mediaURL, content, 0, ""))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 默认当作文本处理
|
||||||
|
if content != "" {
|
||||||
|
segments = append(segments, NewTextSegment(content))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return segments
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasSegmentType 检查消息链中是否包含指定类型的segment
|
||||||
|
func HasSegmentType(segments model.MessageSegments, segmentType SegmentType) bool {
|
||||||
|
for _, segment := range segments {
|
||||||
|
if segment.Type == string(segmentType) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSegmentsByType 获取消息链中所有指定类型的segment
|
||||||
|
func GetSegmentsByType(segments model.MessageSegments, segmentType SegmentType) []model.MessageSegment {
|
||||||
|
var result []model.MessageSegment
|
||||||
|
for _, segment := range segments {
|
||||||
|
if segment.Type == string(segmentType) {
|
||||||
|
result = append(result, segment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFirstImageURL 获取消息链中第一张图片的URL
|
||||||
|
// 如果没有图片,返回空字符串
|
||||||
|
func GetFirstImageURL(segments model.MessageSegments) string {
|
||||||
|
for _, segment := range segments {
|
||||||
|
if segment.Type == string(SegmentTypeImage) {
|
||||||
|
if url, ok := segment.Data["url"].(string); ok {
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFirstMediaURL 获取消息链中第一个媒体文件的URL(图片/视频/语音/文件)
|
||||||
|
// 用于兼容旧版本API
|
||||||
|
func GetFirstMediaURL(segments model.MessageSegments) string {
|
||||||
|
for _, segment := range segments {
|
||||||
|
switch segment.Type {
|
||||||
|
case string(SegmentTypeImage), string(SegmentTypeVideo), string(SegmentTypeVoice), string(SegmentTypeFile):
|
||||||
|
if url, ok := segment.Data["url"].(string); ok {
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetermineContentType 从消息链推断消息类型(用于兼容旧版本)
|
||||||
|
func DetermineContentType(segments model.MessageSegments) string {
|
||||||
|
if len(segments) == 0 {
|
||||||
|
return "text"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先检查媒体类型
|
||||||
|
for _, segment := range segments {
|
||||||
|
switch segment.Type {
|
||||||
|
case string(SegmentTypeImage):
|
||||||
|
return "image"
|
||||||
|
case string(SegmentTypeVideo):
|
||||||
|
return "video"
|
||||||
|
case string(SegmentTypeVoice):
|
||||||
|
return "voice"
|
||||||
|
case string(SegmentTypeFile):
|
||||||
|
return "file"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 默认返回text
|
||||||
|
return "text"
|
||||||
|
}
|
||||||
253
internal/handler/comment_handler.go
Normal file
253
internal/handler/comment_handler.go
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CommentHandler 评论处理器
|
||||||
|
type CommentHandler struct {
|
||||||
|
commentService *service.CommentService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCommentHandler 创建评论处理器
|
||||||
|
func NewCommentHandler(commentService *service.CommentService) *CommentHandler {
|
||||||
|
return &CommentHandler{
|
||||||
|
commentService: commentService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建评论
|
||||||
|
func (h *CommentHandler) Create(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateRequest struct {
|
||||||
|
PostID string `json:"post_id" binding:"required"`
|
||||||
|
Content string `json:"content"` // 内容可选,允许只发图片
|
||||||
|
ParentID *string `json:"parent_id"`
|
||||||
|
Images []string `json:"images"` // 图片URL列表
|
||||||
|
}
|
||||||
|
|
||||||
|
var req CreateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证:评论必须有内容或图片
|
||||||
|
if req.Content == "" && len(req.Images) == 0 {
|
||||||
|
response.BadRequest(c, "评论内容或图片不能同时为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将图片列表转换为JSON字符串
|
||||||
|
var imagesJSON string
|
||||||
|
if len(req.Images) > 0 {
|
||||||
|
imagesBytes, _ := json.Marshal(req.Images)
|
||||||
|
imagesJSON = string(imagesBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
comment, err := h.commentService.Create(c.Request.Context(), req.PostID, userID, req.Content, req.ParentID, imagesJSON, req.Images)
|
||||||
|
if err != nil {
|
||||||
|
var moderationErr *service.CommentModerationRejectedError
|
||||||
|
if errors.As(err, &moderationErr) {
|
||||||
|
response.BadRequest(c, moderationErr.UserMessage())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to create comment")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertCommentToResponse(comment, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 获取单条评论详情
|
||||||
|
func (h *CommentHandler) GetByID(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "comment not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := dto.ConvertCommentToResponse(comment, h.commentService.IsLiked(c.Request.Context(), id, userID))
|
||||||
|
response.Success(c, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByPostID 获取帖子评论
|
||||||
|
func (h *CommentHandler) GetByPostID(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
postID := c.Param("id")
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
comments, total, err := h.commentService.GetByPostID(c.Request.Context(), postID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get comments")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应结构,检查每个评论的点赞状态
|
||||||
|
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
|
||||||
|
|
||||||
|
response.Paginated(c, commentResponses, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReplies 获取回复
|
||||||
|
func (h *CommentHandler) GetReplies(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
parentID := c.Param("id")
|
||||||
|
|
||||||
|
comments, err := h.commentService.GetReplies(c.Request.Context(), parentID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get replies")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应结构,检查每个回复的点赞状态
|
||||||
|
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
|
||||||
|
|
||||||
|
response.Success(c, commentResponses)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRepliesByRootID 根据根评论ID分页获取回复(扁平化)
|
||||||
|
func (h *CommentHandler) GetRepliesByRootID(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
rootID := c.Param("id")
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
|
||||||
|
|
||||||
|
replies, total, err := h.commentService.GetRepliesByRootID(c.Request.Context(), rootID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get replies")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应结构,检查每个回复的点赞状态
|
||||||
|
replyResponses := dto.ConvertCommentsToResponseWithUser(replies, userID, h.commentService)
|
||||||
|
|
||||||
|
response.Paginated(c, replyResponses, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新评论
|
||||||
|
func (h *CommentHandler) Update(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "comment not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if comment.UserID != userID {
|
||||||
|
response.Forbidden(c, "cannot update others' comment")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateRequest struct {
|
||||||
|
Content string `json:"content" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
comment.Content = req.Content
|
||||||
|
|
||||||
|
err = h.commentService.Update(c.Request.Context(), comment)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to update comment")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertCommentToResponse(comment, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除评论
|
||||||
|
func (h *CommentHandler) Delete(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "comment not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if comment.UserID != userID {
|
||||||
|
response.Forbidden(c, "cannot delete others' comment")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.commentService.Delete(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to delete comment")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "comment deleted", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Like 点赞评论
|
||||||
|
func (h *CommentHandler) Like(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
err := h.commentService.Like(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to like comment")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "liked", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlike 取消点赞评论
|
||||||
|
func (h *CommentHandler) Unlike(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
err := h.commentService.Unlike(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to unlike comment")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "unliked", nil)
|
||||||
|
}
|
||||||
234
internal/handler/gorse_handler.go
Normal file
234
internal/handler/gorse_handler.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/config"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/gorse"
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
|
||||||
|
gorseio "github.com/gorse-io/gorse-go"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GorseHandler Gorse推荐处理器
|
||||||
|
type GorseHandler struct {
|
||||||
|
importPassword string
|
||||||
|
gorseConfig config.GorseConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGorseHandler 创建Gorse处理器
|
||||||
|
func NewGorseHandler(cfg config.GorseConfig) *GorseHandler {
|
||||||
|
return &GorseHandler{
|
||||||
|
importPassword: cfg.ImportPassword,
|
||||||
|
gorseConfig: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImportRequest 导入请求
|
||||||
|
type ImportRequest struct {
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImportData 导入数据到Gorse
|
||||||
|
// POST /api/v1/gorse/import
|
||||||
|
func (h *GorseHandler) ImportData(c *gin.Context) {
|
||||||
|
// 验证密码
|
||||||
|
if h.importPassword == "" {
|
||||||
|
response.BadRequest(c, "Gorse import is disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ImportRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Password != h.importPassword {
|
||||||
|
response.Unauthorized(c, "invalid password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stats, err := h.importAllData(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[ERROR] gorse import failed: %v", err)
|
||||||
|
response.InternalServerError(c, "gorse import failed: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"message": "import completed",
|
||||||
|
"status": "done",
|
||||||
|
"stats": stats,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStatus 获取Gorse状态
|
||||||
|
// GET /api/v1/gorse/status
|
||||||
|
func (h *GorseHandler) GetStatus(c *gin.Context) {
|
||||||
|
// 返回Gorse连接状态和配置信息
|
||||||
|
hasPassword := h.importPassword != ""
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"enabled": h.gorseConfig.Enabled,
|
||||||
|
"has_password": hasPassword,
|
||||||
|
"address": h.gorseConfig.Address,
|
||||||
|
"api_key": strings.Repeat("*", 8), // 不返回实际APIKey
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *GorseHandler) importAllData(ctx context.Context) (gin.H, error) {
|
||||||
|
gorseClient := gorseio.NewGorseClient(h.gorseConfig.Address, h.gorseConfig.APIKey)
|
||||||
|
gorse.InitEmbeddingWithConfig(h.gorseConfig.EmbeddingAPIKey, h.gorseConfig.EmbeddingURL, h.gorseConfig.EmbeddingModel)
|
||||||
|
|
||||||
|
stats := gin.H{
|
||||||
|
"items": 0,
|
||||||
|
"users": 0,
|
||||||
|
"likes": 0,
|
||||||
|
"favorites": 0,
|
||||||
|
"comments": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导入帖子
|
||||||
|
var posts []model.Post
|
||||||
|
if err := model.DB.Find(&posts).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, post := range posts {
|
||||||
|
embedding, err := gorse.GetEmbedding(strings.TrimSpace(post.Title + " " + post.Content))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] get embedding failed for post %s: %v", post.ID, err)
|
||||||
|
embedding = make([]float64, 1024)
|
||||||
|
}
|
||||||
|
_, err = gorseClient.InsertItem(ctx, gorseio.Item{
|
||||||
|
ItemId: post.ID,
|
||||||
|
IsHidden: post.DeletedAt.Valid,
|
||||||
|
Categories: buildPostCategories(&post),
|
||||||
|
Comment: post.Title,
|
||||||
|
Timestamp: post.CreatedAt.UTC().Truncate(time.Second),
|
||||||
|
Labels: map[string]any{
|
||||||
|
"embedding": embedding,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] import item failed (%s): %v", post.ID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats["items"] = stats["items"].(int) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导入用户
|
||||||
|
var users []model.User
|
||||||
|
if err := model.DB.Find(&users).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, user := range users {
|
||||||
|
_, err := gorseClient.InsertUser(ctx, gorseio.User{
|
||||||
|
UserId: user.ID,
|
||||||
|
Labels: map[string]any{
|
||||||
|
"posts_count": user.PostsCount,
|
||||||
|
"followers_count": user.FollowersCount,
|
||||||
|
"following_count": user.FollowingCount,
|
||||||
|
},
|
||||||
|
Comment: user.Nickname,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] import user failed (%s): %v", user.ID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats["users"] = stats["users"].(int) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导入点赞
|
||||||
|
var likes []model.PostLike
|
||||||
|
if err := model.DB.Find(&likes).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, like := range likes {
|
||||||
|
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||||
|
FeedbackType: string(gorse.FeedbackTypeLike),
|
||||||
|
UserId: like.UserID,
|
||||||
|
ItemId: like.PostID,
|
||||||
|
Timestamp: like.CreatedAt.UTC().Truncate(time.Second),
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] import like failed (%s/%s): %v", like.UserID, like.PostID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats["likes"] = stats["likes"].(int) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导入收藏
|
||||||
|
var favorites []model.Favorite
|
||||||
|
if err := model.DB.Find(&favorites).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, fav := range favorites {
|
||||||
|
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||||
|
FeedbackType: string(gorse.FeedbackTypeStar),
|
||||||
|
UserId: fav.UserID,
|
||||||
|
ItemId: fav.PostID,
|
||||||
|
Timestamp: fav.CreatedAt.UTC().Truncate(time.Second),
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] import favorite failed (%s/%s): %v", fav.UserID, fav.PostID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats["favorites"] = stats["favorites"].(int) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导入评论(按用户-帖子去重)
|
||||||
|
var comments []model.Comment
|
||||||
|
if err := model.DB.Where("status = ?", model.CommentStatusPublished).Find(&comments).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
for _, cm := range comments {
|
||||||
|
key := cm.UserID + ":" + cm.PostID
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||||
|
FeedbackType: string(gorse.FeedbackTypeComment),
|
||||||
|
UserId: cm.UserID,
|
||||||
|
ItemId: cm.PostID,
|
||||||
|
Timestamp: cm.CreatedAt.UTC().Truncate(time.Second),
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] import comment failed (%s/%s): %v", cm.UserID, cm.PostID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats["comments"] = stats["comments"].(int) + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPostCategories(post *model.Post) []string {
|
||||||
|
var categories []string
|
||||||
|
if post.ViewsCount > 1000 {
|
||||||
|
categories = append(categories, "hot_high")
|
||||||
|
} else if post.ViewsCount > 100 {
|
||||||
|
categories = append(categories, "hot_medium")
|
||||||
|
}
|
||||||
|
if post.LikesCount > 100 {
|
||||||
|
categories = append(categories, "likes_100+")
|
||||||
|
} else if post.LikesCount > 10 {
|
||||||
|
categories = append(categories, "likes_10+")
|
||||||
|
}
|
||||||
|
age := time.Since(post.CreatedAt)
|
||||||
|
if age < 24*time.Hour {
|
||||||
|
categories = append(categories, "today")
|
||||||
|
} else if age < 7*24*time.Hour {
|
||||||
|
categories = append(categories, "this_week")
|
||||||
|
}
|
||||||
|
return categories
|
||||||
|
}
|
||||||
1801
internal/handler/group_handler.go
Normal file
1801
internal/handler/group_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
879
internal/handler/message_handler.go
Normal file
879
internal/handler/message_handler.go
Normal file
@@ -0,0 +1,879 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageHandler 消息处理器
|
||||||
|
type MessageHandler struct {
|
||||||
|
chatService service.ChatService
|
||||||
|
messageService *service.MessageService
|
||||||
|
userService *service.UserService
|
||||||
|
groupService service.GroupService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMessageHandler 创建消息处理器
|
||||||
|
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler {
|
||||||
|
return &MessageHandler{
|
||||||
|
chatService: chatService,
|
||||||
|
messageService: messageService,
|
||||||
|
userService: userService,
|
||||||
|
groupService: groupService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversations 获取会话列表
|
||||||
|
// GET /api/conversations
|
||||||
|
func (h *MessageHandler) GetConversations(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
// 添加调试日志
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get conversations")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 过滤掉系统会话(系统通知现在使用独立的表)
|
||||||
|
filteredConvs := make([]*model.Conversation, 0)
|
||||||
|
for _, conv := range convs {
|
||||||
|
if conv.ID != model.SystemConversationID {
|
||||||
|
filteredConvs = append(filteredConvs, conv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := make([]*dto.ConversationResponse, len(filteredConvs))
|
||||||
|
for i, conv := range filteredConvs {
|
||||||
|
// 获取未读数
|
||||||
|
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
|
||||||
|
|
||||||
|
// 获取最后一条消息
|
||||||
|
var lastMessage *model.Message
|
||||||
|
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
|
||||||
|
if len(messages) > 0 {
|
||||||
|
lastMessage = messages[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 群聊时返回member_count,私聊时返回participants
|
||||||
|
var resp *dto.ConversationResponse
|
||||||
|
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||||
|
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||||
|
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
|
||||||
|
// 群聊:实时计算群成员数量
|
||||||
|
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
|
||||||
|
// 创建响应并设置member_count
|
||||||
|
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
|
||||||
|
resp.MemberCount = memberCount
|
||||||
|
} else {
|
||||||
|
// 私聊:获取参与者信息
|
||||||
|
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
|
||||||
|
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
|
||||||
|
}
|
||||||
|
result[i] = resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新 total 为过滤后的数量
|
||||||
|
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateConversation 创建私聊会话
|
||||||
|
// POST /api/conversations
|
||||||
|
func (h *MessageHandler) CreateConversation(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req dto.CreateConversationRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证目标用户是否存在
|
||||||
|
targetUser, err := h.userService.GetUserByID(c.Request.Context(), req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "target user not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不能和自己创建会话
|
||||||
|
if userID == req.UserID {
|
||||||
|
response.BadRequest(c, "cannot create conversation with yourself")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to create conversation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取参与者信息
|
||||||
|
participants := []*model.User{targetUser}
|
||||||
|
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||||
|
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationByID 获取会话详情
|
||||||
|
// GET /api/conversations/:id
|
||||||
|
func (h *MessageHandler) GetConversationByID(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationIDStr := c.Param("id")
|
||||||
|
fmt.Printf("[DEBUG] GetConversationByID: conversationIDStr = %s\n", conversationIDStr)
|
||||||
|
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] GetConversationByID: failed to parse conversation ID: %v\n", err)
|
||||||
|
response.BadRequest(c, "invalid conversation id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("[DEBUG] GetConversationByID: conversationID = %s\n", conversationID)
|
||||||
|
|
||||||
|
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取未读数
|
||||||
|
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||||
|
|
||||||
|
// 获取参与者信息
|
||||||
|
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
|
||||||
|
|
||||||
|
// 获取当前用户的已读位置
|
||||||
|
myLastReadSeq := int64(0)
|
||||||
|
isPinned := false
|
||||||
|
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
|
||||||
|
for _, p := range allParticipants {
|
||||||
|
if p.UserID == userID {
|
||||||
|
myLastReadSeq = p.LastReadSeq
|
||||||
|
isPinned = p.IsPinned
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取对方用户的已读位置
|
||||||
|
otherLastReadSeq := int64(0)
|
||||||
|
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages 获取消息列表
|
||||||
|
// GET /api/conversations/:id/messages
|
||||||
|
func (h *MessageHandler) GetMessages(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationIDStr := c.Param("id")
|
||||||
|
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid conversation id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否使用增量同步(after_seq参数)
|
||||||
|
afterSeqStr := c.Query("after_seq")
|
||||||
|
if afterSeqStr != "" {
|
||||||
|
// 增量同步模式
|
||||||
|
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid after_seq")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||||
|
|
||||||
|
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := dto.ConvertMessagesToResponse(messages)
|
||||||
|
|
||||||
|
response.Success(c, &dto.MessageSyncResponse{
|
||||||
|
Messages: result,
|
||||||
|
HasMore: len(messages) == limit,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否使用历史消息加载(before_seq参数)
|
||||||
|
beforeSeqStr := c.Query("before_seq")
|
||||||
|
if beforeSeqStr != "" {
|
||||||
|
// 加载更早的消息(下拉加载更多)
|
||||||
|
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid before_seq")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||||
|
|
||||||
|
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := dto.ConvertMessagesToResponse(messages)
|
||||||
|
|
||||||
|
response.Success(c, &dto.MessageSyncResponse{
|
||||||
|
Messages: result,
|
||||||
|
HasMore: len(messages) == limit,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分页模式
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := dto.ConvertMessagesToResponse(messages)
|
||||||
|
|
||||||
|
response.Paginated(c, result, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage 发送消息
|
||||||
|
// POST /api/conversations/:id/messages
|
||||||
|
func (h *MessageHandler) SendMessage(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationIDStr := c.Param("id")
|
||||||
|
fmt.Printf("[DEBUG] SendMessage: conversationIDStr = %s\n", conversationIDStr)
|
||||||
|
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] SendMessage: failed to parse conversation ID: %v\n", err)
|
||||||
|
response.BadRequest(c, "invalid conversation id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("[DEBUG] SendMessage: conversationID = %s, userID = %s\n", conversationID, userID)
|
||||||
|
|
||||||
|
var req dto.SendMessageRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用 segments
|
||||||
|
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, req.Segments, req.ReplyToID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertMessageToResponse(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleSendMessage RESTful 风格的发送消息端点
|
||||||
|
// POST /api/v1/conversations/send_message
|
||||||
|
// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
|
||||||
|
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var params dto.SendMessageParams
|
||||||
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证参数
|
||||||
|
if params.ConversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if params.DetailType == "" {
|
||||||
|
response.BadRequest(c, "detail_type is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if params.Segments == nil || len(params.Segments) == 0 {
|
||||||
|
response.BadRequest(c, "segments is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送消息
|
||||||
|
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 WSEventResponse 格式响应
|
||||||
|
wsResponse := dto.WSEventResponse{
|
||||||
|
ID: msg.ID,
|
||||||
|
Time: msg.CreatedAt.UnixMilli(),
|
||||||
|
Type: "message",
|
||||||
|
DetailType: params.DetailType,
|
||||||
|
Seq: strconv.FormatInt(msg.Seq, 10),
|
||||||
|
Segments: params.Segments,
|
||||||
|
SenderID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, wsResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleDeleteMsg 撤回消息
|
||||||
|
// POST /api/v1/messages/delete_msg
|
||||||
|
// 请求体格式: {"message_id": "xxx"}
|
||||||
|
func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var params dto.DeleteMsgParams
|
||||||
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证参数
|
||||||
|
if params.MessageID == "" {
|
||||||
|
response.BadRequest(c, "message_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 撤回消息
|
||||||
|
err := h.chatService.RecallMessage(c.Request.Context(), params.MessageID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "消息已撤回", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleGetConversationList 获取会话列表
|
||||||
|
// GET /api/v1/conversations/list
|
||||||
|
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get conversations")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 过滤掉系统会话(系统通知现在使用独立的表)
|
||||||
|
filteredConvs := make([]*model.Conversation, 0)
|
||||||
|
for _, conv := range convs {
|
||||||
|
if conv.ID != model.SystemConversationID {
|
||||||
|
filteredConvs = append(filteredConvs, conv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := make([]*dto.ConversationResponse, len(filteredConvs))
|
||||||
|
for i, conv := range filteredConvs {
|
||||||
|
// 获取未读数
|
||||||
|
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
|
||||||
|
|
||||||
|
// 获取最后一条消息
|
||||||
|
var lastMessage *model.Message
|
||||||
|
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
|
||||||
|
if len(messages) > 0 {
|
||||||
|
lastMessage = messages[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 群聊时返回member_count,私聊时返回participants
|
||||||
|
var resp *dto.ConversationResponse
|
||||||
|
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||||
|
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||||
|
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
|
||||||
|
// 群聊:实时计算群成员数量
|
||||||
|
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
|
||||||
|
// 创建响应并设置member_count
|
||||||
|
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
|
||||||
|
resp.MemberCount = memberCount
|
||||||
|
} else {
|
||||||
|
// 私聊:获取参与者信息
|
||||||
|
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
|
||||||
|
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
|
||||||
|
}
|
||||||
|
result[i] = resp
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleDeleteConversationForSelf 仅自己删除会话
|
||||||
|
// DELETE /api/v1/conversations/:id/self
|
||||||
|
func (h *MessageHandler) HandleDeleteConversationForSelf(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationID := c.Param("id")
|
||||||
|
if conversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.chatService.DeleteConversationForSelf(c.Request.Context(), conversationID, userID); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "conversation deleted for self", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记为已读
|
||||||
|
// POST /api/conversations/:id/read
|
||||||
|
func (h *MessageHandler) MarkAsRead(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationIDStr := c.Param("id")
|
||||||
|
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid conversation id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req dto.MarkReadRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "last_read_seq is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "marked as read", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取未读消息总数
|
||||||
|
// GET /api/conversations/unread/count
|
||||||
|
func (h *MessageHandler) GetUnreadCount(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
// 添加调试日志
|
||||||
|
fmt.Printf("[DEBUG] GetUnreadCount: user_id from context = %q\n", userID)
|
||||||
|
|
||||||
|
if userID == "" {
|
||||||
|
fmt.Printf("[DEBUG] GetUnreadCount: user_id is empty, returning 401\n")
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := h.chatService.GetAllUnreadCount(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get unread count")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, &dto.UnreadCountResponse{
|
||||||
|
TotalUnreadCount: count,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationUnreadCount 获取单个会话的未读数
|
||||||
|
// GET /api/conversations/:id/unread/count
|
||||||
|
func (h *MessageHandler) GetConversationUnreadCount(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationIDStr := c.Param("id")
|
||||||
|
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid conversation id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, &dto.ConversationUnreadCountResponse{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
UnreadCount: count,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecallMessage 撤回消息
|
||||||
|
// POST /api/messages/:id/recall
|
||||||
|
func (h *MessageHandler) RecallMessage(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
messageIDStr := c.Param("id")
|
||||||
|
// 直接使用 string 类型的 messageID
|
||||||
|
err := h.chatService.RecallMessage(c.Request.Context(), messageIDStr, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "message recalled", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMessage 删除消息
|
||||||
|
// DELETE /api/messages/:id
|
||||||
|
func (h *MessageHandler) DeleteMessage(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
messageIDStr := c.Param("id")
|
||||||
|
// 直接使用 string 类型的 messageID
|
||||||
|
err := h.chatService.DeleteMessage(c.Request.Context(), messageIDStr, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "message deleted", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:验证内容类型
|
||||||
|
func isValidContentType(contentType model.ContentType) bool {
|
||||||
|
switch contentType {
|
||||||
|
case model.ContentTypeText, model.ContentTypeImage, model.ContentTypeVideo, model.ContentTypeAudio, model.ContentTypeFile:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:获取会话参与者信息
|
||||||
|
func (h *MessageHandler) getConversationParticipants(ctx context.Context, conversationID string, currentUserID string) ([]*model.User, error) {
|
||||||
|
// 从repository获取参与者列表
|
||||||
|
participants, err := h.messageService.GetConversationParticipants(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取参与者用户信息
|
||||||
|
var users []*model.User
|
||||||
|
for _, p := range participants {
|
||||||
|
// 跳过当前用户
|
||||||
|
if p.UserID == currentUserID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
user, err := h.userService.GetUserByID(ctx, p.UserID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户在会话中的参与者信息
|
||||||
|
func (h *MessageHandler) getMyConversationParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
|
||||||
|
participants, err := h.messageService.GetConversationParticipants(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, p := range participants {
|
||||||
|
if p.UserID == userID {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== RESTful Action 端点 ====================
|
||||||
|
|
||||||
|
// HandleCreateConversation 创建会话
|
||||||
|
// POST /api/v1/conversations/create
|
||||||
|
func (h *MessageHandler) HandleCreateConversation(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var params dto.CreateConversationParams
|
||||||
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证目标用户是否存在
|
||||||
|
targetUser, err := h.userService.GetUserByID(c.Request.Context(), params.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "target user not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不能和自己创建会话
|
||||||
|
if userID == params.UserID {
|
||||||
|
response.BadRequest(c, "cannot create conversation with yourself")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, params.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to create conversation")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取参与者信息
|
||||||
|
participants := []*model.User{targetUser}
|
||||||
|
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||||
|
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleGetConversation 获取会话详情
|
||||||
|
// GET /api/v1/conversations/get?conversation_id=xxx
|
||||||
|
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationID := c.Query("conversation_id")
|
||||||
|
if conversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取未读数
|
||||||
|
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||||
|
|
||||||
|
// 获取参与者信息
|
||||||
|
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
|
||||||
|
|
||||||
|
// 获取当前用户的已读位置
|
||||||
|
myLastReadSeq := int64(0)
|
||||||
|
isPinned := false
|
||||||
|
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
|
||||||
|
for _, p := range allParticipants {
|
||||||
|
if p.UserID == userID {
|
||||||
|
myLastReadSeq = p.LastReadSeq
|
||||||
|
isPinned = p.IsPinned
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取对方用户的已读位置
|
||||||
|
otherLastReadSeq := int64(0)
|
||||||
|
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleGetMessages 获取会话消息
|
||||||
|
// GET /api/v1/conversations/get_messages?conversation_id=xxx
|
||||||
|
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationID := c.Query("conversation_id")
|
||||||
|
if conversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否使用增量同步(after_seq参数)
|
||||||
|
afterSeqStr := c.Query("after_seq")
|
||||||
|
if afterSeqStr != "" {
|
||||||
|
// 增量同步模式
|
||||||
|
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid after_seq")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||||
|
|
||||||
|
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := dto.ConvertMessagesToResponse(messages)
|
||||||
|
|
||||||
|
response.Success(c, &dto.MessageSyncResponse{
|
||||||
|
Messages: result,
|
||||||
|
HasMore: len(messages) == limit,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否使用历史消息加载(before_seq参数)
|
||||||
|
beforeSeqStr := c.Query("before_seq")
|
||||||
|
if beforeSeqStr != "" {
|
||||||
|
// 加载更早的消息(下拉加载更多)
|
||||||
|
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid before_seq")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||||
|
|
||||||
|
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := dto.ConvertMessagesToResponse(messages)
|
||||||
|
|
||||||
|
response.Success(c, &dto.MessageSyncResponse{
|
||||||
|
Messages: result,
|
||||||
|
HasMore: len(messages) == limit,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分页模式
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := dto.ConvertMessagesToResponse(messages)
|
||||||
|
|
||||||
|
response.Paginated(c, result, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleMarkRead 标记已读
|
||||||
|
// POST /api/v1/conversations/mark_read
|
||||||
|
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var params dto.MarkReadParams
|
||||||
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.ConversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "marked as read", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleSetConversationPinned 设置会话置顶
|
||||||
|
// POST /api/v1/conversations/set_pinned
|
||||||
|
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var params dto.SetConversationPinnedParams
|
||||||
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.ConversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
|
||||||
|
"conversation_id": params.ConversationID,
|
||||||
|
"is_pinned": params.IsPinned,
|
||||||
|
})
|
||||||
|
}
|
||||||
132
internal/handler/notification_handler.go
Normal file
132
internal/handler/notification_handler.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NotificationHandler 通知处理器
|
||||||
|
type NotificationHandler struct {
|
||||||
|
notificationService *service.NotificationService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNotificationHandler 创建通知处理器
|
||||||
|
func NewNotificationHandler(notificationService *service.NotificationService) *NotificationHandler {
|
||||||
|
return &NotificationHandler{
|
||||||
|
notificationService: notificationService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNotifications 获取通知列表
|
||||||
|
func (h *NotificationHandler) GetNotifications(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
unreadOnly := c.Query("unread_only") == "true"
|
||||||
|
|
||||||
|
notifications, total, err := h.notificationService.GetByUserID(c.Request.Context(), userID, page, pageSize, unreadOnly)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get notifications")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Paginated(c, notifications, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记为已读
|
||||||
|
func (h *NotificationHandler) MarkAsRead(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
err := h.notificationService.MarkAsReadWithUserID(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to mark as read")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "marked as read", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAllAsRead 标记所有为已读
|
||||||
|
func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.notificationService.MarkAllAsRead(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to mark all as read")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "all marked as read", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取未读数量
|
||||||
|
func (h *NotificationHandler) GetUnreadCount(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := h.notificationService.GetUnreadCount(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get unread count")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"count": count})
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteNotification 删除通知
|
||||||
|
func (h *NotificationHandler) DeleteNotification(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
err := h.notificationService.DeleteNotification(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to delete notification")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAllNotifications 清空所有通知
|
||||||
|
func (h *NotificationHandler) ClearAllNotifications(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.notificationService.ClearAllNotifications(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to clear notifications")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
511
internal/handler/post_handler.go
Normal file
511
internal/handler/post_handler.go
Normal file
@@ -0,0 +1,511 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostHandler 帖子处理器
|
||||||
|
type PostHandler struct {
|
||||||
|
postService *service.PostService
|
||||||
|
userService *service.UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPostHandler 创建帖子处理器
|
||||||
|
func NewPostHandler(postService *service.PostService, userService *service.UserService) *PostHandler {
|
||||||
|
return &PostHandler{
|
||||||
|
postService: postService,
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建帖子
|
||||||
|
func (h *PostHandler) Create(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateRequest struct {
|
||||||
|
Title string `json:"title" binding:"required"`
|
||||||
|
Content string `json:"content" binding:"required"`
|
||||||
|
Images []string `json:"images"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req CreateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
post, err := h.postService.Create(c.Request.Context(), userID, req.Title, req.Content, req.Images)
|
||||||
|
if err != nil {
|
||||||
|
var moderationErr *service.PostModerationRejectedError
|
||||||
|
if errors.As(err, &moderationErr) {
|
||||||
|
response.BadRequest(c, moderationErr.UserMessage())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to create post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertPostToResponse(post, false, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 获取帖子(不增加浏览量)
|
||||||
|
func (h *PostHandler) GetByID(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "post not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非作者不可查看未发布内容
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
if post.Status != model.PostStatusPublished && post.UserID != currentUserID {
|
||||||
|
response.NotFound(c, "post not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注意:不再自动增加浏览量,浏览量通过 RecordView 端点单独记录
|
||||||
|
|
||||||
|
// 获取当前用户ID用于判断点赞和收藏状态
|
||||||
|
fmt.Printf("[DEBUG] GetByID - postID: %s, currentUserID: %s\n", id, currentUserID)
|
||||||
|
|
||||||
|
var isLiked, isFavorited bool
|
||||||
|
if currentUserID != "" {
|
||||||
|
isLiked = h.postService.IsLiked(c.Request.Context(), id, currentUserID)
|
||||||
|
isFavorited = h.postService.IsFavorited(c.Request.Context(), id, currentUserID)
|
||||||
|
fmt.Printf("[DEBUG] GetByID - isLiked: %v, isFavorited: %v\n", isLiked, isFavorited)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[DEBUG] GetByID - user not logged in, isLiked: false, isFavorited: false\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有当前用户,检查与帖子作者的相互关注状态
|
||||||
|
var authorWithFollowStatus *dto.UserResponse
|
||||||
|
if currentUserID != "" && post.User != nil {
|
||||||
|
_, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), post.UserID, currentUserID)
|
||||||
|
if err == nil {
|
||||||
|
authorWithFollowStatus = dto.ConvertUserToResponseWithMutualFollow(post.User, isFollowing, isFollowingMe)
|
||||||
|
} else {
|
||||||
|
// 如果出错,使用默认的author
|
||||||
|
authorWithFollowStatus = dto.ConvertUserToResponse(post.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建响应
|
||||||
|
responseData := &dto.PostResponse{
|
||||||
|
ID: post.ID,
|
||||||
|
UserID: post.UserID,
|
||||||
|
Title: post.Title,
|
||||||
|
Content: post.Content,
|
||||||
|
Images: dto.ConvertPostImagesToResponse(post.Images),
|
||||||
|
LikesCount: post.LikesCount,
|
||||||
|
CommentsCount: post.CommentsCount,
|
||||||
|
FavoritesCount: post.FavoritesCount,
|
||||||
|
SharesCount: post.SharesCount,
|
||||||
|
ViewsCount: post.ViewsCount,
|
||||||
|
IsPinned: post.IsPinned,
|
||||||
|
IsLocked: post.IsLocked,
|
||||||
|
IsVote: post.IsVote,
|
||||||
|
CreatedAt: dto.FormatTime(post.CreatedAt),
|
||||||
|
Author: authorWithFollowStatus,
|
||||||
|
IsLiked: isLiked,
|
||||||
|
IsFavorited: isFavorited,
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, responseData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordView 记录帖子浏览(增加浏览量)
|
||||||
|
func (h *PostHandler) RecordView(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
|
||||||
|
// 验证帖子存在
|
||||||
|
_, err := h.postService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "post not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加浏览量
|
||||||
|
if err := h.postService.IncrementViews(c.Request.Context(), id, userID); err != nil {
|
||||||
|
fmt.Printf("[DEBUG] Failed to increment views for post %s: %v\n", id, err)
|
||||||
|
response.InternalServerError(c, "failed to record view")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 获取帖子列表
|
||||||
|
func (h *PostHandler) List(c *gin.Context) {
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
userID := c.Query("user_id")
|
||||||
|
tab := c.Query("tab") // recommend, follow, hot, latest
|
||||||
|
|
||||||
|
// 获取当前用户ID
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
|
||||||
|
var posts []*model.Post
|
||||||
|
var total int64
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 根据 tab 参数选择不同的获取方式
|
||||||
|
switch tab {
|
||||||
|
case "follow":
|
||||||
|
// 获取关注用户的帖子,需要登录
|
||||||
|
if currentUserID == "" {
|
||||||
|
response.Unauthorized(c, "请先登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
posts, total, err = h.postService.GetFollowingPosts(c.Request.Context(), currentUserID, page, pageSize)
|
||||||
|
case "hot":
|
||||||
|
// 获取热门帖子
|
||||||
|
posts, total, err = h.postService.GetHotPosts(c.Request.Context(), page, pageSize)
|
||||||
|
case "recommend":
|
||||||
|
// 推荐帖子(从Gorse获取个性化推荐)
|
||||||
|
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
|
||||||
|
case "latest":
|
||||||
|
// 最新帖子
|
||||||
|
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||||
|
default:
|
||||||
|
// 默认获取最新帖子
|
||||||
|
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get posts")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] List - tab: %s, currentUserID: %s, posts count: %d\n", tab, currentUserID, len(posts))
|
||||||
|
|
||||||
|
isLikedMap := make(map[string]bool)
|
||||||
|
isFavoritedMap := make(map[string]bool)
|
||||||
|
if currentUserID != "" {
|
||||||
|
for _, post := range posts {
|
||||||
|
isLiked := h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
isFavorited := h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
isLikedMap[post.ID] = isLiked
|
||||||
|
isFavoritedMap[post.ID] = isFavorited
|
||||||
|
fmt.Printf("[DEBUG] List - postID: %s, isLiked: %v, isFavorited: %v\n", post.ID, isLiked, isFavorited)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[DEBUG] List - user not logged in\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应结构
|
||||||
|
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||||
|
|
||||||
|
response.Paginated(c, postResponses, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新帖子
|
||||||
|
func (h *PostHandler) Update(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "post not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if post.UserID != userID {
|
||||||
|
response.Forbidden(c, "cannot update others' post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateRequest struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Title != "" {
|
||||||
|
post.Title = req.Title
|
||||||
|
}
|
||||||
|
if req.Content != "" {
|
||||||
|
post.Content = req.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.postService.Update(c.Request.Context(), post)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to update post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
var isLiked, isFavorited bool
|
||||||
|
if currentUserID != "" {
|
||||||
|
isLiked = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
isFavorited = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除帖子
|
||||||
|
func (h *PostHandler) Delete(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
|
||||||
|
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "post not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if post.UserID != userID {
|
||||||
|
response.Forbidden(c, "cannot delete others' post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.postService.Delete(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to delete post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "post deleted", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Like 点赞帖子
|
||||||
|
func (h *PostHandler) Like(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
fmt.Printf("[DEBUG] Like - postID: %s, userID: %s\n", id, userID)
|
||||||
|
|
||||||
|
err := h.postService.Like(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to like post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取更新后的帖子状态
|
||||||
|
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||||
|
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||||
|
fmt.Printf("[DEBUG] Like - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlike 取消点赞
|
||||||
|
func (h *PostHandler) Unlike(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
fmt.Printf("[DEBUG] Unlike - postID: %s, userID: %s\n", id, userID)
|
||||||
|
|
||||||
|
err := h.postService.Unlike(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to unlike post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取更新后的帖子状态
|
||||||
|
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||||
|
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||||
|
fmt.Printf("[DEBUG] Unlike - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Favorite 收藏帖子
|
||||||
|
func (h *PostHandler) Favorite(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
fmt.Printf("[DEBUG] Favorite - postID: %s, userID: %s\n", id, userID)
|
||||||
|
|
||||||
|
err := h.postService.Favorite(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to favorite post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取更新后的帖子状态
|
||||||
|
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||||
|
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||||
|
fmt.Printf("[DEBUG] Favorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unfavorite 取消收藏
|
||||||
|
func (h *PostHandler) Unfavorite(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
fmt.Printf("[DEBUG] Unfavorite - postID: %s, userID: %s\n", id, userID)
|
||||||
|
|
||||||
|
err := h.postService.Unfavorite(c.Request.Context(), id, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to unfavorite post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取更新后的帖子状态
|
||||||
|
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get post")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||||
|
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||||
|
fmt.Printf("[DEBUG] Unfavorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserPosts 获取用户帖子列表
|
||||||
|
func (h *PostHandler) GetUserPosts(c *gin.Context) {
|
||||||
|
userID := c.Param("id")
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get user posts")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户ID用于判断点赞和收藏状态
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
isLikedMap := make(map[string]bool)
|
||||||
|
isFavoritedMap := make(map[string]bool)
|
||||||
|
if currentUserID != "" {
|
||||||
|
for _, post := range posts {
|
||||||
|
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应结构
|
||||||
|
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||||
|
|
||||||
|
response.Paginated(c, postResponses, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFavorites 获取收藏列表
|
||||||
|
func (h *PostHandler) GetFavorites(c *gin.Context) {
|
||||||
|
userID := c.Param("id")
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
posts, total, err := h.postService.GetFavorites(c.Request.Context(), userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get favorites")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户ID用于判断点赞和收藏状态
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
isLikedMap := make(map[string]bool)
|
||||||
|
isFavoritedMap := make(map[string]bool)
|
||||||
|
if currentUserID != "" {
|
||||||
|
for _, post := range posts {
|
||||||
|
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应结构
|
||||||
|
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||||
|
|
||||||
|
response.Paginated(c, postResponses, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search 搜索帖子
|
||||||
|
func (h *PostHandler) Search(c *gin.Context) {
|
||||||
|
keyword := c.Query("keyword")
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
posts, total, err := h.postService.Search(c.Request.Context(), keyword, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to search posts")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户ID用于判断点赞和收藏状态
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
isLikedMap := make(map[string]bool)
|
||||||
|
isFavoritedMap := make(map[string]bool)
|
||||||
|
if currentUserID != "" {
|
||||||
|
for _, post := range posts {
|
||||||
|
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应结构
|
||||||
|
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||||
|
|
||||||
|
response.Paginated(c, postResponses, total, page, pageSize)
|
||||||
|
}
|
||||||
157
internal/handler/push_handler.go
Normal file
157
internal/handler/push_handler.go
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PushHandler 推送处理器
|
||||||
|
type PushHandler struct {
|
||||||
|
pushService service.PushService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPushHandler 创建推送处理器
|
||||||
|
func NewPushHandler(pushService service.PushService) *PushHandler {
|
||||||
|
return &PushHandler{
|
||||||
|
pushService: pushService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterDevice 注册设备
|
||||||
|
// POST /api/v1/push/devices
|
||||||
|
func (h *PushHandler) RegisterDevice(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req dto.RegisterDeviceRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证设备类型
|
||||||
|
deviceType := model.DeviceType(req.DeviceType)
|
||||||
|
if !isValidDeviceType(deviceType) {
|
||||||
|
response.BadRequest(c, "invalid device type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.pushService.RegisterDevice(c.Request.Context(), userID, req.DeviceID, deviceType, req.PushToken)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to register device")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "device registered successfully", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterDevice 注销设备
|
||||||
|
// DELETE /api/v1/push/devices/:device_id
|
||||||
|
func (h *PushHandler) UnregisterDevice(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceID := c.Param("device_id")
|
||||||
|
if deviceID == "" {
|
||||||
|
response.BadRequest(c, "device_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.pushService.UnregisterDevice(c.Request.Context(), deviceID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to unregister device")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "device unregistered successfully", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMyDevices 获取当前用户的设备列表
|
||||||
|
// GET /api/v1/push/devices
|
||||||
|
func (h *PushHandler) GetMyDevices(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 这里需要从DeviceTokenRepository获取设备列表
|
||||||
|
// 由于PushService接口没有提供获取设备列表的方法,我们暂时返回空列表
|
||||||
|
// TODO: 在PushService接口中添加GetUserDevices方法
|
||||||
|
_ = userID // 避免未使用变量警告
|
||||||
|
|
||||||
|
response.Success(c, []*dto.DeviceTokenResponse{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPushRecords 获取推送记录
|
||||||
|
// GET /api/v1/push/records
|
||||||
|
func (h *PushHandler) GetPushRecords(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
records, err := h.pushService.GetPendingPushes(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get push records")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, &dto.PushRecordListResponse{
|
||||||
|
Records: dto.PushRecordsToResponse(records),
|
||||||
|
Total: int64(len(records)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 辅助函数:验证设备类型
|
||||||
|
func isValidDeviceType(deviceType model.DeviceType) bool {
|
||||||
|
switch deviceType {
|
||||||
|
case model.DeviceTypeIOS, model.DeviceTypeAndroid, model.DeviceTypeWeb:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDeviceToken 更新设备推送Token
|
||||||
|
// PUT /api/v1/push/devices/:device_id/token
|
||||||
|
func (h *PushHandler) UpdateDeviceToken(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceID := c.Param("device_id")
|
||||||
|
if deviceID == "" {
|
||||||
|
response.BadRequest(c, "device_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
PushToken string `json:"push_token" binding:"required"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.pushService.UpdateDeviceToken(c.Request.Context(), deviceID, req.PushToken)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to update device token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "device token updated successfully", nil)
|
||||||
|
}
|
||||||
164
internal/handler/sticker_handler.go
Normal file
164
internal/handler/sticker_handler.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StickerHandler 自定义表情处理器
|
||||||
|
type StickerHandler struct {
|
||||||
|
stickerService service.StickerService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStickerHandler 创建自定义表情处理器
|
||||||
|
func NewStickerHandler(stickerService service.StickerService) *StickerHandler {
|
||||||
|
return &StickerHandler{
|
||||||
|
stickerService: stickerService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStickersRequest 获取表情列表请求
|
||||||
|
type GetStickersRequest struct {
|
||||||
|
UserID string `form:"user_id" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddStickerRequest 添加表情请求
|
||||||
|
type AddStickerRequest struct {
|
||||||
|
URL string `json:"url" binding:"required"`
|
||||||
|
Width int `json:"width"`
|
||||||
|
Height int `json:"height"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteStickerRequest 删除表情请求
|
||||||
|
type DeleteStickerRequest struct {
|
||||||
|
StickerID string `json:"sticker_id" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReorderStickersRequest 重新排序请求
|
||||||
|
type ReorderStickersRequest struct {
|
||||||
|
Orders map[string]int `json:"orders" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckStickerRequest 检查表情是否存在请求
|
||||||
|
type CheckStickerRequest struct {
|
||||||
|
URL string `form:"url" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStickers 获取用户的表情列表
|
||||||
|
func (h *StickerHandler) GetStickers(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stickers, err := h.stickerService.GetUserStickers(userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get stickers")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"stickers": stickers})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSticker 添加表情
|
||||||
|
func (h *StickerHandler) AddSticker(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req AddStickerRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sticker, err := h.stickerService.AddSticker(userID, req.URL, req.Width, req.Height)
|
||||||
|
if err != nil {
|
||||||
|
if err == service.ErrStickerAlreadyExists {
|
||||||
|
response.Error(c, http.StatusConflict, "sticker already exists")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == service.ErrInvalidStickerURL {
|
||||||
|
response.BadRequest(c, "invalid sticker url, only http/https is allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"sticker": sticker})
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSticker 删除表情
|
||||||
|
func (h *StickerHandler) DeleteSticker(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req DeleteStickerRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.stickerService.DeleteSticker(userID, req.StickerID); err != nil {
|
||||||
|
response.InternalServerError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "sticker deleted successfully", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReorderStickers 重新排序表情
|
||||||
|
func (h *StickerHandler) ReorderStickers(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ReorderStickersRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.stickerService.ReorderStickers(userID, req.Orders); err != nil {
|
||||||
|
response.InternalServerError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "stickers reordered successfully", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckStickerExists 检查表情是否存在
|
||||||
|
func (h *StickerHandler) CheckStickerExists(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
url := c.Query("url")
|
||||||
|
if url == "" {
|
||||||
|
response.BadRequest(c, "url is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := h.stickerService.CheckExists(userID, url)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"exists": exists})
|
||||||
|
}
|
||||||
154
internal/handler/system_message_handler.go
Normal file
154
internal/handler/system_message_handler.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SystemMessageHandler 系统消息处理器
|
||||||
|
type SystemMessageHandler struct {
|
||||||
|
systemMsgService service.SystemMessageService
|
||||||
|
notifyRepo *repository.SystemNotificationRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSystemMessageHandler 创建系统消息处理器
|
||||||
|
func NewSystemMessageHandler(
|
||||||
|
systemMsgService service.SystemMessageService,
|
||||||
|
notifyRepo *repository.SystemNotificationRepository,
|
||||||
|
) *SystemMessageHandler {
|
||||||
|
return &SystemMessageHandler{
|
||||||
|
systemMsgService: systemMsgService,
|
||||||
|
notifyRepo: notifyRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSystemMessages 获取系统消息列表
|
||||||
|
// GET /api/v1/messages/system
|
||||||
|
func (h *SystemMessageHandler) GetSystemMessages(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
// 获取当前用户的系统通知(从独立表中获取)
|
||||||
|
notifications, total, err := h.notifyRepo.GetByReceiverID(userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get system messages")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式
|
||||||
|
result := make([]*dto.SystemMessageResponse, 0)
|
||||||
|
for _, n := range notifications {
|
||||||
|
resp := dto.SystemNotificationToResponse(n)
|
||||||
|
result = append(result, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Paginated(c, result, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取系统消息未读数
|
||||||
|
// GET /api/v1/messages/system/unread-count
|
||||||
|
func (h *SystemMessageHandler) GetUnreadCount(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户的未读通知数
|
||||||
|
unreadCount, err := h.notifyRepo.GetUnreadCount(userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get unread count")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, &dto.SystemUnreadCountResponse{
|
||||||
|
UnreadCount: unreadCount,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记系统消息为已读
|
||||||
|
// PUT /api/v1/messages/system/:id/read
|
||||||
|
func (h *SystemMessageHandler) MarkAsRead(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
notificationIDStr := c.Param("id")
|
||||||
|
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid notification id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 标记为已读
|
||||||
|
err = h.notifyRepo.MarkAsRead(notificationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to mark as read")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "marked as read", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAllAsRead 标记所有系统消息为已读
|
||||||
|
// PUT /api/v1/messages/system/read-all
|
||||||
|
func (h *SystemMessageHandler) MarkAllAsRead(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 标记当前用户所有通知为已读
|
||||||
|
err := h.notifyRepo.MarkAllAsRead(userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to mark all as read")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "all messages marked as read", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSystemMessage 删除系统消息
|
||||||
|
// DELETE /api/v1/messages/system/:id
|
||||||
|
func (h *SystemMessageHandler) DeleteSystemMessage(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
notificationIDStr := c.Param("id")
|
||||||
|
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid notification id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除通知
|
||||||
|
err = h.notifyRepo.Delete(notificationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to delete notification")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||||
|
|
||||||
|
response.SuccessWithMessage(c, "notification deleted", nil)
|
||||||
|
}
|
||||||
90
internal/handler/upload_handler.go
Normal file
90
internal/handler/upload_handler.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UploadHandler 上传处理器
|
||||||
|
type UploadHandler struct {
|
||||||
|
uploadService *service.UploadService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUploadHandler 创建上传处理器
|
||||||
|
func NewUploadHandler(uploadService *service.UploadService) *UploadHandler {
|
||||||
|
return &UploadHandler{
|
||||||
|
uploadService: uploadService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadImage 上传图片
|
||||||
|
func (h *UploadHandler) UploadImage(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := c.FormFile("image")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "image file is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := h.uploadService.UploadImage(c.Request.Context(), file)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to upload image")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"url": url})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadAvatar 上传头像
|
||||||
|
func (h *UploadHandler) UploadAvatar(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := c.FormFile("image")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "avatar file is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := h.uploadService.UploadAvatar(c.Request.Context(), userID, file)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to upload avatar")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"url": url})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadCover 上传头图(个人主页封面)
|
||||||
|
func (h *UploadHandler) UploadCover(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := c.FormFile("image")
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "image file is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := h.uploadService.UploadCover(c.Request.Context(), userID, file)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to upload cover")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"url": url})
|
||||||
|
}
|
||||||
705
internal/handler/user_handler.go
Normal file
705
internal/handler/user_handler.go
Normal file
@@ -0,0 +1,705 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserHandler 用户处理器
|
||||||
|
type UserHandler struct {
|
||||||
|
userService *service.UserService
|
||||||
|
jwtService *service.JWTService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserHandler 创建用户处理器
|
||||||
|
func NewUserHandler(userService *service.UserService) *UserHandler {
|
||||||
|
return &UserHandler{
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register 用户注册
|
||||||
|
func (h *UserHandler) Register(c *gin.Context) {
|
||||||
|
type RegisterRequest struct {
|
||||||
|
Username string `json:"username" binding:"required"`
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
Password string `json:"password" binding:"required,min=6"`
|
||||||
|
Nickname string `json:"nickname" binding:"required"`
|
||||||
|
Phone string `json:"phone"`
|
||||||
|
VerificationCode string `json:"verification_code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req RegisterRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.userService.Register(c.Request.Context(), req.Username, req.Email, req.Password, req.Nickname, req.Phone, req.VerificationCode)
|
||||||
|
if err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to register")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成Token
|
||||||
|
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
|
||||||
|
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"user": dto.ConvertUserToResponse(user),
|
||||||
|
"token": accessToken,
|
||||||
|
"refresh_token": refreshToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login 用户登录
|
||||||
|
func (h *UserHandler) Login(c *gin.Context) {
|
||||||
|
type LoginRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Account string `json:"account"`
|
||||||
|
Password string `json:"password" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req LoginRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
account := req.Account
|
||||||
|
if account == "" {
|
||||||
|
account = req.Username
|
||||||
|
}
|
||||||
|
if account == "" {
|
||||||
|
response.BadRequest(c, "username or account is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.userService.Login(c.Request.Context(), account, req.Password)
|
||||||
|
if err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to login")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成Token
|
||||||
|
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
|
||||||
|
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"user": dto.ConvertUserToResponse(user),
|
||||||
|
"token": accessToken,
|
||||||
|
"refresh_token": refreshToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRegisterCode 发送注册验证码
|
||||||
|
func (h *UserHandler) SendRegisterCode(c *gin.Context) {
|
||||||
|
type SendCodeRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req SendCodeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.userService.SendRegisterCode(c.Request.Context(), req.Email); err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to send register verification code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendPasswordResetCode 发送找回密码验证码
|
||||||
|
func (h *UserHandler) SendPasswordResetCode(c *gin.Context) {
|
||||||
|
type SendCodeRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req SendCodeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.userService.SendPasswordResetCode(c.Request.Context(), req.Email); err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to send reset verification code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetPassword 找回密码并重置
|
||||||
|
func (h *UserHandler) ResetPassword(c *gin.Context) {
|
||||||
|
type ResetPasswordRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
VerificationCode string `json:"verification_code" binding:"required"`
|
||||||
|
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ResetPasswordRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.userService.ResetPasswordByEmail(c.Request.Context(), req.Email, req.VerificationCode, req.NewPassword); err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to reset password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCurrentUser 获取当前用户
|
||||||
|
func (h *UserHandler) GetCurrentUser(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "user not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实时计算帖子数量
|
||||||
|
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
// 如果获取失败,使用数据库中的值
|
||||||
|
postsCount = int64(user.PostsCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByID 获取指定用户
|
||||||
|
func (h *UserHandler) GetUserByID(c *gin.Context) {
|
||||||
|
id := c.Param("id")
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
|
||||||
|
// 获取用户信息,包含双向关注状态
|
||||||
|
user, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), id, currentUserID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "user not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实时计算帖子数量
|
||||||
|
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), id)
|
||||||
|
if err != nil {
|
||||||
|
// 如果获取失败,使用数据库中的值
|
||||||
|
postsCount = int64(user.PostsCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为响应格式,包含双向关注状态和实时计算的帖子数量
|
||||||
|
userResponse := dto.ConvertUserToResponseWithMutualFollowAndPostsCount(user, isFollowing, isFollowingMe, int(postsCount))
|
||||||
|
|
||||||
|
response.Success(c, userResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUser 更新用户
|
||||||
|
func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateRequest struct {
|
||||||
|
Nickname string `json:"nickname"`
|
||||||
|
Bio string `json:"bio"`
|
||||||
|
Website string `json:"website"`
|
||||||
|
Location string `json:"location"`
|
||||||
|
Avatar string `json:"avatar"`
|
||||||
|
Phone *string `json:"phone"`
|
||||||
|
Email *string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req UpdateRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "user not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Nickname != "" {
|
||||||
|
user.Nickname = req.Nickname
|
||||||
|
}
|
||||||
|
if req.Bio != "" {
|
||||||
|
user.Bio = req.Bio
|
||||||
|
}
|
||||||
|
if req.Website != "" {
|
||||||
|
user.Website = req.Website
|
||||||
|
}
|
||||||
|
if req.Location != "" {
|
||||||
|
user.Location = req.Location
|
||||||
|
}
|
||||||
|
if req.Avatar != "" {
|
||||||
|
user.Avatar = req.Avatar
|
||||||
|
}
|
||||||
|
if req.Phone != nil {
|
||||||
|
user.Phone = req.Phone
|
||||||
|
}
|
||||||
|
if req.Email != nil {
|
||||||
|
if user.Email == nil || *user.Email != *req.Email {
|
||||||
|
user.EmailVerified = false
|
||||||
|
}
|
||||||
|
user.Email = req.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.userService.UpdateUser(c.Request.Context(), user)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to update user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 实时计算帖子数量
|
||||||
|
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
|
||||||
|
if err != nil {
|
||||||
|
// 如果获取失败,使用数据库中的值
|
||||||
|
postsCount = int64(user.PostsCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendEmailVerifyCode 发送当前用户邮箱验证码
|
||||||
|
func (h *UserHandler) SendEmailVerifyCode(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type SendCodeRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
}
|
||||||
|
var req SendCodeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.userService.SendCurrentUserEmailVerifyCode(c.Request.Context(), userID, req.Email); err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to send email verify code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyEmail 验证当前用户邮箱
|
||||||
|
func (h *UserHandler) VerifyEmail(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type VerifyEmailRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
VerificationCode string `json:"verification_code" binding:"required"`
|
||||||
|
}
|
||||||
|
var req VerifyEmailRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.userService.VerifyCurrentUserEmail(c.Request.Context(), userID, req.Email, req.VerificationCode); err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to verify email")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken 刷新Token
|
||||||
|
func (h *UserHandler) RefreshToken(c *gin.Context) {
|
||||||
|
type RefreshRequest struct {
|
||||||
|
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req RefreshRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 refresh token
|
||||||
|
claims, err := h.jwtService.ParseToken(req.RefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
response.Unauthorized(c, "invalid refresh token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成新 token
|
||||||
|
accessToken, _ := h.jwtService.GenerateAccessToken(claims.UserID, claims.Username)
|
||||||
|
refreshToken, _ := h.jwtService.GenerateRefreshToken(claims.UserID, claims.Username)
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"token": accessToken,
|
||||||
|
"refresh_token": refreshToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetJWTService 设置JWT服务
|
||||||
|
func (h *UserHandler) SetJWTService(jwtSvc *service.JWTService) {
|
||||||
|
h.jwtService = jwtSvc
|
||||||
|
}
|
||||||
|
|
||||||
|
// FollowUser 关注用户
|
||||||
|
func (h *UserHandler) FollowUser(c *gin.Context) {
|
||||||
|
userID := c.Param("id")
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
|
||||||
|
if userID == currentUserID {
|
||||||
|
response.BadRequest(c, "cannot follow yourself")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.FollowUser(c.Request.Context(), currentUserID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to follow user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnfollowUser 取消关注用户
|
||||||
|
func (h *UserHandler) UnfollowUser(c *gin.Context) {
|
||||||
|
userID := c.Param("id")
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
|
||||||
|
err := h.userService.UnfollowUser(c.Request.Context(), currentUserID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to unfollow user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockUser 拉黑用户
|
||||||
|
func (h *UserHandler) BlockUser(c *gin.Context) {
|
||||||
|
targetUserID := c.Param("id")
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
|
||||||
|
if targetUserID == currentUserID {
|
||||||
|
response.BadRequest(c, "cannot block yourself")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.BlockUser(c.Request.Context(), currentUserID, targetUserID)
|
||||||
|
if err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to block user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnblockUser 取消拉黑
|
||||||
|
func (h *UserHandler) UnblockUser(c *gin.Context) {
|
||||||
|
targetUserID := c.Param("id")
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
|
||||||
|
if targetUserID == currentUserID {
|
||||||
|
response.BadRequest(c, "cannot unblock yourself")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.UnblockUser(c.Request.Context(), currentUserID, targetUserID)
|
||||||
|
if err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to unblock user")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockedUsers 获取黑名单列表
|
||||||
|
func (h *UserHandler) GetBlockedUsers(c *gin.Context) {
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
if currentUserID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
if page <= 0 {
|
||||||
|
page = 1
|
||||||
|
}
|
||||||
|
if pageSize <= 0 {
|
||||||
|
pageSize = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
users, total, err := h.userService.GetBlockedUsers(c.Request.Context(), currentUserID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get blocked users")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDs := make([]string, len(users))
|
||||||
|
for i, u := range users {
|
||||||
|
userIDs[i] = u.ID
|
||||||
|
}
|
||||||
|
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||||
|
userResponses := dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||||
|
response.Paginated(c, userResponses, total, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockStatus 获取拉黑状态
|
||||||
|
func (h *UserHandler) GetBlockStatus(c *gin.Context) {
|
||||||
|
targetUserID := c.Param("id")
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
if currentUserID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if targetUserID == "" {
|
||||||
|
response.BadRequest(c, "target user id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isBlocked, err := h.userService.IsBlocked(c.Request.Context(), currentUserID, targetUserID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get block status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"is_blocked": isBlocked})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowingList 获取关注列表
|
||||||
|
func (h *UserHandler) GetFollowingList(c *gin.Context) {
|
||||||
|
userID := c.Param("id")
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
page := c.DefaultQuery("page", "1")
|
||||||
|
pageSize := c.DefaultQuery("page_size", "20")
|
||||||
|
|
||||||
|
users, err := h.userService.GetFollowingList(c.Request.Context(), userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get following list")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果已登录,获取双向关注状态和实时计算的帖子数量
|
||||||
|
var userResponses []*dto.UserResponse
|
||||||
|
if currentUserID != "" && len(users) > 0 {
|
||||||
|
userIDs := make([]string, len(users))
|
||||||
|
for i, u := range users {
|
||||||
|
userIDs[i] = u.ID
|
||||||
|
}
|
||||||
|
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
|
||||||
|
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||||
|
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
|
||||||
|
} else if len(users) > 0 {
|
||||||
|
userIDs := make([]string, len(users))
|
||||||
|
for i, u := range users {
|
||||||
|
userIDs[i] = u.ID
|
||||||
|
}
|
||||||
|
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||||
|
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||||
|
} else {
|
||||||
|
userResponses = dto.ConvertUsersToResponse(users)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"list": userResponses,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowersList 获取粉丝列表
|
||||||
|
func (h *UserHandler) GetFollowersList(c *gin.Context) {
|
||||||
|
userID := c.Param("id")
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
page := c.DefaultQuery("page", "1")
|
||||||
|
pageSize := c.DefaultQuery("page_size", "20")
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] GetFollowersList: userID=%s, currentUserID=%s\n", userID, currentUserID)
|
||||||
|
|
||||||
|
users, err := h.userService.GetFollowersList(c.Request.Context(), userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to get followers list")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] GetFollowersList: found %d users\n", len(users))
|
||||||
|
|
||||||
|
// 如果已登录,获取双向关注状态和实时计算的帖子数量
|
||||||
|
var userResponses []*dto.UserResponse
|
||||||
|
if currentUserID != "" && len(users) > 0 {
|
||||||
|
userIDs := make([]string, len(users))
|
||||||
|
for i, u := range users {
|
||||||
|
userIDs[i] = u.ID
|
||||||
|
}
|
||||||
|
fmt.Printf("[DEBUG] GetFollowersList: checking mutual follow status for userIDs=%v\n", userIDs)
|
||||||
|
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
|
||||||
|
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||||
|
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
|
||||||
|
} else if len(users) > 0 {
|
||||||
|
userIDs := make([]string, len(users))
|
||||||
|
for i, u := range users {
|
||||||
|
userIDs[i] = u.ID
|
||||||
|
}
|
||||||
|
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||||
|
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||||
|
} else {
|
||||||
|
userResponses = dto.ConvertUsersToResponse(users)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{
|
||||||
|
"list": userResponses,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUsername 检查用户名是否可用
|
||||||
|
func (h *UserHandler) CheckUsername(c *gin.Context) {
|
||||||
|
username := c.Query("username")
|
||||||
|
if username == "" {
|
||||||
|
response.BadRequest(c, "username is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
available, err := h.userService.CheckUsernameAvailable(c.Request.Context(), username)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to check username")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"available": available})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangePassword 修改密码
|
||||||
|
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
|
||||||
|
type ChangePasswordRequest struct {
|
||||||
|
OldPassword string `json:"old_password" binding:"required"`
|
||||||
|
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||||
|
VerificationCode string `json:"verification_code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ChangePasswordRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.ChangePassword(c.Request.Context(), currentUserID, req.OldPassword, req.NewPassword, req.VerificationCode)
|
||||||
|
if err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to change password")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendChangePasswordCode 发送修改密码验证码
|
||||||
|
func (h *UserHandler) SendChangePasswordCode(c *gin.Context) {
|
||||||
|
currentUserID := c.GetString("user_id")
|
||||||
|
if currentUserID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.SendChangePasswordCode(c.Request.Context(), currentUserID)
|
||||||
|
if err != nil {
|
||||||
|
if se, ok := err.(*service.ServiceError); ok {
|
||||||
|
response.Error(c, se.Code, se.Message)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.InternalServerError(c, "failed to send change password code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search 搜索用户
|
||||||
|
func (h *UserHandler) Search(c *gin.Context) {
|
||||||
|
keyword := c.Query("keyword")
|
||||||
|
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||||
|
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||||
|
|
||||||
|
users, total, err := h.userService.Search(c.Request.Context(), keyword, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "failed to search users")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取实时计算的帖子数量
|
||||||
|
var userResponses []*dto.UserResponse
|
||||||
|
if len(users) > 0 {
|
||||||
|
userIDs := make([]string, len(users))
|
||||||
|
for i, u := range users {
|
||||||
|
userIDs[i] = u.ID
|
||||||
|
}
|
||||||
|
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||||
|
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||||
|
} else {
|
||||||
|
userResponses = dto.ConvertUsersToResponse(users)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Paginated(c, userResponses, total, page, pageSize)
|
||||||
|
}
|
||||||
216
internal/handler/vote_handler.go
Normal file
216
internal/handler/vote_handler.go
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VoteHandler 投票处理器
|
||||||
|
type VoteHandler struct {
|
||||||
|
voteService *service.VoteService
|
||||||
|
postService *service.PostService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewVoteHandler 创建投票处理器
|
||||||
|
func NewVoteHandler(voteService *service.VoteService, postService *service.PostService) *VoteHandler {
|
||||||
|
return &VoteHandler{
|
||||||
|
voteService: voteService,
|
||||||
|
postService: postService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateVotePost 创建投票帖子
|
||||||
|
// POST /api/v1/posts/vote
|
||||||
|
func (h *VoteHandler) CreateVotePost(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "请先登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req dto.CreateVotePostRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
post, err := h.voteService.CreateVotePost(c.Request.Context(), userID, &req)
|
||||||
|
if err != nil {
|
||||||
|
var moderationErr *service.PostModerationRejectedError
|
||||||
|
if errors.As(err, &moderationErr) {
|
||||||
|
response.BadRequest(c, moderationErr.UserMessage())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, post)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVoteResult 获取投票结果
|
||||||
|
// GET /api/v1/posts/:id/vote
|
||||||
|
func (h *VoteHandler) GetVoteResult(c *gin.Context) {
|
||||||
|
postID := c.Param("id")
|
||||||
|
if postID == "" {
|
||||||
|
response.BadRequest(c, "帖子ID不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证帖子存在
|
||||||
|
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "帖子不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前用户ID(可选登录)
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
|
||||||
|
// 如果用户未登录,返回不带has_voted的结果
|
||||||
|
if userID == "" {
|
||||||
|
options, err := h.voteService.GetVoteOptions(postID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "获取投票选项失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算总票数
|
||||||
|
totalVotes := 0
|
||||||
|
for _, option := range options {
|
||||||
|
totalVotes += option.VotesCount
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &dto.VoteResultDTO{
|
||||||
|
Options: options,
|
||||||
|
TotalVotes: totalVotes,
|
||||||
|
HasVoted: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户已登录,获取完整的投票结果
|
||||||
|
result, err := h.voteService.GetVoteResult(postID, userID)
|
||||||
|
if err != nil {
|
||||||
|
response.InternalServerError(c, "获取投票结果失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vote 投票
|
||||||
|
// POST /api/v1/posts/:id/vote
|
||||||
|
func (h *VoteHandler) Vote(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "请先登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
postID := c.Param("id")
|
||||||
|
if postID == "" {
|
||||||
|
response.BadRequest(c, "帖子ID不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证帖子存在
|
||||||
|
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "帖子不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析请求体
|
||||||
|
var req struct {
|
||||||
|
OptionID string `json:"option_id" binding:"required"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.voteService.Vote(c.Request.Context(), postID, userID, req.OptionID); err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unvote 取消投票
|
||||||
|
// DELETE /api/v1/posts/:id/vote
|
||||||
|
func (h *VoteHandler) Unvote(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "请先登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
postID := c.Param("id")
|
||||||
|
if postID == "" {
|
||||||
|
response.BadRequest(c, "帖子ID不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证帖子存在
|
||||||
|
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||||
|
if err != nil {
|
||||||
|
response.NotFound(c, "帖子不存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.voteService.Unvote(c.Request.Context(), postID, userID); err != nil {
|
||||||
|
response.Error(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateVoteOption 更新投票选项(仅作者)
|
||||||
|
// PUT /api/v1/vote-options/:id
|
||||||
|
func (h *VoteHandler) UpdateVoteOption(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "请先登录")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
optionID := c.Param("id")
|
||||||
|
if optionID == "" {
|
||||||
|
response.BadRequest(c, "选项ID不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析请求体
|
||||||
|
var req struct {
|
||||||
|
Content string `json:"content" binding:"required"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取帖子ID(从查询参数或请求体中获取)
|
||||||
|
postID := c.Query("post_id")
|
||||||
|
if postID == "" {
|
||||||
|
response.BadRequest(c, "帖子ID不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.voteService.UpdateVoteOption(c.Request.Context(), postID, optionID, userID, req.Content); err != nil {
|
||||||
|
response.Error(c, http.StatusForbidden, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"success": true})
|
||||||
|
}
|
||||||
866
internal/handler/websocket_handler.go
Normal file
866
internal/handler/websocket_handler.go
Normal file
@@ -0,0 +1,866 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
ws "carrot_bbs/internal/pkg/websocket"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var upgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true // 允许所有来源,生产环境应该限制
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocketHandler WebSocket处理器
|
||||||
|
type WebSocketHandler struct {
|
||||||
|
jwtService *service.JWTService
|
||||||
|
chatService service.ChatService
|
||||||
|
groupService service.GroupService
|
||||||
|
groupRepo repository.GroupRepository
|
||||||
|
userRepo *repository.UserRepository
|
||||||
|
wsManager *ws.WebSocketManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebSocketHandler 创建WebSocket处理器
|
||||||
|
func NewWebSocketHandler(
|
||||||
|
jwtService *service.JWTService,
|
||||||
|
chatService service.ChatService,
|
||||||
|
groupService service.GroupService,
|
||||||
|
groupRepo repository.GroupRepository,
|
||||||
|
userRepo *repository.UserRepository,
|
||||||
|
wsManager *ws.WebSocketManager,
|
||||||
|
) *WebSocketHandler {
|
||||||
|
return &WebSocketHandler{
|
||||||
|
jwtService: jwtService,
|
||||||
|
chatService: chatService,
|
||||||
|
groupService: groupService,
|
||||||
|
groupRepo: groupRepo,
|
||||||
|
userRepo: userRepo,
|
||||||
|
wsManager: wsManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWebSocket 处理WebSocket连接
|
||||||
|
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||||
|
// 调试:打印请求头信息
|
||||||
|
log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
|
||||||
|
log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
|
||||||
|
c.GetHeader("Connection"),
|
||||||
|
c.GetHeader("Upgrade"))
|
||||||
|
log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
|
||||||
|
c.GetHeader("Sec-WebSocket-Key"),
|
||||||
|
c.GetHeader("Sec-WebSocket-Version"))
|
||||||
|
|
||||||
|
// 从query参数获取token
|
||||||
|
token := c.Query("token")
|
||||||
|
if token == "" {
|
||||||
|
// 尝试从header获取
|
||||||
|
authHeader := c.GetHeader("Authorization")
|
||||||
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
|
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证token
|
||||||
|
claims, err := h.jwtService.ParseToken(token)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Invalid token: %v", err)
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := claims.UserID
|
||||||
|
if userID == "" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 升级HTTP连接为WebSocket连接
|
||||||
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to upgrade connection: %v", err)
|
||||||
|
log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
|
||||||
|
c.GetHeader("User-Agent"),
|
||||||
|
c.GetHeader("Content-Type"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果用户已在线,先注销旧连接
|
||||||
|
if h.wsManager.IsUserOnline(userID) {
|
||||||
|
log.Printf("[DEBUG] 用户 %s 已有在线连接,复用该连接", userID)
|
||||||
|
} else {
|
||||||
|
log.Printf("[DEBUG] 用户 %s 当前不在线,创建新连接", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建客户端
|
||||||
|
client := &ws.Client{
|
||||||
|
ID: userID,
|
||||||
|
UserID: userID,
|
||||||
|
Conn: conn,
|
||||||
|
Send: make(chan []byte, 256),
|
||||||
|
Manager: h.wsManager,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注册客户端
|
||||||
|
h.wsManager.Register(client)
|
||||||
|
|
||||||
|
// 启动读写协程
|
||||||
|
go client.WritePump()
|
||||||
|
go h.handleMessages(client)
|
||||||
|
|
||||||
|
log.Printf("[DEBUG] WebSocket连接建立: userID=%s, 当前在线=%v", userID, h.wsManager.IsUserOnline(userID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMessages 处理客户端消息
|
||||||
|
// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳
|
||||||
|
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
|
||||||
|
defer func() {
|
||||||
|
h.wsManager.Unregister(client)
|
||||||
|
client.Conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
client.Conn.SetReadLimit(512 * 1024) // 512KB
|
||||||
|
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
||||||
|
client.Conn.SetPongHandler(func(string) error {
|
||||||
|
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// 心跳定时器 - 服务端主动 ping 间隔保持 30 秒
|
||||||
|
pingTicker := time.NewTicker(30 * time.Second)
|
||||||
|
defer pingTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-pingTicker.C:
|
||||||
|
// 发送心跳
|
||||||
|
if err := client.SendPing(); err != nil {
|
||||||
|
log.Printf("Failed to send ping: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
_, message, err := client.Conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
|
log.Printf("WebSocket error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wsMsg ws.WSMessage
|
||||||
|
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
||||||
|
log.Printf("Failed to unmarshal message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
h.processMessage(client, &wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMessage 处理消息
|
||||||
|
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||||
|
switch msg.Type {
|
||||||
|
case ws.MessageTypePing:
|
||||||
|
// 响应心跳
|
||||||
|
if err := client.SendPong(); err != nil {
|
||||||
|
log.Printf("Failed to send pong: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case ws.MessageTypePong:
|
||||||
|
// 客户端响应心跳
|
||||||
|
|
||||||
|
case ws.MessageTypeMessage:
|
||||||
|
// 处理聊天消息
|
||||||
|
h.handleChatMessage(client, msg)
|
||||||
|
|
||||||
|
case ws.MessageTypeTyping:
|
||||||
|
// 处理正在输入状态
|
||||||
|
h.handleTyping(client, msg)
|
||||||
|
|
||||||
|
case ws.MessageTypeRead:
|
||||||
|
// 处理已读回执
|
||||||
|
h.handleReadReceipt(client, msg)
|
||||||
|
|
||||||
|
// 群组消息处理
|
||||||
|
case ws.MessageTypeGroupMessage:
|
||||||
|
// 处理群消息
|
||||||
|
h.handleGroupMessage(client, msg)
|
||||||
|
|
||||||
|
case ws.MessageTypeGroupTyping:
|
||||||
|
// 处理群输入状态
|
||||||
|
h.handleGroupTyping(client, msg)
|
||||||
|
|
||||||
|
case ws.MessageTypeGroupRead:
|
||||||
|
// 处理群消息已读
|
||||||
|
h.handleGroupReadReceipt(client, msg)
|
||||||
|
|
||||||
|
case ws.MessageTypeGroupRecall:
|
||||||
|
// 处理群消息撤回
|
||||||
|
h.handleGroupRecall(client, msg)
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Printf("Unknown message type: %s", msg.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatMessage 处理聊天消息
|
||||||
|
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||||
|
data, ok := msg.Data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
log.Printf("Invalid message data format")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[DEBUG handleChatMessage] 完整data: %+v", data)
|
||||||
|
|
||||||
|
conversationIDStr, _ := data["conversationId"].(string)
|
||||||
|
|
||||||
|
if conversationIDStr == "" {
|
||||||
|
log.Printf("Missing conversationId")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析会话ID
|
||||||
|
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Invalid conversation ID: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 segments
|
||||||
|
var segments model.MessageSegments
|
||||||
|
if data["segments"] != nil {
|
||||||
|
segmentsBytes, err := json.Marshal(data["segments"])
|
||||||
|
if err == nil {
|
||||||
|
json.Unmarshal(segmentsBytes, &segments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 segments 中提取回复消息ID
|
||||||
|
replyToID := dto.GetReplyMessageID(segments)
|
||||||
|
var replyToIDPtr *string
|
||||||
|
if replyToID != "" {
|
||||||
|
replyToIDPtr = &replyToID
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送消息 - 使用 segments
|
||||||
|
message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to send message: %v", err)
|
||||||
|
// 发送错误消息
|
||||||
|
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||||
|
"error": "Failed to send message",
|
||||||
|
})
|
||||||
|
if client.Send != nil {
|
||||||
|
msgBytes, _ := json.Marshal(errorMsg)
|
||||||
|
client.Send <- msgBytes
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送确认消息(使用 meta 事件格式,包含完整的消息内容)
|
||||||
|
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
||||||
|
"detail_type": ws.MetaDetailTypeAck,
|
||||||
|
"conversation_id": conversationID,
|
||||||
|
"id": message.ID,
|
||||||
|
"user_id": client.UserID,
|
||||||
|
"sender_id": client.UserID,
|
||||||
|
"seq": message.Seq,
|
||||||
|
"segments": message.Segments,
|
||||||
|
"created_at": message.CreatedAt.UnixMilli(),
|
||||||
|
})
|
||||||
|
if client.Send != nil {
|
||||||
|
msgBytes, _ := json.Marshal(metaAckMsg)
|
||||||
|
log.Printf("[DEBUG handleChatMessage] 私聊 ack 消息: %s", string(msgBytes))
|
||||||
|
log.Printf("[DEBUG handleChatMessage] message.Segments 类型: %T, 值: %+v", message.Segments, message.Segments)
|
||||||
|
client.Send <- msgBytes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTyping 处理正在输入状态
|
||||||
|
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
|
||||||
|
data, ok := msg.Data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationIDStr, _ := data["conversationId"].(string)
|
||||||
|
if conversationIDStr == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用 string 类型的 userID
|
||||||
|
h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleReadReceipt 处理已读回执
|
||||||
|
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
||||||
|
data, ok := msg.Data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationIDStr, _ := data["conversationId"].(string)
|
||||||
|
if conversationIDStr == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取lastReadSeq
|
||||||
|
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
||||||
|
if lastReadSeq == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用 string 类型的 userID 和 conversationID
|
||||||
|
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
||||||
|
log.Printf("Failed to mark as read: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 群组消息处理 ====================
|
||||||
|
|
||||||
|
// handleGroupMessage 处理群消息
|
||||||
|
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||||
|
// 打印接收到的消息类型和数据,用于调试
|
||||||
|
log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
|
||||||
|
log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)
|
||||||
|
|
||||||
|
data, ok := msg.Data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
log.Printf("Invalid group message data format: data is not map[string]interface{}")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析群组ID(支持 camelCase 和 snake_case)
|
||||||
|
var groupIDFloat float64
|
||||||
|
groupID := "" // 使用 groupID 作为最终变量名
|
||||||
|
if val, ok := data["groupId"].(float64); ok {
|
||||||
|
groupIDFloat = val
|
||||||
|
groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||||
|
} else if val, ok := data["group_id"].(string); ok {
|
||||||
|
groupID = val
|
||||||
|
}
|
||||||
|
|
||||||
|
if groupID == "" {
|
||||||
|
log.Printf("Missing groupId in group message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析会话ID(支持 camelCase 和 snake_case)
|
||||||
|
var conversationID string
|
||||||
|
if val, ok := data["conversationId"].(string); ok {
|
||||||
|
conversationID = val
|
||||||
|
} else if val, ok := data["conversation_id"].(string); ok {
|
||||||
|
conversationID = val
|
||||||
|
}
|
||||||
|
if conversationID == "" {
|
||||||
|
log.Printf("Missing conversationId in group message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析 segments
|
||||||
|
var segments model.MessageSegments
|
||||||
|
if data["segments"] != nil {
|
||||||
|
segmentsBytes, err := json.Marshal(data["segments"])
|
||||||
|
if err == nil {
|
||||||
|
json.Unmarshal(segmentsBytes, &segments)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析@用户列表(支持 camelCase 和 snake_case)
|
||||||
|
var mentionUsers []uint64
|
||||||
|
var mentionUsersInterface []interface{}
|
||||||
|
if val, ok := data["mentionUsers"].([]interface{}); ok {
|
||||||
|
mentionUsersInterface = val
|
||||||
|
} else if val, ok := data["mention_users"].([]interface{}); ok {
|
||||||
|
mentionUsersInterface = val
|
||||||
|
}
|
||||||
|
if len(mentionUsersInterface) > 0 {
|
||||||
|
for _, uid := range mentionUsersInterface {
|
||||||
|
if uidFloat, ok := uid.(float64); ok {
|
||||||
|
mentionUsers = append(mentionUsers, uint64(uidFloat))
|
||||||
|
} else if uidStr, ok := uid.(string); ok {
|
||||||
|
// 处理字符串格式的用户ID
|
||||||
|
if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil {
|
||||||
|
mentionUsers = append(mentionUsers, uidInt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析@所有人(支持 camelCase 和 snake_case)
|
||||||
|
var mentionAll bool
|
||||||
|
if val, ok := data["mentionAll"].(bool); ok {
|
||||||
|
mentionAll = val
|
||||||
|
} else if val, ok := data["mention_all"].(bool); ok {
|
||||||
|
mentionAll = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查用户是否可以发送群消息(验证成员身份和禁言状态)
|
||||||
|
// client.UserID 已经是 string 格式的 UUID
|
||||||
|
if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
|
||||||
|
log.Printf("User cannot send group message: %v", err)
|
||||||
|
// 发送错误消息
|
||||||
|
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||||
|
"error": "Cannot send group message",
|
||||||
|
"reason": err.Error(),
|
||||||
|
"type": "group_message_error",
|
||||||
|
"groupId": groupID,
|
||||||
|
})
|
||||||
|
if client.Send != nil {
|
||||||
|
msgBytes, _ := json.Marshal(errorMsg)
|
||||||
|
client.Send <- msgBytes
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查@所有人权限(只有群主和管理员可以@所有人)
|
||||||
|
if mentionAll {
|
||||||
|
if !h.groupService.IsGroupAdmin(client.UserID, groupID) {
|
||||||
|
log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
|
||||||
|
mentionAll = false // 取消@所有人标记
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建消息
|
||||||
|
message := &model.Message{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
SenderID: client.UserID,
|
||||||
|
Segments: segments,
|
||||||
|
Status: model.MessageStatusNormal,
|
||||||
|
MentionAll: mentionAll,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 序列化mentionUsers为JSON
|
||||||
|
if len(mentionUsers) > 0 {
|
||||||
|
mentionUsersJSON, _ := json.Marshal(mentionUsers)
|
||||||
|
message.MentionUsers = string(mentionUsersJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播)
|
||||||
|
savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to save group message: %v", err)
|
||||||
|
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||||
|
"error": "Failed to save group message",
|
||||||
|
})
|
||||||
|
if client.Send != nil {
|
||||||
|
msgBytes, _ := json.Marshal(errorMsg)
|
||||||
|
client.Send <- msgBytes
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新消息的mention信息
|
||||||
|
if len(mentionUsers) > 0 || mentionAll {
|
||||||
|
message.ID = savedMessage.ID
|
||||||
|
message.Seq = savedMessage.Seq
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造群消息响应
|
||||||
|
groupMsg := &ws.GroupMessage{
|
||||||
|
ID: savedMessage.ID,
|
||||||
|
ConversationID: conversationID,
|
||||||
|
GroupID: groupID,
|
||||||
|
SenderID: client.UserID,
|
||||||
|
Seq: savedMessage.Seq,
|
||||||
|
Segments: segments,
|
||||||
|
MentionUsers: mentionUsers,
|
||||||
|
MentionAll: mentionAll,
|
||||||
|
CreatedAt: savedMessage.CreatedAt.UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 广播消息给群组所有成员(排除发送者)
|
||||||
|
h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)
|
||||||
|
|
||||||
|
// 发送确认消息给发送者(作为meta事件)
|
||||||
|
// 使用 meta 事件格式发送 ack
|
||||||
|
log.Printf("[DEBUG HandleGroupMessageSend] 准备发送 ack 消息, userID=%s, messageID=%s, seq=%d",
|
||||||
|
client.UserID, savedMessage.ID, savedMessage.Seq)
|
||||||
|
|
||||||
|
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
||||||
|
"detail_type": ws.MetaDetailTypeAck,
|
||||||
|
"conversation_id": conversationID,
|
||||||
|
"group_id": groupID,
|
||||||
|
"id": savedMessage.ID,
|
||||||
|
"user_id": client.UserID,
|
||||||
|
"sender_id": client.UserID,
|
||||||
|
"seq": savedMessage.Seq,
|
||||||
|
"segments": segments,
|
||||||
|
"created_at": savedMessage.CreatedAt.UnixMilli(),
|
||||||
|
})
|
||||||
|
if client.Send != nil {
|
||||||
|
msgBytes, _ := json.Marshal(metaAckMsg)
|
||||||
|
log.Printf("[DEBUG HandleGroupMessageSend] 发送 ack 消息到 channel, userID=%s, msg=%s",
|
||||||
|
client.UserID, string(msgBytes))
|
||||||
|
client.Send <- msgBytes
|
||||||
|
} else {
|
||||||
|
log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理@提及通知
|
||||||
|
if len(mentionUsers) > 0 || mentionAll {
|
||||||
|
// 提取文本正文(不含 @ 部分)
|
||||||
|
textContent := dto.ExtractTextContentFromModel(segments)
|
||||||
|
// 在通知内容前拼接被@的真实昵称,通过群成员列表查找
|
||||||
|
mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
|
||||||
|
h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGroupTyping 处理群输入状态
|
||||||
|
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
|
||||||
|
data, ok := msg.Data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDFloat, _ := data["groupId"].(float64)
|
||||||
|
if groupIDFloat == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||||
|
|
||||||
|
isTyping, _ := data["isTyping"].(bool)
|
||||||
|
|
||||||
|
// 验证用户是否是群成员
|
||||||
|
// client.UserID 已经是 string 格式的 UUID
|
||||||
|
isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
|
||||||
|
if err != nil || !isMember {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户信息
|
||||||
|
user, err := h.userRepo.GetByID(client.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造输入状态消息
|
||||||
|
typingMsg := &ws.GroupTypingMessage{
|
||||||
|
GroupID: groupID,
|
||||||
|
UserID: client.UserID,
|
||||||
|
Username: user.Username,
|
||||||
|
IsTyping: isTyping,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 广播给群组其他成员
|
||||||
|
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
||||||
|
h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGroupReadReceipt 处理群消息已读回执
|
||||||
|
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
||||||
|
data, ok := msg.Data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conversationID, _ := data["conversationId"].(string)
|
||||||
|
if conversationID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
||||||
|
if lastReadSeq == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 标记已读
|
||||||
|
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
||||||
|
log.Printf("Failed to mark group message as read: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGroupRecall 处理群消息撤回
|
||||||
|
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
|
||||||
|
data, ok := msg.Data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
messageID, _ := data["messageId"].(string)
|
||||||
|
if messageID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDFloat, _ := data["groupId"].(float64)
|
||||||
|
if groupIDFloat == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||||
|
|
||||||
|
// 撤回消息
|
||||||
|
if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
|
||||||
|
log.Printf("Failed to recall group message: %v", err)
|
||||||
|
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||||
|
"error": "Failed to recall message",
|
||||||
|
})
|
||||||
|
if client.Send != nil {
|
||||||
|
msgBytes, _ := json.Marshal(errorMsg)
|
||||||
|
client.Send <- msgBytes
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 广播撤回通知给群组所有成员
|
||||||
|
recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
|
||||||
|
"messageId": messageID,
|
||||||
|
"groupId": groupID,
|
||||||
|
"userId": client.UserID,
|
||||||
|
"timestamp": time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
h.BroadcastGroupNotice(groupID, recallNotice)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGroupMention 处理群消息@提及通知
|
||||||
|
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
|
||||||
|
// 如果@所有人,获取所有群成员
|
||||||
|
if mentionAll {
|
||||||
|
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get group members for mention all: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, member := range members {
|
||||||
|
// 不通知发送者自己
|
||||||
|
memberIDStr := member.UserID
|
||||||
|
if memberIDStr == senderID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送@提及通知
|
||||||
|
mentionMsg := &ws.GroupMentionMessage{
|
||||||
|
GroupID: groupID,
|
||||||
|
MessageID: messageID,
|
||||||
|
FromUserID: senderID,
|
||||||
|
Content: truncateContent(content, 50),
|
||||||
|
MentionAll: true,
|
||||||
|
CreatedAt: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
||||||
|
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理特定用户的@提及
|
||||||
|
for _, userID := range mentionUsers {
|
||||||
|
// userID 是 uint64,转换为 string
|
||||||
|
userIDStr := strconv.FormatUint(userID, 10)
|
||||||
|
if userIDStr == senderID {
|
||||||
|
continue // 不通知发送者自己
|
||||||
|
}
|
||||||
|
|
||||||
|
mentionMsg := &ws.GroupMentionMessage{
|
||||||
|
GroupID: groupID,
|
||||||
|
MessageID: messageID,
|
||||||
|
FromUserID: senderID,
|
||||||
|
Content: truncateContent(content, 50),
|
||||||
|
MentionAll: false,
|
||||||
|
CreatedAt: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
||||||
|
h.wsManager.SendToUser(userIDStr, wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称
|
||||||
|
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
|
||||||
|
var prefix string
|
||||||
|
if mentionAll {
|
||||||
|
prefix = "@所有人 "
|
||||||
|
} else if len(mentionUsers) > 0 {
|
||||||
|
// 查询群成员列表,找到被@用户的昵称
|
||||||
|
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||||
|
if err == nil {
|
||||||
|
memberNickMap := make(map[string]string, len(members))
|
||||||
|
for _, m := range members {
|
||||||
|
displayName := m.Nickname
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = m.UserID
|
||||||
|
}
|
||||||
|
memberNickMap[m.UserID] = displayName
|
||||||
|
}
|
||||||
|
for _, uid := range mentionUsers {
|
||||||
|
uidStr := strconv.FormatUint(uid, 10)
|
||||||
|
if name, ok := memberNickMap[uidStr]; ok {
|
||||||
|
prefix += "@" + name + " "
|
||||||
|
} else {
|
||||||
|
prefix += "@某人 "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for range mentionUsers {
|
||||||
|
prefix += "@某人 "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prefix + textBody
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastGroupMessage 向群组所有成员广播消息
|
||||||
|
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
|
||||||
|
// 获取群组所有成员
|
||||||
|
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get group members for broadcast: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建WebSocket消息
|
||||||
|
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
|
||||||
|
|
||||||
|
// 遍历成员,如果在线则发送消息
|
||||||
|
for _, member := range members {
|
||||||
|
memberIDStr := member.UserID
|
||||||
|
|
||||||
|
// 排除发送者
|
||||||
|
if memberIDStr == excludeUserID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送消息
|
||||||
|
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastGroupNotice 广播群组通知给所有成员
|
||||||
|
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
|
||||||
|
// 获取群组所有成员
|
||||||
|
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 遍历成员,如果在线则发送通知
|
||||||
|
for _, member := range members {
|
||||||
|
memberIDStr := member.UserID
|
||||||
|
h.wsManager.SendToUser(memberIDStr, notice)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户)
|
||||||
|
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
|
||||||
|
// 获取群组所有成员
|
||||||
|
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 遍历成员,如果在线则发送通知
|
||||||
|
for _, member := range members {
|
||||||
|
memberIDStr := member.UserID
|
||||||
|
if memberIDStr == excludeUserID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.wsManager.SendToUser(memberIDStr, notice)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendGroupMemberNotice 发送群成员变动通知
|
||||||
|
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
|
||||||
|
notice := &ws.GroupNoticeMessage{
|
||||||
|
NoticeType: noticeType,
|
||||||
|
GroupID: groupID,
|
||||||
|
Data: data,
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
|
||||||
|
h.BroadcastGroupNotice(groupID, wsMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateContent 截断内容
|
||||||
|
func truncateContent(content string, maxLen int) string {
|
||||||
|
if len(content) <= maxLen {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
return content[:maxLen] + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastGroupTyping 向群组所有成员广播输入状态
|
||||||
|
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
|
||||||
|
// 获取群组所有成员
|
||||||
|
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get group members for typing broadcast: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建WebSocket消息
|
||||||
|
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
||||||
|
|
||||||
|
// 遍历成员,如果在线则发送消息
|
||||||
|
for _, member := range members {
|
||||||
|
memberIDStr := member.UserID
|
||||||
|
|
||||||
|
// 排除指定用户
|
||||||
|
if memberIDStr == excludeUserID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送消息
|
||||||
|
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastGroupRead 向群组所有成员广播已读状态
|
||||||
|
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
|
||||||
|
// 获取群组所有成员
|
||||||
|
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get group members for read broadcast: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建WebSocket消息
|
||||||
|
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
|
||||||
|
|
||||||
|
// 遍历成员,如果在线则发送消息
|
||||||
|
for _, member := range members {
|
||||||
|
memberIDStr := member.UserID
|
||||||
|
|
||||||
|
// 排除指定用户
|
||||||
|
if memberIDStr == excludeUserID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送消息
|
||||||
|
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
95
internal/middleware/auth.go
Normal file
95
internal/middleware/auth.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Auth 认证中间件
|
||||||
|
func Auth(jwtService *service.JWTService) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
authHeader := c.GetHeader("Authorization")
|
||||||
|
fmt.Printf("[DEBUG] Auth middleware: Authorization header = %q\n", authHeader)
|
||||||
|
|
||||||
|
if authHeader == "" {
|
||||||
|
fmt.Printf("[DEBUG] Auth middleware: no Authorization header, returning 401\n")
|
||||||
|
response.Unauthorized(c, "authorization header is required")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取Token
|
||||||
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
|
fmt.Printf("[DEBUG] Auth middleware: invalid Authorization header format\n")
|
||||||
|
response.Unauthorized(c, "invalid authorization header format")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := parts[1]
|
||||||
|
fmt.Printf("[DEBUG] Auth middleware: token = %q\n", token[:min(20, len(token))]+"...")
|
||||||
|
|
||||||
|
// 验证Token
|
||||||
|
claims, err := jwtService.ParseToken(token)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] Auth middleware: failed to parse token: %v\n", err)
|
||||||
|
response.Unauthorized(c, "invalid token")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] Auth middleware: parsed claims, user_id = %q, username = %q\n", claims.UserID, claims.Username)
|
||||||
|
|
||||||
|
// 将用户信息存入上下文
|
||||||
|
c.Set("user_id", claims.UserID)
|
||||||
|
c.Set("username", claims.Username)
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptionalAuth 可选认证中间件
|
||||||
|
func OptionalAuth(jwtService *service.JWTService) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
authHeader := c.GetHeader("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取Token
|
||||||
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := parts[1]
|
||||||
|
|
||||||
|
// 验证Token
|
||||||
|
claims, err := jwtService.ParseToken(token)
|
||||||
|
if err != nil {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将用户信息存入上下文
|
||||||
|
c.Set("user_id", claims.UserID)
|
||||||
|
c.Set("username", claims.Username)
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
46
internal/middleware/cors.go
Normal file
46
internal/middleware/cors.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CORS CORS中间件
|
||||||
|
func CORS() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// 获取请求路径
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||||
|
// 添加 WebSocket 升级所需的头
|
||||||
|
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, Connection, Upgrade, Sec-WebSocket-Key, Sec-WebSocket-Version, Sec-WebSocket-Protocol, Sec-WebSocket-Extensions")
|
||||||
|
c.Header("Access-Control-Expose-Headers", "Content-Length, Connection, Upgrade")
|
||||||
|
c.Header("Access-Control-Allow-Credentials", "true")
|
||||||
|
|
||||||
|
// 处理 WebSocket 升级请求的预检
|
||||||
|
if c.Request.Method == "OPTIONS" {
|
||||||
|
log.Printf("[CORS] OPTIONS 预检请求: %s", path)
|
||||||
|
c.AbortWithStatus(204)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 针对 WebSocket 路径的特殊处理
|
||||||
|
if path == "/ws" {
|
||||||
|
connection := c.GetHeader("Connection")
|
||||||
|
upgrade := c.GetHeader("Upgrade")
|
||||||
|
log.Printf("[CORS] WebSocket 请求: Connection=%s, Upgrade=%s", connection, upgrade)
|
||||||
|
|
||||||
|
// 检查是否是有效的 WebSocket 升级请求
|
||||||
|
if strings.Contains(strings.ToLower(connection), "upgrade") && strings.ToLower(upgrade) == "websocket" {
|
||||||
|
log.Printf("[CORS] 有效的 WebSocket 升级请求")
|
||||||
|
} else {
|
||||||
|
log.Printf("[CORS] 警告: 不是有效的 WebSocket 升级请求!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
49
internal/middleware/logger.go
Normal file
49
internal/middleware/logger.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Logger 日志中间件
|
||||||
|
func Logger(logger *zap.Logger) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
start := time.Now()
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
latency := time.Since(start)
|
||||||
|
statusCode := c.Writer.Status()
|
||||||
|
|
||||||
|
logger.Info("request",
|
||||||
|
zap.String("method", c.Request.Method),
|
||||||
|
zap.String("path", path),
|
||||||
|
zap.Int("status", statusCode),
|
||||||
|
zap.Duration("latency", latency),
|
||||||
|
zap.String("ip", c.ClientIP()),
|
||||||
|
zap.String("user-agent", c.Request.UserAgent()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recovery 恢复中间件
|
||||||
|
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
logger.Error("panic recovered",
|
||||||
|
zap.Any("error", err),
|
||||||
|
)
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"code": 500,
|
||||||
|
"message": "internal server error",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
102
internal/middleware/ratelimit.go
Normal file
102
internal/middleware/ratelimit.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimiter 限流器
|
||||||
|
type RateLimiter struct {
|
||||||
|
requests map[string][]time.Time
|
||||||
|
mu sync.Mutex
|
||||||
|
limit int
|
||||||
|
window time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimiter 创建限流器
|
||||||
|
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||||
|
rl := &RateLimiter{
|
||||||
|
requests: make(map[string][]time.Time),
|
||||||
|
limit: limit,
|
||||||
|
window: window,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 定期清理过期的记录
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(window)
|
||||||
|
rl.cleanup()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return rl
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup 清理过期的记录
|
||||||
|
func (rl *RateLimiter) cleanup() {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for key, times := range rl.requests {
|
||||||
|
var valid []time.Time
|
||||||
|
for _, t := range times {
|
||||||
|
if now.Sub(t) < rl.window {
|
||||||
|
valid = append(valid, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(valid) == 0 {
|
||||||
|
delete(rl.requests, key)
|
||||||
|
} else {
|
||||||
|
rl.requests[key] = valid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAllowed 检查是否允许请求
|
||||||
|
func (rl *RateLimiter) isAllowed(key string) bool {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
times := rl.requests[key]
|
||||||
|
|
||||||
|
// 过滤掉过期的
|
||||||
|
var valid []time.Time
|
||||||
|
for _, t := range times {
|
||||||
|
if now.Sub(t) < rl.window {
|
||||||
|
valid = append(valid, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(valid) >= rl.limit {
|
||||||
|
rl.requests[key] = valid
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
rl.requests[key] = append(valid, now)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimit 限流中间件
|
||||||
|
func RateLimit(requestsPerMinute int) gin.HandlerFunc {
|
||||||
|
limiter := NewRateLimiter(requestsPerMinute, time.Minute)
|
||||||
|
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
ip := c.ClientIP()
|
||||||
|
|
||||||
|
if !limiter.isAllowed(ip) {
|
||||||
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||||
|
"code": 429,
|
||||||
|
"message": "too many requests",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
118
internal/model/audit_log.go
Normal file
118
internal/model/audit_log.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditTargetType 审核对象类型
|
||||||
|
type AuditTargetType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuditTargetTypePost AuditTargetType = "post" // 帖子
|
||||||
|
AuditTargetTypeComment AuditTargetType = "comment" // 评论
|
||||||
|
AuditTargetTypeMessage AuditTargetType = "message" // 私信
|
||||||
|
AuditTargetTypeUser AuditTargetType = "user" // 用户资料
|
||||||
|
AuditTargetTypeImage AuditTargetType = "image" // 图片
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditResult 审核结果
|
||||||
|
type AuditResult string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuditResultPass AuditResult = "pass" // 通过
|
||||||
|
AuditResultReview AuditResult = "review" // 需人工复审
|
||||||
|
AuditResultBlock AuditResult = "block" // 违规拦截
|
||||||
|
AuditResultUnknown AuditResult = "unknown" // 未知
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditRiskLevel 风险等级
|
||||||
|
type AuditRiskLevel string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuditRiskLevelLow AuditRiskLevel = "low" // 低风险
|
||||||
|
AuditRiskLevelMedium AuditRiskLevel = "medium" // 中风险
|
||||||
|
AuditRiskLevelHigh AuditRiskLevel = "high" // 高风险
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditSource 审核来源
|
||||||
|
type AuditSource string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuditSourceAuto AuditSource = "auto" // 自动审核
|
||||||
|
AuditSourceManual AuditSource = "manual" // 人工审核
|
||||||
|
AuditSourceCallback AuditSource = "callback" // 回调审核
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditLog 审核日志实体
|
||||||
|
type AuditLog struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
TargetType AuditTargetType `json:"target_type" gorm:"type:varchar(50);index"`
|
||||||
|
TargetID string `json:"target_id" gorm:"type:varchar(255);index"`
|
||||||
|
Content string `json:"content" gorm:"type:text"` // 待审核内容
|
||||||
|
ContentType string `json:"content_type" gorm:"type:varchar(50)"` // 内容类型: text, image
|
||||||
|
ContentURL string `json:"content_url" gorm:"type:text"` // 图片/文件URL
|
||||||
|
AuditType string `json:"audit_type" gorm:"type:varchar(50)"` // 审核类型: porn, violence, ad, political, fraud, gamble
|
||||||
|
Result AuditResult `json:"result" gorm:"type:varchar(50);index"`
|
||||||
|
RiskLevel AuditRiskLevel `json:"risk_level" gorm:"type:varchar(20)"`
|
||||||
|
Labels string `json:"labels" gorm:"type:text"` // JSON数组,标签列表
|
||||||
|
Suggestion string `json:"suggestion" gorm:"type:varchar(50)"` // pass, review, block
|
||||||
|
Detail string `json:"detail" gorm:"type:text"` // 详细说明
|
||||||
|
ThirdPartyID string `json:"third_party_id" gorm:"type:varchar(255)"` // 第三方审核服务返回的ID
|
||||||
|
Source AuditSource `json:"source" gorm:"type:varchar(20);default:auto"`
|
||||||
|
ReviewerID string `json:"reviewer_id" gorm:"type:varchar(255)"` // 审核人ID(人工审核时使用)
|
||||||
|
ReviewerName string `json:"reviewer_name" gorm:"type:varchar(100)"` // 审核人名称
|
||||||
|
ReviewTime *time.Time `json:"review_time" gorm:"index"` // 审核时间
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(255);index"` // 内容发布者ID
|
||||||
|
UserIP string `json:"user_ip" gorm:"type:varchar(45)"` // 用户IP
|
||||||
|
Status string `json:"status" gorm:"type:varchar(20);default:pending"` // pending, completed, failed
|
||||||
|
RejectReason string `json:"reject_reason" gorm:"type:text"` // 拒绝原因(人工审核时使用)
|
||||||
|
ExtraData string `json:"extra_data" gorm:"type:text"` // 额外数据,JSON格式
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (al *AuditLog) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if al.ID == "" {
|
||||||
|
al.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (AuditLog) TableName() string {
|
||||||
|
return "audit_logs"
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditLogRequest 创建审核日志请求
|
||||||
|
type AuditLogRequest struct {
|
||||||
|
TargetType AuditTargetType `json:"target_type" validate:"required"`
|
||||||
|
TargetID string `json:"target_id" validate:"required"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
ContentType string `json:"content_type"`
|
||||||
|
ContentURL string `json:"content_url"`
|
||||||
|
AuditType string `json:"audit_type"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
UserIP string `json:"user_ip"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditLogListItem 审核日志列表项
|
||||||
|
type AuditLogListItem struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
TargetType AuditTargetType `json:"target_type"`
|
||||||
|
TargetID string `json:"target_id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
ContentType string `json:"content_type"`
|
||||||
|
Result AuditResult `json:"result"`
|
||||||
|
RiskLevel AuditRiskLevel `json:"risk_level"`
|
||||||
|
Suggestion string `json:"suggestion"`
|
||||||
|
Source AuditSource `json:"source"`
|
||||||
|
ReviewerID string `json:"reviewer_id"`
|
||||||
|
ReviewTime *time.Time `json:"review_time"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
80
internal/model/comment.go
Normal file
80
internal/model/comment.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CommentStatus 评论状态
|
||||||
|
type CommentStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CommentStatusDraft CommentStatus = "draft"
|
||||||
|
CommentStatusPending CommentStatus = "pending"
|
||||||
|
CommentStatusPublished CommentStatus = "published"
|
||||||
|
CommentStatusRejected CommentStatus = "rejected"
|
||||||
|
CommentStatusDeleted CommentStatus = "deleted"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Comment 评论实体
|
||||||
|
type Comment struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index:idx_comments_post_parent_status_created,priority:1"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
|
||||||
|
ParentID *string `json:"parent_id" gorm:"type:varchar(36);index:idx_comments_post_parent_status_created,priority:2"` // 父评论 ID(支持嵌套)
|
||||||
|
RootID *string `json:"root_id" gorm:"type:varchar(36);index:idx_comments_root_status_created,priority:1"` // 根评论 ID(用于高效查询)
|
||||||
|
Content string `json:"content" gorm:"type:text;not null"`
|
||||||
|
Images string `json:"images" gorm:"type:text"` // 图片URL列表,JSON数组格式
|
||||||
|
|
||||||
|
// 关联
|
||||||
|
User *User `json:"-" gorm:"foreignKey:UserID"`
|
||||||
|
Replies []*Comment `json:"-" gorm:"-"` // 子回复(手动加载,非 GORM 关联)
|
||||||
|
|
||||||
|
// 审核状态
|
||||||
|
Status CommentStatus `json:"status" gorm:"type:varchar(20);default:published;index:idx_comments_post_parent_status_created,priority:3;index:idx_comments_root_status_created,priority:2"`
|
||||||
|
|
||||||
|
// 统计
|
||||||
|
LikesCount int `json:"likes_count" gorm:"default:0"`
|
||||||
|
RepliesCount int `json:"replies_count" gorm:"default:0"`
|
||||||
|
|
||||||
|
// 软删除
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_comments_post_parent_status_created,priority:4,sort:asc;index:idx_comments_root_status_created,priority:3,sort:asc"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (c *Comment) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if c.ID == "" {
|
||||||
|
c.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Comment) TableName() string {
|
||||||
|
return "comments"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommentLike 评论点赞
|
||||||
|
type CommentLike struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
CommentID string `json:"comment_id" gorm:"type:varchar(36);index;not null"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (cl *CommentLike) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if cl.ID == "" {
|
||||||
|
cl.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (CommentLike) TableName() string {
|
||||||
|
return "comment_likes"
|
||||||
|
}
|
||||||
68
internal/model/conversation.go
Normal file
68
internal/model/conversation.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationType 会话类型
|
||||||
|
type ConversationType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ConversationTypePrivate ConversationType = "private" // 私聊
|
||||||
|
ConversationTypeGroup ConversationType = "group" // 群聊
|
||||||
|
ConversationTypeSystem ConversationType = "system" // 系统通知会话
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conversation 会话实体
|
||||||
|
// 使用雪花算法ID(作为string存储)和seq机制实现消息排序和增量同步
|
||||||
|
type Conversation struct {
|
||||||
|
ID string `gorm:"primaryKey;size:20" json:"id"` // 雪花算法ID(转为string避免精度丢失)
|
||||||
|
Type ConversationType `gorm:"type:varchar(20);default:'private'" json:"type"` // 会话类型
|
||||||
|
GroupID *string `gorm:"index" json:"group_id,omitempty"` // 关联的群组ID(群聊时使用,string类型避免JS精度丢失);使用指针支持NULL值
|
||||||
|
LastSeq int64 `gorm:"default:0" json:"last_seq"` // 最后一条消息的seq
|
||||||
|
LastMsgTime *time.Time `json:"last_msg_time,omitempty"` // 最后消息时间
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
|
||||||
|
// 关联 - 使用 polymorphic 模式避免外键约束问题
|
||||||
|
Participants []ConversationParticipant `gorm:"foreignKey:ConversationID" json:"participants,omitempty"`
|
||||||
|
Group *Group `gorm:"foreignKey:GroupID;references:ID" json:"group,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成雪花算法ID
|
||||||
|
func (c *Conversation) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if c.ID == "" {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.ID = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Conversation) TableName() string {
|
||||||
|
return "conversations"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationParticipant 会话参与者
|
||||||
|
type ConversationParticipant struct {
|
||||||
|
ID uint `gorm:"primaryKey" json:"id"`
|
||||||
|
ConversationID string `gorm:"not null;size:20;uniqueIndex:idx_conversation_user,priority:1;index:idx_cp_conversation_user,priority:1" json:"conversation_id"` // 雪花算法ID(string类型)
|
||||||
|
UserID string `gorm:"column:user_id;type:varchar(50);not null;uniqueIndex:idx_conversation_user,priority:2;index:idx_cp_conversation_user,priority:2;index:idx_cp_user_hidden_pinned_updated,priority:1" json:"user_id"` // UUID格式,与JWT中user_id保持一致
|
||||||
|
LastReadSeq int64 `gorm:"default:0" json:"last_read_seq"` // 已读到的seq位置
|
||||||
|
Muted bool `gorm:"default:false" json:"muted"` // 是否免打扰
|
||||||
|
IsPinned bool `gorm:"default:false;index:idx_cp_user_hidden_pinned_updated,priority:3" json:"is_pinned"` // 是否置顶会话(用户维度)
|
||||||
|
HiddenAt *time.Time `gorm:"index:idx_cp_user_hidden_pinned_updated,priority:2" json:"hidden_at,omitempty"` // 仅自己删除会话时使用,收到新消息后自动恢复
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime;index:idx_cp_user_hidden_pinned_updated,priority:4,sort:desc"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ConversationParticipant) TableName() string {
|
||||||
|
return "conversation_participants"
|
||||||
|
}
|
||||||
94
internal/model/device_token.go
Normal file
94
internal/model/device_token.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceType 设备类型
|
||||||
|
type DeviceType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
DeviceTypeIOS DeviceType = "ios" // iOS设备
|
||||||
|
DeviceTypeAndroid DeviceType = "android" // Android设备
|
||||||
|
DeviceTypeWeb DeviceType = "web" // Web端
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceToken 设备Token实体
|
||||||
|
// 用于管理用户的多设备推送Token
|
||||||
|
type DeviceToken struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` // 雪花算法ID
|
||||||
|
UserID string `gorm:"column:user_id;type:varchar(50);index;not null" json:"user_id"` // 用户ID (UUID格式)
|
||||||
|
DeviceID string `gorm:"type:varchar(100);not null" json:"device_id"` // 设备唯一标识
|
||||||
|
DeviceType DeviceType `gorm:"type:varchar(20);not null" json:"device_type"` // 设备类型
|
||||||
|
PushToken string `gorm:"type:varchar(256);not null" json:"push_token"` // 推送Token(FCM/APNs等)
|
||||||
|
IsActive bool `gorm:"default:true" json:"is_active"` // 是否活跃
|
||||||
|
DeviceName string `gorm:"type:varchar(100)" json:"device_name,omitempty"` // 设备名称(可选)
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
LastUsedAt *time.Time `json:"last_used_at,omitempty"` // 最后使用时间
|
||||||
|
|
||||||
|
// 软删除
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成雪花算法ID
|
||||||
|
func (d *DeviceToken) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if d.ID == 0 {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
d.ID = id
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DeviceToken) TableName() string {
|
||||||
|
return "device_tokens"
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLastUsed 更新最后使用时间
|
||||||
|
func (d *DeviceToken) UpdateLastUsed() {
|
||||||
|
now := time.Now()
|
||||||
|
d.LastUsedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deactivate 停用设备
|
||||||
|
func (d *DeviceToken) Deactivate() {
|
||||||
|
d.IsActive = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Activate 激活设备
|
||||||
|
func (d *DeviceToken) Activate() {
|
||||||
|
d.IsActive = true
|
||||||
|
now := time.Now()
|
||||||
|
d.LastUsedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIOS 判断是否为iOS设备
|
||||||
|
func (d *DeviceToken) IsIOS() bool {
|
||||||
|
return d.DeviceType == DeviceTypeIOS
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAndroid 判断是否为Android设备
|
||||||
|
func (d *DeviceToken) IsAndroid() bool {
|
||||||
|
return d.DeviceType == DeviceTypeAndroid
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsWeb 判断是否为Web端
|
||||||
|
func (d *DeviceToken) IsWeb() bool {
|
||||||
|
return d.DeviceType == DeviceTypeWeb
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportsMobilePush 判断是否支持手机推送
|
||||||
|
func (d *DeviceToken) SupportsMobilePush() bool {
|
||||||
|
return d.DeviceType == DeviceTypeIOS || d.DeviceType == DeviceTypeAndroid
|
||||||
|
}
|
||||||
28
internal/model/favorite.go
Normal file
28
internal/model/favorite.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Favorite 收藏
|
||||||
|
type Favorite struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_favorite_post_user,priority:1"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_favorite_post_user,priority:2"`
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (f *Favorite) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if f.ID == "" {
|
||||||
|
f.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Favorite) TableName() string {
|
||||||
|
return "favorites"
|
||||||
|
}
|
||||||
28
internal/model/follow.go
Normal file
28
internal/model/follow.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Follow 关注关系
|
||||||
|
type Follow struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
FollowerID string `json:"follower_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_follower_following"` // 关注者
|
||||||
|
FollowingID string `json:"following_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_follower_following"` // 被关注者
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (f *Follow) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if f.ID == "" {
|
||||||
|
f.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Follow) TableName() string {
|
||||||
|
return "follows"
|
||||||
|
}
|
||||||
57
internal/model/group.go
Normal file
57
internal/model/group.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JoinType 群组加入类型
|
||||||
|
type JoinType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
JoinTypeAnyone JoinType = 0 // 允许任何人加入
|
||||||
|
JoinTypeApproval JoinType = 1 // 需要审批
|
||||||
|
JoinTypeForbidden JoinType = 2 // 不允许加入
|
||||||
|
)
|
||||||
|
|
||||||
|
// Group 群组模型
|
||||||
|
type Group struct {
|
||||||
|
ID string `gorm:"primaryKey;size:20" json:"id"`
|
||||||
|
Name string `gorm:"size:50;not null" json:"name"`
|
||||||
|
Avatar string `gorm:"size:512" json:"avatar"`
|
||||||
|
Description string `gorm:"size:500" json:"description"`
|
||||||
|
OwnerID string `gorm:"type:varchar(36);not null;index" json:"owner_id"`
|
||||||
|
MemberCount int `gorm:"default:0" json:"member_count"`
|
||||||
|
MaxMembers int `gorm:"default:500" json:"max_members"`
|
||||||
|
JoinType JoinType `gorm:"default:0" json:"join_type"` // 0:允许任何人加入 1:需要审批 2:不允许加入
|
||||||
|
MuteAll bool `gorm:"default:false" json:"mute_all"` // 全员禁言
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成雪花算法ID
|
||||||
|
func (g *Group) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if g.ID == "" {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
g.ID = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIDInt 获取数字类型的ID(用于比较)
|
||||||
|
func (g *Group) GetIDInt() uint64 {
|
||||||
|
id, _ := strconv.ParseUint(g.ID, 10, 64)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName 指定表名
|
||||||
|
func (Group) TableName() string {
|
||||||
|
return "groups"
|
||||||
|
}
|
||||||
38
internal/model/group_announcement.go
Normal file
38
internal/model/group_announcement.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GroupAnnouncement 群公告模型
|
||||||
|
type GroupAnnouncement struct {
|
||||||
|
ID string `gorm:"primaryKey;size:20" json:"id"`
|
||||||
|
GroupID string `gorm:"not null;index" json:"group_id"`
|
||||||
|
Content string `gorm:"type:text;not null" json:"content"`
|
||||||
|
AuthorID string `gorm:"type:varchar(36);not null" json:"author_id"`
|
||||||
|
IsPinned bool `gorm:"default:false" json:"is_pinned"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成雪花算法ID
|
||||||
|
func (ga *GroupAnnouncement) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if ga.ID == "" {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ga.ID = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName 指定表名
|
||||||
|
func (GroupAnnouncement) TableName() string {
|
||||||
|
return "group_announcements"
|
||||||
|
}
|
||||||
59
internal/model/group_join_request.go
Normal file
59
internal/model/group_join_request.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GroupJoinRequestType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
GroupJoinRequestTypeInvite GroupJoinRequestType = "invite"
|
||||||
|
GroupJoinRequestTypeJoinApply GroupJoinRequestType = "join_apply"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GroupJoinRequestStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
GroupJoinRequestStatusPending GroupJoinRequestStatus = "pending"
|
||||||
|
GroupJoinRequestStatusAccepted GroupJoinRequestStatus = "accepted"
|
||||||
|
GroupJoinRequestStatusRejected GroupJoinRequestStatus = "rejected"
|
||||||
|
GroupJoinRequestStatusCancelled GroupJoinRequestStatus = "cancelled"
|
||||||
|
GroupJoinRequestStatusExpired GroupJoinRequestStatus = "expired"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GroupJoinRequest 统一保存邀请入群和主动加群申请
|
||||||
|
type GroupJoinRequest struct {
|
||||||
|
ID string `gorm:"primaryKey;size:20" json:"id"`
|
||||||
|
Flag string `gorm:"size:64;uniqueIndex;not null" json:"flag"`
|
||||||
|
GroupID string `gorm:"not null;index;index:idx_gjr_group_target_type_status_created,priority:1" json:"group_id"`
|
||||||
|
InitiatorID string `gorm:"type:varchar(36);not null;index" json:"initiator_id"`
|
||||||
|
TargetUserID string `gorm:"type:varchar(36);not null;index;index:idx_gjr_group_target_type_status_created,priority:2" json:"target_user_id"`
|
||||||
|
RequestType GroupJoinRequestType `gorm:"size:20;not null;index;index:idx_gjr_group_target_type_status_created,priority:3" json:"request_type"`
|
||||||
|
Status GroupJoinRequestStatus `gorm:"size:20;not null;index;index:idx_gjr_group_target_type_status_created,priority:4" json:"status"`
|
||||||
|
Reason string `gorm:"size:500" json:"reason"`
|
||||||
|
ReviewerID string `gorm:"type:varchar(36);index" json:"reviewer_id"`
|
||||||
|
ReviewedAt *time.Time `json:"reviewed_at"`
|
||||||
|
ExpireAt *time.Time `json:"expire_at"`
|
||||||
|
CreatedAt time.Time `gorm:"index:idx_gjr_group_target_type_status_created,priority:5,sort:desc" json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GroupJoinRequest) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if r.ID == "" {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.ID = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (GroupJoinRequest) TableName() string {
|
||||||
|
return "group_join_requests"
|
||||||
|
}
|
||||||
47
internal/model/group_member.go
Normal file
47
internal/model/group_member.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GroupMember 群成员模型
|
||||||
|
type GroupMember struct {
|
||||||
|
ID string `gorm:"primaryKey;size:20" json:"id"`
|
||||||
|
GroupID string `gorm:"not null;uniqueIndex:idx_group_user" json:"group_id"`
|
||||||
|
UserID string `gorm:"type:varchar(36);not null;uniqueIndex:idx_group_user" json:"user_id"`
|
||||||
|
Role string `gorm:"size:20;default:'member'" json:"role"` // owner, admin, member
|
||||||
|
Nickname string `gorm:"size:50" json:"nickname"` // 群内昵称
|
||||||
|
Muted bool `gorm:"default:false" json:"muted"` // 是否被禁言
|
||||||
|
JoinTime time.Time `json:"join_time"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成雪花算法ID
|
||||||
|
func (gm *GroupMember) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if gm.ID == "" {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
gm.ID = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName 指定表名
|
||||||
|
func (GroupMember) TableName() string {
|
||||||
|
return "group_members"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Role 常量
|
||||||
|
const (
|
||||||
|
GroupRoleOwner = "owner"
|
||||||
|
GroupRoleAdmin = "admin"
|
||||||
|
GroupRoleMember = "member"
|
||||||
|
)
|
||||||
159
internal/model/init.go
Normal file
159
internal/model/init.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DB 全局数据库连接
|
||||||
|
var DB *gorm.DB
|
||||||
|
|
||||||
|
// InitDB 初始化数据库连接
|
||||||
|
func InitDB(cfg *config.DatabaseConfig) error {
|
||||||
|
var err error
|
||||||
|
var db *gorm.DB
|
||||||
|
gormLogger := logger.New(
|
||||||
|
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
||||||
|
logger.Config{
|
||||||
|
SlowThreshold: time.Duration(cfg.SlowThresholdMs) * time.Millisecond,
|
||||||
|
LogLevel: parseGormLogLevel(cfg.LogLevel),
|
||||||
|
IgnoreRecordNotFoundError: cfg.IgnoreRecordNotFound,
|
||||||
|
ParameterizedQueries: cfg.ParameterizedQueries,
|
||||||
|
Colorful: false,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
// 根据数据库类型选择驱动
|
||||||
|
switch cfg.Type {
|
||||||
|
case "sqlite":
|
||||||
|
db, err = gorm.Open(sqlite.Open(cfg.SQLite.Path), &gorm.Config{
|
||||||
|
Logger: gormLogger,
|
||||||
|
})
|
||||||
|
case "postgres", "postgresql":
|
||||||
|
dsn := cfg.Postgres.DSN()
|
||||||
|
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||||
|
Logger: gormLogger,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
// 默认使用PostgreSQL
|
||||||
|
dsn := cfg.Postgres.DSN()
|
||||||
|
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||||
|
Logger: gormLogger,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB = db
|
||||||
|
|
||||||
|
// 配置连接池(SQLite不支持连接池配置,跳过)
|
||||||
|
if cfg.Type != "sqlite" {
|
||||||
|
sqlDB, err := DB.DB()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get database instance: %w", err)
|
||||||
|
}
|
||||||
|
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||||
|
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动迁移
|
||||||
|
if err := autoMigrate(DB); err != nil {
|
||||||
|
return fmt.Errorf("failed to auto migrate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Database connected (%s) and migrated successfully", cfg.Type)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseGormLogLevel(level string) logger.LogLevel {
|
||||||
|
switch level {
|
||||||
|
case "silent":
|
||||||
|
return logger.Silent
|
||||||
|
case "error":
|
||||||
|
return logger.Error
|
||||||
|
case "warn":
|
||||||
|
return logger.Warn
|
||||||
|
case "info":
|
||||||
|
return logger.Info
|
||||||
|
default:
|
||||||
|
return logger.Warn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// autoMigrate 自动迁移数据库表
|
||||||
|
func autoMigrate(db *gorm.DB) error {
|
||||||
|
err := db.AutoMigrate(
|
||||||
|
// 用户相关
|
||||||
|
&User{},
|
||||||
|
|
||||||
|
// 帖子相关
|
||||||
|
&Post{},
|
||||||
|
&PostImage{},
|
||||||
|
|
||||||
|
// 评论相关
|
||||||
|
&Comment{},
|
||||||
|
&CommentLike{},
|
||||||
|
|
||||||
|
// 消息相关(使用雪花算法ID和seq机制)
|
||||||
|
// 已读位置存储在 ConversationParticipant.LastReadSeq 中
|
||||||
|
&Conversation{},
|
||||||
|
&ConversationParticipant{},
|
||||||
|
&Message{},
|
||||||
|
|
||||||
|
// 系统通知(独立表,每个用户只能看到自己的通知)
|
||||||
|
&SystemNotification{},
|
||||||
|
|
||||||
|
// 通知
|
||||||
|
&Notification{},
|
||||||
|
|
||||||
|
// 推送中心相关
|
||||||
|
&PushRecord{}, // 推送记录
|
||||||
|
&DeviceToken{}, // 设备Token
|
||||||
|
|
||||||
|
// 社交
|
||||||
|
&Follow{},
|
||||||
|
&UserBlock{},
|
||||||
|
&PostLike{},
|
||||||
|
&Favorite{},
|
||||||
|
|
||||||
|
// 投票
|
||||||
|
&VoteOption{},
|
||||||
|
&UserVote{},
|
||||||
|
|
||||||
|
// 敏感词和审核
|
||||||
|
&SensitiveWord{},
|
||||||
|
&AuditLog{},
|
||||||
|
|
||||||
|
// 群组相关
|
||||||
|
&Group{},
|
||||||
|
&GroupMember{},
|
||||||
|
&GroupAnnouncement{},
|
||||||
|
&GroupJoinRequest{},
|
||||||
|
|
||||||
|
// 自定义表情
|
||||||
|
&UserSticker{},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseDB 关闭数据库连接
|
||||||
|
func CloseDB() {
|
||||||
|
if sqlDB, err := DB.DB(); err == nil {
|
||||||
|
sqlDB.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
28
internal/model/like.go
Normal file
28
internal/model/like.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostLike 帖子点赞
|
||||||
|
type PostLike struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_post_like_user,priority:1"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_post_like_user,priority:2"`
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (pl *PostLike) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if pl.ID == "" {
|
||||||
|
pl.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (PostLike) TableName() string {
|
||||||
|
return "post_likes"
|
||||||
|
}
|
||||||
205
internal/model/message.go
Normal file
205
internal/model/message.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 系统消息相关常量
|
||||||
|
const (
|
||||||
|
// SystemSenderID 系统消息发送者ID (类似QQ 10000)
|
||||||
|
SystemSenderID int64 = 10000
|
||||||
|
|
||||||
|
// SystemSenderIDStr 系统消息发送者ID字符串版本
|
||||||
|
SystemSenderIDStr string = "10000"
|
||||||
|
|
||||||
|
// SystemConversationID 系统通知会话ID (string类型)
|
||||||
|
SystemConversationID string = "9999999999"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContentType 消息内容类型
|
||||||
|
type ContentType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ContentTypeText ContentType = "text"
|
||||||
|
ContentTypeImage ContentType = "image"
|
||||||
|
ContentTypeVideo ContentType = "video"
|
||||||
|
ContentTypeAudio ContentType = "audio"
|
||||||
|
ContentTypeFile ContentType = "file"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageStatus 消息状态
|
||||||
|
type MessageStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessageStatusNormal MessageStatus = "normal" // 正常
|
||||||
|
MessageStatusRecalled MessageStatus = "recalled" // 已撤回
|
||||||
|
MessageStatusDeleted MessageStatus = "deleted" // 已删除
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageCategory 消息类别
|
||||||
|
type MessageCategory string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CategoryChat MessageCategory = "chat" // 普通聊天
|
||||||
|
CategoryNotification MessageCategory = "notification" // 通知类消息
|
||||||
|
CategoryAnnouncement MessageCategory = "announcement" // 系统公告
|
||||||
|
CategoryMarketing MessageCategory = "marketing" // 营销消息
|
||||||
|
)
|
||||||
|
|
||||||
|
// SystemMessageType 系统消息类型 (对应原NotificationType)
|
||||||
|
type SystemMessageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 互动通知
|
||||||
|
SystemTypeLikePost SystemMessageType = "like_post" // 点赞帖子
|
||||||
|
SystemTypeLikeComment SystemMessageType = "like_comment" // 点赞评论
|
||||||
|
SystemTypeComment SystemMessageType = "comment" // 评论
|
||||||
|
SystemTypeReply SystemMessageType = "reply" // 回复
|
||||||
|
SystemTypeFollow SystemMessageType = "follow" // 关注
|
||||||
|
SystemTypeMention SystemMessageType = "mention" // @提及
|
||||||
|
|
||||||
|
// 系统消息
|
||||||
|
SystemTypeSystem SystemMessageType = "system" // 系统通知
|
||||||
|
SystemTypeAnnounce SystemMessageType = "announce" // 系统公告
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExtraData 消息额外数据,用于存储系统消息的相关信息
|
||||||
|
type ExtraData struct {
|
||||||
|
// 操作者信息
|
||||||
|
ActorID int64 `json:"actor_id,omitempty"` // 操作者ID (数字格式,兼容旧数据)
|
||||||
|
ActorIDStr string `json:"actor_id_str,omitempty"` // 操作者ID (UUID字符串格式)
|
||||||
|
ActorName string `json:"actor_name,omitempty"` // 操作者名称
|
||||||
|
AvatarURL string `json:"avatar_url,omitempty"` // 操作者头像
|
||||||
|
|
||||||
|
// 目标信息
|
||||||
|
TargetID int64 `json:"target_id,omitempty"` // 目标ID(帖子ID、评论ID等)
|
||||||
|
TargetTitle string `json:"target_title,omitempty"` // 目标标题
|
||||||
|
TargetType string `json:"target_type,omitempty"` // 目标类型(post/comment等)
|
||||||
|
|
||||||
|
// 其他信息
|
||||||
|
ActionURL string `json:"action_url,omitempty"` // 跳转链接
|
||||||
|
ActionTime string `json:"action_time,omitempty"` // 操作时间
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value 实现driver.Valuer接口,用于数据库存储
|
||||||
|
func (e ExtraData) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan 实现sql.Scanner接口,用于数据库读取
|
||||||
|
func (e *ExtraData) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bytes, ok := value.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("type assertion to []byte failed")
|
||||||
|
}
|
||||||
|
return json.Unmarshal(bytes, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageSegmentData 单个消息段的数据
|
||||||
|
type MessageSegmentData map[string]interface{}
|
||||||
|
|
||||||
|
// MessageSegment 消息段
|
||||||
|
type MessageSegment struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Data MessageSegmentData `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageSegments 消息链类型
|
||||||
|
type MessageSegments []MessageSegment
|
||||||
|
|
||||||
|
// Value 实现driver.Valuer接口,用于数据库存储
|
||||||
|
func (s MessageSegments) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan 实现sql.Scanner接口,用于数据库读取
|
||||||
|
func (s *MessageSegments) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
*s = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bytes, ok := value.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("type assertion to []byte failed")
|
||||||
|
}
|
||||||
|
return json.Unmarshal(bytes, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message 消息实体
|
||||||
|
// 使用雪花算法ID(string类型)和seq机制实现消息排序和增量同步
|
||||||
|
type Message struct {
|
||||||
|
ID string `gorm:"primaryKey;size:20" json:"id"` // 雪花算法ID(string类型)
|
||||||
|
ConversationID string `gorm:"not null;size:20;index:idx_msg_conversation_seq,priority:1" json:"conversation_id"` // 会话ID(string类型)
|
||||||
|
SenderID string `gorm:"column:sender_id;type:varchar(50);index;not null" json:"sender_id"` // 发送者ID (UUID格式)
|
||||||
|
Seq int64 `gorm:"not null;index:idx_msg_conversation_seq,priority:2" json:"seq"` // 会话内序号,用于排序和增量同步
|
||||||
|
Segments MessageSegments `gorm:"type:json" json:"segments"` // 消息链(结构体数组)
|
||||||
|
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID(string类型)
|
||||||
|
Status MessageStatus `gorm:"type:varchar(20);default:'normal'" json:"status"` // 消息状态
|
||||||
|
|
||||||
|
// 新增字段:消息分类和系统消息类型
|
||||||
|
Category MessageCategory `gorm:"type:varchar(20);default:'chat'" json:"category"` // 消息分类
|
||||||
|
SystemType SystemMessageType `gorm:"type:varchar(30)" json:"system_type,omitempty"` // 系统消息类型
|
||||||
|
ExtraData *ExtraData `gorm:"type:json" json:"extra_data,omitempty"` // 额外数据(JSON格式)
|
||||||
|
|
||||||
|
// @相关字段
|
||||||
|
MentionUsers string `gorm:"type:text" json:"mention_users"` // @的用户ID列表,JSON数组
|
||||||
|
MentionAll bool `gorm:"default:false" json:"mention_all"` // 是否@所有人
|
||||||
|
|
||||||
|
// 软删除
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SenderIDStr 返回发送者ID字符串(保持兼容性)
|
||||||
|
func (m *Message) SenderIDStr() string {
|
||||||
|
return m.SenderID
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成雪花算法ID
|
||||||
|
func (m *Message) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if m.ID == "" {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.ID = strconv.FormatInt(id, 10)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Message) TableName() string {
|
||||||
|
return "messages"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSystemMessage 判断是否为系统消息
|
||||||
|
func (m *Message) IsSystemMessage() bool {
|
||||||
|
return m.SenderID == SystemSenderIDStr || m.Category == CategoryNotification || m.Category == CategoryAnnouncement
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInteractionNotification 判断是否为互动通知
|
||||||
|
func (m *Message) IsInteractionNotification() bool {
|
||||||
|
if m.Category != CategoryNotification {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch m.SystemType {
|
||||||
|
case SystemTypeLikePost, SystemTypeLikeComment, SystemTypeComment,
|
||||||
|
SystemTypeReply, SystemTypeFollow, SystemTypeMention:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
19
internal/model/message_read.go
Normal file
19
internal/model/message_read.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageRead 消息已读状态
|
||||||
|
// 记录每个用户在每个会话中的已读位置
|
||||||
|
type MessageRead struct {
|
||||||
|
ID uint `gorm:"primaryKey" json:"id"`
|
||||||
|
ConversationID int64 `gorm:"uniqueIndex:idx_conversation_user;not null" json:"conversation_id"`
|
||||||
|
UserID uint `gorm:"uniqueIndex:idx_conversation_user;not null" json:"user_id"`
|
||||||
|
LastReadSeq int64 `gorm:"not null" json:"last_read_seq"` // 已读到的seq位置
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MessageRead) TableName() string {
|
||||||
|
return "message_reads"
|
||||||
|
}
|
||||||
53
internal/model/notification.go
Normal file
53
internal/model/notification.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NotificationType 通知类型
|
||||||
|
type NotificationType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
NotificationTypeLikePost NotificationType = "like_post"
|
||||||
|
NotificationTypeLikeComment NotificationType = "like_comment"
|
||||||
|
NotificationTypeComment NotificationType = "comment"
|
||||||
|
NotificationTypeReply NotificationType = "reply"
|
||||||
|
NotificationTypeFollow NotificationType = "follow"
|
||||||
|
NotificationTypeMention NotificationType = "mention"
|
||||||
|
NotificationTypeSystem NotificationType = "system"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Notification 通知实体
|
||||||
|
type Notification struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index:idx_notifications_user_read_created,priority:1"` // 接收者
|
||||||
|
Type NotificationType `json:"type" gorm:"type:varchar(30);not null"`
|
||||||
|
Title string `json:"title" gorm:"type:varchar(200);not null"`
|
||||||
|
Content string `json:"content" gorm:"type:text"`
|
||||||
|
Data string `json:"data" gorm:"type:jsonb"` // 相关数据(JSON)
|
||||||
|
|
||||||
|
// 已读状态
|
||||||
|
IsRead bool `json:"is_read" gorm:"default:false;index:idx_notifications_user_read_created,priority:2"`
|
||||||
|
ReadAt *time.Time `json:"read_at" gorm:"type:timestamp"`
|
||||||
|
|
||||||
|
// 软删除
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_notifications_user_read_created,priority:3,sort:desc"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (n *Notification) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if n.ID == "" {
|
||||||
|
n.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Notification) TableName() string {
|
||||||
|
return "notifications"
|
||||||
|
}
|
||||||
100
internal/model/post.go
Normal file
100
internal/model/post.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostStatus 帖子状态
|
||||||
|
type PostStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PostStatusDraft PostStatus = "draft"
|
||||||
|
PostStatusPending PostStatus = "pending" // 待审核
|
||||||
|
PostStatusPublished PostStatus = "published"
|
||||||
|
PostStatusRejected PostStatus = "rejected"
|
||||||
|
PostStatusDeleted PostStatus = "deleted"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Post 帖子实体
|
||||||
|
type Post struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);index;index:idx_posts_user_status_created,priority:1;not null"`
|
||||||
|
CommunityID string `json:"community_id" gorm:"type:varchar(36);index"`
|
||||||
|
Title string `json:"title" gorm:"type:varchar(200);not null"`
|
||||||
|
Content string `json:"content" gorm:"type:text;not null"`
|
||||||
|
|
||||||
|
// 关联
|
||||||
|
// User 需要参与缓存序列化;否则列表命中缓存后会丢失作者信息,前端退化为“匿名用户”
|
||||||
|
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
|
||||||
|
Images []PostImage `json:"images" gorm:"foreignKey:PostID"`
|
||||||
|
|
||||||
|
// 审核状态
|
||||||
|
Status PostStatus `json:"status" gorm:"type:varchar(20);default:published;index:idx_posts_status_created,priority:1;index:idx_posts_user_status_created,priority:2"`
|
||||||
|
ReviewedAt *time.Time `json:"reviewed_at" gorm:"type:timestamp"`
|
||||||
|
ReviewedBy string `json:"reviewed_by" gorm:"type:varchar(50)"`
|
||||||
|
RejectReason string `json:"reject_reason" gorm:"type:varchar(500)"`
|
||||||
|
|
||||||
|
// 统计
|
||||||
|
LikesCount int `json:"likes_count" gorm:"column:likes_count;default:0"`
|
||||||
|
CommentsCount int `json:"comments_count" gorm:"column:comments_count;default:0"`
|
||||||
|
FavoritesCount int `json:"favorites_count" gorm:"column:favorites_count;default:0"`
|
||||||
|
SharesCount int `json:"shares_count" gorm:"column:shares_count;default:0"`
|
||||||
|
ViewsCount int `json:"views_count" gorm:"column:views_count;default:0"`
|
||||||
|
HotScore float64 `json:"hot_score" gorm:"column:hot_score;default:0;index:idx_posts_hot_score_created,priority:1"`
|
||||||
|
|
||||||
|
// 置顶/锁定
|
||||||
|
IsPinned bool `json:"is_pinned" gorm:"default:false"`
|
||||||
|
IsLocked bool `json:"is_locked" gorm:"default:false"`
|
||||||
|
IsDeleted bool `json:"-" gorm:"default:false"`
|
||||||
|
|
||||||
|
// 投票
|
||||||
|
IsVote bool `json:"is_vote" gorm:"column:is_vote;default:false"`
|
||||||
|
|
||||||
|
// 软删除
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (p *Post) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if p.ID == "" {
|
||||||
|
p.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (Post) TableName() string {
|
||||||
|
return "posts"
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostImage 帖子图片
|
||||||
|
type PostImage struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
PostID string `json:"post_id" gorm:"type:varchar(36);index;not null"`
|
||||||
|
URL string `json:"url" gorm:"type:text;not null"`
|
||||||
|
ThumbnailURL string `json:"thumbnail_url" gorm:"type:text"`
|
||||||
|
Width int `json:"width" gorm:"default:0"`
|
||||||
|
Height int `json:"height" gorm:"default:0"`
|
||||||
|
Size int64 `json:"size" gorm:"default:0"` // 文件大小(字节)
|
||||||
|
MimeType string `json:"mime_type" gorm:"type:varchar(50)"`
|
||||||
|
SortOrder int `json:"sort_order" gorm:"default:0"`
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (pi *PostImage) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if pi.ID == "" {
|
||||||
|
pi.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (PostImage) TableName() string {
|
||||||
|
return "post_images"
|
||||||
|
}
|
||||||
129
internal/model/push_record.go
Normal file
129
internal/model/push_record.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PushChannel 推送通道类型
|
||||||
|
type PushChannel string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PushChannelWebSocket PushChannel = "websocket" // WebSocket推送
|
||||||
|
PushChannelFCM PushChannel = "fcm" // Firebase Cloud Messaging
|
||||||
|
PushChannelAPNs PushChannel = "apns" // Apple Push Notification service
|
||||||
|
PushChannelHuawei PushChannel = "huawei" // 华为推送
|
||||||
|
)
|
||||||
|
|
||||||
|
// PushStatus 推送状态
|
||||||
|
type PushStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
PushStatusPending PushStatus = "pending" // 待推送
|
||||||
|
PushStatusPushing PushStatus = "pushing" // 推送中
|
||||||
|
PushStatusPushed PushStatus = "pushed" // 已推送(成功发送到推送服务)
|
||||||
|
PushStatusDelivered PushStatus = "delivered" // 已送达(客户端确认)
|
||||||
|
PushStatusFailed PushStatus = "failed" // 推送失败
|
||||||
|
PushStatusExpired PushStatus = "expired" // 消息过期
|
||||||
|
)
|
||||||
|
|
||||||
|
// PushRecord 推送记录实体
|
||||||
|
// 用于跟踪消息的推送状态,支持多设备推送和重试机制
|
||||||
|
type PushRecord struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` // 雪花算法ID
|
||||||
|
UserID string `gorm:"column:user_id;type:varchar(50);index;not null" json:"user_id"` // 目标用户ID (UUID格式)
|
||||||
|
MessageID string `gorm:"index;not null;size:20" json:"message_id"` // 关联的消息ID (string类型)
|
||||||
|
PushChannel PushChannel `gorm:"type:varchar(20);not null" json:"push_channel"` // 推送通道
|
||||||
|
PushStatus PushStatus `gorm:"type:varchar(20);not null;default:'pending'" json:"push_status"` // 推送状态
|
||||||
|
|
||||||
|
// 设备信息
|
||||||
|
DeviceToken string `gorm:"type:varchar(256)" json:"device_token,omitempty"` // 设备Token(用于手机推送)
|
||||||
|
DeviceType string `gorm:"type:varchar(20)" json:"device_type,omitempty"` // 设备类型 (ios/android/web)
|
||||||
|
|
||||||
|
// 重试机制
|
||||||
|
RetryCount int `gorm:"default:0" json:"retry_count"` // 重试次数
|
||||||
|
MaxRetry int `gorm:"default:3" json:"max_retry"` // 最大重试次数
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
PushedAt *time.Time `json:"pushed_at,omitempty"` // 推送时间
|
||||||
|
DeliveredAt *time.Time `json:"delivered_at,omitempty"` // 送达时间
|
||||||
|
ExpiredAt *time.Time `gorm:"index" json:"expired_at,omitempty"` // 过期时间
|
||||||
|
|
||||||
|
// 错误信息
|
||||||
|
ErrorMessage string `gorm:"type:varchar(500)" json:"error_message,omitempty"` // 错误信息
|
||||||
|
|
||||||
|
// 软删除
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成雪花算法ID
|
||||||
|
func (r *PushRecord) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if r.ID == 0 {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.ID = id
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (PushRecord) TableName() string {
|
||||||
|
return "push_records"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanRetry 判断是否可以重试
|
||||||
|
func (r *PushRecord) CanRetry() bool {
|
||||||
|
return r.RetryCount < r.MaxRetry && r.PushStatus != PushStatusDelivered && r.PushStatus != PushStatusExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired 判断是否已过期
|
||||||
|
func (r *PushRecord) IsExpired() bool {
|
||||||
|
if r.ExpiredAt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().After(*r.ExpiredAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkPushing 标记为推送中
|
||||||
|
func (r *PushRecord) MarkPushing() {
|
||||||
|
r.PushStatus = PushStatusPushing
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkPushed 标记为已推送
|
||||||
|
func (r *PushRecord) MarkPushed() {
|
||||||
|
now := time.Now()
|
||||||
|
r.PushStatus = PushStatusPushed
|
||||||
|
r.PushedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkDelivered 标记为已送达
|
||||||
|
func (r *PushRecord) MarkDelivered() {
|
||||||
|
now := time.Now()
|
||||||
|
r.PushStatus = PushStatusDelivered
|
||||||
|
r.DeliveredAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkFailed 标记为推送失败
|
||||||
|
func (r *PushRecord) MarkFailed(errMsg string) {
|
||||||
|
r.PushStatus = PushStatusFailed
|
||||||
|
r.ErrorMessage = errMsg
|
||||||
|
r.RetryCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkExpired 标记为已过期
|
||||||
|
func (r *PushRecord) MarkExpired() {
|
||||||
|
r.PushStatus = PushStatusExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementRetry 增加重试次数
|
||||||
|
func (r *PushRecord) IncrementRetry() {
|
||||||
|
r.RetryCount++
|
||||||
|
}
|
||||||
77
internal/model/sensitive_word.go
Normal file
77
internal/model/sensitive_word.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SensitiveWordLevel 敏感词级别
|
||||||
|
type SensitiveWordLevel int
|
||||||
|
|
||||||
|
const (
|
||||||
|
SensitiveWordLevelLow SensitiveWordLevel = 1 // 低危
|
||||||
|
SensitiveWordLevelMedium SensitiveWordLevel = 2 // 中危
|
||||||
|
SensitiveWordLevelHigh SensitiveWordLevel = 3 // 高危
|
||||||
|
)
|
||||||
|
|
||||||
|
// SensitiveWordCategory 敏感词分类
|
||||||
|
type SensitiveWordCategory string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SensitiveWordCategoryPolitical SensitiveWordCategory = "political" // 政治
|
||||||
|
SensitiveWordCategoryPorn SensitiveWordCategory = "porn" // 色情
|
||||||
|
SensitiveWordCategoryViolence SensitiveWordCategory = "violence" // 暴力
|
||||||
|
SensitiveWordCategoryAd SensitiveWordCategory = "ad" // 广告
|
||||||
|
SensitiveWordCategoryGamble SensitiveWordCategory = "gamble" // 赌博
|
||||||
|
SensitiveWordCategoryFraud SensitiveWordCategory = "fraud" // 诈骗
|
||||||
|
SensitiveWordCategoryOther SensitiveWordCategory = "other" // 其他
|
||||||
|
)
|
||||||
|
|
||||||
|
// SensitiveWord 敏感词实体
|
||||||
|
type SensitiveWord struct {
|
||||||
|
ID string `gorm:"type:varchar(36);primaryKey"`
|
||||||
|
Word string `gorm:"type:varchar(255);uniqueIndex;not null"`
|
||||||
|
Category SensitiveWordCategory `gorm:"type:varchar(50);index"`
|
||||||
|
Level SensitiveWordLevel `gorm:"type:int;default:1"`
|
||||||
|
IsActive bool `gorm:"default:true"`
|
||||||
|
CreatedBy string `gorm:"type:varchar(255)"`
|
||||||
|
UpdatedBy string `gorm:"type:varchar(255)"`
|
||||||
|
Remark string `gorm:"type:text"`
|
||||||
|
CreatedAt time.Time `gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `gorm:"autoUpdateTime"`
|
||||||
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (sw *SensitiveWord) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if sw.ID == "" {
|
||||||
|
sw.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (SensitiveWord) TableName() string {
|
||||||
|
return "sensitive_words"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SensitiveWordRequest 创建/更新敏感词请求
|
||||||
|
type SensitiveWordRequest struct {
|
||||||
|
Word string `json:"word" validate:"required,min=1,max=255"`
|
||||||
|
Category SensitiveWordCategory `json:"category"`
|
||||||
|
Level SensitiveWordLevel `json:"level"`
|
||||||
|
Remark string `json:"remark"`
|
||||||
|
CreatedBy string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SensitiveWordListItem 敏感词列表项(用于列表展示)
|
||||||
|
type SensitiveWordListItem struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Word string `json:"word"`
|
||||||
|
Category SensitiveWordCategory `json:"category"`
|
||||||
|
Level SensitiveWordLevel `json:"level"`
|
||||||
|
IsActive bool `json:"is_active"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Remark string `json:"remark"`
|
||||||
|
}
|
||||||
33
internal/model/sticker.go
Normal file
33
internal/model/sticker.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserSticker 用户自定义表情
|
||||||
|
type UserSticker struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index:idx_user_stickers"`
|
||||||
|
URL string `json:"url" gorm:"type:text;not null"`
|
||||||
|
Width int `json:"width" gorm:"default:0"`
|
||||||
|
Height int `json:"height" gorm:"default:0"`
|
||||||
|
SortOrder int `json:"sort_order" gorm:"default:0;index:idx_user_stickers_sort"`
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName 表名
|
||||||
|
func (UserSticker) TableName() string {
|
||||||
|
return "user_stickers"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (s *UserSticker) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if s.ID == "" {
|
||||||
|
s.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
127
internal/model/system_notification.go
Normal file
127
internal/model/system_notification.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SystemNotificationType 系统通知类型
|
||||||
|
type SystemNotificationType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 互动通知
|
||||||
|
SysNotifyLikePost SystemNotificationType = "like_post" // 点赞帖子
|
||||||
|
SysNotifyLikeComment SystemNotificationType = "like_comment" // 点赞评论
|
||||||
|
SysNotifyComment SystemNotificationType = "comment" // 评论
|
||||||
|
SysNotifyReply SystemNotificationType = "reply" // 回复
|
||||||
|
SysNotifyFollow SystemNotificationType = "follow" // 关注
|
||||||
|
SysNotifyMention SystemNotificationType = "mention" // @提及
|
||||||
|
SysNotifyFavoritePost SystemNotificationType = "favorite_post" // 收藏帖子
|
||||||
|
SysNotifyLikeReply SystemNotificationType = "like_reply" // 点赞回复
|
||||||
|
|
||||||
|
// 系统消息
|
||||||
|
SysNotifySystem SystemNotificationType = "system" // 系统通知
|
||||||
|
SysNotifyAnnounce SystemNotificationType = "announce" // 系统公告
|
||||||
|
SysNotifyGroupInvite SystemNotificationType = "group_invite" // 群邀请
|
||||||
|
SysNotifyGroupJoinApply SystemNotificationType = "group_join_apply" // 加群申请待审批
|
||||||
|
SysNotifyGroupJoinApproved SystemNotificationType = "group_join_approved" // 加群申请通过
|
||||||
|
SysNotifyGroupJoinRejected SystemNotificationType = "group_join_rejected" // 加群申请拒绝
|
||||||
|
)
|
||||||
|
|
||||||
|
// SystemNotificationExtra 额外数据
|
||||||
|
type SystemNotificationExtra struct {
|
||||||
|
// 操作者信息
|
||||||
|
ActorID int64 `json:"actor_id,omitempty"`
|
||||||
|
ActorIDStr string `json:"actor_id_str,omitempty"`
|
||||||
|
ActorName string `json:"actor_name,omitempty"`
|
||||||
|
AvatarURL string `json:"avatar_url,omitempty"`
|
||||||
|
|
||||||
|
// 目标信息
|
||||||
|
TargetID string `json:"target_id,omitempty"` // 改为string类型以支持UUID
|
||||||
|
TargetTitle string `json:"target_title,omitempty"`
|
||||||
|
TargetType string `json:"target_type,omitempty"`
|
||||||
|
|
||||||
|
// 其他信息
|
||||||
|
ActionURL string `json:"action_url,omitempty"`
|
||||||
|
ActionTime string `json:"action_time,omitempty"`
|
||||||
|
|
||||||
|
// 群邀请/加群申请扩展字段
|
||||||
|
GroupID string `json:"group_id,omitempty"`
|
||||||
|
GroupName string `json:"group_name,omitempty"`
|
||||||
|
GroupAvatar string `json:"group_avatar,omitempty"`
|
||||||
|
GroupDescription string `json:"group_description,omitempty"`
|
||||||
|
Flag string `json:"flag,omitempty"`
|
||||||
|
RequestType string `json:"request_type,omitempty"`
|
||||||
|
RequestStatus string `json:"request_status,omitempty"`
|
||||||
|
Reason string `json:"reason,omitempty"`
|
||||||
|
TargetUserID string `json:"target_user_id,omitempty"`
|
||||||
|
TargetUserName string `json:"target_user_name,omitempty"`
|
||||||
|
TargetUserAvatar string `json:"target_user_avatar,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value 实现driver.Valuer接口
|
||||||
|
func (e SystemNotificationExtra) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan 实现sql.Scanner接口
|
||||||
|
func (e *SystemNotificationExtra) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bytes, ok := value.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("type assertion to []byte failed")
|
||||||
|
}
|
||||||
|
return json.Unmarshal(bytes, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemNotification 系统通知(独立表,与消息完全分离)
|
||||||
|
// 每个用户只能看到自己的系统通知
|
||||||
|
type SystemNotification struct {
|
||||||
|
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"`
|
||||||
|
ReceiverID string `gorm:"column:receiver_id;type:varchar(50);not null;index:idx_sys_notifications_receiver_read_created,priority:1" json:"receiver_id"` // 接收者ID (UUID)
|
||||||
|
Type SystemNotificationType `gorm:"type:varchar(30);not null" json:"type"` // 通知类型
|
||||||
|
Title string `gorm:"type:varchar(200)" json:"title,omitempty"` // 标题
|
||||||
|
Content string `gorm:"type:text;not null" json:"content"` // 内容
|
||||||
|
ExtraData *SystemNotificationExtra `gorm:"type:json" json:"extra_data,omitempty"` // 额外数据
|
||||||
|
IsRead bool `gorm:"default:false;index:idx_sys_notifications_receiver_read_created,priority:2" json:"is_read"` // 是否已读
|
||||||
|
ReadAt *time.Time `json:"read_at,omitempty"` // 阅读时间
|
||||||
|
|
||||||
|
// 软删除
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_sys_notifications_receiver_read_created,priority:3,sort:desc"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成雪花算法ID
|
||||||
|
func (n *SystemNotification) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if n.ID == 0 {
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.ID = id
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName 指定表名
|
||||||
|
func (SystemNotification) TableName() string {
|
||||||
|
return "system_notifications"
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记为已读
|
||||||
|
func (n *SystemNotification) MarkAsRead() {
|
||||||
|
now := time.Now()
|
||||||
|
n.IsRead = true
|
||||||
|
n.ReadAt = &now
|
||||||
|
}
|
||||||
66
internal/model/user.go
Normal file
66
internal/model/user.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserStatus 用户状态
|
||||||
|
type UserStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
UserStatusActive UserStatus = "active"
|
||||||
|
UserStatusBanned UserStatus = "banned"
|
||||||
|
UserStatusInactive UserStatus = "inactive"
|
||||||
|
)
|
||||||
|
|
||||||
|
// User 用户实体
|
||||||
|
type User struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
Username string `json:"username" gorm:"type:varchar(50);uniqueIndex;not null"`
|
||||||
|
Nickname string `json:"nickname" gorm:"type:varchar(100);not null"`
|
||||||
|
Email *string `json:"email" gorm:"type:varchar(255);uniqueIndex"`
|
||||||
|
Phone *string `json:"phone" gorm:"type:varchar(20);uniqueIndex"`
|
||||||
|
EmailVerified bool `json:"email_verified" gorm:"default:false"`
|
||||||
|
PasswordHash string `json:"-" gorm:"type:varchar(255);not null"`
|
||||||
|
Avatar string `json:"avatar" gorm:"type:text"`
|
||||||
|
CoverURL string `json:"cover_url" gorm:"type:text"` // 头图URL
|
||||||
|
Bio string `json:"bio" gorm:"type:text"`
|
||||||
|
Website string `json:"website" gorm:"type:varchar(255)"`
|
||||||
|
Location string `json:"location" gorm:"type:varchar(100)"`
|
||||||
|
|
||||||
|
// 实名认证信息(可选)
|
||||||
|
RealName string `json:"real_name" gorm:"type:varchar(100)"` // 真实姓名
|
||||||
|
IDCard string `json:"-" gorm:"type:varchar(18)"` // 身份证号(加密存储)
|
||||||
|
IsVerified bool `json:"is_verified" gorm:"default:false"` // 是否实名认证
|
||||||
|
VerifiedAt *time.Time `json:"verified_at" gorm:"type:timestamp"`
|
||||||
|
|
||||||
|
// 统计计数
|
||||||
|
PostsCount int `json:"posts_count" gorm:"default:0"`
|
||||||
|
FollowersCount int `json:"followers_count" gorm:"default:0"`
|
||||||
|
FollowingCount int `json:"following_count" gorm:"default:0"`
|
||||||
|
|
||||||
|
// 状态
|
||||||
|
Status UserStatus `json:"status" gorm:"type:varchar(20);default:active"`
|
||||||
|
LastLoginAt *time.Time `json:"last_login_at" gorm:"type:timestamp"`
|
||||||
|
LastLoginIP string `json:"last_login_ip" gorm:"type:varchar(45)"`
|
||||||
|
|
||||||
|
// 时间戳
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (u *User) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if u.ID == "" {
|
||||||
|
u.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) TableName() string {
|
||||||
|
return "users"
|
||||||
|
}
|
||||||
27
internal/model/user_block.go
Normal file
27
internal/model/user_block.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserBlock 用户拉黑关系
|
||||||
|
type UserBlock struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
BlockerID string `json:"blocker_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_blocker_blocked"` // 拉黑人
|
||||||
|
BlockedID string `json:"blocked_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_blocker_blocked"` // 被拉黑人
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *UserBlock) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if b.ID == "" {
|
||||||
|
b.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserBlock) TableName() string {
|
||||||
|
return "user_blocks"
|
||||||
|
}
|
||||||
52
internal/model/vote.go
Normal file
52
internal/model/vote.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VoteOption 投票选项
|
||||||
|
type VoteOption struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
PostID string `json:"post_id" gorm:"type:varchar(36);index:idx_vote_option_post_sort,priority:1;not null"`
|
||||||
|
Content string `json:"content" gorm:"type:varchar(200);not null"`
|
||||||
|
SortOrder int `json:"sort_order" gorm:"default:0;index:idx_vote_option_post_sort,priority:2"`
|
||||||
|
VotesCount int `json:"votes_count" gorm:"default:0"`
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (vo *VoteOption) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if vo.ID == "" {
|
||||||
|
vo.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (VoteOption) TableName() string {
|
||||||
|
return "vote_options"
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserVote 用户投票记录
|
||||||
|
type UserVote struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
PostID string `json:"post_id" gorm:"type:varchar(36);index;uniqueIndex:idx_user_vote_post_user,priority:1;not null"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);index;uniqueIndex:idx_user_vote_post_user,priority:2;not null"`
|
||||||
|
OptionID string `json:"option_id" gorm:"type:varchar(36);index;not null"`
|
||||||
|
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeCreate 创建前生成UUID
|
||||||
|
func (uv *UserVote) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if uv.ID == "" {
|
||||||
|
uv.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (UserVote) TableName() string {
|
||||||
|
return "user_votes"
|
||||||
|
}
|
||||||
115
internal/pkg/avatar/avatar.go
Normal file
115
internal/pkg/avatar/avatar.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package avatar
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 预定义一组好看的颜色
|
||||||
|
var colors = []string{
|
||||||
|
"#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
|
||||||
|
"#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
|
||||||
|
"#BB8FCE", "#85C1E9", "#F8B500", "#00CED1",
|
||||||
|
"#E74C3C", "#3498DB", "#2ECC71", "#9B59B6",
|
||||||
|
"#1ABC9C", "#F39C12", "#E67E22", "#16A085",
|
||||||
|
}
|
||||||
|
|
||||||
|
// SVG模板
|
||||||
|
const svgTemplate = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100" width="%d" height="%d">
|
||||||
|
<rect width="100" height="100" fill="%s"/>
|
||||||
|
<text x="50" y="50" font-family="Arial, sans-serif" font-size="40" font-weight="bold" fill="#ffffff" text-anchor="middle" dominant-baseline="central">%s</text>
|
||||||
|
</svg>`
|
||||||
|
|
||||||
|
// GenerateSVGAvatar 根据用户名生成SVG头像
|
||||||
|
// username: 用户名
|
||||||
|
// size: 头像尺寸(像素)
|
||||||
|
func GenerateSVGAvatar(username string, size int) string {
|
||||||
|
initials := getInitials(username)
|
||||||
|
color := stringToColor(username)
|
||||||
|
return fmt.Sprintf(svgTemplate, size, size, color, initials)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAvatarDataURI 生成Data URI格式的头像
|
||||||
|
// 可以直接在HTML img标签或CSS background-image中使用
|
||||||
|
func GenerateAvatarDataURI(username string, size int) string {
|
||||||
|
svg := GenerateSVGAvatar(username, size)
|
||||||
|
encoded := base64.StdEncoding.EncodeToString([]byte(svg))
|
||||||
|
return fmt.Sprintf("data:image/svg+xml;base64,%s", encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getInitials 获取用户名首字母
|
||||||
|
// 中文取第一个字,英文取首字母(最多2个)
|
||||||
|
func getInitials(username string) string {
|
||||||
|
if username == "" {
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是中文字符
|
||||||
|
firstRune, _ := utf8.DecodeRuneInString(username)
|
||||||
|
if isChinese(firstRune) {
|
||||||
|
// 中文直接返回第一个字符
|
||||||
|
return string(firstRune)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 英文处理:取前两个单词的首字母
|
||||||
|
// 例如: "John Doe" -> "JD", "john" -> "J"
|
||||||
|
result := []rune{}
|
||||||
|
for i, r := range username {
|
||||||
|
if i == 0 {
|
||||||
|
result = append(result, toUpper(r))
|
||||||
|
} else if r == ' ' || r == '_' || r == '-' {
|
||||||
|
// 找到下一个字符作为第二个首字母
|
||||||
|
nextIdx := i + 1
|
||||||
|
if nextIdx < len(username) {
|
||||||
|
nextRune, _ := utf8.DecodeRuneInString(username[nextIdx:])
|
||||||
|
if nextRune != utf8.RuneError && nextRune != ' ' {
|
||||||
|
result = append(result, toUpper(nextRune))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) == 0 {
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最多返回2个字符
|
||||||
|
if len(result) > 2 {
|
||||||
|
result = result[:2]
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isChinese 判断是否是中文字符
|
||||||
|
func isChinese(r rune) bool {
|
||||||
|
return r >= 0x4E00 && r <= 0x9FFF
|
||||||
|
}
|
||||||
|
|
||||||
|
// toUpper 将字母转换为大写
|
||||||
|
func toUpper(r rune) rune {
|
||||||
|
if r >= 'a' && r <= 'z' {
|
||||||
|
return r - 32
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// stringToColor 根据字符串生成颜色
|
||||||
|
// 使用简单的哈希算法确保同一用户名每次生成的颜色一致
|
||||||
|
func stringToColor(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return colors[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := 0
|
||||||
|
for _, r := range s {
|
||||||
|
hash = (hash*31 + int(r)) % len(colors)
|
||||||
|
}
|
||||||
|
if hash < 0 {
|
||||||
|
hash = -hash
|
||||||
|
}
|
||||||
|
|
||||||
|
return colors[hash%len(colors)]
|
||||||
|
}
|
||||||
118
internal/pkg/avatar/avatar_test.go
Normal file
118
internal/pkg/avatar/avatar_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package avatar
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetInitials(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
username string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"中文用户名", "张三", "张"},
|
||||||
|
{"英文用户名", "John", "J"},
|
||||||
|
{"英文全名", "John Doe", "JD"},
|
||||||
|
{"带下划线", "john_doe", "JD"},
|
||||||
|
{"带连字符", "john-doe", "JD"},
|
||||||
|
{"空字符串", "", "?"},
|
||||||
|
{"小写英文", "alice", "A"},
|
||||||
|
{"中文复合", "李小龙", "李"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := getInitials(tt.username)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("getInitials(%q) = %q, want %q", tt.username, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringToColor(t *testing.T) {
|
||||||
|
// 测试同一用户名生成的颜色一致
|
||||||
|
color1 := stringToColor("张三")
|
||||||
|
color2 := stringToColor("张三")
|
||||||
|
if color1 != color2 {
|
||||||
|
t.Errorf("stringToColor should return consistent colors for the same input")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试不同用户名生成不同颜色(大概率)
|
||||||
|
color3 := stringToColor("李四")
|
||||||
|
if color1 == color3 {
|
||||||
|
t.Logf("Warning: different usernames generated the same color (possible but unlikely)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试空字符串
|
||||||
|
color4 := stringToColor("")
|
||||||
|
if color4 == "" {
|
||||||
|
t.Errorf("stringToColor should return a color for empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证颜色格式
|
||||||
|
if !strings.HasPrefix(color4, "#") {
|
||||||
|
t.Errorf("stringToColor should return hex color format starting with #")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSVGAvatar(t *testing.T) {
|
||||||
|
svg := GenerateSVGAvatar("张三", 100)
|
||||||
|
|
||||||
|
// 验证SVG结构
|
||||||
|
if !strings.Contains(svg, "<svg") {
|
||||||
|
t.Errorf("SVG should contain <svg tag")
|
||||||
|
}
|
||||||
|
if !strings.Contains(svg, "</svg>") {
|
||||||
|
t.Errorf("SVG should contain </svg> tag")
|
||||||
|
}
|
||||||
|
if !strings.Contains(svg, "width=\"100\"") {
|
||||||
|
t.Errorf("SVG should have width=100")
|
||||||
|
}
|
||||||
|
if !strings.Contains(svg, "height=\"100\"") {
|
||||||
|
t.Errorf("SVG should have height=100")
|
||||||
|
}
|
||||||
|
if !strings.Contains(svg, "张") {
|
||||||
|
t.Errorf("SVG should contain the initial character")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateAvatarDataURI(t *testing.T) {
|
||||||
|
dataURI := GenerateAvatarDataURI("张三", 100)
|
||||||
|
|
||||||
|
// 验证Data URI格式
|
||||||
|
if !strings.HasPrefix(dataURI, "data:image/svg+xml;base64,") {
|
||||||
|
t.Errorf("Data URI should start with data:image/svg+xml;base64,")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证base64部分不为空
|
||||||
|
parts := strings.Split(dataURI, ",")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
t.Errorf("Data URI should have two parts separated by comma")
|
||||||
|
}
|
||||||
|
if parts[1] == "" {
|
||||||
|
t.Errorf("Base64 part should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsChinese(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
r rune
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{'中', true},
|
||||||
|
{'文', true},
|
||||||
|
{'a', false},
|
||||||
|
{'Z', false},
|
||||||
|
{'0', false},
|
||||||
|
{'_', false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := isChinese(tt.r)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isChinese(%q) = %v, want %v", tt.r, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
131
internal/pkg/email/client.go
Normal file
131
internal/pkg/email/client.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
package email
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gomail "gopkg.in/gomail.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message 发信参数
|
||||||
|
type Message struct {
|
||||||
|
To []string
|
||||||
|
Cc []string
|
||||||
|
Bcc []string
|
||||||
|
ReplyTo []string
|
||||||
|
Subject string
|
||||||
|
TextBody string
|
||||||
|
HTMLBody string
|
||||||
|
Attachments []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client interface {
|
||||||
|
IsEnabled() bool
|
||||||
|
Config() Config
|
||||||
|
Send(ctx context.Context, msg Message) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientImpl struct {
|
||||||
|
cfg Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(cfg Config) Client {
|
||||||
|
return &clientImpl{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) IsEnabled() bool {
|
||||||
|
return c.cfg.Enabled &&
|
||||||
|
strings.TrimSpace(c.cfg.Host) != "" &&
|
||||||
|
c.cfg.Port > 0 &&
|
||||||
|
strings.TrimSpace(c.cfg.FromAddress) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) Config() Config {
|
||||||
|
return c.cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) Send(ctx context.Context, msg Message) error {
|
||||||
|
if !c.IsEnabled() {
|
||||||
|
return fmt.Errorf("email client is disabled or misconfigured")
|
||||||
|
}
|
||||||
|
if len(msg.To) == 0 {
|
||||||
|
return fmt.Errorf("email recipient is empty")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(msg.Subject) == "" {
|
||||||
|
return fmt.Errorf("email subject is empty")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(msg.TextBody) == "" && strings.TrimSpace(msg.HTMLBody) == "" {
|
||||||
|
return fmt.Errorf("email body is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
m := gomail.NewMessage()
|
||||||
|
m.SetAddressHeader("From", c.cfg.FromAddress, c.cfg.FromName)
|
||||||
|
m.SetHeader("To", msg.To...)
|
||||||
|
if len(msg.Cc) > 0 {
|
||||||
|
m.SetHeader("Cc", msg.Cc...)
|
||||||
|
}
|
||||||
|
if len(msg.Bcc) > 0 {
|
||||||
|
m.SetHeader("Bcc", msg.Bcc...)
|
||||||
|
}
|
||||||
|
if len(msg.ReplyTo) > 0 {
|
||||||
|
m.SetHeader("Reply-To", msg.ReplyTo...)
|
||||||
|
}
|
||||||
|
m.SetHeader("Subject", msg.Subject)
|
||||||
|
|
||||||
|
if strings.TrimSpace(msg.TextBody) != "" && strings.TrimSpace(msg.HTMLBody) != "" {
|
||||||
|
m.SetBody("text/plain", msg.TextBody)
|
||||||
|
m.AddAlternative("text/html", msg.HTMLBody)
|
||||||
|
} else if strings.TrimSpace(msg.HTMLBody) != "" {
|
||||||
|
m.SetBody("text/html", msg.HTMLBody)
|
||||||
|
} else {
|
||||||
|
m.SetBody("text/plain", msg.TextBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, attachment := range msg.Attachments {
|
||||||
|
if strings.TrimSpace(attachment) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.Attach(attachment)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := c.cfg.TimeoutSeconds
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 15
|
||||||
|
}
|
||||||
|
dialer := gomail.NewDialer(c.cfg.Host, c.cfg.Port, c.cfg.Username, c.cfg.Password)
|
||||||
|
if c.cfg.UseTLS {
|
||||||
|
dialer.TLSConfig = &tls.Config{
|
||||||
|
ServerName: c.cfg.Host,
|
||||||
|
InsecureSkipVerify: c.cfg.InsecureSkipVerify,
|
||||||
|
}
|
||||||
|
// 465 端口通常要求直接 TLS(Implicit TLS)。
|
||||||
|
if c.cfg.Port == 465 {
|
||||||
|
dialer.SSL = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sendCtx := ctx
|
||||||
|
cancel := func() {}
|
||||||
|
if timeout > 0 {
|
||||||
|
sendCtx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- dialer.DialAndSend(m)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-sendCtx.Done():
|
||||||
|
return fmt.Errorf("send email canceled: %w", sendCtx.Err())
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("send email failed: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
33
internal/pkg/email/config.go
Normal file
33
internal/pkg/email/config.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package email
|
||||||
|
|
||||||
|
import "carrot_bbs/internal/config"
|
||||||
|
|
||||||
|
// Config SMTP 邮件配置(由应用配置转换)
|
||||||
|
type Config struct {
|
||||||
|
Enabled bool
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
FromAddress string
|
||||||
|
FromName string
|
||||||
|
UseTLS bool
|
||||||
|
InsecureSkipVerify bool
|
||||||
|
TimeoutSeconds int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigFromAppConfig 从应用配置转换
|
||||||
|
func ConfigFromAppConfig(cfg *config.EmailConfig) Config {
|
||||||
|
return Config{
|
||||||
|
Enabled: cfg.Enabled,
|
||||||
|
Host: cfg.Host,
|
||||||
|
Port: cfg.Port,
|
||||||
|
Username: cfg.Username,
|
||||||
|
Password: cfg.Password,
|
||||||
|
FromAddress: cfg.FromAddress,
|
||||||
|
FromName: cfg.FromName,
|
||||||
|
UseTLS: cfg.UseTLS,
|
||||||
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
|
TimeoutSeconds: cfg.Timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
286
internal/pkg/gorse/client.go
Normal file
286
internal/pkg/gorse/client.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package gorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gorseio "github.com/gorse-io/gorse-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FeedbackType 反馈类型
|
||||||
|
type FeedbackType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
FeedbackTypeLike FeedbackType = "like" // 点赞
|
||||||
|
FeedbackTypeStar FeedbackType = "star" // 收藏
|
||||||
|
FeedbackTypeComment FeedbackType = "comment" // 评论
|
||||||
|
FeedbackTypeRead FeedbackType = "read" // 浏览
|
||||||
|
)
|
||||||
|
|
||||||
|
// Score 非个性化推荐返回的评分项
|
||||||
|
type Score struct {
|
||||||
|
Id string `json:"Id"`
|
||||||
|
Score float64 `json:"Score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client Gorse客户端接口
|
||||||
|
type Client interface {
|
||||||
|
// InsertFeedback 插入用户反馈
|
||||||
|
InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error
|
||||||
|
// DeleteFeedback 删除用户反馈
|
||||||
|
DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error
|
||||||
|
// GetRecommend 获取个性化推荐列表
|
||||||
|
GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error)
|
||||||
|
// GetNonPersonalized 获取非个性化推荐(通过名称)
|
||||||
|
GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error)
|
||||||
|
// UpsertItem 插入或更新物品(无embedding)
|
||||||
|
UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error
|
||||||
|
// UpsertItemWithEmbedding 插入或更新物品(带embedding)
|
||||||
|
UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error
|
||||||
|
// DeleteItem 删除物品
|
||||||
|
DeleteItem(ctx context.Context, itemID string) error
|
||||||
|
// UpsertUser 插入或更新用户
|
||||||
|
UpsertUser(ctx context.Context, userID string, labels map[string]any) error
|
||||||
|
// IsEnabled 检查是否启用
|
||||||
|
IsEnabled() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// client Gorse客户端实现
|
||||||
|
type client struct {
|
||||||
|
config Config
|
||||||
|
gorse *gorseio.GorseClient
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient 创建新的Gorse客户端
|
||||||
|
func NewClient(cfg Config) Client {
|
||||||
|
if !cfg.Enabled {
|
||||||
|
return &noopClient{}
|
||||||
|
}
|
||||||
|
|
||||||
|
gorse := gorseio.NewGorseClient(cfg.Address, cfg.APIKey)
|
||||||
|
return &client{
|
||||||
|
config: cfg,
|
||||||
|
gorse: gorse,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled 检查是否启用
|
||||||
|
func (c *client) IsEnabled() bool {
|
||||||
|
return c.config.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertFeedback 插入用户反馈
|
||||||
|
func (c *client) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||||
|
if !c.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.gorse.InsertFeedback(ctx, []gorseio.Feedback{
|
||||||
|
{
|
||||||
|
FeedbackType: string(feedbackType),
|
||||||
|
UserId: userID,
|
||||||
|
ItemId: itemID,
|
||||||
|
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFeedback 删除用户反馈
|
||||||
|
func (c *client) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||||
|
if !c.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.gorse.DeleteFeedback(ctx, string(feedbackType), userID, itemID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRecommend 获取个性化推荐列表
|
||||||
|
func (c *client) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) {
|
||||||
|
if !c.config.Enabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := c.gorse.GetRecommend(ctx, userID, "", n, offset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNonPersonalized 获取非个性化推荐
|
||||||
|
// name: 推荐器名称,如 "most_liked_weekly"
|
||||||
|
// n: 返回数量
|
||||||
|
// offset: 偏移量
|
||||||
|
// userID: 可选,用于排除用户已读物品
|
||||||
|
func (c *client) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) {
|
||||||
|
if !c.config.Enabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建URL
|
||||||
|
url := fmt.Sprintf("%s/api/non-personalized/%s?n=%d&offset=%d", c.config.Address, name, n, offset)
|
||||||
|
if userID != "" {
|
||||||
|
url += fmt.Sprintf("&user-id=%s", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建请求
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置API Key
|
||||||
|
if c.config.APIKey != "" {
|
||||||
|
req.Header.Set("X-API-Key", c.config.APIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 读取响应
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return nil, fmt.Errorf("gorse api error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
var scores []Score
|
||||||
|
if err := json.Unmarshal(body, &scores); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取ID
|
||||||
|
ids := make([]string, len(scores))
|
||||||
|
for i, score := range scores {
|
||||||
|
ids[i] = score.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertItem 插入或更新物品
|
||||||
|
func (c *client) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error {
|
||||||
|
if !c.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.gorse.InsertItem(ctx, gorseio.Item{
|
||||||
|
ItemId: itemID,
|
||||||
|
IsHidden: false,
|
||||||
|
Categories: categories,
|
||||||
|
Comment: comment,
|
||||||
|
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertItemWithEmbedding 插入或更新物品(带embedding)
|
||||||
|
func (c *client) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error {
|
||||||
|
if !c.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成embedding
|
||||||
|
var embedding []float64
|
||||||
|
if textToEmbed != "" {
|
||||||
|
var err error
|
||||||
|
embedding, err = GetEmbedding(textToEmbed)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] Failed to get embedding for item %s: %v, using zero vector", itemID, err)
|
||||||
|
embedding = make([]float64, 1024)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
embedding = make([]float64, 1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.gorse.InsertItem(ctx, gorseio.Item{
|
||||||
|
ItemId: itemID,
|
||||||
|
IsHidden: false,
|
||||||
|
Categories: categories,
|
||||||
|
Comment: comment,
|
||||||
|
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||||
|
Labels: map[string]any{
|
||||||
|
"embedding": embedding,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteItem 删除物品
|
||||||
|
func (c *client) DeleteItem(ctx context.Context, itemID string) error {
|
||||||
|
if !c.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.gorse.DeleteItem(ctx, itemID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertUser 插入或更新用户
|
||||||
|
func (c *client) UpsertUser(ctx context.Context, userID string, labels map[string]any) error {
|
||||||
|
if !c.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.gorse.InsertUser(ctx, gorseio.User{
|
||||||
|
UserId: userID,
|
||||||
|
Labels: labels,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// noopClient 空操作客户端(用于未启用推荐功能时)
|
||||||
|
type noopClient struct{}
|
||||||
|
|
||||||
|
func (c *noopClient) IsEnabled() bool { return false }
|
||||||
|
func (c *noopClient) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *noopClient) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *noopClient) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (c *noopClient) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (c *noopClient) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *noopClient) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *noopClient) DeleteItem(ctx context.Context, itemID string) error { return nil }
|
||||||
|
func (c *noopClient) UpsertUser(ctx context.Context, userID string, labels map[string]any) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保实现了接口
|
||||||
|
var _ Client = (*client)(nil)
|
||||||
|
var _ Client = (*noopClient)(nil)
|
||||||
|
|
||||||
|
// log 用于内部日志
|
||||||
|
func init() {
|
||||||
|
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||||
|
}
|
||||||
23
internal/pkg/gorse/config.go
Normal file
23
internal/pkg/gorse/config.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package gorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config Gorse客户端配置(从config.GorseConfig转换)
|
||||||
|
type Config struct {
|
||||||
|
Address string
|
||||||
|
APIKey string
|
||||||
|
Enabled bool
|
||||||
|
Dashboard string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigFromAppConfig 从应用配置创建Gorse配置
|
||||||
|
func ConfigFromAppConfig(cfg *config.GorseConfig) Config {
|
||||||
|
return Config{
|
||||||
|
Address: cfg.Address,
|
||||||
|
APIKey: cfg.APIKey,
|
||||||
|
Enabled: cfg.Enabled,
|
||||||
|
Dashboard: cfg.Dashboard,
|
||||||
|
}
|
||||||
|
}
|
||||||
106
internal/pkg/gorse/embedding.go
Normal file
106
internal/pkg/gorse/embedding.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package gorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EmbeddingConfig embedding服务配置
|
||||||
|
type EmbeddingConfig struct {
|
||||||
|
APIKey string
|
||||||
|
URL string
|
||||||
|
Model string
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultEmbeddingConfig = EmbeddingConfig{
|
||||||
|
APIKey: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD",
|
||||||
|
URL: "https://api.littlelan.cn/v1/embeddings",
|
||||||
|
Model: "BAAI/bge-m3",
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEmbeddingConfig 设置embedding配置
|
||||||
|
func SetEmbeddingConfig(apiKey, url, model string) {
|
||||||
|
if apiKey != "" {
|
||||||
|
defaultEmbeddingConfig.APIKey = apiKey
|
||||||
|
}
|
||||||
|
if url != "" {
|
||||||
|
defaultEmbeddingConfig.URL = url
|
||||||
|
}
|
||||||
|
if model != "" {
|
||||||
|
defaultEmbeddingConfig.Model = model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEmbedding 获取文本的embedding
|
||||||
|
func GetEmbedding(text string) ([]float64, error) {
|
||||||
|
type embeddingRequest struct {
|
||||||
|
Input string `json:"input"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type embeddingResponse struct {
|
||||||
|
Data []struct {
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := embeddingRequest{
|
||||||
|
Input: text,
|
||||||
|
Model: defaultEmbeddingConfig.Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", defaultEmbeddingConfig.URL, bytes.NewReader(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+defaultEmbeddingConfig.APIKey)
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("embedding API error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result embeddingResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Data) == 0 {
|
||||||
|
return nil, fmt.Errorf("no embedding returned")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.Data[0].Embedding, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitEmbeddingWithConfig 从应用配置初始化embedding
|
||||||
|
func InitEmbeddingWithConfig(apiKey, url, model string) {
|
||||||
|
if apiKey == "" {
|
||||||
|
log.Println("[WARN] Gorse embedding API key not set, using default")
|
||||||
|
}
|
||||||
|
defaultEmbeddingConfig.APIKey = apiKey
|
||||||
|
if url != "" {
|
||||||
|
defaultEmbeddingConfig.URL = url
|
||||||
|
}
|
||||||
|
if model != "" {
|
||||||
|
defaultEmbeddingConfig.Model = model
|
||||||
|
}
|
||||||
|
}
|
||||||
105
internal/pkg/jwt/jwt.go
Normal file
105
internal/pkg/jwt/jwt.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidToken = errors.New("invalid token")
|
||||||
|
ErrExpiredToken = errors.New("token has expired")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claims JWT 声明
|
||||||
|
type Claims struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWT JWT工具
|
||||||
|
type JWT struct {
|
||||||
|
secretKey string
|
||||||
|
accessTokenExpire time.Duration
|
||||||
|
refreshTokenExpire time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// New 创建JWT实例
|
||||||
|
func New(secret string, accessExpire, refreshExpire time.Duration) *JWT {
|
||||||
|
return &JWT{
|
||||||
|
secretKey: secret,
|
||||||
|
accessTokenExpire: accessExpire,
|
||||||
|
refreshTokenExpire: refreshExpire,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAccessToken 生成访问令牌
|
||||||
|
func (j *JWT) GenerateAccessToken(userID, username string) (string, error) {
|
||||||
|
now := time.Now()
|
||||||
|
claims := Claims{
|
||||||
|
UserID: userID,
|
||||||
|
Username: username,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
Issuer: "carrot_bbs",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
return token.SignedString([]byte(j.secretKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRefreshToken 生成刷新令牌
|
||||||
|
func (j *JWT) GenerateRefreshToken(userID, username string) (string, error) {
|
||||||
|
now := time.Now()
|
||||||
|
claims := Claims{
|
||||||
|
UserID: userID,
|
||||||
|
Username: username,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
Issuer: "carrot_bbs",
|
||||||
|
ID: "refresh",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
return token.SignedString([]byte(j.secretKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseToken 解析令牌
|
||||||
|
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
|
||||||
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(j.secretKey), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken 验证令牌
|
||||||
|
func (j *JWT) ValidateToken(tokenString string) error {
|
||||||
|
claims, err := j.ParseToken(tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是刷新令牌
|
||||||
|
if claims.ID == "refresh" {
|
||||||
|
return errors.New("cannot use refresh token as access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
438
internal/pkg/openai/client.go
Normal file
438
internal/pkg/openai/client.go
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
_ "image/gif"
|
||||||
|
"image/jpeg"
|
||||||
|
_ "image/png"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
xdraw "golang.org/x/image/draw"
|
||||||
|
)
|
||||||
|
|
||||||
|
const moderationSystemPrompt = "你是中文社区的内容审核助手,负责对“帖子标题、正文、配图”做联合审核。目标是平衡社区安全与正常交流:必须拦截高风险违规内容,但不要误伤正常玩梗、二创、吐槽和轻度调侃。请只输出指定JSON。\n\n审核流程:\n1) 先判断是否命中硬性违规;\n2) 再判断语境(玩笑/自嘲/朋友间互动/作品讨论);\n3) 做文图交叉判断(文本+图片合并理解);\n4) 给出 approved 与简短 reason。\n\n硬性违规(命中任一项必须 approved=false):\nA. 宣传对立与煽动撕裂:\n- 明确煽动群体对立、地域对立、性别对立、民族宗教对立,鼓动仇恨、排斥、报复。\nB. 严重人身攻击与网暴引导:\n- 持续性侮辱贬损、羞辱人格、号召围攻/骚扰/挂人/线下冲突。\nC. 开盒/人肉/隐私暴露:\n- 故意公开、拼接、索取他人可识别隐私信息(姓名+联系方式、身份证号、住址、学校单位、车牌、定位轨迹等);\n- 图片/截图中出现可识别隐私信息并伴随曝光意图,也按违规处理。\nD. 其他高危违规:\n- 违法犯罪、暴力威胁、极端仇恨、色情低俗、诈骗引流、恶意广告等。\n\n放行规则(以下通常 approved=true):\n- 正常玩梗、表情包、谐音梗、二次创作、无恶意的吐槽;\n- 非定向、轻度口语化吐槽(无明确攻击对象、无网暴号召、无隐私暴露);\n- 对社会事件/作品的理性讨论、观点争论(即使语气尖锐,但未煽动对立或人身攻击)。\n\n边界判定:\n- 若只是“梗文化表达”且不指向现实伤害,优先通过;\n- 若存在明确伤害意图(煽动、围攻、曝光隐私),必须拒绝;\n- 对模糊内容不因个别粗口直接拒绝,需结合对象、意图、号召性和可执行性综合判断。\n\nreason 要求:\n- approved=false 时:中文10-30字,说明核心违规点;\n- approved=true 时:reason 为空字符串。\n\n输出格式(严格):\n仅输出一行JSON对象,不要Markdown,不要额外解释:\n{\"approved\": true/false, \"reason\": \"...\"}"
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultMaxImagesPerModerationRequest = 1
|
||||||
|
maxModerationResultRetries = 3
|
||||||
|
maxChatCompletionRetries = 3
|
||||||
|
initialRetryBackoff = 500 * time.Millisecond
|
||||||
|
maxDownloadImageBytes = 10 * 1024 * 1024
|
||||||
|
maxModerationImageSide = 1280
|
||||||
|
compressedJPEGQuality = 72
|
||||||
|
maxCompressedPayloadBytes = 1536 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client interface {
|
||||||
|
IsEnabled() bool
|
||||||
|
Config() Config
|
||||||
|
ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error)
|
||||||
|
ModerateComment(ctx context.Context, content string, images []string) (bool, string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientImpl struct {
|
||||||
|
cfg Config
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(cfg Config) Client {
|
||||||
|
timeout := cfg.RequestTimeoutSeconds
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 30
|
||||||
|
}
|
||||||
|
return &clientImpl{
|
||||||
|
cfg: cfg,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: time.Duration(timeout) * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) IsEnabled() bool {
|
||||||
|
return c.cfg.Enabled && c.cfg.APIKey != "" && c.cfg.BaseURL != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) Config() Config {
|
||||||
|
return c.cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error) {
|
||||||
|
if !c.IsEnabled() {
|
||||||
|
return true, "", nil
|
||||||
|
}
|
||||||
|
return c.moderateContentInBatches(ctx, fmt.Sprintf("帖子标题:%s\n帖子内容:%s", title, content), images)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) ModerateComment(ctx context.Context, content string, images []string) (bool, string, error) {
|
||||||
|
if !c.IsEnabled() {
|
||||||
|
return true, "", nil
|
||||||
|
}
|
||||||
|
return c.moderateContentInBatches(ctx, fmt.Sprintf("评论内容:%s", content), images)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) moderateContentInBatches(ctx context.Context, contentPrompt string, images []string) (bool, string, error) {
|
||||||
|
cleanImages := normalizeImageURLs(images)
|
||||||
|
optimizedImages := c.optimizeImagesForModeration(ctx, cleanImages)
|
||||||
|
maxImagesPerRequest := c.maxImagesPerModerationRequest()
|
||||||
|
totalBatches := 1
|
||||||
|
if len(optimizedImages) > 0 {
|
||||||
|
totalBatches = (len(optimizedImages) + maxImagesPerRequest - 1) / maxImagesPerRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
// 图片超过单批上限时分批审核,任一批次拒绝即整体拒绝
|
||||||
|
for i := 0; i < totalBatches; i++ {
|
||||||
|
start := i * maxImagesPerRequest
|
||||||
|
end := start + maxImagesPerRequest
|
||||||
|
if end > len(optimizedImages) {
|
||||||
|
end = len(optimizedImages)
|
||||||
|
}
|
||||||
|
|
||||||
|
batchImages := []string{}
|
||||||
|
if len(optimizedImages) > 0 {
|
||||||
|
batchImages = optimizedImages[start:end]
|
||||||
|
}
|
||||||
|
|
||||||
|
approved, reason, err := c.moderateSingleBatch(ctx, contentPrompt, batchImages, i+1, totalBatches)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
if !approved {
|
||||||
|
if strings.TrimSpace(reason) != "" && totalBatches > 1 {
|
||||||
|
reason = fmt.Sprintf("第%d/%d批图片未通过:%s", i+1, totalBatches, reason)
|
||||||
|
}
|
||||||
|
return false, reason, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) moderateSingleBatch(
|
||||||
|
ctx context.Context,
|
||||||
|
contentPrompt string,
|
||||||
|
images []string,
|
||||||
|
batchNo, totalBatches int,
|
||||||
|
) (bool, string, error) {
|
||||||
|
userPrompt := fmt.Sprintf(
|
||||||
|
"%s\n图片批次:%d/%d(本次仅提供当前批次图片)",
|
||||||
|
contentPrompt,
|
||||||
|
batchNo,
|
||||||
|
totalBatches,
|
||||||
|
)
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt < maxModerationResultRetries; attempt++ {
|
||||||
|
replyText, err := c.chatCompletion(ctx, c.cfg.ModerationModel, moderationSystemPrompt, userPrompt, images, 0.1, 220)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
} else {
|
||||||
|
parsed := struct {
|
||||||
|
Approved bool `json:"approved"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}{}
|
||||||
|
if err := json.Unmarshal([]byte(extractJSONObject(replyText)), &parsed); err != nil {
|
||||||
|
lastErr = fmt.Errorf("failed to parse moderation result: %w", err)
|
||||||
|
} else {
|
||||||
|
return parsed.Approved, parsed.Reason, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if attempt == maxModerationResultRetries-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
|
||||||
|
return false, "", sleepErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, "", fmt.Errorf(
|
||||||
|
"moderation batch %d/%d failed after %d attempts: %w",
|
||||||
|
batchNo,
|
||||||
|
totalBatches,
|
||||||
|
maxModerationResultRetries,
|
||||||
|
lastErr,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatCompletionsRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []chatMessage `json:"messages"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content interface{} `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type contentPart struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ImageURL *imageURLPart `json:"image_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageURLPart struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatCompletionsResponse struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) chatCompletion(
|
||||||
|
ctx context.Context,
|
||||||
|
model string,
|
||||||
|
systemPrompt string,
|
||||||
|
userPrompt string,
|
||||||
|
images []string,
|
||||||
|
temperature float64,
|
||||||
|
maxTokens int,
|
||||||
|
) (string, error) {
|
||||||
|
if model == "" {
|
||||||
|
return "", fmt.Errorf("model is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanImages := normalizeImageURLs(images)
|
||||||
|
|
||||||
|
userParts := []contentPart{
|
||||||
|
{Type: "text", Text: userPrompt},
|
||||||
|
}
|
||||||
|
for _, image := range cleanImages {
|
||||||
|
userParts = append(userParts, contentPart{
|
||||||
|
Type: "image_url",
|
||||||
|
ImageURL: &imageURLPart{URL: image},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := chatCompletionsRequest{
|
||||||
|
Model: model,
|
||||||
|
Messages: []chatMessage{
|
||||||
|
{Role: "system", Content: systemPrompt},
|
||||||
|
{Role: "user", Content: userParts},
|
||||||
|
},
|
||||||
|
Temperature: temperature,
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := strings.TrimRight(c.cfg.BaseURL, "/")
|
||||||
|
endpoint := baseURL + "/v1/chat/completions"
|
||||||
|
if strings.HasSuffix(baseURL, "/v1") {
|
||||||
|
endpoint = baseURL + "/chat/completions"
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt < maxChatCompletionRetries; attempt++ {
|
||||||
|
body, statusCode, err := c.doChatCompletionRequest(ctx, endpoint, data)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
} else if statusCode >= 400 {
|
||||||
|
lastErr = fmt.Errorf("openai error status=%d body=%s", statusCode, string(body))
|
||||||
|
if !isRetryableStatusCode(statusCode) {
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var parsed chatCompletionsResponse
|
||||||
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", err)
|
||||||
|
}
|
||||||
|
if len(parsed.Choices) == 0 {
|
||||||
|
return "", fmt.Errorf("empty response choices")
|
||||||
|
}
|
||||||
|
return parsed.Choices[0].Message.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if attempt == maxChatCompletionRetries-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
|
||||||
|
return "", sleepErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) doChatCompletionRequest(ctx context.Context, endpoint string, data []byte) ([]byte, int, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("request openai: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
return body, resp.StatusCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRetryableStatusCode(statusCode int) bool {
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return statusCode >= 500 && statusCode <= 599
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepWithBackoff(ctx context.Context, attempt int) error {
|
||||||
|
delay := initialRetryBackoff * time.Duration(1<<attempt)
|
||||||
|
timer := time.NewTimer(delay)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("request cancelled: %w", ctx.Err())
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeImageURLs(images []string) []string {
|
||||||
|
clean := make([]string, 0, len(images))
|
||||||
|
for _, image := range images {
|
||||||
|
trimmed := strings.TrimSpace(image)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
clean = append(clean, trimmed)
|
||||||
|
}
|
||||||
|
return clean
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractJSONObject(raw string) string {
|
||||||
|
text := strings.TrimSpace(raw)
|
||||||
|
start := strings.Index(text, "{")
|
||||||
|
end := strings.LastIndex(text, "}")
|
||||||
|
if start >= 0 && end > start {
|
||||||
|
return text[start : end+1]
|
||||||
|
}
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) maxImagesPerModerationRequest() int {
|
||||||
|
// 审核固定单图请求,降低单次payload体积,减少超时风险。
|
||||||
|
if c.cfg.ModerationMaxImagesPerRequest <= 0 {
|
||||||
|
return defaultMaxImagesPerModerationRequest
|
||||||
|
}
|
||||||
|
if c.cfg.ModerationMaxImagesPerRequest > 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return c.cfg.ModerationMaxImagesPerRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) optimizeImagesForModeration(ctx context.Context, images []string) []string {
|
||||||
|
if len(images) == 0 {
|
||||||
|
return images
|
||||||
|
}
|
||||||
|
|
||||||
|
optimized := make([]string, 0, len(images))
|
||||||
|
for _, imageURL := range images {
|
||||||
|
optimized = append(optimized, c.tryCompressImageForModeration(ctx, imageURL))
|
||||||
|
}
|
||||||
|
return optimized
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *clientImpl) tryCompressImageForModeration(ctx context.Context, imageURL string) string {
|
||||||
|
parsed, err := url.Parse(imageURL)
|
||||||
|
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.ToLower(resp.Header.Get("Content-Type")), "image/") {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
|
||||||
|
originBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxDownloadImageBytes))
|
||||||
|
if err != nil || len(originBytes) == 0 {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
|
||||||
|
srcImg, _, err := image.Decode(bytes.NewReader(originBytes))
|
||||||
|
if err != nil {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
|
||||||
|
dstImg := resizeIfNeeded(srcImg, maxModerationImageSide)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := jpeg.Encode(&buf, dstImg, &jpeg.Options{Quality: compressedJPEGQuality}); err != nil {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed := buf.Bytes()
|
||||||
|
if len(compressed) == 0 || len(compressed) > maxCompressedPayloadBytes {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
// 压缩效果不明显时直接使用原图URL,避免增大请求体。
|
||||||
|
if len(compressed) >= int(float64(len(originBytes))*0.95) {
|
||||||
|
return imageURL
|
||||||
|
}
|
||||||
|
|
||||||
|
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(compressed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resizeIfNeeded(src image.Image, maxSide int) image.Image {
|
||||||
|
bounds := src.Bounds()
|
||||||
|
w := bounds.Dx()
|
||||||
|
h := bounds.Dy()
|
||||||
|
if w <= 0 || h <= 0 || max(w, h) <= maxSide {
|
||||||
|
return src
|
||||||
|
}
|
||||||
|
|
||||||
|
newW, newH := w, h
|
||||||
|
if w >= h {
|
||||||
|
newW = maxSide
|
||||||
|
newH = int(float64(h) * (float64(maxSide) / float64(w)))
|
||||||
|
} else {
|
||||||
|
newH = maxSide
|
||||||
|
newW = int(float64(w) * (float64(maxSide) / float64(h)))
|
||||||
|
}
|
||||||
|
if newW < 1 {
|
||||||
|
newW = 1
|
||||||
|
}
|
||||||
|
if newH < 1 {
|
||||||
|
newH = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
dst := image.NewRGBA(image.Rect(0, 0, newW, newH))
|
||||||
|
xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, bounds, xdraw.Over, nil)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
27
internal/pkg/openai/config.go
Normal file
27
internal/pkg/openai/config.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "carrot_bbs/internal/config"
|
||||||
|
|
||||||
|
// Config OpenAI 兼容接口配置
|
||||||
|
type Config struct {
|
||||||
|
Enabled bool
|
||||||
|
BaseURL string
|
||||||
|
APIKey string
|
||||||
|
ModerationModel string
|
||||||
|
ModerationMaxImagesPerRequest int
|
||||||
|
RequestTimeoutSeconds int
|
||||||
|
StrictModeration bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigFromAppConfig 从应用配置转换
|
||||||
|
func ConfigFromAppConfig(cfg *config.OpenAIConfig) Config {
|
||||||
|
return Config{
|
||||||
|
Enabled: cfg.Enabled,
|
||||||
|
BaseURL: cfg.BaseURL,
|
||||||
|
APIKey: cfg.APIKey,
|
||||||
|
ModerationModel: cfg.ModerationModel,
|
||||||
|
ModerationMaxImagesPerRequest: cfg.ModerationMaxImagesPerRequest,
|
||||||
|
RequestTimeoutSeconds: cfg.RequestTimeout,
|
||||||
|
StrictModeration: cfg.StrictModeration,
|
||||||
|
}
|
||||||
|
}
|
||||||
119
internal/pkg/redis/redis.go
Normal file
119
internal/pkg/redis/redis.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package redis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client Redis客户端
|
||||||
|
type Client struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
isMiniRedis bool
|
||||||
|
mr *miniredis.Miniredis
|
||||||
|
}
|
||||||
|
|
||||||
|
// New 创建Redis客户端
|
||||||
|
func New(cfg *config.RedisConfig) (*Client, error) {
|
||||||
|
switch cfg.Type {
|
||||||
|
case "miniredis":
|
||||||
|
// 启动内嵌Redis模拟
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start miniredis: %w", err)
|
||||||
|
}
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: mr.Addr(),
|
||||||
|
Password: "",
|
||||||
|
DB: 0,
|
||||||
|
})
|
||||||
|
return &Client{
|
||||||
|
rdb: rdb,
|
||||||
|
isMiniRedis: true,
|
||||||
|
mr: mr,
|
||||||
|
}, nil
|
||||||
|
case "redis":
|
||||||
|
// 使用真实Redis
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: cfg.Redis.Addr(),
|
||||||
|
Password: cfg.Redis.Password,
|
||||||
|
DB: cfg.Redis.DB,
|
||||||
|
PoolSize: cfg.PoolSize,
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||||||
|
}
|
||||||
|
return &Client{rdb: rdb, isMiniRedis: false}, nil
|
||||||
|
default:
|
||||||
|
// 默认使用miniredis
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start miniredis: %w", err)
|
||||||
|
}
|
||||||
|
rdb := redis.NewClient(&redis.Options{
|
||||||
|
Addr: mr.Addr(),
|
||||||
|
})
|
||||||
|
return &Client{
|
||||||
|
rdb: rdb,
|
||||||
|
isMiniRedis: true,
|
||||||
|
mr: mr,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取值
|
||||||
|
func (c *Client) Get(ctx context.Context, key string) (string, error) {
|
||||||
|
return c.rdb.Get(ctx, key).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set 设置值
|
||||||
|
func (c *Client) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
|
||||||
|
return c.rdb.Set(ctx, key, value, expiration).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del 删除键
|
||||||
|
func (c *Client) Del(ctx context.Context, keys ...string) error {
|
||||||
|
return c.rdb.Del(ctx, keys...).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists 检查键是否存在
|
||||||
|
func (c *Client) Exists(ctx context.Context, keys ...string) (int64, error) {
|
||||||
|
return c.rdb.Exists(ctx, keys...).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Incr 递增
|
||||||
|
func (c *Client) Incr(ctx context.Context, key string) (int64, error) {
|
||||||
|
return c.rdb.Incr(ctx, key).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expire 设置过期时间
|
||||||
|
func (c *Client) Expire(ctx context.Context, key string, expiration time.Duration) (bool, error) {
|
||||||
|
return c.rdb.Expire(ctx, key, expiration).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient 获取原生客户端
|
||||||
|
func (c *Client) GetClient() *redis.Client {
|
||||||
|
return c.rdb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭连接
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
if err := c.rdb.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c.mr != nil {
|
||||||
|
c.mr.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMiniRedis 返回是否是miniredis
|
||||||
|
func (c *Client) IsMiniRedis() bool {
|
||||||
|
return c.isMiniRedis
|
||||||
|
}
|
||||||
117
internal/pkg/response/response.go
Normal file
117
internal/pkg/response/response.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package response
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response 统一响应结构
|
||||||
|
type Response struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data interface{} `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseSnakeCase 统一响应结构(snake_case)
|
||||||
|
type ResponseSnakeCase struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data interface{} `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success 成功响应
|
||||||
|
func Success(c *gin.Context, data interface{}) {
|
||||||
|
c.JSON(http.StatusOK, Response{
|
||||||
|
Code: 0,
|
||||||
|
Message: "success",
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SuccessWithMessage 成功响应带消息
|
||||||
|
func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
|
||||||
|
c.JSON(http.StatusOK, Response{
|
||||||
|
Code: 0,
|
||||||
|
Message: message,
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error 错误响应
|
||||||
|
func Error(c *gin.Context, code int, message string) {
|
||||||
|
c.JSON(http.StatusBadRequest, Response{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorWithStatusCode 带状态码的错误响应
|
||||||
|
func ErrorWithStatusCode(c *gin.Context, statusCode int, code int, message string) {
|
||||||
|
c.JSON(statusCode, Response{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BadRequest 参数错误
|
||||||
|
func BadRequest(c *gin.Context, message string) {
|
||||||
|
ErrorWithStatusCode(c, http.StatusBadRequest, 400, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unauthorized 未授权
|
||||||
|
func Unauthorized(c *gin.Context, message string) {
|
||||||
|
if message == "" {
|
||||||
|
message = "unauthorized"
|
||||||
|
}
|
||||||
|
ErrorWithStatusCode(c, http.StatusUnauthorized, 401, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forbidden 禁止访问
|
||||||
|
func Forbidden(c *gin.Context, message string) {
|
||||||
|
if message == "" {
|
||||||
|
message = "forbidden"
|
||||||
|
}
|
||||||
|
ErrorWithStatusCode(c, http.StatusForbidden, 403, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotFound 资源不存在
|
||||||
|
func NotFound(c *gin.Context, message string) {
|
||||||
|
if message == "" {
|
||||||
|
message = "resource not found"
|
||||||
|
}
|
||||||
|
ErrorWithStatusCode(c, http.StatusNotFound, 404, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InternalServerError 服务器内部错误
|
||||||
|
func InternalServerError(c *gin.Context, message string) {
|
||||||
|
if message == "" {
|
||||||
|
message = "internal server error"
|
||||||
|
}
|
||||||
|
ErrorWithStatusCode(c, http.StatusInternalServerError, 500, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PaginatedResponse 分页响应
|
||||||
|
type PaginatedResponse struct {
|
||||||
|
List interface{} `json:"list"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
TotalPages int `json:"total_pages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Paginated 分页成功响应
|
||||||
|
func Paginated(c *gin.Context, list interface{}, total int64, page, pageSize int) {
|
||||||
|
totalPages := int(total) / pageSize
|
||||||
|
if int(total)%pageSize > 0 {
|
||||||
|
totalPages++
|
||||||
|
}
|
||||||
|
|
||||||
|
Success(c, PaginatedResponse{
|
||||||
|
List: list,
|
||||||
|
Total: total,
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
TotalPages: totalPages,
|
||||||
|
})
|
||||||
|
}
|
||||||
119
internal/pkg/s3/s3.go
Normal file
119
internal/pkg/s3/s3.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package s3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/minio/minio-go/v7"
|
||||||
|
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client S3客户端
|
||||||
|
type Client struct {
|
||||||
|
client *minio.Client
|
||||||
|
bucket string
|
||||||
|
domain string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New 创建S3客户端
|
||||||
|
func New(cfg *config.S3Config) (*Client, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client, err := minio.New(cfg.Endpoint, &minio.Options{
|
||||||
|
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
|
||||||
|
Secure: cfg.UseSSL,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create S3 client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查bucket是否存在
|
||||||
|
exists, err := client.BucketExists(ctx, cfg.Bucket)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check bucket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{
|
||||||
|
Region: cfg.Region,
|
||||||
|
}); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create bucket: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有配置domain,则使用默认的endpoint
|
||||||
|
domain := cfg.Domain
|
||||||
|
if domain == "" {
|
||||||
|
domain = cfg.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
client: client,
|
||||||
|
bucket: cfg.Bucket,
|
||||||
|
domain: domain,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload 上传文件
|
||||||
|
func (c *Client) Upload(ctx context.Context, objectName string, filePath string, contentType string) (string, error) {
|
||||||
|
_, err := c.client.FPutObject(ctx, c.bucket, objectName, filePath, minio.PutObjectOptions{
|
||||||
|
ContentType: contentType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to upload file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s/%s", c.bucket, objectName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadData 上传数据
|
||||||
|
func (c *Client) UploadData(ctx context.Context, objectName string, data []byte, contentType string) (string, error) {
|
||||||
|
_, err := c.client.PutObject(ctx, c.bucket, objectName, bytes.NewReader(data), int64(len(data)), minio.PutObjectOptions{
|
||||||
|
ContentType: contentType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to upload data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回完整URL,包含bucket名称
|
||||||
|
scheme := "https"
|
||||||
|
if c.domain == c.bucket || c.domain == "" {
|
||||||
|
scheme = "http"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetURL 获取文件URL - 使用自定义域名
|
||||||
|
func (c *Client) GetURL(ctx context.Context, objectName string) (string, error) {
|
||||||
|
// 使用自定义域名构建URL,包含bucket名称
|
||||||
|
scheme := "https"
|
||||||
|
if c.domain == c.bucket || c.domain == "" {
|
||||||
|
scheme = "http"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPresignedURL 获取预签名URL(用于私有桶)
|
||||||
|
func (c *Client) GetPresignedURL(ctx context.Context, objectName string) (string, error) {
|
||||||
|
url, err := c.client.PresignedGetObject(ctx, c.bucket, objectName, time.Hour*24, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get presigned URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return url.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除文件
|
||||||
|
func (c *Client) Delete(ctx context.Context, objectName string) error {
|
||||||
|
return c.client.RemoveObject(ctx, c.bucket, objectName, minio.RemoveObjectOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient 获取原生客户端
|
||||||
|
func (c *Client) GetClient() *minio.Client {
|
||||||
|
return c.client
|
||||||
|
}
|
||||||
52
internal/pkg/utils/avatar.go
Normal file
52
internal/pkg/utils/avatar.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AvatarServiceBaseURL 默认头像服务基础URL (使用 UI Avatars API)
|
||||||
|
const AvatarServiceBaseURL = "https://ui-avatars.com/api"
|
||||||
|
|
||||||
|
// DefaultAvatarSize 默认头像尺寸
|
||||||
|
const DefaultAvatarSize = 100
|
||||||
|
|
||||||
|
// AvatarInfo 头像信息
|
||||||
|
type AvatarInfo struct {
|
||||||
|
Username string
|
||||||
|
Nickname string
|
||||||
|
Avatar string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAvatarOrDefault 获取头像URL,如果为空则返回在线头像生成服务的URL
|
||||||
|
// 优先使用已有的头像,否则使用昵称或用户名生成默认头像
|
||||||
|
func GetAvatarOrDefault(username, nickname, avatar string) string {
|
||||||
|
if avatar != "" {
|
||||||
|
return avatar
|
||||||
|
}
|
||||||
|
// 使用用户名生成默认头像URL(优先使用昵称)
|
||||||
|
displayName := nickname
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = username
|
||||||
|
}
|
||||||
|
return GenerateDefaultAvatarURL(displayName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAvatarOrDefaultFromInfo 从 AvatarInfo 获取头像URL
|
||||||
|
func GetAvatarOrDefaultFromInfo(info AvatarInfo) string {
|
||||||
|
return GetAvatarOrDefault(info.Username, info.Nickname, info.Avatar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateDefaultAvatarURL 生成默认头像URL
|
||||||
|
// 使用 UI Avatars API 生成基于用户名首字母的头像
|
||||||
|
func GenerateDefaultAvatarURL(name string) string {
|
||||||
|
if name == "" {
|
||||||
|
name = "?"
|
||||||
|
}
|
||||||
|
// 使用 UI Avatars API 生成头像
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("name", url.QueryEscape(name))
|
||||||
|
params.Set("size", "100")
|
||||||
|
params.Set("background", "0D8ABC") // 默认蓝色背景
|
||||||
|
params.Set("color", "ffffff") // 白色文字
|
||||||
|
return AvatarServiceBaseURL + "?" + params.Encode()
|
||||||
|
}
|
||||||
17
internal/pkg/utils/hash.go
Normal file
17
internal/pkg/utils/hash.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HashPassword 密码哈希
|
||||||
|
func HashPassword(password string) (string, error) {
|
||||||
|
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
return string(bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckPasswordHash 验证密码
|
||||||
|
func CheckPasswordHash(password, hash string) bool {
|
||||||
|
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
261
internal/pkg/utils/snowflake.go
Normal file
261
internal/pkg/utils/snowflake.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 雪花算法常量定义
|
||||||
|
const (
|
||||||
|
// 64位ID结构:1位符号位 + 41位时间戳 + 10位机器ID + 12位序列号
|
||||||
|
|
||||||
|
// 机器ID占用的位数
|
||||||
|
nodeIDBits uint64 = 10
|
||||||
|
// 序列号占用的位数
|
||||||
|
sequenceBits uint64 = 12
|
||||||
|
|
||||||
|
// 机器ID的最大值 (0-1023)
|
||||||
|
maxNodeID int64 = -1 ^ (-1 << nodeIDBits)
|
||||||
|
// 序列号的最大值 (0-4095)
|
||||||
|
maxSequence int64 = -1 ^ (-1 << sequenceBits)
|
||||||
|
|
||||||
|
// 机器ID左移位数
|
||||||
|
nodeIDShift uint64 = sequenceBits
|
||||||
|
// 时间戳左移位数
|
||||||
|
timestampShift uint64 = sequenceBits + nodeIDBits
|
||||||
|
|
||||||
|
// 自定义纪元时间:2024-01-01 00:00:00 UTC
|
||||||
|
// 使用自定义纪元可以延长ID有效期约24年(从2024年开始)
|
||||||
|
customEpoch int64 = 1704067200000 // 2024-01-01 00:00:00 UTC 的毫秒时间戳
|
||||||
|
)
|
||||||
|
|
||||||
|
// 错误定义
|
||||||
|
var (
|
||||||
|
// ErrInvalidNodeID 机器ID无效
|
||||||
|
ErrInvalidNodeID = errors.New("node ID must be between 0 and 1023")
|
||||||
|
// ErrClockBackwards 时钟回拨
|
||||||
|
ErrClockBackwards = errors.New("clock moved backwards, refusing to generate ID")
|
||||||
|
)
|
||||||
|
|
||||||
|
// IDInfo 解析后的ID信息
|
||||||
|
type IDInfo struct {
|
||||||
|
Timestamp int64 // 生成ID时的时间戳(毫秒)
|
||||||
|
NodeID int64 // 机器ID
|
||||||
|
Sequence int64 // 序列号
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snowflake 雪花算法ID生成器
|
||||||
|
type Snowflake struct {
|
||||||
|
mu sync.Mutex // 互斥锁,保证线程安全
|
||||||
|
nodeID int64 // 机器ID (0-1023)
|
||||||
|
sequence int64 // 当前序列号 (0-4095)
|
||||||
|
lastTimestamp int64 // 上次生成ID的时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// 全局雪花算法实例
|
||||||
|
var (
|
||||||
|
globalSnowflake *Snowflake
|
||||||
|
globalSnowflakeOnce sync.Once
|
||||||
|
globalSnowflakeErr error
|
||||||
|
)
|
||||||
|
|
||||||
|
// InitSnowflake 初始化全局雪花算法实例
|
||||||
|
// nodeID: 机器ID,范围0-1023,可以通过环境变量 NODE_ID 配置
|
||||||
|
func InitSnowflake(nodeID int64) error {
|
||||||
|
globalSnowflake, globalSnowflakeErr = NewSnowflake(nodeID)
|
||||||
|
return globalSnowflakeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSnowflake 获取全局雪花算法实例
|
||||||
|
// 如果未初始化,会自动使用默认配置初始化
|
||||||
|
func GetSnowflake() *Snowflake {
|
||||||
|
globalSnowflakeOnce.Do(func() {
|
||||||
|
if globalSnowflake == nil {
|
||||||
|
globalSnowflake, globalSnowflakeErr = NewSnowflake(-1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return globalSnowflake
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSnowflake 创建雪花算法ID生成器实例
|
||||||
|
// nodeID: 机器ID,范围0-1023,可以通过环境变量 NODE_ID 配置
|
||||||
|
// 如果nodeID为-1,则尝试从环境变量 NODE_ID 读取
|
||||||
|
func NewSnowflake(nodeID int64) (*Snowflake, error) {
|
||||||
|
// 如果传入-1,尝试从环境变量读取
|
||||||
|
if nodeID == -1 {
|
||||||
|
nodeIDStr := os.Getenv("NODE_ID")
|
||||||
|
if nodeIDStr != "" {
|
||||||
|
// 解析环境变量
|
||||||
|
parsedID, err := parseInt(nodeIDStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrInvalidNodeID
|
||||||
|
}
|
||||||
|
nodeID = parsedID
|
||||||
|
} else {
|
||||||
|
// 默认使用0
|
||||||
|
nodeID = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证机器ID范围
|
||||||
|
if nodeID < 0 || nodeID > maxNodeID {
|
||||||
|
return nil, ErrInvalidNodeID
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Snowflake{
|
||||||
|
nodeID: nodeID,
|
||||||
|
sequence: 0,
|
||||||
|
lastTimestamp: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseInt 辅助函数:解析整数
|
||||||
|
func parseInt(s string) (int64, error) {
|
||||||
|
var result int64
|
||||||
|
var negative bool
|
||||||
|
|
||||||
|
if len(s) == 0 {
|
||||||
|
return 0, errors.New("empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
if s[0] == '-' {
|
||||||
|
negative = true
|
||||||
|
i = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
for ; i < len(s); i++ {
|
||||||
|
if s[i] < '0' || s[i] > '9' {
|
||||||
|
return 0, errors.New("invalid character")
|
||||||
|
}
|
||||||
|
result = result*10 + int64(s[i]-'0')
|
||||||
|
}
|
||||||
|
|
||||||
|
if negative {
|
||||||
|
result = -result
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateID 生成唯一的雪花算法ID
|
||||||
|
// 返回值:生成的ID,以及可能的错误(如时钟回拨)
|
||||||
|
// 线程安全:使用互斥锁保证并发安全
|
||||||
|
func (s *Snowflake) GenerateID() (int64, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// 获取当前时间戳(毫秒)
|
||||||
|
now := currentTimestamp()
|
||||||
|
|
||||||
|
// 处理时钟回拨
|
||||||
|
if now < s.lastTimestamp {
|
||||||
|
return 0, ErrClockBackwards
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同一毫秒内
|
||||||
|
if now == s.lastTimestamp {
|
||||||
|
// 序列号递增
|
||||||
|
s.sequence = (s.sequence + 1) & maxSequence
|
||||||
|
// 序列号溢出,等待下一毫秒
|
||||||
|
if s.sequence == 0 {
|
||||||
|
now = s.waitNextMillis(now)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 不同毫秒,序列号重置为0
|
||||||
|
s.sequence = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新上次生成时间
|
||||||
|
s.lastTimestamp = now
|
||||||
|
|
||||||
|
// 组装ID
|
||||||
|
// ID结构:时间戳部分 | 机器ID部分 | 序列号部分
|
||||||
|
id := ((now - customEpoch) << timestampShift) |
|
||||||
|
(s.nodeID << nodeIDShift) |
|
||||||
|
s.sequence
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitNextMillis 等待到下一毫秒
|
||||||
|
// 参数:当前时间戳
|
||||||
|
// 返回值:下一毫秒的时间戳
|
||||||
|
func (s *Snowflake) waitNextMillis(timestamp int64) int64 {
|
||||||
|
now := currentTimestamp()
|
||||||
|
for now <= timestamp {
|
||||||
|
now = currentTimestamp()
|
||||||
|
}
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseID 解析雪花算法ID,提取其中的信息
|
||||||
|
// id: 要解析的雪花算法ID
|
||||||
|
// 返回值:包含时间戳、机器ID、序列号的结构体
|
||||||
|
func ParseID(id int64) *IDInfo {
|
||||||
|
// 提取序列号(低12位)
|
||||||
|
sequence := id & maxSequence
|
||||||
|
|
||||||
|
// 提取机器ID(中间10位)
|
||||||
|
nodeID := (id >> nodeIDShift) & maxNodeID
|
||||||
|
|
||||||
|
// 提取时间戳(高41位)
|
||||||
|
timestamp := (id >> timestampShift) + customEpoch
|
||||||
|
|
||||||
|
return &IDInfo{
|
||||||
|
Timestamp: timestamp,
|
||||||
|
NodeID: nodeID,
|
||||||
|
Sequence: sequence,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// currentTimestamp 获取当前时间戳(毫秒)
|
||||||
|
func currentTimestamp() int64 {
|
||||||
|
return time.Now().UnixNano() / 1e6
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNodeID 获取当前机器ID
|
||||||
|
func (s *Snowflake) GetNodeID() int64 {
|
||||||
|
return s.nodeID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCustomEpoch 获取自定义纪元时间
|
||||||
|
func GetCustomEpoch() int64 {
|
||||||
|
return customEpoch
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDToTime 将雪花算法ID转换为生成时间
|
||||||
|
// 这是一个便捷方法,等价于 ParseID(id).Timestamp
|
||||||
|
func IDToTime(id int64) time.Time {
|
||||||
|
info := ParseID(id)
|
||||||
|
return time.Unix(0, info.Timestamp*1e6) // 毫秒转纳秒
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateID 验证ID是否为有效的雪花算法ID
|
||||||
|
// 检查时间戳是否在合理范围内
|
||||||
|
func ValidateID(id int64) bool {
|
||||||
|
if id <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
info := ParseID(id)
|
||||||
|
|
||||||
|
// 检查时间戳是否在合理范围内
|
||||||
|
// 不能早于纪元时间,不能晚于当前时间太多(允许1分钟的时钟偏差)
|
||||||
|
now := currentTimestamp()
|
||||||
|
if info.Timestamp < customEpoch || info.Timestamp > now+60000 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查机器ID和序列号是否在有效范围内
|
||||||
|
if info.NodeID < 0 || info.NodeID > maxNodeID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if info.Sequence < 0 || info.Sequence > maxSequence {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
46
internal/pkg/utils/validator.go
Normal file
46
internal/pkg/utils/validator.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateEmail 验证邮箱
|
||||||
|
func ValidateEmail(email string) bool {
|
||||||
|
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
|
||||||
|
matched, _ := regexp.MatchString(pattern, email)
|
||||||
|
return matched
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateUsername 验证用户名
|
||||||
|
func ValidateUsername(username string) bool {
|
||||||
|
if len(username) < 3 || len(username) > 30 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
pattern := `^[a-zA-Z0-9_]+$`
|
||||||
|
matched, _ := regexp.MatchString(pattern, username)
|
||||||
|
return matched
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidatePassword 验证密码强度
|
||||||
|
func ValidatePassword(password string) bool {
|
||||||
|
if len(password) < 6 || len(password) > 50 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidatePhone 验证手机号
|
||||||
|
func ValidatePhone(phone string) bool {
|
||||||
|
pattern := `^1[3-9]\d{9}$`
|
||||||
|
matched, _ := regexp.MatchString(pattern, phone)
|
||||||
|
return matched
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeHTML 清理HTML
|
||||||
|
func SanitizeHTML(input string) string {
|
||||||
|
// 简单实现,实际使用建议用专门库
|
||||||
|
input = strings.ReplaceAll(input, "<", "<")
|
||||||
|
input = strings.ReplaceAll(input, ">", ">")
|
||||||
|
return input
|
||||||
|
}
|
||||||
440
internal/pkg/websocket/websocket.go
Normal file
440
internal/pkg/websocket/websocket.go
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebSocket消息类型常量
|
||||||
|
const (
|
||||||
|
MessageTypePing = "ping"
|
||||||
|
MessageTypePong = "pong"
|
||||||
|
MessageTypeMessage = "message"
|
||||||
|
MessageTypeTyping = "typing"
|
||||||
|
MessageTypeRead = "read"
|
||||||
|
MessageTypeAck = "ack"
|
||||||
|
MessageTypeError = "error"
|
||||||
|
MessageTypeRecall = "recall" // 撤回消息
|
||||||
|
MessageTypeSystem = "system" // 系统消息
|
||||||
|
MessageTypeNotification = "notification" // 通知消息
|
||||||
|
MessageTypeAnnouncement = "announcement" // 公告消息
|
||||||
|
|
||||||
|
// 群组相关消息类型
|
||||||
|
MessageTypeGroupMessage = "group_message" // 群消息
|
||||||
|
MessageTypeGroupTyping = "group_typing" // 群输入状态
|
||||||
|
MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等)
|
||||||
|
MessageTypeGroupMention = "group_mention" // @提及通知
|
||||||
|
MessageTypeGroupRead = "group_read" // 群消息已读
|
||||||
|
MessageTypeGroupRecall = "group_recall" // 群消息撤回
|
||||||
|
|
||||||
|
// Meta事件详细类型
|
||||||
|
MetaDetailTypeHeartbeat = "heartbeat"
|
||||||
|
MetaDetailTypeTyping = "typing"
|
||||||
|
MetaDetailTypeAck = "ack" // 消息发送确认
|
||||||
|
MetaDetailTypeRead = "read" // 已读回执
|
||||||
|
)
|
||||||
|
|
||||||
|
// WSMessage WebSocket消息结构
|
||||||
|
type WSMessage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
Timestamp int64 `json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatMessage 聊天消息结构
|
||||||
|
type ChatMessage struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
ConversationID string `json:"conversation_id"`
|
||||||
|
SenderID string `json:"sender_id"`
|
||||||
|
Seq int64 `json:"seq"`
|
||||||
|
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||||
|
ReplyToID *string `json:"reply_to_id,omitempty"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemMessage 系统消息结构
|
||||||
|
type SystemMessage struct {
|
||||||
|
ID string `json:"id"` // 消息ID
|
||||||
|
Type string `json:"type"` // 消息子类型(如:account_banned, post_approved等)
|
||||||
|
Title string `json:"title"` // 消息标题
|
||||||
|
Content string `json:"content"` // 消息内容
|
||||||
|
Data map[string]interface{} `json:"data"` // 额外数据
|
||||||
|
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotificationMessage 通知消息结构
|
||||||
|
type NotificationMessage struct {
|
||||||
|
ID string `json:"id"` // 通知ID
|
||||||
|
Type string `json:"type"` // 通知类型(like, comment, follow, mention等)
|
||||||
|
Title string `json:"title"` // 通知标题
|
||||||
|
Content string `json:"content"` // 通知内容
|
||||||
|
TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户
|
||||||
|
ResourceType string `json:"resource_type"` // 资源类型(post, comment等)
|
||||||
|
ResourceID string `json:"resource_id"` // 资源ID
|
||||||
|
Extra map[string]interface{} `json:"extra"` // 额外数据
|
||||||
|
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotificationUser 通知中的用户信息
|
||||||
|
type NotificationUser struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Avatar string `json:"avatar"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AnnouncementMessage 公告消息结构
|
||||||
|
type AnnouncementMessage struct {
|
||||||
|
ID string `json:"id"` // 公告ID
|
||||||
|
Title string `json:"title"` // 公告标题
|
||||||
|
Content string `json:"content"` // 公告内容
|
||||||
|
Priority int `json:"priority"` // 优先级(1-10)
|
||||||
|
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMessage 群消息结构
|
||||||
|
type GroupMessage struct {
|
||||||
|
ID string `json:"id"` // 消息ID
|
||||||
|
ConversationID string `json:"conversation_id"` // 会话ID(群聊会话)
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
SenderID string `json:"sender_id"` // 发送者ID
|
||||||
|
Seq int64 `json:"seq"` // 消息序号
|
||||||
|
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||||
|
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
|
||||||
|
MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表
|
||||||
|
MentionAll bool `json:"mention_all"` // 是否@所有人
|
||||||
|
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupTypingMessage 群输入状态消息
|
||||||
|
type GroupTypingMessage struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
UserID string `json:"user_id"` // 用户ID
|
||||||
|
Username string `json:"username"` // 用户名
|
||||||
|
IsTyping bool `json:"is_typing"` // 是否正在输入
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupNoticeMessage 群组通知消息
|
||||||
|
type GroupNoticeMessage struct {
|
||||||
|
NoticeType string `json:"notice_type"` // 通知类型:member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
Data interface{} `json:"data"` // 通知数据
|
||||||
|
Timestamp int64 `json:"timestamp"` // 时间戳
|
||||||
|
MessageID string `json:"message_id,omitempty"` // 消息ID(如果通知保存为消息)
|
||||||
|
Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupNoticeData 通知数据结构
|
||||||
|
type GroupNoticeData struct {
|
||||||
|
// 成员变动
|
||||||
|
UserID string `json:"user_id,omitempty"` // 变动的用户ID
|
||||||
|
Username string `json:"username,omitempty"` // 用户名
|
||||||
|
OperatorID string `json:"operator_id,omitempty"` // 操作者ID
|
||||||
|
OpName string `json:"op_name,omitempty"` // 操作者名称
|
||||||
|
NewRole string `json:"new_role,omitempty"` // 新角色
|
||||||
|
OldRole string `json:"old_role,omitempty"` // 旧角色
|
||||||
|
MemberCount int `json:"member_count,omitempty"` // 当前成员数
|
||||||
|
|
||||||
|
// 群设置变更
|
||||||
|
MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupMentionMessage @提及通知消息
|
||||||
|
type GroupMentionMessage struct {
|
||||||
|
GroupID string `json:"group_id"` // 群组ID
|
||||||
|
MessageID string `json:"message_id"` // 消息ID
|
||||||
|
FromUserID string `json:"from_user_id"` // 发送者ID
|
||||||
|
FromName string `json:"from_name"` // 发送者名称
|
||||||
|
Content string `json:"content"` // 消息内容预览
|
||||||
|
MentionAll bool `json:"mention_all"` // 是否@所有人
|
||||||
|
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// AckMessage 消息发送确认结构
|
||||||
|
type AckMessage struct {
|
||||||
|
ConversationID string `json:"conversation_id"` // 会话ID
|
||||||
|
GroupID string `json:"group_id,omitempty"` // 群组ID(群聊时)
|
||||||
|
ID string `json:"id"` // 消息ID
|
||||||
|
SenderID string `json:"sender_id"` // 发送者ID
|
||||||
|
Seq int64 `json:"seq"` // 消息序号
|
||||||
|
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||||
|
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client WebSocket客户端
|
||||||
|
type Client struct {
|
||||||
|
ID string
|
||||||
|
UserID string
|
||||||
|
Conn *websocket.Conn
|
||||||
|
Send chan []byte
|
||||||
|
Manager *WebSocketManager
|
||||||
|
IsClosed bool
|
||||||
|
Mu sync.Mutex
|
||||||
|
closeOnce sync.Once // 确保 Send channel 只关闭一次
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocketManager WebSocket连接管理器
|
||||||
|
type WebSocketManager struct {
|
||||||
|
clients map[string]*Client // userID -> Client
|
||||||
|
register chan *Client
|
||||||
|
unregister chan *Client
|
||||||
|
broadcast chan *BroadcastMessage
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastMessage 广播消息
|
||||||
|
type BroadcastMessage struct {
|
||||||
|
Message *WSMessage
|
||||||
|
ExcludeUser string // 排除的用户ID,为空表示不排除
|
||||||
|
TargetUser string // 目标用户ID,为空表示广播给所有用户
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebSocketManager 创建WebSocket管理器
|
||||||
|
func NewWebSocketManager() *WebSocketManager {
|
||||||
|
return &WebSocketManager{
|
||||||
|
clients: make(map[string]*Client),
|
||||||
|
register: make(chan *Client, 100),
|
||||||
|
unregister: make(chan *Client, 100),
|
||||||
|
broadcast: make(chan *BroadcastMessage, 100),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动管理器
|
||||||
|
func (m *WebSocketManager) Start() {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case client := <-m.register:
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.clients[client.UserID] = client
|
||||||
|
m.mutex.Unlock()
|
||||||
|
log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients))
|
||||||
|
|
||||||
|
case client := <-m.unregister:
|
||||||
|
m.mutex.Lock()
|
||||||
|
if _, ok := m.clients[client.UserID]; ok {
|
||||||
|
delete(m.clients, client.UserID)
|
||||||
|
// 使用 closeOnce 确保 channel 只关闭一次,避免 panic
|
||||||
|
client.closeOnce.Do(func() {
|
||||||
|
close(client.Send)
|
||||||
|
})
|
||||||
|
log.Printf("WebSocket client disconnected: userID=%s", client.UserID)
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
case broadcast := <-m.broadcast:
|
||||||
|
m.sendMessage(broadcast)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register 注册客户端
|
||||||
|
func (m *WebSocketManager) Register(client *Client) {
|
||||||
|
m.register <- client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister 注销客户端
|
||||||
|
func (m *WebSocketManager) Unregister(client *Client) {
|
||||||
|
m.unregister <- client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast 广播消息给所有用户
|
||||||
|
func (m *WebSocketManager) Broadcast(msg *WSMessage) {
|
||||||
|
m.broadcast <- &BroadcastMessage{
|
||||||
|
Message: msg,
|
||||||
|
TargetUser: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendToUser 发送消息给指定用户
|
||||||
|
func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) {
|
||||||
|
m.broadcast <- &BroadcastMessage{
|
||||||
|
Message: msg,
|
||||||
|
TargetUser: userID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendToUsers 发送消息给指定用户列表
|
||||||
|
func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) {
|
||||||
|
for _, userID := range userIDs {
|
||||||
|
m.SendToUser(userID, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient 获取客户端
|
||||||
|
func (m *WebSocketManager) GetClient(userID string) (*Client, bool) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
client, ok := m.clients[userID]
|
||||||
|
return client, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllClients 获取所有客户端
|
||||||
|
func (m *WebSocketManager) GetAllClients() map[string]*Client {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
return m.clients
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientCount 获取在线用户数量
|
||||||
|
func (m *WebSocketManager) GetClientCount() int {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
return len(m.clients)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUserOnline 检查用户是否在线
|
||||||
|
func (m *WebSocketManager) IsUserOnline(userID string) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
_, ok := m.clients[userID]
|
||||||
|
log.Printf("[DEBUG IsUserOnline] 检查用户 %s, 结果=%v, 当前在线用户=%v", userID, ok, m.clients)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMessage 发送消息
|
||||||
|
func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) {
|
||||||
|
msgBytes, err := json.Marshal(broadcast.Message)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to marshal message: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[DEBUG WebSocket] sendMessage: 目标用户=%s, 当前在线用户数=%d, 消息类型=%s",
|
||||||
|
broadcast.TargetUser, len(m.clients), broadcast.Message.Type)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
for userID, client := range m.clients {
|
||||||
|
// 如果指定了目标用户,只发送给目标用户
|
||||||
|
if broadcast.TargetUser != "" && userID != broadcast.TargetUser {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果指定了排除用户,跳过
|
||||||
|
if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case client.Send <- msgBytes:
|
||||||
|
log.Printf("[DEBUG WebSocket] 成功发送消息到用户 %s, 消息类型=%s", userID, broadcast.Message.Type)
|
||||||
|
default:
|
||||||
|
log.Printf("Failed to send message to user %s: channel full", userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendPing 发送心跳
|
||||||
|
func (c *Client) SendPing() error {
|
||||||
|
c.Mu.Lock()
|
||||||
|
defer c.Mu.Unlock()
|
||||||
|
if c.IsClosed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
msg := WSMessage{
|
||||||
|
Type: MessageTypePing,
|
||||||
|
Data: nil,
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
msgBytes, _ := json.Marshal(msg)
|
||||||
|
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendPong 发送Pong响应
|
||||||
|
func (c *Client) SendPong() error {
|
||||||
|
c.Mu.Lock()
|
||||||
|
defer c.Mu.Unlock()
|
||||||
|
if c.IsClosed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
msg := WSMessage{
|
||||||
|
Type: MessageTypePong,
|
||||||
|
Data: nil,
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
msgBytes, _ := json.Marshal(msg)
|
||||||
|
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WritePump 写入泵,将消息从Manager发送到客户端
|
||||||
|
func (c *Client) WritePump() {
|
||||||
|
defer func() {
|
||||||
|
c.Conn.Close()
|
||||||
|
c.Mu.Lock()
|
||||||
|
c.IsClosed = true
|
||||||
|
c.Mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
message, ok := <-c.Send
|
||||||
|
if !ok {
|
||||||
|
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Mu.Lock()
|
||||||
|
if c.IsClosed {
|
||||||
|
c.Mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := c.Conn.WriteMessage(websocket.TextMessage, message)
|
||||||
|
c.Mu.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Write error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPump 读取泵,从客户端读取消息
|
||||||
|
func (c *Client) ReadPump(handler func(msg *WSMessage)) {
|
||||||
|
defer func() {
|
||||||
|
c.Manager.Unregister(c)
|
||||||
|
c.Conn.Close()
|
||||||
|
c.Mu.Lock()
|
||||||
|
c.IsClosed = true
|
||||||
|
c.Mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.Conn.SetReadLimit(512 * 1024) // 512KB
|
||||||
|
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||||
|
c.Conn.SetPongHandler(func(string) error {
|
||||||
|
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, message, err := c.Conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
|
log.Printf("WebSocket error: %v", err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var wsMsg WSMessage
|
||||||
|
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
||||||
|
log.Printf("Failed to unmarshal message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
handler(&wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateWSMessage 创建WebSocket消息
|
||||||
|
func CreateWSMessage(msgType string, data interface{}) *WSMessage {
|
||||||
|
return &WSMessage{
|
||||||
|
Type: msgType,
|
||||||
|
Data: data,
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
}
|
||||||
296
internal/repository/comment_repo.go
Normal file
296
internal/repository/comment_repo.go
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CommentRepository 评论仓储
|
||||||
|
type CommentRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCommentRepository 创建评论仓储
|
||||||
|
func NewCommentRepository(db *gorm.DB) *CommentRepository {
|
||||||
|
return &CommentRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建评论
|
||||||
|
func (r *CommentRepository) Create(comment *model.Comment) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 创建评论
|
||||||
|
err := tx.Create(comment).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加帖子的评论数并同步热度分
|
||||||
|
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"comments_count": gorm.Expr("comments_count + 1"),
|
||||||
|
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
|
||||||
|
}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是回复,增加父评论的回复数
|
||||||
|
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||||
|
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||||
|
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取评论
|
||||||
|
func (r *CommentRepository) GetByID(id string) (*model.Comment, error) {
|
||||||
|
var comment model.Comment
|
||||||
|
err := r.db.Preload("User").First(&comment, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &comment, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新评论
|
||||||
|
func (r *CommentRepository) Update(comment *model.Comment) error {
|
||||||
|
return r.db.Save(comment).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModerationStatus 更新评论审核状态
|
||||||
|
func (r *CommentRepository) UpdateModerationStatus(commentID string, status model.CommentStatus) error {
|
||||||
|
return r.db.Model(&model.Comment{}).
|
||||||
|
Where("id = ?", commentID).
|
||||||
|
Update("status", status).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除评论(软删除,同时清理关联数据)
|
||||||
|
func (r *CommentRepository) Delete(id string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 先查询评论获取post_id和parent_id
|
||||||
|
var comment model.Comment
|
||||||
|
if err := tx.First(&comment, "id = ?", id).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除评论点赞记录
|
||||||
|
if err := tx.Where("comment_id = ?", id).Delete(&model.CommentLike{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除评论(软删除)
|
||||||
|
if err := tx.Delete(&model.Comment{}, "id = ?", id).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 减少帖子的评论数并同步热度分
|
||||||
|
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"comments_count": gorm.Expr("comments_count - 1"),
|
||||||
|
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
|
||||||
|
}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是回复,减少父评论的回复数
|
||||||
|
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||||
|
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||||
|
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByPostID 获取帖子评论
|
||||||
|
func (r *CommentRepository) GetByPostID(postID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||||
|
var comments []*model.Comment
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
|
||||||
|
Preload("User").
|
||||||
|
Offset(offset).Limit(pageSize).
|
||||||
|
Order("created_at ASC").
|
||||||
|
Find(&comments).Error
|
||||||
|
|
||||||
|
return comments, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByPostIDWithReplies 获取帖子评论(包含回复,扁平化结构)
|
||||||
|
// 所有层级的回复都扁平化展示在顶级评论的 replies 中
|
||||||
|
func (r *CommentRepository) GetByPostIDWithReplies(postID string, page, pageSize, replyLimit int) ([]*model.Comment, int64, error) {
|
||||||
|
var comments []*model.Comment
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
|
||||||
|
Preload("User").
|
||||||
|
Offset(offset).Limit(pageSize).
|
||||||
|
Order("created_at ASC").
|
||||||
|
Find(&comments).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(comments) == 0 {
|
||||||
|
return comments, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rootIDs := make([]string, 0, len(comments))
|
||||||
|
commentsByID := make(map[string]*model.Comment, len(comments))
|
||||||
|
for _, comment := range comments {
|
||||||
|
rootIDs = append(rootIDs, comment.ID)
|
||||||
|
commentsByID[comment.ID] = comment
|
||||||
|
}
|
||||||
|
|
||||||
|
// 批量加载所有回复,内存中按 root_id 分组并裁剪每个根评论的返回条数
|
||||||
|
var allReplies []*model.Comment
|
||||||
|
if err := r.db.Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished).
|
||||||
|
Preload("User").
|
||||||
|
Order("created_at ASC").
|
||||||
|
Find(&allReplies).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
repliesByRoot := make(map[string][]*model.Comment, len(rootIDs))
|
||||||
|
for _, reply := range allReplies {
|
||||||
|
if reply.RootID == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rootID := *reply.RootID
|
||||||
|
if replyLimit <= 0 || len(repliesByRoot[rootID]) < replyLimit {
|
||||||
|
repliesByRoot[rootID] = append(repliesByRoot[rootID], reply)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type replyCountRow struct {
|
||||||
|
RootID string
|
||||||
|
Total int64
|
||||||
|
}
|
||||||
|
var replyCountRows []replyCountRow
|
||||||
|
if err := r.db.Model(&model.Comment{}).
|
||||||
|
Select("root_id, COUNT(*) AS total").
|
||||||
|
Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished).
|
||||||
|
Group("root_id").
|
||||||
|
Scan(&replyCountRows).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
replyCountMap := make(map[string]int64, len(replyCountRows))
|
||||||
|
for _, row := range replyCountRows {
|
||||||
|
replyCountMap[row.RootID] = row.Total
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rootID := range rootIDs {
|
||||||
|
comment := commentsByID[rootID]
|
||||||
|
comment.Replies = repliesByRoot[rootID]
|
||||||
|
comment.RepliesCount = int(replyCountMap[rootID])
|
||||||
|
}
|
||||||
|
|
||||||
|
return comments, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadFlatReplies 加载评论的所有回复(扁平化,所有层级都在同一层)
|
||||||
|
func (r *CommentRepository) loadFlatReplies(rootComment *model.Comment, limit int) {
|
||||||
|
var allReplies []*model.Comment
|
||||||
|
|
||||||
|
// 查询所有以该评论为根评论的回复(不包括顶级评论本身)
|
||||||
|
r.db.Where("root_id = ? AND status = ?", rootComment.ID, model.CommentStatusPublished).
|
||||||
|
Preload("User").
|
||||||
|
Order("created_at ASC").
|
||||||
|
Limit(limit).
|
||||||
|
Find(&allReplies)
|
||||||
|
|
||||||
|
rootComment.Replies = allReplies
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRepliesByRootID 根据根评论ID分页获取回复(扁平化)
|
||||||
|
func (r *CommentRepository) GetRepliesByRootID(rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||||
|
var replies []*model.Comment
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
// 统计总数
|
||||||
|
r.db.Model(&model.Comment{}).Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).Count(&total)
|
||||||
|
|
||||||
|
// 分页查询
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).
|
||||||
|
Preload("User").
|
||||||
|
Order("created_at ASC").
|
||||||
|
Offset(offset).
|
||||||
|
Limit(pageSize).
|
||||||
|
Find(&replies).Error
|
||||||
|
|
||||||
|
return replies, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReplies 获取回复
|
||||||
|
func (r *CommentRepository) GetReplies(parentID string) ([]*model.Comment, error) {
|
||||||
|
var comments []*model.Comment
|
||||||
|
err := r.db.Where("parent_id = ? AND status = ?", parentID, model.CommentStatusPublished).
|
||||||
|
Preload("User").
|
||||||
|
Order("created_at ASC").
|
||||||
|
Find(&comments).Error
|
||||||
|
return comments, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Like 点赞评论
|
||||||
|
func (r *CommentRepository) Like(commentID, userID string) error {
|
||||||
|
// 检查是否已经点赞
|
||||||
|
var existing model.CommentLike
|
||||||
|
err := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).First(&existing).Error
|
||||||
|
if err == nil {
|
||||||
|
// 已经点赞
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建点赞记录
|
||||||
|
like := &model.CommentLike{
|
||||||
|
CommentID: commentID,
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
err = r.db.Create(like).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加评论点赞数
|
||||||
|
return r.db.Model(&model.Comment{}).Where("id = ?", commentID).
|
||||||
|
UpdateColumn("likes_count", gorm.Expr("likes_count + 1")).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlike 取消点赞评论
|
||||||
|
func (r *CommentRepository) Unlike(commentID, userID string) error {
|
||||||
|
result := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).Delete(&model.CommentLike{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected > 0 {
|
||||||
|
// 减少评论点赞数
|
||||||
|
return r.db.Model(&model.Comment{}).Where("id = ?", commentID).
|
||||||
|
UpdateColumn("likes_count", gorm.Expr("likes_count - 1")).Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLiked 检查是否已点赞
|
||||||
|
func (r *CommentRepository) IsLiked(commentID, userID string) bool {
|
||||||
|
var count int64
|
||||||
|
r.db.Model(&model.CommentLike{}).Where("comment_id = ? AND user_id = ?", commentID, userID).Count(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
166
internal/repository/device_token_repo.go
Normal file
166
internal/repository/device_token_repo.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceTokenRepository 设备Token仓储
|
||||||
|
type DeviceTokenRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDeviceTokenRepository 创建设备Token仓储
|
||||||
|
func NewDeviceTokenRepository(db *gorm.DB) *DeviceTokenRepository {
|
||||||
|
return &DeviceTokenRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建设备Token
|
||||||
|
func (r *DeviceTokenRepository) Create(token *model.DeviceToken) error {
|
||||||
|
return r.db.Create(token).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取设备Token
|
||||||
|
func (r *DeviceTokenRepository) GetByID(id int64) (*model.DeviceToken, error) {
|
||||||
|
var token model.DeviceToken
|
||||||
|
err := r.db.First(&token, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新设备Token
|
||||||
|
func (r *DeviceTokenRepository) Update(token *model.DeviceToken) error {
|
||||||
|
return r.db.Save(token).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除设备Token(软删除)
|
||||||
|
func (r *DeviceTokenRepository) Delete(id int64) error {
|
||||||
|
return r.db.Delete(&model.DeviceToken{}, id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByUserID 获取用户所有设备
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *DeviceTokenRepository) GetByUserID(userID string) ([]*model.DeviceToken, error) {
|
||||||
|
var tokens []*model.DeviceToken
|
||||||
|
err := r.db.Where("user_id = ?", userID).
|
||||||
|
Order("created_at DESC").
|
||||||
|
Find(&tokens).Error
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveByUserID 获取用户活跃设备
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *DeviceTokenRepository) GetActiveByUserID(userID string) ([]*model.DeviceToken, error) {
|
||||||
|
var tokens []*model.DeviceToken
|
||||||
|
err := r.db.Where("user_id = ? AND is_active = ?", userID, true).
|
||||||
|
Order("last_used_at DESC").
|
||||||
|
Find(&tokens).Error
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByDeviceID 根据设备ID获取设备Token
|
||||||
|
func (r *DeviceTokenRepository) GetByDeviceID(deviceID string) (*model.DeviceToken, error) {
|
||||||
|
var token model.DeviceToken
|
||||||
|
err := r.db.Where("device_id = ?", deviceID).First(&token).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByPushToken 根据推送Token获取设备信息
|
||||||
|
func (r *DeviceTokenRepository) GetByPushToken(pushToken string) (*model.DeviceToken, error) {
|
||||||
|
var token model.DeviceToken
|
||||||
|
err := r.db.Where("push_token = ?", pushToken).First(&token).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeactivateAllExcept 登出其他设备(停用除指定设备外的所有设备)
|
||||||
|
func (r *DeviceTokenRepository) DeactivateAllExcept(userID int64, deviceID string) error {
|
||||||
|
return r.db.Model(&model.DeviceToken{}).
|
||||||
|
Where("user_id = ? AND device_id != ?", userID, deviceID).
|
||||||
|
Update("is_active", false).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upsert 创建或更新设备Token
|
||||||
|
// 如果设备ID已存在,则更新Token和激活状态;否则创建新记录
|
||||||
|
func (r *DeviceTokenRepository) Upsert(token *model.DeviceToken) error {
|
||||||
|
var existing model.DeviceToken
|
||||||
|
err := r.db.Where("device_id = ?", token.DeviceID).First(&existing).Error
|
||||||
|
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
// 创建新记录
|
||||||
|
return r.db.Create(token).Error
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新现有记录
|
||||||
|
return r.db.Model(&existing).Updates(map[string]interface{}{
|
||||||
|
"push_token": token.PushToken,
|
||||||
|
"is_active": true,
|
||||||
|
"device_name": token.DeviceName,
|
||||||
|
"last_used_at": time.Now(),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLastUsed 更新最后使用时间
|
||||||
|
func (r *DeviceTokenRepository) UpdateLastUsed(deviceID string) error {
|
||||||
|
return r.db.Model(&model.DeviceToken{}).
|
||||||
|
Where("device_id = ?", deviceID).
|
||||||
|
Update("last_used_at", time.Now()).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deactivate 停用设备
|
||||||
|
func (r *DeviceTokenRepository) Deactivate(deviceID string) error {
|
||||||
|
return r.db.Model(&model.DeviceToken{}).
|
||||||
|
Where("device_id = ?", deviceID).
|
||||||
|
Update("is_active", false).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Activate 激活设备
|
||||||
|
func (r *DeviceTokenRepository) Activate(deviceID string) error {
|
||||||
|
return r.db.Model(&model.DeviceToken{}).
|
||||||
|
Where("device_id = ?", deviceID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"is_active": true,
|
||||||
|
"last_used_at": time.Now(),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByUserID 删除用户所有设备Token
|
||||||
|
func (r *DeviceTokenRepository) DeleteByUserID(userID int64) error {
|
||||||
|
return r.db.Where("user_id = ?", userID).Delete(&model.DeviceToken{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceCountByUserID 获取用户设备数量
|
||||||
|
func (r *DeviceTokenRepository) GetDeviceCountByUserID(userID int64) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.DeviceToken{}).
|
||||||
|
Where("user_id = ?", userID).
|
||||||
|
Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveDeviceCountByUserID 获取用户活跃设备数量
|
||||||
|
func (r *DeviceTokenRepository) GetActiveDeviceCountByUserID(userID int64) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.DeviceToken{}).
|
||||||
|
Where("user_id = ? AND is_active = ?", userID, true).
|
||||||
|
Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteInactiveDevices 删除长时间未使用的设备
|
||||||
|
func (r *DeviceTokenRepository) DeleteInactiveDevices(before time.Time) error {
|
||||||
|
return r.db.Where("is_active = ? AND last_used_at < ?", false, before).
|
||||||
|
Delete(&model.DeviceToken{}).Error
|
||||||
|
}
|
||||||
50
internal/repository/group_join_request_repo.go
Normal file
50
internal/repository/group_join_request_repo.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GroupJoinRequestRepository interface {
|
||||||
|
Create(req *model.GroupJoinRequest) error
|
||||||
|
GetByFlag(flag string) (*model.GroupJoinRequest, error)
|
||||||
|
Update(req *model.GroupJoinRequest) error
|
||||||
|
GetPendingByGroupAndTarget(groupID, targetUserID string, reqType model.GroupJoinRequestType) (*model.GroupJoinRequest, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type groupJoinRequestRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGroupJoinRequestRepository(db *gorm.DB) GroupJoinRequestRepository {
|
||||||
|
return &groupJoinRequestRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *groupJoinRequestRepository) Create(req *model.GroupJoinRequest) error {
|
||||||
|
return r.db.Create(req).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *groupJoinRequestRepository) GetByFlag(flag string) (*model.GroupJoinRequest, error) {
|
||||||
|
var req model.GroupJoinRequest
|
||||||
|
if err := r.db.Where("flag = ?", flag).First(&req).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *groupJoinRequestRepository) Update(req *model.GroupJoinRequest) error {
|
||||||
|
return r.db.Save(req).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *groupJoinRequestRepository) GetPendingByGroupAndTarget(groupID, targetUserID string, reqType model.GroupJoinRequestType) (*model.GroupJoinRequest, error) {
|
||||||
|
var req model.GroupJoinRequest
|
||||||
|
err := r.db.Where("group_id = ? AND target_user_id = ? AND request_type = ? AND status = ?",
|
||||||
|
groupID, targetUserID, reqType, model.GroupJoinRequestStatusPending).
|
||||||
|
Order("created_at DESC").
|
||||||
|
First(&req).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &req, nil
|
||||||
|
}
|
||||||
242
internal/repository/group_repo.go
Normal file
242
internal/repository/group_repo.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GroupRepository 群组仓库接口
|
||||||
|
type GroupRepository interface {
|
||||||
|
// 群组操作
|
||||||
|
Create(group *model.Group) error
|
||||||
|
GetByID(id string) (*model.Group, error)
|
||||||
|
Update(group *model.Group) error
|
||||||
|
Delete(id string) error
|
||||||
|
GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error)
|
||||||
|
|
||||||
|
// 群成员操作
|
||||||
|
AddMember(member *model.GroupMember) error
|
||||||
|
GetMember(groupID string, userID string) (*model.GroupMember, error)
|
||||||
|
GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error)
|
||||||
|
UpdateMember(member *model.GroupMember) error
|
||||||
|
RemoveMember(groupID string, userID string) error
|
||||||
|
GetMemberCount(groupID string) (int64, error)
|
||||||
|
IsMember(groupID string, userID string) (bool, error)
|
||||||
|
GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error)
|
||||||
|
|
||||||
|
// 角色相关
|
||||||
|
GetMemberRole(groupID string, userID string) (string, error)
|
||||||
|
SetMemberRole(groupID string, userID string, role string) error
|
||||||
|
GetAdmins(groupID string) ([]model.GroupMember, error)
|
||||||
|
|
||||||
|
// 群公告操作
|
||||||
|
CreateAnnouncement(announcement *model.GroupAnnouncement) error
|
||||||
|
GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error)
|
||||||
|
GetAnnouncementByID(id string) (*model.GroupAnnouncement, error)
|
||||||
|
DeleteAnnouncement(id string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupRepository 群组仓库实现
|
||||||
|
type groupRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGroupRepository 创建群组仓库
|
||||||
|
func NewGroupRepository(db *gorm.DB) GroupRepository {
|
||||||
|
return &groupRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建群组
|
||||||
|
func (r *groupRepository) Create(group *model.Group) error {
|
||||||
|
return r.db.Create(group).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取群组
|
||||||
|
func (r *groupRepository) GetByID(id string) (*model.Group, error) {
|
||||||
|
var group model.Group
|
||||||
|
err := r.db.First(&group, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新群组
|
||||||
|
func (r *groupRepository) Update(group *model.Group) error {
|
||||||
|
return r.db.Save(group).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除群组
|
||||||
|
func (r *groupRepository) Delete(id string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 删除群成员
|
||||||
|
if err := tx.Where("group_id = ?", id).Delete(&model.GroupMember{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 删除群公告
|
||||||
|
if err := tx.Where("group_id = ?", id).Delete(&model.GroupAnnouncement{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 删除群组
|
||||||
|
if err := tx.Delete(&model.Group{}, "id = ?", id).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByOwnerID 根据群主ID获取群组列表
|
||||||
|
func (r *groupRepository) GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error) {
|
||||||
|
var groups []model.Group
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.Group{}).Where("owner_id = ?", ownerID)
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error
|
||||||
|
return groups, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMember 添加群成员
|
||||||
|
func (r *groupRepository) AddMember(member *model.GroupMember) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Create(member).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 更新群组成员数量
|
||||||
|
return tx.Model(&model.Group{}).Where("id = ?", member.GroupID).
|
||||||
|
Update("member_count", gorm.Expr("member_count + ?", 1)).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMember 获取群成员
|
||||||
|
func (r *groupRepository) GetMember(groupID string, userID string) (*model.GroupMember, error) {
|
||||||
|
var member model.GroupMember
|
||||||
|
err := r.db.First(&member, "group_id = ? AND user_id = ?", groupID, userID).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &member, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMembers 获取群成员列表
|
||||||
|
func (r *groupRepository) GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) {
|
||||||
|
var members []model.GroupMember
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID)
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Offset(offset).Limit(pageSize).Order("created_at ASC").Find(&members).Error
|
||||||
|
return members, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMember 更新群成员
|
||||||
|
func (r *groupRepository) UpdateMember(member *model.GroupMember) error {
|
||||||
|
return r.db.Save(member).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveMember 移除群成员
|
||||||
|
func (r *groupRepository) RemoveMember(groupID string, userID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 删除成员
|
||||||
|
if err := tx.Where("group_id = ? AND user_id = ?", groupID, userID).Delete(&model.GroupMember{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 更新群组成员数量
|
||||||
|
return tx.Model(&model.Group{}).Where("id = ?", groupID).
|
||||||
|
Update("member_count", gorm.Expr("member_count - ?", 1)).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMemberCount 获取群成员数量
|
||||||
|
func (r *groupRepository) GetMemberCount(groupID string) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMember 检查是否是群成员
|
||||||
|
func (r *groupRepository) IsMember(groupID string, userID string) (bool, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.GroupMember{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count).Error
|
||||||
|
return count > 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserGroups 获取用户加入的群组列表
|
||||||
|
func (r *groupRepository) GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) {
|
||||||
|
var groups []model.Group
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
// 通过群成员表查询用户加入的群组
|
||||||
|
subQuery := r.db.Model(&model.GroupMember{}).
|
||||||
|
Select("group_id").
|
||||||
|
Where("user_id = ?", userID)
|
||||||
|
|
||||||
|
query := r.db.Model(&model.Group{}).Where("id IN (?)", subQuery)
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error
|
||||||
|
return groups, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMemberRole 获取成员角色
|
||||||
|
func (r *groupRepository) GetMemberRole(groupID string, userID string) (string, error) {
|
||||||
|
member, err := r.GetMember(groupID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return member.Role, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMemberRole 设置成员角色
|
||||||
|
func (r *groupRepository) SetMemberRole(groupID string, userID string, role string) error {
|
||||||
|
return r.db.Model(&model.GroupMember{}).
|
||||||
|
Where("group_id = ? AND user_id = ?", groupID, userID).
|
||||||
|
Update("role", role).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAdmins 获取群管理员列表
|
||||||
|
func (r *groupRepository) GetAdmins(groupID string) ([]model.GroupMember, error) {
|
||||||
|
var admins []model.GroupMember
|
||||||
|
err := r.db.Where("group_id = ? AND role = ?", groupID, model.GroupRoleAdmin).Find(&admins).Error
|
||||||
|
return admins, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAnnouncement 创建群公告
|
||||||
|
func (r *groupRepository) CreateAnnouncement(announcement *model.GroupAnnouncement) error {
|
||||||
|
return r.db.Create(announcement).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAnnouncements 获取群公告列表
|
||||||
|
func (r *groupRepository) GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) {
|
||||||
|
var announcements []model.GroupAnnouncement
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.GroupAnnouncement{}).Where("group_id = ?", groupID)
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
// 置顶的排在前面,然后按时间倒序
|
||||||
|
err := query.Offset(offset).Limit(pageSize).Order("is_pinned DESC, created_at DESC").Find(&announcements).Error
|
||||||
|
return announcements, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAnnouncementByID 根据ID获取群公告
|
||||||
|
func (r *groupRepository) GetAnnouncementByID(id string) (*model.GroupAnnouncement, error) {
|
||||||
|
var announcement model.GroupAnnouncement
|
||||||
|
err := r.db.First(&announcement, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &announcement, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAnnouncement 删除群公告
|
||||||
|
func (r *groupRepository) DeleteAnnouncement(id string) error {
|
||||||
|
return r.db.Delete(&model.GroupAnnouncement{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
543
internal/repository/message_repo.go
Normal file
543
internal/repository/message_repo.go
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageRepository 消息仓储
|
||||||
|
type MessageRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMessageRepository 创建消息仓储
|
||||||
|
func NewMessageRepository(db *gorm.DB) *MessageRepository {
|
||||||
|
return &MessageRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMessage 创建消息
|
||||||
|
func (r *MessageRepository) CreateMessage(msg *model.Message) error {
|
||||||
|
return r.db.Create(msg).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversation 获取会话
|
||||||
|
func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) {
|
||||||
|
var conv model.Conversation
|
||||||
|
err := r.db.Preload("Group").First(&conv, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &conv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreatePrivateConversation 获取或创建私聊会话
|
||||||
|
// 使用参与者关系表来管理会话
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) {
|
||||||
|
var conv model.Conversation
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: user1ID=%s, user2ID=%s\n", user1ID, user2ID)
|
||||||
|
|
||||||
|
// 查找两个用户共同参与的私聊会话
|
||||||
|
err := r.db.Table("conversations c").
|
||||||
|
Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID).
|
||||||
|
Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID).
|
||||||
|
Where("c.type = ?", model.ConversationTypePrivate).
|
||||||
|
First(&conv).Error
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
_ = r.db.Model(&model.ConversationParticipant{}).
|
||||||
|
Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}).
|
||||||
|
Update("hidden_at", nil).Error
|
||||||
|
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: found existing conversation, ID=%s\n", conv.ID)
|
||||||
|
return &conv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 没找到会话,创建新会话
|
||||||
|
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: no existing conversation found, creating new one\n")
|
||||||
|
conv = model.Conversation{
|
||||||
|
Type: model.ConversationTypePrivate,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用事务创建会话和参与者
|
||||||
|
err = r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
if err := tx.Create(&conv).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建参与者记录 - UserID 存储为 string (UUID)
|
||||||
|
participants := []model.ConversationParticipant{
|
||||||
|
{ConversationID: conv.ID, UserID: user1ID},
|
||||||
|
{ConversationID: conv.ID, UserID: user2ID},
|
||||||
|
}
|
||||||
|
if err := tx.Create(&participants).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: created new conversation, ID=%s\n", conv.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &conv, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversations 获取用户会话列表
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||||
|
var convs []*model.Conversation
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
// 获取总数
|
||||||
|
r.db.Model(&model.ConversationParticipant{}).
|
||||||
|
Where("user_id = ? AND hidden_at IS NULL", userID).
|
||||||
|
Count(&total)
|
||||||
|
|
||||||
|
if total == 0 {
|
||||||
|
return convs, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
// 查询会话列表并预加载关联数据:
|
||||||
|
// 当前用户维度先按置顶排序,再按更新时间排序
|
||||||
|
err := r.db.Model(&model.Conversation{}).
|
||||||
|
Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id").
|
||||||
|
Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID).
|
||||||
|
Preload("Group").
|
||||||
|
Offset(offset).
|
||||||
|
Limit(pageSize).
|
||||||
|
Order("cp.is_pinned DESC").
|
||||||
|
Order("conversations.updated_at DESC").
|
||||||
|
Find(&convs).Error
|
||||||
|
|
||||||
|
return convs, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages 获取会话消息
|
||||||
|
func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||||
|
var messages []*model.Message
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
r.db.Model(&model.Message{}).Where("conversation_id = ?", conversationID).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("conversation_id = ?", conversationID).
|
||||||
|
Offset(offset).
|
||||||
|
Limit(pageSize).
|
||||||
|
Order("seq DESC").
|
||||||
|
Find(&messages).Error
|
||||||
|
|
||||||
|
return messages, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
||||||
|
func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
var messages []*model.Message
|
||||||
|
err := r.db.Where("conversation_id = ? AND seq > ?", conversationID, afterSeq).
|
||||||
|
Order("seq ASC").
|
||||||
|
Limit(limit).
|
||||||
|
Find(&messages).Error
|
||||||
|
return messages, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
||||||
|
func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
var messages []*model.Message
|
||||||
|
fmt.Printf("[DEBUG] GetMessagesBeforeSeq: conversationID=%s, beforeSeq=%d, limit=%d\n", conversationID, beforeSeq, limit)
|
||||||
|
err := r.db.Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq).
|
||||||
|
Order("seq DESC"). // 降序获取最新消息在前
|
||||||
|
Limit(limit).
|
||||||
|
Find(&messages).Error
|
||||||
|
fmt.Printf("[DEBUG] GetMessagesBeforeSeq: found %d messages, seq range: ", len(messages))
|
||||||
|
for i, m := range messages {
|
||||||
|
if i < 5 || i >= len(messages)-2 {
|
||||||
|
fmt.Printf("%d ", m.Seq)
|
||||||
|
} else if i == 5 {
|
||||||
|
fmt.Printf("... ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
// 反转回正序
|
||||||
|
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
messages[i], messages[j] = messages[j], messages[i]
|
||||||
|
}
|
||||||
|
return messages, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationParticipants 获取会话参与者
|
||||||
|
func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
|
||||||
|
var participants []*model.ConversationParticipant
|
||||||
|
err := r.db.Where("conversation_id = ?", conversationID).Find(&participants).Error
|
||||||
|
return participants, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipant 获取用户在会话中的参与者信息
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *MessageRepository) GetParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
|
||||||
|
var participant model.ConversationParticipant
|
||||||
|
err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
|
||||||
|
if err != nil {
|
||||||
|
// 如果找不到参与者,尝试添加(修复没有参与者记录的问题)
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
// 检查会话是否存在
|
||||||
|
var conv model.Conversation
|
||||||
|
if err := r.db.First(&conv, conversationID).Error; err == nil {
|
||||||
|
// 会话存在,添加参与者
|
||||||
|
participant = model.ConversationParticipant{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
if err := r.db.Create(&participant).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &participant, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &participant, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLastReadSeq 更新已读位置
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error {
|
||||||
|
result := r.db.Model(&model.ConversationParticipant{}).
|
||||||
|
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||||
|
Update("last_read_seq", lastReadSeq)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有更新任何记录,说明参与者记录不存在,需要插入
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
// 尝试插入新记录(跨数据库 upsert)
|
||||||
|
err := r.db.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{
|
||||||
|
{Name: "conversation_id"},
|
||||||
|
{Name: "user_id"},
|
||||||
|
},
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||||
|
"last_read_seq": lastReadSeq,
|
||||||
|
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||||
|
}),
|
||||||
|
}).Create(&model.ConversationParticipant{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
UserID: userID,
|
||||||
|
LastReadSeq: lastReadSeq,
|
||||||
|
}).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePinned 更新会话置顶状态(用户维度)
|
||||||
|
func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error {
|
||||||
|
result := r.db.Model(&model.ConversationParticipant{}).
|
||||||
|
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||||
|
Update("is_pinned", isPinned)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return r.db.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{
|
||||||
|
{Name: "conversation_id"},
|
||||||
|
{Name: "user_id"},
|
||||||
|
},
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||||
|
"is_pinned": isPinned,
|
||||||
|
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||||
|
}),
|
||||||
|
}).Create(&model.ConversationParticipant{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
UserID: userID,
|
||||||
|
IsPinned: isPinned,
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取未读消息数
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) {
|
||||||
|
var participant model.ConversationParticipant
|
||||||
|
err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
err = r.db.Model(&model.Message{}).
|
||||||
|
Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, participant.LastReadSeq).
|
||||||
|
Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateConversationLastSeq 更新会话的最后消息seq和时间
|
||||||
|
func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error {
|
||||||
|
return r.db.Model(&model.Conversation{}).
|
||||||
|
Where("id = ?", conversationID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"last_seq": seq,
|
||||||
|
"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextSeq 获取会话的下一个seq值
|
||||||
|
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
|
||||||
|
var conv model.Conversation
|
||||||
|
err := r.db.Select("last_seq").First(&conv, conversationID).Error
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return conv.LastSeq + 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMessageWithSeq 创建消息并更新seq(事务操作)
|
||||||
|
func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 获取当前seq并+1
|
||||||
|
var conv model.Conversation
|
||||||
|
if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.Seq = conv.LastSeq + 1
|
||||||
|
|
||||||
|
// 创建消息
|
||||||
|
if err := tx.Create(msg).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新会话的last_seq
|
||||||
|
if err := tx.Model(&model.Conversation{}).
|
||||||
|
Where("id = ?", msg.ConversationID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"last_seq": msg.Seq,
|
||||||
|
"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||||
|
}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 新消息到达后,自动恢复被“仅自己删除”的会话
|
||||||
|
if err := tx.Model(&model.ConversationParticipant{}).
|
||||||
|
Where("conversation_id = ?", msg.ConversationID).
|
||||||
|
Update("hidden_at", nil).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllUnreadCount 获取用户所有会话的未读消息总数
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) {
|
||||||
|
var totalUnread int64
|
||||||
|
err := r.db.Table("conversation_participants AS cp").
|
||||||
|
Joins("LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL", userID).
|
||||||
|
Where("cp.user_id = ?", userID).
|
||||||
|
Select("COALESCE(COUNT(m.id), 0)").
|
||||||
|
Scan(&totalUnread).Error
|
||||||
|
return totalUnread, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessageByID 根据ID获取消息
|
||||||
|
func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) {
|
||||||
|
var message model.Message
|
||||||
|
err := r.db.First(&message, "id = ?", messageID).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &message, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountMessagesBySenderInConversation 统计会话中某用户已发送消息数
|
||||||
|
func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.Message{}).
|
||||||
|
Where("conversation_id = ? AND sender_id = ?", conversationID, senderID).
|
||||||
|
Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMessageStatus 更新消息状态
|
||||||
|
func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error {
|
||||||
|
return r.db.Model(&model.Message{}).
|
||||||
|
Where("id = ?", messageID).
|
||||||
|
Update("status", status).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateSystemParticipant 获取或创建用户在系统会话中的参与者记录
|
||||||
|
// 系统会话是虚拟会话,但需要参与者记录来跟踪已读状态
|
||||||
|
func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) {
|
||||||
|
var participant model.ConversationParticipant
|
||||||
|
err := r.db.Where("conversation_id = ? AND user_id = ?",
|
||||||
|
model.SystemConversationID, userID).First(&participant).Error
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return &participant, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动创建参与者记录
|
||||||
|
participant = model.ConversationParticipant{
|
||||||
|
ConversationID: model.SystemConversationID,
|
||||||
|
UserID: userID,
|
||||||
|
LastReadSeq: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.db.Create(&participant).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &participant, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSystemMessagesUnreadCount 获取系统消息未读数
|
||||||
|
func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) {
|
||||||
|
// 获取或创建参与者记录
|
||||||
|
participant, err := r.GetOrCreateSystemParticipant(userID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算未读数:查询 seq > last_read_seq 的消息
|
||||||
|
var count int64
|
||||||
|
err = r.db.Model(&model.Message{}).
|
||||||
|
Where("conversation_id = ? AND seq > ?",
|
||||||
|
model.SystemConversationID, participant.LastReadSeq).
|
||||||
|
Count(&count).Error
|
||||||
|
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAllSystemMessagesAsRead 标记所有系统消息已读
|
||||||
|
func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error {
|
||||||
|
// 获取系统会话的最新 seq
|
||||||
|
var maxSeq int64
|
||||||
|
err := r.db.Model(&model.Message{}).
|
||||||
|
Where("conversation_id = ?", model.SystemConversationID).
|
||||||
|
Select("COALESCE(MAX(seq), 0)").
|
||||||
|
Scan(&maxSeq).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用跨数据库 upsert 方式更新或创建参与者记录
|
||||||
|
return r.db.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{
|
||||||
|
{Name: "conversation_id"},
|
||||||
|
{Name: "user_id"},
|
||||||
|
},
|
||||||
|
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||||
|
"last_read_seq": maxSeq,
|
||||||
|
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||||
|
}),
|
||||||
|
}).Create(&model.ConversationParticipant{
|
||||||
|
ConversationID: model.SystemConversationID,
|
||||||
|
UserID: userID,
|
||||||
|
LastReadSeq: maxSeq,
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationByGroupID 通过群组ID获取会话
|
||||||
|
func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) {
|
||||||
|
var conv model.Conversation
|
||||||
|
err := r.db.Where("group_id = ?", groupID).First(&conv).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &conv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveParticipant 移除会话参与者
|
||||||
|
// 当用户退出群聊时,需要同时移除其在对应会话中的参与者记录
|
||||||
|
func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error {
|
||||||
|
return r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||||
|
Delete(&model.ConversationParticipant{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddParticipant 添加会话参与者
|
||||||
|
// 当用户加入群聊时,需要同时将其添加到对应会话的参与者记录
|
||||||
|
func (r *MessageRepository) AddParticipant(conversationID string, userID string) error {
|
||||||
|
// 先检查是否已经是参与者
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.ConversationParticipant{}).
|
||||||
|
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||||
|
Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果已经是参与者,直接返回
|
||||||
|
if count > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加参与者
|
||||||
|
participant := model.ConversationParticipant{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
UserID: userID,
|
||||||
|
LastReadSeq: 0,
|
||||||
|
}
|
||||||
|
return r.db.Create(&participant).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteConversationByGroupID 删除群组对应的会话及其参与者
|
||||||
|
// 当解散群组时调用
|
||||||
|
func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error {
|
||||||
|
// 获取群组对应的会话
|
||||||
|
conv, err := r.GetConversationByGroupID(groupID)
|
||||||
|
if err != nil {
|
||||||
|
// 如果会话不存在,直接返回
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 删除会话参与者
|
||||||
|
if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 删除会话中的消息
|
||||||
|
if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 删除会话
|
||||||
|
if err := tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HideConversationForUser 仅对当前用户隐藏会话(私聊删除)
|
||||||
|
func (r *MessageRepository) HideConversationForUser(conversationID, userID string) error {
|
||||||
|
now := time.Now()
|
||||||
|
return r.db.Model(&model.ConversationParticipant{}).
|
||||||
|
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||||
|
Update("hidden_at", &now).Error
|
||||||
|
}
|
||||||
78
internal/repository/notification_repo.go
Normal file
78
internal/repository/notification_repo.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NotificationRepository 通知仓储
|
||||||
|
type NotificationRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNotificationRepository 创建通知仓储
|
||||||
|
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
|
||||||
|
return &NotificationRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建通知
|
||||||
|
func (r *NotificationRepository) Create(notification *model.Notification) error {
|
||||||
|
return r.db.Create(notification).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取通知
|
||||||
|
func (r *NotificationRepository) GetByID(id string) (*model.Notification, error) {
|
||||||
|
var notification model.Notification
|
||||||
|
err := r.db.First(¬ification, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ¬ification, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByUserID 获取用户通知
|
||||||
|
func (r *NotificationRepository) GetByUserID(userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
|
||||||
|
var notifications []*model.Notification
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.Notification{}).Where("user_id = ?", userID)
|
||||||
|
|
||||||
|
if unreadOnly {
|
||||||
|
query = query.Where("is_read = ?", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(¬ifications).Error
|
||||||
|
|
||||||
|
return notifications, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记为已读
|
||||||
|
func (r *NotificationRepository) MarkAsRead(id string) error {
|
||||||
|
return r.db.Model(&model.Notification{}).Where("id = ?", id).Update("is_read", true).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAllAsRead 标记所有为已读
|
||||||
|
func (r *NotificationRepository) MarkAllAsRead(userID string) error {
|
||||||
|
return r.db.Model(&model.Notification{}).Where("user_id = ?", userID).Update("is_read", true).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除通知
|
||||||
|
func (r *NotificationRepository) Delete(id string) error {
|
||||||
|
return r.db.Delete(&model.Notification{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取未读数量
|
||||||
|
func (r *NotificationRepository) GetUnreadCount(userID string) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.Notification{}).Where("user_id = ? AND is_read = ?", userID, false).Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAllByUserID 删除用户所有通知
|
||||||
|
func (r *NotificationRepository) DeleteAllByUserID(userID string) error {
|
||||||
|
return r.db.Where("user_id = ?", userID).Delete(&model.Notification{}).Error
|
||||||
|
}
|
||||||
360
internal/repository/post_repo.go
Normal file
360
internal/repository/post_repo.go
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostRepository 帖子仓储
|
||||||
|
type PostRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPostRepository 创建帖子仓储
|
||||||
|
func NewPostRepository(db *gorm.DB) *PostRepository {
|
||||||
|
return &PostRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建帖子
|
||||||
|
func (r *PostRepository) Create(post *model.Post, images []string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 创建帖子
|
||||||
|
if err := tx.Create(post).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建图片记录
|
||||||
|
for i, url := range images {
|
||||||
|
image := &model.PostImage{
|
||||||
|
PostID: post.ID,
|
||||||
|
URL: url,
|
||||||
|
SortOrder: i,
|
||||||
|
}
|
||||||
|
if err := tx.Create(image).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取帖子
|
||||||
|
func (r *PostRepository) GetByID(id string) (*model.Post, error) {
|
||||||
|
var post model.Post
|
||||||
|
err := r.db.Preload("User").Preload("Images").First(&post, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &post, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新帖子
|
||||||
|
func (r *PostRepository) Update(post *model.Post) error {
|
||||||
|
return r.db.Save(post).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModerationStatus 更新帖子审核状态
|
||||||
|
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
|
||||||
|
updates := map[string]interface{}{
|
||||||
|
"status": status,
|
||||||
|
"reviewed_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||||
|
"reviewed_by": reviewedBy,
|
||||||
|
"reject_reason": rejectReason,
|
||||||
|
}
|
||||||
|
return r.db.Model(&model.Post{}).Where("id = ?", postID).Updates(updates).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除帖子(软删除,同时清理关联数据)
|
||||||
|
func (r *PostRepository) Delete(id string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 删除帖子图片
|
||||||
|
if err := tx.Where("post_id = ?", id).Delete(&model.PostImage{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除帖子点赞记录
|
||||||
|
if err := tx.Where("post_id = ?", id).Delete(&model.PostLike{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除帖子收藏记录
|
||||||
|
if err := tx.Where("post_id = ?", id).Delete(&model.Favorite{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除评论点赞记录(子查询获取该帖子所有评论ID)
|
||||||
|
if err := tx.Where("comment_id IN (SELECT id FROM comments WHERE post_id = ?)", id).Delete(&model.CommentLike{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除帖子评论
|
||||||
|
if err := tx.Where("post_id = ?", id).Delete(&model.Comment{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最后删除帖子本身(软删除)
|
||||||
|
return tx.Delete(&model.Post{}, "id = ?", id).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 分页获取帖子列表
|
||||||
|
func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||||
|
var posts []*model.Post
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
|
||||||
|
|
||||||
|
if userID != "" {
|
||||||
|
query = query.Where("user_id = ?", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||||
|
|
||||||
|
return posts, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserPosts 获取用户帖子
|
||||||
|
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
var posts []*model.Post
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||||
|
|
||||||
|
return posts, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFavorites 获取用户收藏
|
||||||
|
func (r *PostRepository) GetFavorites(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
var posts []*model.Post
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
subQuery := r.db.Model(&model.Favorite{}).Where("user_id = ?", userID).Select("post_id")
|
||||||
|
r.db.Model(&model.Post{}).Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||||
|
|
||||||
|
return posts, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Like 点赞帖子
|
||||||
|
func (r *PostRepository) Like(postID, userID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 检查是否已经点赞
|
||||||
|
var existing model.PostLike
|
||||||
|
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
|
||||||
|
if err == nil {
|
||||||
|
// 已经点赞,直接返回
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建点赞记录
|
||||||
|
if err := tx.Create(&model.PostLike{
|
||||||
|
PostID: postID,
|
||||||
|
UserID: userID,
|
||||||
|
}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加帖子点赞数并同步热度分
|
||||||
|
return tx.Model(&model.Post{}).Where("id = ?", postID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"likes_count": gorm.Expr("likes_count + 1"),
|
||||||
|
"hot_score": gorm.Expr("(likes_count + 1) * 2 + comments_count * 3 + views_count * 0.1"),
|
||||||
|
}).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlike 取消点赞
|
||||||
|
func (r *PostRepository) Unlike(postID, userID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.PostLike{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected > 0 {
|
||||||
|
// 减少帖子点赞数并同步热度分
|
||||||
|
return tx.Model(&model.Post{}).Where("id = ?", postID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"likes_count": gorm.Expr("likes_count - 1"),
|
||||||
|
"hot_score": gorm.Expr("(likes_count - 1) * 2 + comments_count * 3 + views_count * 0.1"),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLiked 检查是否点赞
|
||||||
|
func (r *PostRepository) IsLiked(postID, userID string) bool {
|
||||||
|
var count int64
|
||||||
|
r.db.Model(&model.PostLike{}).Where("post_id = ? AND user_id = ?", postID, userID).Count(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Favorite 收藏帖子
|
||||||
|
func (r *PostRepository) Favorite(postID, userID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 检查是否已经收藏
|
||||||
|
var existing model.Favorite
|
||||||
|
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
|
||||||
|
if err == nil {
|
||||||
|
// 已经收藏,直接返回
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建收藏记录
|
||||||
|
if err := tx.Create(&model.Favorite{
|
||||||
|
PostID: postID,
|
||||||
|
UserID: userID,
|
||||||
|
}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加帖子收藏数
|
||||||
|
return tx.Model(&model.Post{}).Where("id = ?", postID).
|
||||||
|
UpdateColumn("favorites_count", gorm.Expr("favorites_count + 1")).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unfavorite 取消收藏
|
||||||
|
func (r *PostRepository) Unfavorite(postID, userID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.Favorite{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
if result.RowsAffected > 0 {
|
||||||
|
// 减少帖子收藏数
|
||||||
|
return tx.Model(&model.Post{}).Where("id = ?", postID).
|
||||||
|
UpdateColumn("favorites_count", gorm.Expr("favorites_count - 1")).Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFavorited 检查是否收藏
|
||||||
|
func (r *PostRepository) IsFavorited(postID, userID string) bool {
|
||||||
|
var count int64
|
||||||
|
r.db.Model(&model.Favorite{}).Where("post_id = ? AND user_id = ?", postID, userID).Count(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementViews 增加帖子观看量
|
||||||
|
func (r *PostRepository) IncrementViews(postID string) error {
|
||||||
|
return r.db.Model(&model.Post{}).Where("id = ?", postID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"views_count": gorm.Expr("views_count + 1"),
|
||||||
|
"hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search 搜索帖子
|
||||||
|
func (r *PostRepository) Search(keyword string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
var posts []*model.Post
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
|
||||||
|
|
||||||
|
// 搜索标题和内容
|
||||||
|
if keyword != "" {
|
||||||
|
if r.db.Dialector.Name() == "postgres" {
|
||||||
|
// PostgreSQL 使用全文检索表达式,为 pg_trgm/GIN 索引升级预留路径
|
||||||
|
query = query.Where(
|
||||||
|
"to_tsvector('simple', COALESCE(title, '') || ' ' || COALESCE(content, '')) @@ plainto_tsquery('simple', ?)",
|
||||||
|
keyword,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
searchPattern := "%" + keyword + "%"
|
||||||
|
query = query.Where("title LIKE ? OR content LIKE ?", searchPattern, searchPattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||||
|
|
||||||
|
return posts, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowingPosts 获取关注用户的帖子
|
||||||
|
func (r *PostRepository) GetFollowingPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
var posts []*model.Post
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
// 子查询:获取当前用户关注的所有用户ID
|
||||||
|
subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
|
||||||
|
|
||||||
|
// 统计总数
|
||||||
|
r.db.Model(&model.Post{}).Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).
|
||||||
|
Preload("User").Preload("Images").
|
||||||
|
Offset(offset).Limit(pageSize).
|
||||||
|
Order("created_at DESC").
|
||||||
|
Find(&posts).Error
|
||||||
|
|
||||||
|
return posts, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHotPosts 获取热门帖子(按点赞数和评论数排序)
|
||||||
|
func (r *PostRepository) GetHotPosts(page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
var posts []*model.Post
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
// 热门排序使用预计算热度分,避免每次请求进行表达式排序计算
|
||||||
|
err := r.db.Where("status = ?", model.PostStatusPublished).Preload("User").Preload("Images").
|
||||||
|
Offset(offset).Limit(pageSize).
|
||||||
|
Order("hot_score DESC, created_at DESC").
|
||||||
|
Find(&posts).Error
|
||||||
|
|
||||||
|
return posts, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByIDs 根据ID列表获取帖子(保持传入顺序)
|
||||||
|
func (r *PostRepository) GetByIDs(ids []string) ([]*model.Post, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return []*model.Post{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var posts []*model.Post
|
||||||
|
err := r.db.Preload("User").Preload("Images").
|
||||||
|
Where("id IN ? AND status = ?", ids, model.PostStatusPublished).
|
||||||
|
Find(&posts).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按传入ID顺序排序
|
||||||
|
postMap := make(map[string]*model.Post)
|
||||||
|
for _, post := range posts {
|
||||||
|
postMap[post.ID] = post
|
||||||
|
}
|
||||||
|
|
||||||
|
ordered := make([]*model.Post, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
if post, ok := postMap[id]; ok {
|
||||||
|
ordered = append(ordered, post)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ordered, nil
|
||||||
|
}
|
||||||
172
internal/repository/push_repo.go
Normal file
172
internal/repository/push_repo.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PushRecordRepository 推送记录仓储
|
||||||
|
type PushRecordRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPushRecordRepository 创建推送记录仓储
|
||||||
|
func NewPushRecordRepository(db *gorm.DB) *PushRecordRepository {
|
||||||
|
return &PushRecordRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建推送记录
|
||||||
|
func (r *PushRecordRepository) Create(record *model.PushRecord) error {
|
||||||
|
return r.db.Create(record).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取推送记录
|
||||||
|
func (r *PushRecordRepository) GetByID(id int64) (*model.PushRecord, error) {
|
||||||
|
var record model.PushRecord
|
||||||
|
err := r.db.First(&record, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新推送记录
|
||||||
|
func (r *PushRecordRepository) Update(record *model.PushRecord) error {
|
||||||
|
return r.db.Save(record).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingPushes 获取待推送记录
|
||||||
|
func (r *PushRecordRepository) GetPendingPushes(limit int) ([]*model.PushRecord, error) {
|
||||||
|
var records []*model.PushRecord
|
||||||
|
err := r.db.Where("push_status = ?", model.PushStatusPending).
|
||||||
|
Where("expired_at IS NULL OR expired_at > ?", time.Now()).
|
||||||
|
Order("created_at ASC").
|
||||||
|
Limit(limit).
|
||||||
|
Find(&records).Error
|
||||||
|
return records, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByUserID 根据用户ID获取推送记录
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (r *PushRecordRepository) GetByUserID(userID string, limit, offset int) ([]*model.PushRecord, error) {
|
||||||
|
var records []*model.PushRecord
|
||||||
|
err := r.db.Where("user_id = ?", userID).
|
||||||
|
Order("created_at DESC").
|
||||||
|
Offset(offset).
|
||||||
|
Limit(limit).
|
||||||
|
Find(&records).Error
|
||||||
|
return records, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByMessageID 根据消息ID获取推送记录
|
||||||
|
func (r *PushRecordRepository) GetByMessageID(messageID int64) ([]*model.PushRecord, error) {
|
||||||
|
var records []*model.PushRecord
|
||||||
|
err := r.db.Where("message_id = ?", messageID).
|
||||||
|
Order("created_at DESC").
|
||||||
|
Find(&records).Error
|
||||||
|
return records, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFailedPushesForRetry 获取失败待重试的推送
|
||||||
|
func (r *PushRecordRepository) GetFailedPushesForRetry(limit int) ([]*model.PushRecord, error) {
|
||||||
|
var records []*model.PushRecord
|
||||||
|
err := r.db.Where("push_status = ?", model.PushStatusFailed).
|
||||||
|
Where("retry_count < max_retry").
|
||||||
|
Where("expired_at IS NULL OR expired_at > ?", time.Now()).
|
||||||
|
Order("created_at ASC").
|
||||||
|
Limit(limit).
|
||||||
|
Find(&records).Error
|
||||||
|
return records, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchCreate 批量创建推送记录
|
||||||
|
func (r *PushRecordRepository) BatchCreate(records []*model.PushRecord) error {
|
||||||
|
if len(records) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.db.Create(&records).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchUpdateStatus 批量更新推送状态
|
||||||
|
func (r *PushRecordRepository) BatchUpdateStatus(ids []int64, status model.PushStatus) error {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
updates := map[string]interface{}{
|
||||||
|
"push_status": status,
|
||||||
|
}
|
||||||
|
if status == model.PushStatusPushed {
|
||||||
|
updates["pushed_at"] = time.Now()
|
||||||
|
}
|
||||||
|
return r.db.Model(&model.PushRecord{}).
|
||||||
|
Where("id IN ?", ids).
|
||||||
|
Updates(updates).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStatus 更新单条记录状态
|
||||||
|
func (r *PushRecordRepository) UpdateStatus(id int64, status model.PushStatus) error {
|
||||||
|
updates := map[string]interface{}{
|
||||||
|
"push_status": status,
|
||||||
|
}
|
||||||
|
if status == model.PushStatusPushed {
|
||||||
|
updates["pushed_at"] = time.Now()
|
||||||
|
}
|
||||||
|
return r.db.Model(&model.PushRecord{}).
|
||||||
|
Where("id = ?", id).
|
||||||
|
Updates(updates).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsFailed 标记为失败
|
||||||
|
func (r *PushRecordRepository) MarkAsFailed(id int64, errMsg string) error {
|
||||||
|
return r.db.Model(&model.PushRecord{}).
|
||||||
|
Where("id = ?", id).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"push_status": model.PushStatusFailed,
|
||||||
|
"error_message": errMsg,
|
||||||
|
"retry_count": gorm.Expr("retry_count + 1"),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsDelivered 标记为已送达
|
||||||
|
func (r *PushRecordRepository) MarkAsDelivered(id int64) error {
|
||||||
|
return r.db.Model(&model.PushRecord{}).
|
||||||
|
Where("id = ?", id).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"push_status": model.PushStatusDelivered,
|
||||||
|
"delivered_at": time.Now(),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteExpiredRecords 删除过期的推送记录(软删除)
|
||||||
|
func (r *PushRecordRepository) DeleteExpiredRecords() error {
|
||||||
|
return r.db.Where("expired_at IS NOT NULL AND expired_at < ?", time.Now()).
|
||||||
|
Delete(&model.PushRecord{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStatsByUserID 获取用户推送统计
|
||||||
|
func (r *PushRecordRepository) GetStatsByUserID(userID int64) (map[model.PushStatus]int64, error) {
|
||||||
|
type statusCount struct {
|
||||||
|
Status model.PushStatus
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
var results []statusCount
|
||||||
|
|
||||||
|
err := r.db.Model(&model.PushRecord{}).
|
||||||
|
Select("push_status as status, count(*) as count").
|
||||||
|
Where("user_id = ?", userID).
|
||||||
|
Group("push_status").
|
||||||
|
Scan(&results).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := make(map[model.PushStatus]int64)
|
||||||
|
for _, r := range results {
|
||||||
|
stats[r.Status] = r.Count
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
112
internal/repository/sticker_repo.go
Normal file
112
internal/repository/sticker_repo.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StickerRepository 自定义表情仓库接口
|
||||||
|
type StickerRepository interface {
|
||||||
|
// 获取用户的所有表情
|
||||||
|
GetByUserID(userID string) ([]model.UserSticker, error)
|
||||||
|
// 根据ID获取表情
|
||||||
|
GetByID(id string) (*model.UserSticker, error)
|
||||||
|
// 创建表情
|
||||||
|
Create(sticker *model.UserSticker) error
|
||||||
|
// 删除表情
|
||||||
|
Delete(id string) error
|
||||||
|
// 删除用户的所有表情
|
||||||
|
DeleteByUserID(userID string) error
|
||||||
|
// 检查表情是否存在
|
||||||
|
Exists(userID string, url string) (bool, error)
|
||||||
|
// 更新排序
|
||||||
|
UpdateSortOrder(id string, sortOrder int) error
|
||||||
|
// 批量更新排序
|
||||||
|
BatchUpdateSortOrder(userID string, orders map[string]int) error
|
||||||
|
// 获取用户表情数量
|
||||||
|
CountByUserID(userID string) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stickerRepository 自定义表情仓库实现
|
||||||
|
type stickerRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStickerRepository 创建自定义表情仓库
|
||||||
|
func NewStickerRepository(db *gorm.DB) StickerRepository {
|
||||||
|
return &stickerRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByUserID 获取用户的所有表情
|
||||||
|
func (r *stickerRepository) GetByUserID(userID string) ([]model.UserSticker, error) {
|
||||||
|
var stickers []model.UserSticker
|
||||||
|
err := r.db.Where("user_id = ?", userID).
|
||||||
|
Order("sort_order ASC, created_at DESC").
|
||||||
|
Find(&stickers).Error
|
||||||
|
return stickers, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取表情
|
||||||
|
func (r *stickerRepository) GetByID(id string) (*model.UserSticker, error) {
|
||||||
|
var sticker model.UserSticker
|
||||||
|
err := r.db.Where("id = ?", id).First(&sticker).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &sticker, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建表情
|
||||||
|
func (r *stickerRepository) Create(sticker *model.UserSticker) error {
|
||||||
|
return r.db.Create(sticker).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除表情
|
||||||
|
func (r *stickerRepository) Delete(id string) error {
|
||||||
|
return r.db.Where("id = ?", id).Delete(&model.UserSticker{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByUserID 删除用户的所有表情
|
||||||
|
func (r *stickerRepository) DeleteByUserID(userID string) error {
|
||||||
|
return r.db.Where("user_id = ?", userID).Delete(&model.UserSticker{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists 检查表情是否存在
|
||||||
|
func (r *stickerRepository) Exists(userID string, url string) (bool, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.UserSticker{}).
|
||||||
|
Where("user_id = ? AND url = ?", userID, url).
|
||||||
|
Count(&count).Error
|
||||||
|
return count > 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateSortOrder 更新排序
|
||||||
|
func (r *stickerRepository) UpdateSortOrder(id string, sortOrder int) error {
|
||||||
|
return r.db.Model(&model.UserSticker{}).
|
||||||
|
Where("id = ?", id).
|
||||||
|
Update("sort_order", sortOrder).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchUpdateSortOrder 批量更新排序
|
||||||
|
func (r *stickerRepository) BatchUpdateSortOrder(userID string, orders map[string]int) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
for id, sortOrder := range orders {
|
||||||
|
if err := tx.Model(&model.UserSticker{}).
|
||||||
|
Where("id = ? AND user_id = ?", id, userID).
|
||||||
|
Update("sort_order", sortOrder).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountByUserID 获取用户表情数量
|
||||||
|
func (r *stickerRepository) CountByUserID(userID string) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.UserSticker{}).
|
||||||
|
Where("user_id = ?", userID).
|
||||||
|
Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
114
internal/repository/system_notification_repo.go
Normal file
114
internal/repository/system_notification_repo.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SystemNotificationRepository 系统通知仓储
|
||||||
|
type SystemNotificationRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSystemNotificationRepository 创建系统通知仓储
|
||||||
|
func NewSystemNotificationRepository(db *gorm.DB) *SystemNotificationRepository {
|
||||||
|
return &SystemNotificationRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建系统通知
|
||||||
|
func (r *SystemNotificationRepository) Create(notification *model.SystemNotification) error {
|
||||||
|
return r.db.Create(notification).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取通知
|
||||||
|
func (r *SystemNotificationRepository) GetByID(id int64) (*model.SystemNotification, error) {
|
||||||
|
var notification model.SystemNotification
|
||||||
|
err := r.db.First(¬ification, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ¬ification, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByReceiverID 获取用户的通知列表
|
||||||
|
func (r *SystemNotificationRepository) GetByReceiverID(receiverID string, page, pageSize int) ([]*model.SystemNotification, int64, error) {
|
||||||
|
var notifications []*model.SystemNotification
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.SystemNotification{}).Where("receiver_id = ?", receiverID)
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Offset(offset).
|
||||||
|
Limit(pageSize).
|
||||||
|
Order("created_at DESC").
|
||||||
|
Find(¬ifications).Error
|
||||||
|
|
||||||
|
return notifications, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadByReceiverID 获取用户的未读通知列表
|
||||||
|
func (r *SystemNotificationRepository) GetUnreadByReceiverID(receiverID string, limit int) ([]*model.SystemNotification, error) {
|
||||||
|
var notifications []*model.SystemNotification
|
||||||
|
err := r.db.Where("receiver_id = ? AND is_read = ?", receiverID, false).
|
||||||
|
Order("created_at DESC").
|
||||||
|
Limit(limit).
|
||||||
|
Find(¬ifications).Error
|
||||||
|
return notifications, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取用户未读通知数量
|
||||||
|
func (r *SystemNotificationRepository) GetUnreadCount(receiverID string) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.SystemNotification{}).
|
||||||
|
Where("receiver_id = ? AND is_read = ?", receiverID, false).
|
||||||
|
Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记单条通知为已读
|
||||||
|
func (r *SystemNotificationRepository) MarkAsRead(id int64, receiverID string) error {
|
||||||
|
now := model.SystemNotification{}.UpdatedAt
|
||||||
|
return r.db.Model(&model.SystemNotification{}).
|
||||||
|
Where("id = ? AND receiver_id = ?", id, receiverID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"is_read": true,
|
||||||
|
"read_at": now,
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAllAsRead 标记用户所有通知为已读
|
||||||
|
func (r *SystemNotificationRepository) MarkAllAsRead(receiverID string) error {
|
||||||
|
now := model.SystemNotification{}.UpdatedAt
|
||||||
|
return r.db.Model(&model.SystemNotification{}).
|
||||||
|
Where("receiver_id = ? AND is_read = ?", receiverID, false).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"is_read": true,
|
||||||
|
"read_at": now,
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除通知(软删除)
|
||||||
|
func (r *SystemNotificationRepository) Delete(id int64, receiverID string) error {
|
||||||
|
return r.db.Where("id = ? AND receiver_id = ?", id, receiverID).
|
||||||
|
Delete(&model.SystemNotification{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByType 获取用户指定类型的通知
|
||||||
|
func (r *SystemNotificationRepository) GetByType(receiverID string, notifyType model.SystemNotificationType, page, pageSize int) ([]*model.SystemNotification, int64, error) {
|
||||||
|
var notifications []*model.SystemNotification
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.SystemNotification{}).
|
||||||
|
Where("receiver_id = ? AND type = ?", receiverID, notifyType)
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Offset(offset).
|
||||||
|
Limit(pageSize).
|
||||||
|
Order("created_at DESC").
|
||||||
|
Find(¬ifications).Error
|
||||||
|
|
||||||
|
return notifications, total, err
|
||||||
|
}
|
||||||
404
internal/repository/user_repo.go
Normal file
404
internal/repository/user_repo.go
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserRepository 用户仓储
|
||||||
|
type UserRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserRepository 创建用户仓储
|
||||||
|
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||||
|
return &UserRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建用户
|
||||||
|
func (r *UserRepository) Create(user *model.User) error {
|
||||||
|
return r.db.Create(user).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取用户
|
||||||
|
func (r *UserRepository) GetByID(id string) (*model.User, error) {
|
||||||
|
var user model.User
|
||||||
|
err := r.db.First(&user, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByUsername 根据用户名获取用户
|
||||||
|
func (r *UserRepository) GetByUsername(username string) (*model.User, error) {
|
||||||
|
var user model.User
|
||||||
|
err := r.db.First(&user, "username = ?", username).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByEmail 根据邮箱获取用户
|
||||||
|
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
|
||||||
|
var user model.User
|
||||||
|
err := r.db.First(&user, "email = ?", email).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByPhone 根据手机号获取用户
|
||||||
|
func (r *UserRepository) GetByPhone(phone string) (*model.User, error) {
|
||||||
|
var user model.User
|
||||||
|
err := r.db.First(&user, "phone = ?", phone).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新用户
|
||||||
|
func (r *UserRepository) Update(user *model.User) error {
|
||||||
|
return r.db.Save(user).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除用户
|
||||||
|
func (r *UserRepository) Delete(id string) error {
|
||||||
|
return r.db.Delete(&model.User{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 分页获取用户列表
|
||||||
|
func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
var users []*model.User
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
r.db.Model(&model.User{}).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Order("created_at DESC, id DESC").Offset(offset).Limit(pageSize).Find(&users).Error
|
||||||
|
|
||||||
|
return users, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowers 获取用户粉丝
|
||||||
|
func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
var users []*model.User
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
subQuery := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id")
|
||||||
|
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("id IN (?)", subQuery).
|
||||||
|
Order("created_at DESC, id DESC").
|
||||||
|
Offset(offset).Limit(pageSize).
|
||||||
|
Find(&users).Error
|
||||||
|
|
||||||
|
return users, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowing 获取用户关注
|
||||||
|
func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
var users []*model.User
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
|
||||||
|
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("id IN (?)", subQuery).
|
||||||
|
Order("created_at DESC, id DESC").
|
||||||
|
Offset(offset).Limit(pageSize).
|
||||||
|
Find(&users).Error
|
||||||
|
|
||||||
|
return users, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFollow 创建关注关系
|
||||||
|
func (r *UserRepository) CreateFollow(follow *model.Follow) error {
|
||||||
|
return r.db.Create(follow).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFollow 删除关注关系
|
||||||
|
func (r *UserRepository) DeleteFollow(followerID, followingID string) error {
|
||||||
|
return r.db.Where("follower_id = ? AND following_id = ?", followerID, followingID).Delete(&model.Follow{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFollowing 检查是否关注了某用户
|
||||||
|
func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.Follow{}).Where("follower_id = ? AND following_id = ?", followerID, followingID).Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementFollowersCount 增加用户粉丝数
|
||||||
|
func (r *UserRepository) IncrementFollowersCount(userID string) error {
|
||||||
|
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||||||
|
UpdateColumn("followers_count", gorm.Expr("followers_count + 1")).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecrementFollowersCount 减少用户粉丝数
|
||||||
|
func (r *UserRepository) DecrementFollowersCount(userID string) error {
|
||||||
|
return r.db.Model(&model.User{}).Where("id = ? AND followers_count > 0", userID).
|
||||||
|
UpdateColumn("followers_count", gorm.Expr("followers_count - 1")).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementFollowingCount 增加用户关注数
|
||||||
|
func (r *UserRepository) IncrementFollowingCount(userID string) error {
|
||||||
|
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||||||
|
UpdateColumn("following_count", gorm.Expr("following_count + 1")).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecrementFollowingCount 减少用户关注数
|
||||||
|
func (r *UserRepository) DecrementFollowingCount(userID string) error {
|
||||||
|
return r.db.Model(&model.User{}).Where("id = ? AND following_count > 0", userID).
|
||||||
|
UpdateColumn("following_count", gorm.Expr("following_count - 1")).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshFollowersCount 刷新用户粉丝数(通过实际计数)
|
||||||
|
func (r *UserRepository) RefreshFollowersCount(userID string) error {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||||||
|
UpdateColumn("followers_count", count).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPostsCount 获取用户帖子数(实时计算)
|
||||||
|
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPostsCountBatch 批量获取用户帖子数(实时计算)
|
||||||
|
// 返回 map[userID]postsCount
|
||||||
|
func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) {
|
||||||
|
result := make(map[string]int64)
|
||||||
|
if len(userIDs) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化所有用户ID的计数为0
|
||||||
|
for _, userID := range userIDs {
|
||||||
|
result[userID] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 GROUP BY 一次性查询所有用户的帖子数
|
||||||
|
type CountResult struct {
|
||||||
|
UserID string
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
var counts []CountResult
|
||||||
|
err := r.db.Model(&model.Post{}).
|
||||||
|
Select("user_id, count(*) as count").
|
||||||
|
Where("user_id IN ?", userIDs).
|
||||||
|
Group("user_id").
|
||||||
|
Scan(&counts).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新查询结果
|
||||||
|
for _, c := range counts {
|
||||||
|
result[c.UserID] = c.Count
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshFollowingCount 刷新用户关注数(通过实际计数)
|
||||||
|
func (r *UserRepository) RefreshFollowingCount(userID string) error {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||||||
|
UpdateColumn("following_count", count).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlocked 检查拉黑关系是否存在(blocker -> blocked)
|
||||||
|
func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.UserBlock{}).
|
||||||
|
Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
|
||||||
|
Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlockedEitherDirection 检查是否任一方向存在拉黑
|
||||||
|
func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) {
|
||||||
|
var count int64
|
||||||
|
err := r.db.Model(&model.UserBlock{}).
|
||||||
|
Where("(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)",
|
||||||
|
userA, userB, userB, userA).
|
||||||
|
Count(&count).Error
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockUserAndCleanupRelations 拉黑用户并清理双向关注关系(事务)
|
||||||
|
func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
block := &model.UserBlock{
|
||||||
|
BlockerID: blockerID,
|
||||||
|
BlockedID: blockedID,
|
||||||
|
}
|
||||||
|
if err := tx.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}},
|
||||||
|
DoNothing: true,
|
||||||
|
}).Create(block).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Where("follower_id = ? AND following_id = ?", blockerID, blockedID).
|
||||||
|
Delete(&model.Follow{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Where("follower_id = ? AND following_id = ?", blockedID, blockerID).
|
||||||
|
Delete(&model.Follow{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, uid := range []string{blockerID, blockedID} {
|
||||||
|
var followersCount int64
|
||||||
|
if err := tx.Model(&model.Follow{}).Where("following_id = ?", uid).Count(&followersCount).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Model(&model.User{}).Where("id = ?", uid).
|
||||||
|
UpdateColumn("followers_count", followersCount).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var followingCount int64
|
||||||
|
if err := tx.Model(&model.Follow{}).Where("follower_id = ?", uid).Count(&followingCount).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Model(&model.User{}).Where("id = ?", uid).
|
||||||
|
UpdateColumn("following_count", followingCount).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnblockUser 取消拉黑
|
||||||
|
func (r *UserRepository) UnblockUser(blockerID, blockedID string) error {
|
||||||
|
return r.db.Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
|
||||||
|
Delete(&model.UserBlock{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockedUsers 获取用户黑名单列表
|
||||||
|
func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
var users []*model.User
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
subQuery := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id")
|
||||||
|
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := r.db.Where("id IN (?)", subQuery).
|
||||||
|
Order("created_at DESC, id DESC").
|
||||||
|
Offset(offset).
|
||||||
|
Limit(pageSize).
|
||||||
|
Find(&users).Error
|
||||||
|
|
||||||
|
return users, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search 搜索用户
|
||||||
|
func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
var users []*model.User
|
||||||
|
var total int64
|
||||||
|
|
||||||
|
query := r.db.Model(&model.User{})
|
||||||
|
|
||||||
|
// 搜索用户名、昵称、简介
|
||||||
|
if keyword != "" {
|
||||||
|
if r.db.Dialector.Name() == "postgres" {
|
||||||
|
query = query.Where(
|
||||||
|
"to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)",
|
||||||
|
keyword,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
searchPattern := "%" + keyword + "%"
|
||||||
|
query = query.Where("username LIKE ? OR nickname LIKE ? OR bio LIKE ?", searchPattern, searchPattern, searchPattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
query.Count(&total)
|
||||||
|
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&users).Error
|
||||||
|
|
||||||
|
return users, total, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMutualFollowStatus 批量获取双向关注状态
|
||||||
|
// 返回 map[userID][isFollowing, isFollowingMe]
|
||||||
|
func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
|
||||||
|
result := make(map[string][2]bool)
|
||||||
|
|
||||||
|
if len(targetUserIDs) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUserID=%s, targetUserIDs=%v\n", currentUserID, targetUserIDs)
|
||||||
|
|
||||||
|
// 初始化所有目标用户为未关注状态
|
||||||
|
for _, userID := range targetUserIDs {
|
||||||
|
result[userID] = [2]bool{false, false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询当前用户关注了哪些目标用户 (isFollowing)
|
||||||
|
var followingIDs []string
|
||||||
|
err := r.db.Model(&model.Follow{}).
|
||||||
|
Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs).
|
||||||
|
Pluck("following_id", &followingIDs).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUser follows these targets: %v\n", followingIDs)
|
||||||
|
for _, id := range followingIDs {
|
||||||
|
status := result[id]
|
||||||
|
status[0] = true
|
||||||
|
result[id] = status
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询哪些目标用户关注了当前用户 (isFollowingMe)
|
||||||
|
var followerIDs []string
|
||||||
|
err = r.db.Model(&model.Follow{}).
|
||||||
|
Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID).
|
||||||
|
Pluck("follower_id", &followerIDs).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fmt.Printf("[DEBUG] GetMutualFollowStatus: these targets follow currentUser: %v\n", followerIDs)
|
||||||
|
for _, id := range followerIDs {
|
||||||
|
status := result[id]
|
||||||
|
status[1] = true
|
||||||
|
result[id] = status
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] GetMutualFollowStatus: final result=%v\n", result)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
141
internal/repository/vote_repo.go
Normal file
141
internal/repository/vote_repo.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VoteRepository 投票仓储
|
||||||
|
type VoteRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewVoteRepository 创建投票仓储
|
||||||
|
func NewVoteRepository(db *gorm.DB) *VoteRepository {
|
||||||
|
return &VoteRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOptions 批量创建投票选项
|
||||||
|
func (r *VoteRepository) CreateOptions(postID string, options []string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
for i, content := range options {
|
||||||
|
option := &model.VoteOption{
|
||||||
|
PostID: postID,
|
||||||
|
Content: content,
|
||||||
|
SortOrder: i,
|
||||||
|
}
|
||||||
|
if err := tx.Create(option).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOptionsByPostID 获取帖子的所有投票选项
|
||||||
|
func (r *VoteRepository) GetOptionsByPostID(postID string) ([]model.VoteOption, error) {
|
||||||
|
var options []model.VoteOption
|
||||||
|
err := r.db.Where("post_id = ?", postID).Order("sort_order ASC").Find(&options).Error
|
||||||
|
return options, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vote 用户投票
|
||||||
|
func (r *VoteRepository) Vote(postID, userID, optionID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 检查用户是否已投票
|
||||||
|
var existing model.UserVote
|
||||||
|
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
|
||||||
|
if err == nil {
|
||||||
|
// 已经投票,返回错误
|
||||||
|
return errors.New("user already voted")
|
||||||
|
}
|
||||||
|
if err != gorm.ErrRecordNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证选项是否属于该帖子
|
||||||
|
var option model.VoteOption
|
||||||
|
if err := tx.Where("id = ? AND post_id = ?", optionID, postID).First(&option).Error; err != nil {
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
return errors.New("invalid option")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建投票记录
|
||||||
|
if err := tx.Create(&model.UserVote{
|
||||||
|
PostID: postID,
|
||||||
|
UserID: userID,
|
||||||
|
OptionID: optionID,
|
||||||
|
}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原子增加选项投票数
|
||||||
|
return tx.Model(&model.VoteOption{}).Where("id = ?", optionID).
|
||||||
|
UpdateColumn("votes_count", gorm.Expr("votes_count + 1")).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unvote 取消投票
|
||||||
|
func (r *VoteRepository) Unvote(postID, userID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 获取用户的投票记录
|
||||||
|
var userVote model.UserVote
|
||||||
|
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error
|
||||||
|
if err != nil {
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
return nil // 没有投票记录,直接返回
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除投票记录
|
||||||
|
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.UserVote{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected > 0 {
|
||||||
|
// 原子减少选项投票数
|
||||||
|
return tx.Model(&model.VoteOption{}).Where("id = ?", userVote.OptionID).
|
||||||
|
UpdateColumn("votes_count", gorm.Expr("votes_count - 1")).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserVote 获取用户在指定帖子的投票
|
||||||
|
func (r *VoteRepository) GetUserVote(postID, userID string) (*model.UserVote, error) {
|
||||||
|
var userVote model.UserVote
|
||||||
|
err := r.db.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error
|
||||||
|
if err != nil {
|
||||||
|
if err == gorm.ErrRecordNotFound {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &userVote, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateOption 更新选项内容
|
||||||
|
func (r *VoteRepository) UpdateOption(optionID, content string) error {
|
||||||
|
return r.db.Model(&model.VoteOption{}).Where("id = ?", optionID).
|
||||||
|
Update("content", content).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOptionsByPostID 删除帖子的所有投票选项
|
||||||
|
func (r *VoteRepository) DeleteOptionsByPostID(postID string) error {
|
||||||
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
// 删除关联的用户投票记录
|
||||||
|
if err := tx.Where("post_id = ?", postID).Delete(&model.UserVote{}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除投票选项
|
||||||
|
return tx.Where("post_id = ?", postID).Delete(&model.VoteOption{}).Error
|
||||||
|
})
|
||||||
|
}
|
||||||
334
internal/router/router.go
Normal file
334
internal/router/router.go
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/handler"
|
||||||
|
"carrot_bbs/internal/middleware"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Router 路由配置
|
||||||
|
type Router struct {
|
||||||
|
engine *gin.Engine
|
||||||
|
userHandler *handler.UserHandler
|
||||||
|
postHandler *handler.PostHandler
|
||||||
|
commentHandler *handler.CommentHandler
|
||||||
|
messageHandler *handler.MessageHandler
|
||||||
|
notificationHandler *handler.NotificationHandler
|
||||||
|
uploadHandler *handler.UploadHandler
|
||||||
|
wsHandler *handler.WebSocketHandler
|
||||||
|
pushHandler *handler.PushHandler
|
||||||
|
systemMessageHandler *handler.SystemMessageHandler
|
||||||
|
groupHandler *handler.GroupHandler
|
||||||
|
stickerHandler *handler.StickerHandler
|
||||||
|
gorseHandler *handler.GorseHandler
|
||||||
|
voteHandler *handler.VoteHandler
|
||||||
|
jwtService *service.JWTService
|
||||||
|
}
|
||||||
|
|
||||||
|
// New 创建路由
|
||||||
|
func New(
|
||||||
|
userHandler *handler.UserHandler,
|
||||||
|
postHandler *handler.PostHandler,
|
||||||
|
commentHandler *handler.CommentHandler,
|
||||||
|
messageHandler *handler.MessageHandler,
|
||||||
|
notificationHandler *handler.NotificationHandler,
|
||||||
|
uploadHandler *handler.UploadHandler,
|
||||||
|
jwtService *service.JWTService,
|
||||||
|
wsHandler *handler.WebSocketHandler,
|
||||||
|
pushHandler *handler.PushHandler,
|
||||||
|
systemMessageHandler *handler.SystemMessageHandler,
|
||||||
|
groupHandler *handler.GroupHandler,
|
||||||
|
stickerHandler *handler.StickerHandler,
|
||||||
|
gorseHandler *handler.GorseHandler,
|
||||||
|
voteHandler *handler.VoteHandler,
|
||||||
|
) *Router {
|
||||||
|
// 设置JWT服务
|
||||||
|
userHandler.SetJWTService(jwtService)
|
||||||
|
|
||||||
|
r := &Router{
|
||||||
|
engine: gin.Default(),
|
||||||
|
userHandler: userHandler,
|
||||||
|
postHandler: postHandler,
|
||||||
|
commentHandler: commentHandler,
|
||||||
|
messageHandler: messageHandler,
|
||||||
|
notificationHandler: notificationHandler,
|
||||||
|
uploadHandler: uploadHandler,
|
||||||
|
wsHandler: wsHandler,
|
||||||
|
pushHandler: pushHandler,
|
||||||
|
systemMessageHandler: systemMessageHandler,
|
||||||
|
groupHandler: groupHandler,
|
||||||
|
stickerHandler: stickerHandler,
|
||||||
|
gorseHandler: gorseHandler,
|
||||||
|
voteHandler: voteHandler,
|
||||||
|
jwtService: jwtService,
|
||||||
|
}
|
||||||
|
|
||||||
|
r.setupRoutes()
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupRoutes 设置路由
|
||||||
|
func (r *Router) setupRoutes() {
|
||||||
|
// 中间件
|
||||||
|
r.engine.Use(middleware.CORS())
|
||||||
|
|
||||||
|
// 健康检查
|
||||||
|
r.engine.GET("/health", func(c *gin.Context) {
|
||||||
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
|
// WebSocket 路由
|
||||||
|
if r.wsHandler != nil {
|
||||||
|
r.engine.GET("/ws", r.wsHandler.HandleWebSocket)
|
||||||
|
}
|
||||||
|
|
||||||
|
// API v1
|
||||||
|
v1 := r.engine.Group("/api/v1")
|
||||||
|
{
|
||||||
|
// 认证路由(公开)
|
||||||
|
auth := v1.Group("/auth")
|
||||||
|
{
|
||||||
|
auth.POST("/register", r.userHandler.Register)
|
||||||
|
auth.POST("/register/send-code", r.userHandler.SendRegisterCode)
|
||||||
|
auth.POST("/login", r.userHandler.Login)
|
||||||
|
auth.POST("/password/send-code", r.userHandler.SendPasswordResetCode)
|
||||||
|
auth.POST("/password/reset", r.userHandler.ResetPassword)
|
||||||
|
auth.POST("/refresh", r.userHandler.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 需要认证的路由
|
||||||
|
authMiddleware := middleware.Auth(r.jwtService)
|
||||||
|
|
||||||
|
// 用户路由
|
||||||
|
users := v1.Group("/users")
|
||||||
|
{
|
||||||
|
// 当前用户
|
||||||
|
users.GET("/me", authMiddleware, r.userHandler.GetCurrentUser)
|
||||||
|
users.PUT("/me", authMiddleware, r.userHandler.UpdateUser)
|
||||||
|
users.POST("/me/email/send-code", authMiddleware, r.userHandler.SendEmailVerifyCode)
|
||||||
|
users.POST("/me/email/verify", authMiddleware, r.userHandler.VerifyEmail)
|
||||||
|
users.POST("/me/avatar", authMiddleware, r.uploadHandler.UploadAvatar)
|
||||||
|
users.POST("/me/cover", authMiddleware, r.uploadHandler.UploadCover)
|
||||||
|
users.POST("/change-password/send-code", authMiddleware, r.userHandler.SendChangePasswordCode)
|
||||||
|
users.POST("/change-password", authMiddleware, r.userHandler.ChangePassword)
|
||||||
|
|
||||||
|
// 搜索用户
|
||||||
|
users.GET("/search", middleware.OptionalAuth(r.jwtService), r.userHandler.Search)
|
||||||
|
|
||||||
|
// 其他用户
|
||||||
|
users.GET("/:id", middleware.OptionalAuth(r.jwtService), r.userHandler.GetUserByID)
|
||||||
|
users.POST("/:id/follow", authMiddleware, r.userHandler.FollowUser)
|
||||||
|
users.DELETE("/:id/follow", authMiddleware, r.userHandler.UnfollowUser)
|
||||||
|
users.POST("/:id/block", authMiddleware, r.userHandler.BlockUser)
|
||||||
|
users.DELETE("/:id/block", authMiddleware, r.userHandler.UnblockUser)
|
||||||
|
users.GET("/:id/block-status", authMiddleware, r.userHandler.GetBlockStatus)
|
||||||
|
users.GET("/blocks", authMiddleware, r.userHandler.GetBlockedUsers)
|
||||||
|
users.GET("/:id/following", authMiddleware, r.userHandler.GetFollowingList)
|
||||||
|
users.GET("/:id/followers", authMiddleware, r.userHandler.GetFollowersList)
|
||||||
|
|
||||||
|
// 用户帖子 - 使用 OptionalAuth 获取当前用户点赞/收藏状态
|
||||||
|
users.GET("/:id/posts", middleware.OptionalAuth(r.jwtService), r.postHandler.GetUserPosts)
|
||||||
|
users.GET("/:id/favorites", middleware.OptionalAuth(r.jwtService), r.postHandler.GetFavorites)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 认证路由(公开)
|
||||||
|
authPublic := v1.Group("/auth")
|
||||||
|
{
|
||||||
|
authPublic.GET("/check-username", r.userHandler.CheckUsername)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 帖子路由
|
||||||
|
posts := v1.Group("/posts")
|
||||||
|
{
|
||||||
|
// 使用 OptionalAuth 中间件来获取用户登录状态
|
||||||
|
posts.GET("", middleware.OptionalAuth(r.jwtService), r.postHandler.List)
|
||||||
|
posts.GET("/search", middleware.OptionalAuth(r.jwtService), r.postHandler.Search)
|
||||||
|
posts.GET("/:id", middleware.OptionalAuth(r.jwtService), r.postHandler.GetByID)
|
||||||
|
posts.POST("", authMiddleware, r.postHandler.Create)
|
||||||
|
posts.PUT("/:id", authMiddleware, r.postHandler.Update)
|
||||||
|
posts.DELETE("/:id", authMiddleware, r.postHandler.Delete)
|
||||||
|
|
||||||
|
// 浏览量记录(可选认证,允许游客浏览)
|
||||||
|
posts.POST("/:id/view", middleware.OptionalAuth(r.jwtService), r.postHandler.RecordView)
|
||||||
|
|
||||||
|
// 点赞
|
||||||
|
posts.POST("/:id/like", authMiddleware, r.postHandler.Like)
|
||||||
|
posts.DELETE("/:id/like", authMiddleware, r.postHandler.Unlike)
|
||||||
|
|
||||||
|
// 收藏
|
||||||
|
posts.POST("/:id/favorite", authMiddleware, r.postHandler.Favorite)
|
||||||
|
posts.DELETE("/:id/favorite", authMiddleware, r.postHandler.Unfavorite)
|
||||||
|
|
||||||
|
// 投票相关路由
|
||||||
|
posts.POST("/vote", authMiddleware, r.voteHandler.CreateVotePost) // 创建投票帖子
|
||||||
|
posts.GET("/:id/vote", middleware.OptionalAuth(r.jwtService), r.voteHandler.GetVoteResult) // 获取投票结果
|
||||||
|
posts.POST("/:id/vote", authMiddleware, r.voteHandler.Vote) // 投票
|
||||||
|
posts.DELETE("/:id/vote", authMiddleware, r.voteHandler.Unvote) // 取消投票
|
||||||
|
}
|
||||||
|
|
||||||
|
// 投票选项路由
|
||||||
|
voteOptions := v1.Group("/vote-options")
|
||||||
|
voteOptions.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
voteOptions.PUT("/:id", r.voteHandler.UpdateVoteOption) // 更新选项
|
||||||
|
}
|
||||||
|
|
||||||
|
// 评论路由
|
||||||
|
comments := v1.Group("/comments")
|
||||||
|
{
|
||||||
|
comments.GET("/post/:id", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetByPostID)
|
||||||
|
comments.GET("/:id", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetByID)
|
||||||
|
comments.POST("", authMiddleware, r.commentHandler.Create)
|
||||||
|
comments.PUT("/:id", authMiddleware, r.commentHandler.Update)
|
||||||
|
comments.DELETE("/:id", authMiddleware, r.commentHandler.Delete)
|
||||||
|
comments.GET("/:id/replies", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetReplies)
|
||||||
|
comments.GET("/:id/replies/flat", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetRepliesByRootID) // 扁平化分页获取回复
|
||||||
|
// 评论点赞
|
||||||
|
comments.POST("/:id/like", authMiddleware, r.commentHandler.Like)
|
||||||
|
comments.DELETE("/:id/like", authMiddleware, r.commentHandler.Unlike)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 会话路由(新版 RESTful action 风格)
|
||||||
|
conversations := v1.Group("/conversations")
|
||||||
|
conversations.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
// 获取会话列表
|
||||||
|
conversations.GET("/list", r.messageHandler.HandleGetConversationList)
|
||||||
|
// 创建会话
|
||||||
|
conversations.POST("/create", r.messageHandler.HandleCreateConversation)
|
||||||
|
// 获取会话详情
|
||||||
|
conversations.GET("/get", r.messageHandler.HandleGetConversation)
|
||||||
|
// 获取会话消息
|
||||||
|
conversations.GET("/get_messages", r.messageHandler.HandleGetMessages)
|
||||||
|
// 发送消息
|
||||||
|
conversations.POST("/send_message", r.messageHandler.HandleSendMessage)
|
||||||
|
// 标记已读
|
||||||
|
conversations.POST("/mark_read", r.messageHandler.HandleMarkRead)
|
||||||
|
// 会话置顶
|
||||||
|
conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned)
|
||||||
|
// 获取未读消息总数
|
||||||
|
conversations.GET("/unread/count", r.messageHandler.GetUnreadCount)
|
||||||
|
// 仅自己删除会话
|
||||||
|
conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 消息操作路由
|
||||||
|
messages := v1.Group("/messages")
|
||||||
|
messages.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
// 撤回/删除消息(统一接口)
|
||||||
|
messages.POST("/delete_msg", r.messageHandler.HandleDeleteMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通知路由
|
||||||
|
notifications := v1.Group("/notifications")
|
||||||
|
{
|
||||||
|
notifications.GET("", authMiddleware, r.notificationHandler.GetNotifications)
|
||||||
|
notifications.POST("/:id/read", authMiddleware, r.notificationHandler.MarkAsRead)
|
||||||
|
notifications.POST("/read-all", authMiddleware, r.notificationHandler.MarkAllAsRead)
|
||||||
|
notifications.GET("/unread-count", authMiddleware, r.notificationHandler.GetUnreadCount)
|
||||||
|
notifications.DELETE("/:id", authMiddleware, r.notificationHandler.DeleteNotification)
|
||||||
|
notifications.DELETE("", authMiddleware, r.notificationHandler.ClearAllNotifications)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 上传路由
|
||||||
|
uploads := v1.Group("/uploads")
|
||||||
|
{
|
||||||
|
uploads.POST("/images", authMiddleware, r.uploadHandler.UploadImage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送相关路由
|
||||||
|
if r.pushHandler != nil {
|
||||||
|
pushGroup := v1.Group("/push")
|
||||||
|
pushGroup.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
pushGroup.POST("/devices", r.pushHandler.RegisterDevice)
|
||||||
|
pushGroup.GET("/devices", r.pushHandler.GetMyDevices)
|
||||||
|
pushGroup.DELETE("/devices/:device_id", r.pushHandler.UnregisterDevice)
|
||||||
|
pushGroup.PUT("/devices/:device_id/token", r.pushHandler.UpdateDeviceToken)
|
||||||
|
pushGroup.GET("/records", r.pushHandler.GetPushRecords)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 系统消息路由
|
||||||
|
if r.systemMessageHandler != nil {
|
||||||
|
msgGroup := v1.Group("/messages")
|
||||||
|
msgGroup.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
msgGroup.GET("/system", r.systemMessageHandler.GetSystemMessages)
|
||||||
|
msgGroup.GET("/system/unread-count", r.systemMessageHandler.GetUnreadCount)
|
||||||
|
msgGroup.PUT("/system/:id/read", r.systemMessageHandler.MarkAsRead)
|
||||||
|
msgGroup.PUT("/system/read-all", r.systemMessageHandler.MarkAllAsRead)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 群组路由(新版 RESTful action 风格)
|
||||||
|
if r.groupHandler != nil {
|
||||||
|
groups := v1.Group("/groups")
|
||||||
|
groups.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
// 群组管理
|
||||||
|
groups.POST("/create", r.groupHandler.HandleCreateGroup)
|
||||||
|
groups.GET("/list", r.groupHandler.HandleGetUserGroups)
|
||||||
|
groups.GET("/get", r.groupHandler.HandleGetGroupInfo)
|
||||||
|
groups.GET("/get_my_info", r.groupHandler.HandleGetMyMemberInfo)
|
||||||
|
groups.POST("/dissolve", r.groupHandler.HandleDissolveGroup)
|
||||||
|
groups.POST("/transfer", r.groupHandler.HandleTransferOwner)
|
||||||
|
|
||||||
|
// 成员管理
|
||||||
|
groups.POST("/invite_members", r.groupHandler.HandleInviteMembers)
|
||||||
|
groups.POST("/join", r.groupHandler.HandleJoinGroup)
|
||||||
|
groups.POST("/respond_invite", r.groupHandler.HandleRespondInvite)
|
||||||
|
groups.POST("/set_group_leave", r.groupHandler.HandleSetGroupLeave)
|
||||||
|
groups.GET("/get_members", r.groupHandler.HandleGetGroupMemberList)
|
||||||
|
groups.POST("/set_group_kick", r.groupHandler.HandleSetGroupKick)
|
||||||
|
groups.POST("/set_group_admin", r.groupHandler.HandleSetGroupAdmin)
|
||||||
|
groups.POST("/set_nickname", r.groupHandler.HandleSetNickname)
|
||||||
|
groups.POST("/set_group_ban", r.groupHandler.HandleSetGroupBan)
|
||||||
|
|
||||||
|
// 群设置
|
||||||
|
groups.POST("/set_group_whole_ban", r.groupHandler.HandleSetGroupWholeBan)
|
||||||
|
groups.POST("/set_join_type", r.groupHandler.HandleSetJoinType)
|
||||||
|
groups.POST("/set_group_name", r.groupHandler.HandleSetGroupName)
|
||||||
|
groups.POST("/set_group_avatar", r.groupHandler.HandleSetGroupAvatar)
|
||||||
|
|
||||||
|
// 群公告
|
||||||
|
groups.POST("/create_announcement", r.groupHandler.HandleCreateAnnouncement)
|
||||||
|
groups.GET("/get_announcements", r.groupHandler.HandleGetAnnouncements)
|
||||||
|
groups.POST("/delete_announcement", r.groupHandler.HandleDeleteAnnouncement)
|
||||||
|
|
||||||
|
// 加群请求处理(预留)
|
||||||
|
groups.POST("/set_group_add_request", r.groupHandler.HandleSetGroupAddRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自定义表情路由
|
||||||
|
if r.stickerHandler != nil {
|
||||||
|
stickers := v1.Group("/stickers")
|
||||||
|
stickers.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
stickers.GET("", r.stickerHandler.GetStickers)
|
||||||
|
stickers.POST("", r.stickerHandler.AddSticker)
|
||||||
|
stickers.DELETE("", r.stickerHandler.DeleteSticker)
|
||||||
|
stickers.POST("/reorder", r.stickerHandler.ReorderStickers)
|
||||||
|
stickers.GET("/check", r.stickerHandler.CheckStickerExists)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gorse 管理路由
|
||||||
|
if r.gorseHandler != nil {
|
||||||
|
gorseGroup := v1.Group("/gorse")
|
||||||
|
{
|
||||||
|
gorseGroup.GET("/status", r.gorseHandler.GetStatus)
|
||||||
|
gorseGroup.POST("/import", r.gorseHandler.ImportData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Engine 获取引擎
|
||||||
|
func (r *Router) Engine() *gin.Engine {
|
||||||
|
return r.engine
|
||||||
|
}
|
||||||
759
internal/service/audit_service.go
Normal file
759
internal/service/audit_service.go
Normal file
@@ -0,0 +1,759 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==================== 内容审核服务接口和实现 ====================
|
||||||
|
|
||||||
|
// AuditServiceProvider 内容审核服务提供商接口
|
||||||
|
type AuditServiceProvider interface {
|
||||||
|
// AuditText 审核文本
|
||||||
|
AuditText(ctx context.Context, text string, scene string) (*AuditResult, error)
|
||||||
|
// AuditImage 审核图片
|
||||||
|
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
|
||||||
|
// GetName 获取提供商名称
|
||||||
|
GetName() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditResult 审核结果
|
||||||
|
type AuditResult struct {
|
||||||
|
Pass bool `json:"pass"` // 是否通过
|
||||||
|
Risk string `json:"risk"` // 风险等级: low, medium, high
|
||||||
|
Labels []string `json:"labels"` // 标签列表
|
||||||
|
Suggest string `json:"suggest"` // 建议: pass, review, block
|
||||||
|
Detail string `json:"detail"` // 详细说明
|
||||||
|
Provider string `json:"provider"` // 服务提供商
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditService 内容审核服务接口
|
||||||
|
type AuditService interface {
|
||||||
|
// AuditText 审核文本
|
||||||
|
AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error)
|
||||||
|
// AuditImage 审核图片
|
||||||
|
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
|
||||||
|
// GetAuditResult 获取审核结果
|
||||||
|
GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error)
|
||||||
|
// SetProvider 设置审核服务提供商
|
||||||
|
SetProvider(provider AuditServiceProvider)
|
||||||
|
// GetProvider 获取当前审核服务提供商
|
||||||
|
GetProvider() AuditServiceProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
// auditServiceImpl 内容审核服务实现
|
||||||
|
type auditServiceImpl struct {
|
||||||
|
db *gorm.DB
|
||||||
|
provider AuditServiceProvider
|
||||||
|
config *AuditConfig
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditConfig 内容审核服务配置
|
||||||
|
type AuditConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
|
||||||
|
// 审核服务提供商: local, aliyun, tencent, baidu
|
||||||
|
Provider string `mapstructure:"provider" yaml:"provider"`
|
||||||
|
// 阿里云配置
|
||||||
|
AliyunAccessKey string `mapstructure:"aliyun_access_key" yaml:"aliyun_access_key"`
|
||||||
|
AliyunSecretKey string `mapstructure:"aliyun_secret_key" yaml:"aliyun_secret_key"`
|
||||||
|
AliyunRegion string `mapstructure:"aliyun_region" yaml:"aliyun_region"`
|
||||||
|
// 腾讯云配置
|
||||||
|
TencentSecretID string `mapstructure:"tencent_secret_id" yaml:"tencent_secret_id"`
|
||||||
|
TencentSecretKey string `mapstructure:"tencent_secret_key" yaml:"tencent_secret_key"`
|
||||||
|
// 百度云配置
|
||||||
|
BaiduAPIKey string `mapstructure:"baidu_api_key" yaml:"baidu_api_key"`
|
||||||
|
BaiduSecretKey string `mapstructure:"baidu_secret_key" yaml:"baidu_secret_key"`
|
||||||
|
// 是否自动审核
|
||||||
|
AutoAudit bool `mapstructure:"auto_audit" yaml:"auto_audit"`
|
||||||
|
// 审核超时时间(秒)
|
||||||
|
Timeout int `mapstructure:"timeout" yaml:"timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuditService 创建内容审核服务
|
||||||
|
func NewAuditService(db *gorm.DB, config *AuditConfig) AuditService {
|
||||||
|
s := &auditServiceImpl{
|
||||||
|
db: db,
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据配置初始化提供商
|
||||||
|
if config.Enabled {
|
||||||
|
provider := s.initProvider(config.Provider)
|
||||||
|
s.provider = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// initProvider 根据配置初始化审核服务提供商
|
||||||
|
func (s *auditServiceImpl) initProvider(providerType string) AuditServiceProvider {
|
||||||
|
switch strings.ToLower(providerType) {
|
||||||
|
case "aliyun":
|
||||||
|
return NewAliyunAuditProvider(s.config.AliyunAccessKey, s.config.AliyunSecretKey, s.config.AliyunRegion)
|
||||||
|
case "tencent":
|
||||||
|
return NewTencentAuditProvider(s.config.TencentSecretID, s.config.TencentSecretKey)
|
||||||
|
case "baidu":
|
||||||
|
return NewBaiduAuditProvider(s.config.BaiduAPIKey, s.config.BaiduSecretKey)
|
||||||
|
case "local":
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
// 默认使用本地审核服务
|
||||||
|
return NewLocalAuditProvider()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditText 审核文本
|
||||||
|
func (s *auditServiceImpl) AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error) {
|
||||||
|
if !s.config.Enabled {
|
||||||
|
// 如果审核服务未启用,直接返回通过
|
||||||
|
return &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Detail: "Audit service disabled",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if text == "" {
|
||||||
|
return &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Detail: "Empty text",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *AuditResult
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 使用提供商审核
|
||||||
|
if s.provider != nil {
|
||||||
|
result, err = s.provider.AuditText(ctx, text, auditType)
|
||||||
|
} else {
|
||||||
|
// 如果没有设置提供商,使用本地审核
|
||||||
|
localProvider := NewLocalAuditProvider()
|
||||||
|
result, err = localProvider.AuditText(ctx, text, auditType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Audit text error: %v", err)
|
||||||
|
return &AuditResult{
|
||||||
|
Pass: false,
|
||||||
|
Risk: "high",
|
||||||
|
Suggest: "review",
|
||||||
|
Detail: fmt.Sprintf("Audit error: %v", err),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录审核日志
|
||||||
|
go s.saveAuditLog(ctx, "text", "", text, auditType, result)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditImage 审核图片
|
||||||
|
func (s *auditServiceImpl) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||||
|
if !s.config.Enabled {
|
||||||
|
return &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Detail: "Audit service disabled",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageURL == "" {
|
||||||
|
return &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Detail: "Empty image URL",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result *AuditResult
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 使用提供商审核
|
||||||
|
if s.provider != nil {
|
||||||
|
result, err = s.provider.AuditImage(ctx, imageURL)
|
||||||
|
} else {
|
||||||
|
// 如果没有设置提供商,使用本地审核
|
||||||
|
localProvider := NewLocalAuditProvider()
|
||||||
|
result, err = localProvider.AuditImage(ctx, imageURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Audit image error: %v", err)
|
||||||
|
return &AuditResult{
|
||||||
|
Pass: false,
|
||||||
|
Risk: "high",
|
||||||
|
Suggest: "review",
|
||||||
|
Detail: fmt.Sprintf("Audit error: %v", err),
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录审核日志
|
||||||
|
go s.saveAuditLog(ctx, "image", "", "", "image", result)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuditResult 获取审核结果
|
||||||
|
func (s *auditServiceImpl) GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error) {
|
||||||
|
if s.db == nil || auditID == "" {
|
||||||
|
return nil, fmt.Errorf("invalid audit ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
var auditLog model.AuditLog
|
||||||
|
if err := s.db.Where("id = ?", auditID).First(&auditLog).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: auditLog.Result == model.AuditResultPass,
|
||||||
|
Risk: string(auditLog.RiskLevel),
|
||||||
|
Suggest: auditLog.Suggestion,
|
||||||
|
Detail: auditLog.Detail,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析标签
|
||||||
|
if auditLog.Labels != "" {
|
||||||
|
json.Unmarshal([]byte(auditLog.Labels), &result.Labels)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetProvider 设置审核服务提供商
|
||||||
|
func (s *auditServiceImpl) SetProvider(provider AuditServiceProvider) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.provider = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProvider 获取当前审核服务提供商
|
||||||
|
func (s *auditServiceImpl) GetProvider() AuditServiceProvider {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveAuditLog 保存审核日志
|
||||||
|
func (s *auditServiceImpl) saveAuditLog(ctx context.Context, contentType, content, imageURL, auditType string, result *AuditResult) {
|
||||||
|
if s.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
auditLog := model.AuditLog{
|
||||||
|
ContentType: contentType,
|
||||||
|
Content: content,
|
||||||
|
ContentURL: imageURL,
|
||||||
|
AuditType: auditType,
|
||||||
|
Labels: strings.Join(result.Labels, ","),
|
||||||
|
Suggestion: result.Suggest,
|
||||||
|
Detail: result.Detail,
|
||||||
|
Source: model.AuditSourceAuto,
|
||||||
|
Status: "completed",
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Pass {
|
||||||
|
auditLog.Result = model.AuditResultPass
|
||||||
|
} else if result.Suggest == "review" {
|
||||||
|
auditLog.Result = model.AuditResultReview
|
||||||
|
} else {
|
||||||
|
auditLog.Result = model.AuditResultBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
switch result.Risk {
|
||||||
|
case "low":
|
||||||
|
auditLog.RiskLevel = model.AuditRiskLevelLow
|
||||||
|
case "medium":
|
||||||
|
auditLog.RiskLevel = model.AuditRiskLevelMedium
|
||||||
|
case "high":
|
||||||
|
auditLog.RiskLevel = model.AuditRiskLevelHigh
|
||||||
|
default:
|
||||||
|
auditLog.RiskLevel = model.AuditRiskLevelLow
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(&auditLog).Error; err != nil {
|
||||||
|
log.Printf("Failed to save audit log: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 本地审核服务提供商 ====================
|
||||||
|
|
||||||
|
// localAuditProvider 本地审核服务提供商
|
||||||
|
type localAuditProvider struct {
|
||||||
|
// 可以注入敏感词服务进行本地审核
|
||||||
|
sensitiveService SensitiveService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLocalAuditProvider 创建本地审核服务提供商
|
||||||
|
func NewLocalAuditProvider() AuditServiceProvider {
|
||||||
|
return &localAuditProvider{
|
||||||
|
sensitiveService: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName 获取提供商名称
|
||||||
|
func (p *localAuditProvider) GetName() string {
|
||||||
|
return "local"
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditText 审核文本
|
||||||
|
func (p *localAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||||
|
// 本地审核逻辑
|
||||||
|
// 1. 敏感词检查
|
||||||
|
// 2. 规则匹配
|
||||||
|
// 3. 简单的关键词检测
|
||||||
|
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Labels: []string{},
|
||||||
|
Provider: "local",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有敏感词服务,使用它进行检测
|
||||||
|
if p.sensitiveService != nil {
|
||||||
|
hasSensitive, words := p.sensitiveService.Check(ctx, text)
|
||||||
|
if hasSensitive {
|
||||||
|
result.Pass = false
|
||||||
|
result.Risk = "high"
|
||||||
|
result.Suggest = "block"
|
||||||
|
result.Detail = fmt.Sprintf("包含敏感词: %s", strings.Join(words, ","))
|
||||||
|
result.Labels = append(result.Labels, "sensitive")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 简单的关键词检测规则
|
||||||
|
// 实际项目中应该从数据库加载
|
||||||
|
suspiciousPatterns := []string{
|
||||||
|
"诈骗",
|
||||||
|
"钓鱼",
|
||||||
|
"木马",
|
||||||
|
"病毒",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pattern := range suspiciousPatterns {
|
||||||
|
if strings.Contains(text, pattern) {
|
||||||
|
result.Pass = false
|
||||||
|
result.Risk = "high"
|
||||||
|
result.Suggest = "block"
|
||||||
|
result.Labels = append(result.Labels, "suspicious")
|
||||||
|
if result.Detail == "" {
|
||||||
|
result.Detail = fmt.Sprintf("包含可疑内容: %s", pattern)
|
||||||
|
} else {
|
||||||
|
result.Detail += fmt.Sprintf(", %s", pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditImage 审核图片
|
||||||
|
func (p *localAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||||
|
// 本地图片审核逻辑
|
||||||
|
// 1. 图片URL合法性检查
|
||||||
|
// 2. 图片格式检查
|
||||||
|
// 3. 可以扩展接入本地图片识别服务
|
||||||
|
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Labels: []string{},
|
||||||
|
Provider: "local",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查URL是否为空
|
||||||
|
if imageURL == "" {
|
||||||
|
result.Detail = "Empty image URL"
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否为支持的图片URL格式
|
||||||
|
validPrefixes := []string{"http://", "https://", "s3://", "oss://", "cos://"}
|
||||||
|
isValid := false
|
||||||
|
for _, prefix := range validPrefixes {
|
||||||
|
if strings.HasPrefix(strings.ToLower(imageURL), prefix) {
|
||||||
|
isValid = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValid {
|
||||||
|
result.Pass = false
|
||||||
|
result.Risk = "medium"
|
||||||
|
result.Suggest = "review"
|
||||||
|
result.Detail = "Invalid image URL format"
|
||||||
|
result.Labels = append(result.Labels, "invalid_url")
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSensitiveService 设置敏感词服务
|
||||||
|
func (p *localAuditProvider) SetSensitiveService(ss SensitiveService) {
|
||||||
|
p.sensitiveService = ss
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 阿里云审核服务提供商 ====================
|
||||||
|
|
||||||
|
// aliyunAuditProvider 阿里云审核服务提供商
|
||||||
|
type aliyunAuditProvider struct {
|
||||||
|
accessKey string
|
||||||
|
secretKey string
|
||||||
|
region string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAliyunAuditProvider 创建阿里云审核服务提供商
|
||||||
|
func NewAliyunAuditProvider(accessKey, secretKey, region string) AuditServiceProvider {
|
||||||
|
return &aliyunAuditProvider{
|
||||||
|
accessKey: accessKey,
|
||||||
|
secretKey: secretKey,
|
||||||
|
region: region,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName 获取提供商名称
|
||||||
|
func (p *aliyunAuditProvider) GetName() string {
|
||||||
|
return "aliyun"
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditText 审核文本
|
||||||
|
func (p *aliyunAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||||
|
// 阿里云内容安全API调用
|
||||||
|
// 实际项目中需要实现阿里云SDK调用
|
||||||
|
// 这里预留接口
|
||||||
|
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Labels: []string{},
|
||||||
|
Provider: "aliyun",
|
||||||
|
Detail: "Aliyun audit not implemented, using pass",
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: 实现阿里云内容安全API调用
|
||||||
|
// 具体参考: https://help.aliyun.com/document_detail/28417.html
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditImage 审核图片
|
||||||
|
func (p *aliyunAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Labels: []string{},
|
||||||
|
Provider: "aliyun",
|
||||||
|
Detail: "Aliyun image audit not implemented, using pass",
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: 实现阿里云图片审核API调用
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 腾讯云审核服务提供商 ====================
|
||||||
|
|
||||||
|
// tencentAuditProvider 腾讯云审核服务提供商
|
||||||
|
type tencentAuditProvider struct {
|
||||||
|
secretID string
|
||||||
|
secretKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTencentAuditProvider 创建腾讯云审核服务提供商
|
||||||
|
func NewTencentAuditProvider(secretID, secretKey string) AuditServiceProvider {
|
||||||
|
return &tencentAuditProvider{
|
||||||
|
secretID: secretID,
|
||||||
|
secretKey: secretKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName 获取提供商名称
|
||||||
|
func (p *tencentAuditProvider) GetName() string {
|
||||||
|
return "tencent"
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditText 审核文本
|
||||||
|
func (p *tencentAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Labels: []string{},
|
||||||
|
Provider: "tencent",
|
||||||
|
Detail: "Tencent audit not implemented, using pass",
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: 实现腾讯云内容审核API调用
|
||||||
|
// 具体参考: https://cloud.tencent.com/document/product/1124/64508
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditImage 审核图片
|
||||||
|
func (p *tencentAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Labels: []string{},
|
||||||
|
Provider: "tencent",
|
||||||
|
Detail: "Tencent image audit not implemented, using pass",
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: 实现腾讯云图片审核API调用
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 百度云审核服务提供商 ====================
|
||||||
|
|
||||||
|
// baiduAuditProvider 百度云审核服务提供商
|
||||||
|
type baiduAuditProvider struct {
|
||||||
|
apiKey string
|
||||||
|
secretKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBaiduAuditProvider 创建百度云审核服务提供商
|
||||||
|
func NewBaiduAuditProvider(apiKey, secretKey string) AuditServiceProvider {
|
||||||
|
return &baiduAuditProvider{
|
||||||
|
apiKey: apiKey,
|
||||||
|
secretKey: secretKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName 获取提供商名称
|
||||||
|
func (p *baiduAuditProvider) GetName() string {
|
||||||
|
return "baidu"
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditText 审核文本
|
||||||
|
func (p *baiduAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Labels: []string{},
|
||||||
|
Provider: "baidu",
|
||||||
|
Detail: "Baidu audit not implemented, using pass",
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: 实现百度云内容审核API调用
|
||||||
|
// 具体参考: https://cloud.baidu.com/doc/ANTISPAM/s/Jjw0r1iF6
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditImage 审核图片
|
||||||
|
func (p *baiduAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||||
|
result := &AuditResult{
|
||||||
|
Pass: true,
|
||||||
|
Risk: "low",
|
||||||
|
Suggest: "pass",
|
||||||
|
Labels: []string{},
|
||||||
|
Provider: "baidu",
|
||||||
|
Detail: "Baidu image audit not implemented, using pass",
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: 实现百度云图片审核API调用
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 审核结果回调处理 ====================
|
||||||
|
|
||||||
|
// AuditCallback 审核回调处理
|
||||||
|
type AuditCallback struct {
|
||||||
|
service AuditService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuditCallback 创建审核回调处理
|
||||||
|
func NewAuditCallback(service AuditService) *AuditCallback {
|
||||||
|
return &AuditCallback{
|
||||||
|
service: service,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleTextCallback 处理文本审核回调
|
||||||
|
func (c *AuditCallback) HandleTextCallback(ctx context.Context, auditID string, result *AuditResult) error {
|
||||||
|
if c.service == nil || auditID == "" || result == nil {
|
||||||
|
return fmt.Errorf("invalid parameters")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Processing text audit callback: auditID=%s, result=%+v", auditID, result)
|
||||||
|
|
||||||
|
// 根据审核结果执行相应操作
|
||||||
|
// 例如: 更新帖子状态、发送通知等
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleImageCallback 处理图片审核回调
|
||||||
|
func (c *AuditCallback) HandleImageCallback(ctx context.Context, auditID string, result *AuditResult) error {
|
||||||
|
if c.service == nil || auditID == "" || result == nil {
|
||||||
|
return fmt.Errorf("invalid parameters")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Processing image audit callback: auditID=%s, result=%+v", auditID, result)
|
||||||
|
|
||||||
|
// 根据审核结果执行相应操作
|
||||||
|
// 例如: 更新图片状态、删除违规图片等
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 辅助函数 ====================
|
||||||
|
|
||||||
|
// IsContentSafe 判断内容是否安全
|
||||||
|
func IsContentSafe(result *AuditResult) bool {
|
||||||
|
if result == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return result.Pass && result.Suggest != "block"
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedReview 判断内容是否需要人工复审
|
||||||
|
func NeedReview(result *AuditResult) bool {
|
||||||
|
if result == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return result.Suggest == "review"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRiskLevel 获取风险等级
|
||||||
|
func GetRiskLevel(result *AuditResult) string {
|
||||||
|
if result == nil {
|
||||||
|
return "low"
|
||||||
|
}
|
||||||
|
return result.Risk
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatAuditResult 格式化审核结果为字符串
|
||||||
|
func FormatAuditResult(result *AuditResult) string {
|
||||||
|
if result == nil {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(result)
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseAuditResult 从字符串解析审核结果
|
||||||
|
func ParseAuditResult(data string) (*AuditResult, error) {
|
||||||
|
if data == "" {
|
||||||
|
return nil, fmt.Errorf("empty data")
|
||||||
|
}
|
||||||
|
var result AuditResult
|
||||||
|
if err := json.Unmarshal([]byte(data), &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 审核日志查询 ====================
|
||||||
|
|
||||||
|
// GetAuditLogs 获取审核日志列表
|
||||||
|
func GetAuditLogs(db *gorm.DB, targetType string, targetID string, result string, page, pageSize int) ([]model.AuditLog, int64, error) {
|
||||||
|
query := db.Model(&model.AuditLog{})
|
||||||
|
|
||||||
|
if targetType != "" {
|
||||||
|
query = query.Where("target_type = ?", targetType)
|
||||||
|
}
|
||||||
|
if targetID != "" {
|
||||||
|
query = query.Where("target_id = ?", targetID)
|
||||||
|
}
|
||||||
|
if result != "" {
|
||||||
|
query = query.Where("result = ?", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int64
|
||||||
|
if err := query.Count(&total).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var logs []model.AuditLog
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&logs).Error; err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return logs, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 定时任务 ====================
|
||||||
|
|
||||||
|
// AuditScheduler 审核调度器
|
||||||
|
type AuditScheduler struct {
|
||||||
|
db *gorm.DB
|
||||||
|
service AuditService
|
||||||
|
interval time.Duration
|
||||||
|
stopCh chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuditScheduler 创建审核调度器
|
||||||
|
func NewAuditScheduler(db *gorm.DB, service AuditService, interval time.Duration) *AuditScheduler {
|
||||||
|
return &AuditScheduler{
|
||||||
|
db: db,
|
||||||
|
service: service,
|
||||||
|
interval: interval,
|
||||||
|
stopCh: make(chan bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动调度器
|
||||||
|
func (s *AuditScheduler) Start() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(s.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
s.processPendingAudits()
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止调度器
|
||||||
|
func (s *AuditScheduler) Stop() {
|
||||||
|
s.stopCh <- true
|
||||||
|
}
|
||||||
|
|
||||||
|
// processPendingAudits 处理待审核内容
|
||||||
|
func (s *AuditScheduler) processPendingAudits() {
|
||||||
|
// 查询待审核的内容
|
||||||
|
// 1. 查询审核状态为 pending 的记录
|
||||||
|
// 2. 调用审核服务
|
||||||
|
// 3. 更新审核状态
|
||||||
|
|
||||||
|
// 示例逻辑,实际需要根据业务需求实现
|
||||||
|
log.Println("Processing pending audits...")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupOldLogs 清理旧的审核日志
|
||||||
|
func CleanupOldLogs(db *gorm.DB, days int) error {
|
||||||
|
// 清理指定天数之前的审核日志
|
||||||
|
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||||
|
return db.Where("created_at < ? AND result = ?", cutoffTime, model.AuditResultPass).Delete(&model.AuditLog{}).Error
|
||||||
|
}
|
||||||
622
internal/service/chat_service.go
Normal file
622
internal/service/chat_service.go
Normal file
@@ -0,0 +1,622 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/websocket"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 撤回消息的时间限制(2分钟)
|
||||||
|
const RecallMessageTimeout = 2 * time.Minute
|
||||||
|
|
||||||
|
// ChatService 聊天服务接口
|
||||||
|
type ChatService interface {
|
||||||
|
// 会话管理
|
||||||
|
GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
|
||||||
|
GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
|
||||||
|
GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
|
||||||
|
DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
|
||||||
|
SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error
|
||||||
|
|
||||||
|
// 消息操作
|
||||||
|
SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||||||
|
GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
|
||||||
|
GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
|
||||||
|
GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)
|
||||||
|
|
||||||
|
// 已读管理
|
||||||
|
MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
|
||||||
|
GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
|
||||||
|
GetAllUnreadCount(ctx context.Context, userID string) (int64, error)
|
||||||
|
|
||||||
|
// 消息扩展功能
|
||||||
|
RecallMessage(ctx context.Context, messageID string, userID string) error
|
||||||
|
DeleteMessage(ctx context.Context, messageID string, userID string) error
|
||||||
|
|
||||||
|
// WebSocket相关
|
||||||
|
SendTyping(ctx context.Context, senderID string, conversationID string)
|
||||||
|
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
|
||||||
|
|
||||||
|
// 系统消息推送
|
||||||
|
IsUserOnline(userID string) bool
|
||||||
|
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
|
||||||
|
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
|
||||||
|
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
|
||||||
|
|
||||||
|
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
|
||||||
|
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// chatServiceImpl 聊天服务实现
|
||||||
|
type chatServiceImpl struct {
|
||||||
|
db *gorm.DB
|
||||||
|
repo *repository.MessageRepository
|
||||||
|
userRepo *repository.UserRepository
|
||||||
|
sensitive SensitiveService
|
||||||
|
wsManager *websocket.WebSocketManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChatService 创建聊天服务
|
||||||
|
func NewChatService(
|
||||||
|
db *gorm.DB,
|
||||||
|
repo *repository.MessageRepository,
|
||||||
|
userRepo *repository.UserRepository,
|
||||||
|
sensitive SensitiveService,
|
||||||
|
wsManager *websocket.WebSocketManager,
|
||||||
|
) ChatService {
|
||||||
|
return &chatServiceImpl{
|
||||||
|
db: db,
|
||||||
|
repo: repo,
|
||||||
|
userRepo: userRepo,
|
||||||
|
sensitive: sensitive,
|
||||||
|
wsManager: wsManager,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateConversation 获取或创建私聊会话
|
||||||
|
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||||
|
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationList 获取用户的会话列表
|
||||||
|
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||||
|
return s.repo.GetConversations(userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationByID 获取会话详情
|
||||||
|
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取会话信息
|
||||||
|
conv, err := s.repo.GetConversation(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 填充用户的已读位置信息
|
||||||
|
_ = participant // 可以用于返回已读位置等信息
|
||||||
|
|
||||||
|
return conv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteConversationForSelf 仅自己删除会话
|
||||||
|
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
|
||||||
|
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
if participant.ConversationID == "" {
|
||||||
|
return errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
|
||||||
|
return fmt.Errorf("failed to hide conversation: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConversationPinned 设置会话置顶(用户维度)
|
||||||
|
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
|
||||||
|
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
if participant.ConversationID == "" {
|
||||||
|
return errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
|
||||||
|
return fmt.Errorf("failed to update pinned status: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage 发送消息(使用 segments)
|
||||||
|
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||||
|
// 首先验证会话是否存在
|
||||||
|
conv, err := s.repo.GetConversation(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("会话不存在,请重新创建会话")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向
|
||||||
|
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
|
||||||
|
participants, pErr := s.repo.GetConversationParticipants(conversationID)
|
||||||
|
if pErr != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get participants: %w", pErr)
|
||||||
|
}
|
||||||
|
var sentCount *int64
|
||||||
|
for _, p := range participants {
|
||||||
|
if p.UserID == senderID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
|
||||||
|
if bErr != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check block status: %w", bErr)
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
return nil, ErrUserBlocked
|
||||||
|
}
|
||||||
|
|
||||||
|
// 陌生人限制:对方未回关前,只允许发送一条文本消息,且禁止发送图片
|
||||||
|
isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
|
||||||
|
if fErr != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check follow status: %w", fErr)
|
||||||
|
}
|
||||||
|
if !isFollowedBack {
|
||||||
|
if containsImageSegment(segments) {
|
||||||
|
return nil, errors.New("对方未关注你,暂不支持发送图片")
|
||||||
|
}
|
||||||
|
if sentCount == nil {
|
||||||
|
c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
|
||||||
|
if cErr != nil {
|
||||||
|
return nil, fmt.Errorf("failed to count sender messages: %w", cErr)
|
||||||
|
}
|
||||||
|
sentCount = &c
|
||||||
|
}
|
||||||
|
if *sentCount >= 1 {
|
||||||
|
return nil, errors.New("对方未关注你前,仅允许发送一条消息")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
participant, err := s.repo.GetParticipant(conversationID, senderID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("您不是该会话的参与者")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建消息
|
||||||
|
message := &model.Message{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
SenderID: senderID, // 直接使用string类型的UUID
|
||||||
|
Segments: segments,
|
||||||
|
ReplyToID: replyToID,
|
||||||
|
Status: model.MessageStatusNormal,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用事务创建消息并更新seq
|
||||||
|
if err := s.repo.CreateMessageWithSeq(message); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送消息给接收者
|
||||||
|
log.Printf("[DEBUG SendMessage] 私聊消息 segments 类型: %T, 值: %+v", message.Segments, message.Segments)
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
|
||||||
|
ID: message.ID,
|
||||||
|
ConversationID: message.ConversationID,
|
||||||
|
SenderID: senderID,
|
||||||
|
Segments: message.Segments,
|
||||||
|
Seq: message.Seq,
|
||||||
|
CreatedAt: message.CreatedAt.UnixMilli(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// 获取会话中的其他参与者
|
||||||
|
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||||
|
if err == nil {
|
||||||
|
for _, p := range participants {
|
||||||
|
// 不发给自己
|
||||||
|
if p.UserID == senderID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 如果接收者在线,发送实时消息
|
||||||
|
if s.wsManager != nil {
|
||||||
|
isOnline := s.wsManager.IsUserOnline(p.UserID)
|
||||||
|
log.Printf("[DEBUG SendMessage] 接收者 UserID=%s, 在线状态=%v", p.UserID, isOnline)
|
||||||
|
if isOnline {
|
||||||
|
log.Printf("[DEBUG SendMessage] 发送WebSocket消息给 UserID=%s, 消息类型=%s", p.UserID, wsMsg.Type)
|
||||||
|
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("[DEBUG SendMessage] 获取参与者失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = participant // 避免未使用变量警告
|
||||||
|
|
||||||
|
return message, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsImageSegment(segments model.MessageSegments) bool {
|
||||||
|
for _, seg := range segments {
|
||||||
|
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages 获取消息历史(分页)
|
||||||
|
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, 0, errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.repo.GetMessages(conversationID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
||||||
|
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
||||||
|
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记已读
|
||||||
|
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新参与者的已读位置
|
||||||
|
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to update last read seq: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送已读回执(作为 meta 事件)
|
||||||
|
if s.wsManager != nil {
|
||||||
|
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
|
||||||
|
"detail_type": websocket.MetaDetailTypeRead,
|
||||||
|
"conversation_id": conversationID,
|
||||||
|
"seq": seq,
|
||||||
|
"user_id": userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 获取会话中的所有参与者
|
||||||
|
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||||
|
if err == nil {
|
||||||
|
// 推送给会话中的所有参与者(包括自己)
|
||||||
|
for _, p := range participants {
|
||||||
|
if s.wsManager.IsUserOnline(p.UserID) {
|
||||||
|
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取指定会话的未读消息数
|
||||||
|
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return 0, errors.New("conversation not found or no permission")
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.repo.GetUnreadCount(conversationID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllUnreadCount 获取所有会话的未读消息总数
|
||||||
|
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
|
||||||
|
return s.repo.GetAllUnreadCount(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecallMessage 撤回消息(2分钟内)
|
||||||
|
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
|
||||||
|
// 获取消息
|
||||||
|
var message model.Message
|
||||||
|
err := s.db.First(&message, "id = ?", messageID).Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return errors.New("message not found")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证是否是消息发送者
|
||||||
|
if message.SenderIDStr() != userID {
|
||||||
|
return errors.New("can only recall your own messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证消息是否已被撤回
|
||||||
|
if message.Status == model.MessageStatusRecalled {
|
||||||
|
return errors.New("message already recalled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证是否在2分钟内
|
||||||
|
if time.Since(message.CreatedAt) > RecallMessageTimeout {
|
||||||
|
return errors.New("message recall timeout (2 minutes)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新消息状态为已撤回
|
||||||
|
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to recall message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送撤回通知
|
||||||
|
if s.wsManager != nil {
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
|
||||||
|
"messageId": messageID,
|
||||||
|
"conversationId": message.ConversationID,
|
||||||
|
"senderId": userID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 通知会话中的所有参与者
|
||||||
|
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
|
||||||
|
if err == nil {
|
||||||
|
for _, p := range participants {
|
||||||
|
if s.wsManager.IsUserOnline(p.UserID) {
|
||||||
|
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMessage 删除消息(仅对自己可见)
|
||||||
|
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
|
||||||
|
// 获取消息
|
||||||
|
var message model.Message
|
||||||
|
err := s.db.First(&message, "id = ?", messageID).Error
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return errors.New("message not found")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
_, err = s.repo.GetParticipant(message.ConversationID, userID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return errors.New("no permission to delete this message")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于删除消息,我们使用软删除,但需要确保只对当前用户隐藏
|
||||||
|
// 这里简化处理:只有发送者可以删除自己的消息
|
||||||
|
if message.SenderIDStr() != userID {
|
||||||
|
return errors.New("can only delete your own messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新消息状态为已删除
|
||||||
|
err = s.db.Model(&message).Update("status", model.MessageStatusDeleted).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendTyping 发送正在输入状态
|
||||||
|
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
_, err := s.repo.GetParticipant(conversationID, senderID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取会话中的其他参与者
|
||||||
|
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range participants {
|
||||||
|
if p.UserID == senderID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 发送正在输入状态
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"senderId": senderID,
|
||||||
|
})
|
||||||
|
|
||||||
|
if s.wsManager.IsUserOnline(p.UserID) {
|
||||||
|
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastMessage 广播消息给用户
|
||||||
|
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
|
||||||
|
if s.wsManager != nil {
|
||||||
|
s.wsManager.SendToUser(targetUser, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUserOnline 检查用户是否在线
|
||||||
|
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.wsManager.IsUserOnline(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushSystemMessage 推送系统消息给指定用户
|
||||||
|
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return errors.New("websocket manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.wsManager.IsUserOnline(userID) {
|
||||||
|
return errors.New("user is offline")
|
||||||
|
}
|
||||||
|
|
||||||
|
sysMsg := &websocket.SystemMessage{
|
||||||
|
ID: "", // 由调用方生成
|
||||||
|
Type: msgType,
|
||||||
|
Title: title,
|
||||||
|
Content: content,
|
||||||
|
Data: data,
|
||||||
|
CreatedAt: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||||||
|
s.wsManager.SendToUser(userID, wsMsg)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushNotificationMessage 推送通知消息给指定用户
|
||||||
|
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return errors.New("websocket manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.wsManager.IsUserOnline(userID) {
|
||||||
|
return errors.New("user is offline")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保时间戳已设置
|
||||||
|
if notification.CreatedAt == 0 {
|
||||||
|
notification.CreatedAt = time.Now().UnixMilli()
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||||
|
s.wsManager.SendToUser(userID, wsMsg)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushAnnouncementMessage 广播公告消息给所有在线用户
|
||||||
|
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return errors.New("websocket manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保时间戳已设置
|
||||||
|
if announcement.CreatedAt == 0 {
|
||||||
|
announcement.CreatedAt = time.Now().UnixMilli()
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||||||
|
s.wsManager.Broadcast(wsMsg)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
|
||||||
|
// 适用于群聊等由调用方自行负责推送的场景
|
||||||
|
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||||
|
// 验证会话是否存在
|
||||||
|
_, err := s.repo.GetConversation(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("会话不存在,请重新创建会话")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证用户是否是会话参与者
|
||||||
|
_, err = s.repo.GetParticipant(conversationID, senderID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("您不是该会话的参与者")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
message := &model.Message{
|
||||||
|
ConversationID: conversationID,
|
||||||
|
SenderID: senderID,
|
||||||
|
Segments: segments,
|
||||||
|
ReplyToID: replyToID,
|
||||||
|
Status: model.MessageStatusNormal,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.repo.CreateMessageWithSeq(message); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return message, nil
|
||||||
|
}
|
||||||
273
internal/service/comment_service.go
Normal file
273
internal/service/comment_service.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/gorse"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CommentService 评论服务
|
||||||
|
type CommentService struct {
|
||||||
|
commentRepo *repository.CommentRepository
|
||||||
|
postRepo *repository.PostRepository
|
||||||
|
systemMessageService SystemMessageService
|
||||||
|
gorseClient gorse.Client
|
||||||
|
postAIService *PostAIService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCommentService 创建评论服务
|
||||||
|
func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *CommentService {
|
||||||
|
return &CommentService{
|
||||||
|
commentRepo: commentRepo,
|
||||||
|
postRepo: postRepo,
|
||||||
|
systemMessageService: systemMessageService,
|
||||||
|
gorseClient: gorseClient,
|
||||||
|
postAIService: postAIService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建评论
|
||||||
|
func (s *CommentService) Create(ctx context.Context, postID, userID, content string, parentID *string, images string, imageURLs []string) (*model.Comment, error) {
|
||||||
|
if s.postAIService != nil {
|
||||||
|
// 采用异步审核,前端先立即返回
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取帖子信息用于发送通知
|
||||||
|
post, err := s.postRepo.GetByID(postID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
comment := &model.Comment{
|
||||||
|
PostID: postID,
|
||||||
|
UserID: userID,
|
||||||
|
Content: content,
|
||||||
|
ParentID: parentID,
|
||||||
|
Images: images,
|
||||||
|
Status: model.CommentStatusPending,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有父评论,设置根评论ID
|
||||||
|
var parentUserID string
|
||||||
|
if parentID != nil {
|
||||||
|
parent, err := s.commentRepo.GetByID(*parentID)
|
||||||
|
if err == nil && parent != nil {
|
||||||
|
if parent.RootID != nil {
|
||||||
|
comment.RootID = parent.RootID
|
||||||
|
} else {
|
||||||
|
comment.RootID = parentID
|
||||||
|
}
|
||||||
|
parentUserID = parent.UserID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.commentRepo.Create(comment)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重新查询以获取关联的 User
|
||||||
|
comment, err = s.commentRepo.GetByID(comment.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.reviewCommentAsync(comment.ID, userID, postID, content, imageURLs, parentID, parentUserID, post.UserID)
|
||||||
|
|
||||||
|
return comment, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CommentService) reviewCommentAsync(
|
||||||
|
commentID, userID, postID, content string,
|
||||||
|
imageURLs []string,
|
||||||
|
parentID *string,
|
||||||
|
parentUserID string,
|
||||||
|
postOwnerID string,
|
||||||
|
) {
|
||||||
|
// 未启用AI时,直接通过审核并发送后续通知
|
||||||
|
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||||
|
if err := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.postAIService.ModerateComment(context.Background(), content, imageURLs)
|
||||||
|
if err != nil {
|
||||||
|
var rejectedErr *CommentModerationRejectedError
|
||||||
|
if errors.As(err, &rejectedErr) {
|
||||||
|
if delErr := s.commentRepo.Delete(commentID); delErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to delete rejected comment %s: %v", commentID, delErr)
|
||||||
|
}
|
||||||
|
s.notifyCommentModerationRejected(userID, rejectedErr.Reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 审核服务异常时降级放行,避免评论长期pending
|
||||||
|
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
|
||||||
|
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
|
||||||
|
// 发送系统消息通知
|
||||||
|
if s.systemMessageService != nil {
|
||||||
|
go func() {
|
||||||
|
if parentID != nil && parentUserID != "" {
|
||||||
|
// 回复评论,通知被回复的人
|
||||||
|
if parentUserID != userID {
|
||||||
|
notifyErr := s.systemMessageService.SendReplyNotification(context.Background(), parentUserID, userID, postID, *parentID, commentID)
|
||||||
|
if notifyErr != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error sending reply notification: %v\n", notifyErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 评论帖子,通知帖子作者
|
||||||
|
if postOwnerID != userID {
|
||||||
|
notifyErr := s.systemMessageService.SendCommentNotification(context.Background(), postOwnerID, userID, postID, commentID)
|
||||||
|
if notifyErr != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error sending comment notification: %v\n", notifyErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送评论行为到Gorse(异步)
|
||||||
|
go func() {
|
||||||
|
if s.gorseClient.IsEnabled() {
|
||||||
|
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeComment, userID, postID); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to insert comment feedback to Gorse: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CommentService) notifyCommentModerationRejected(userID, reason string) {
|
||||||
|
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content := "您发布的评论未通过AI审核,请修改后重试。"
|
||||||
|
if strings.TrimSpace(reason) != "" {
|
||||||
|
content = fmt.Sprintf("您发布的评论未通过AI审核,原因:%s。请修改后重试。", reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.systemMessageService.SendSystemAnnouncement(
|
||||||
|
context.Background(),
|
||||||
|
[]string{userID},
|
||||||
|
"评论审核未通过",
|
||||||
|
content,
|
||||||
|
); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to send comment moderation reject notification: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取评论
|
||||||
|
func (s *CommentService) GetByID(ctx context.Context, id string) (*model.Comment, error) {
|
||||||
|
return s.commentRepo.GetByID(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByPostID 获取帖子评论
|
||||||
|
func (s *CommentService) GetByPostID(ctx context.Context, postID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||||
|
// 使用带回复的查询,默认加载前3条回复
|
||||||
|
return s.commentRepo.GetByPostIDWithReplies(postID, page, pageSize, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRepliesByRootID 根据根评论ID分页获取回复
|
||||||
|
func (s *CommentService) GetRepliesByRootID(ctx context.Context, rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||||
|
return s.commentRepo.GetRepliesByRootID(rootID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReplies 获取回复
|
||||||
|
func (s *CommentService) GetReplies(ctx context.Context, parentID string) ([]*model.Comment, error) {
|
||||||
|
return s.commentRepo.GetReplies(parentID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新评论
|
||||||
|
func (s *CommentService) Update(ctx context.Context, comment *model.Comment) error {
|
||||||
|
return s.commentRepo.Update(comment)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除评论
|
||||||
|
func (s *CommentService) Delete(ctx context.Context, id string) error {
|
||||||
|
return s.commentRepo.Delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Like 点赞评论
|
||||||
|
func (s *CommentService) Like(ctx context.Context, commentID, userID string) error {
|
||||||
|
// 获取评论信息用于发送通知
|
||||||
|
comment, err := s.commentRepo.GetByID(commentID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.commentRepo.Like(commentID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送评论/回复点赞通知(只有不是给自己点赞时才发送)
|
||||||
|
if s.systemMessageService != nil && comment.UserID != userID {
|
||||||
|
go func() {
|
||||||
|
var notifyErr error
|
||||||
|
if comment.ParentID != nil {
|
||||||
|
notifyErr = s.systemMessageService.SendLikeReplyNotification(
|
||||||
|
context.Background(),
|
||||||
|
comment.UserID,
|
||||||
|
userID,
|
||||||
|
comment.PostID,
|
||||||
|
commentID,
|
||||||
|
comment.Content,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
notifyErr = s.systemMessageService.SendLikeCommentNotification(
|
||||||
|
context.Background(),
|
||||||
|
comment.UserID,
|
||||||
|
userID,
|
||||||
|
comment.PostID,
|
||||||
|
commentID,
|
||||||
|
comment.Content,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if notifyErr != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[DEBUG] Like notification sent successfully\n")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlike 取消点赞评论
|
||||||
|
func (s *CommentService) Unlike(ctx context.Context, commentID, userID string) error {
|
||||||
|
return s.commentRepo.Unlike(commentID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLiked 检查是否已点赞
|
||||||
|
func (s *CommentService) IsLiked(ctx context.Context, commentID, userID string) bool {
|
||||||
|
return s.commentRepo.IsLiked(commentID, userID)
|
||||||
|
}
|
||||||
234
internal/service/email_code_service.go
Normal file
234
internal/service/email_code_service.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
verifyCodeTTL = 10 * time.Minute
|
||||||
|
verifyCodeRateLimitTTL = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CodePurposeRegister = "register"
|
||||||
|
CodePurposePasswordReset = "password_reset"
|
||||||
|
CodePurposeEmailVerify = "email_verify"
|
||||||
|
CodePurposeChangePassword = "change_password"
|
||||||
|
)
|
||||||
|
|
||||||
|
type verificationCodePayload struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Purpose string `json:"purpose"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmailCodeService interface {
|
||||||
|
SendCode(ctx context.Context, purpose, email string) error
|
||||||
|
VerifyCode(purpose, email, code string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type emailCodeServiceImpl struct {
|
||||||
|
emailService EmailService
|
||||||
|
cache cache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEmailCodeService(emailService EmailService, cacheBackend cache.Cache) EmailCodeService {
|
||||||
|
if cacheBackend == nil {
|
||||||
|
cacheBackend = cache.GetCache()
|
||||||
|
}
|
||||||
|
return &emailCodeServiceImpl{
|
||||||
|
emailService: emailService,
|
||||||
|
cache: cacheBackend,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func verificationCodeCacheKey(purpose, email string) string {
|
||||||
|
return fmt.Sprintf("auth:verify_code:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func verificationCodeRateLimitKey(purpose, email string) string {
|
||||||
|
return fmt.Sprintf("auth:verify_code_rate_limit:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateNumericCode(length int) (string, error) {
|
||||||
|
if length <= 0 {
|
||||||
|
return "", fmt.Errorf("invalid code length")
|
||||||
|
}
|
||||||
|
max := big.NewInt(10)
|
||||||
|
result := make([]byte, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
n, err := rand.Int(rand.Reader, max)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
result[i] = byte('0' + n.Int64())
|
||||||
|
}
|
||||||
|
return string(result), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCodeServiceImpl) SendCode(ctx context.Context, purpose, email string) error {
|
||||||
|
if strings.TrimSpace(email) == "" || !utils.ValidateEmail(email) {
|
||||||
|
return ErrInvalidEmail
|
||||||
|
}
|
||||||
|
if s.emailService == nil || !s.emailService.IsEnabled() {
|
||||||
|
return ErrEmailServiceUnavailable
|
||||||
|
}
|
||||||
|
if s.cache == nil {
|
||||||
|
return ErrVerificationCodeUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
rateLimitKey := verificationCodeRateLimitKey(purpose, email)
|
||||||
|
if s.cache.Exists(rateLimitKey) {
|
||||||
|
return ErrVerificationCodeTooFrequent
|
||||||
|
}
|
||||||
|
|
||||||
|
code, err := generateNumericCode(6)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate verification code failed: %w", err)
|
||||||
|
}
|
||||||
|
payload := verificationCodePayload{
|
||||||
|
Code: code,
|
||||||
|
Purpose: purpose,
|
||||||
|
Email: strings.ToLower(strings.TrimSpace(email)),
|
||||||
|
ExpiresAt: time.Now().Add(verifyCodeTTL).Unix(),
|
||||||
|
}
|
||||||
|
cacheKey := verificationCodeCacheKey(purpose, email)
|
||||||
|
s.cache.Set(cacheKey, payload, verifyCodeTTL)
|
||||||
|
s.cache.Set(rateLimitKey, "1", verifyCodeRateLimitTTL)
|
||||||
|
|
||||||
|
subject, sceneText := verificationEmailMeta(purpose)
|
||||||
|
textBody := fmt.Sprintf("【%s】验证码:%s\n有效期:10分钟\n请勿将验证码泄露给他人。", sceneText, code)
|
||||||
|
htmlBody := buildVerificationEmailHTML(sceneText, code)
|
||||||
|
if err := s.emailService.Send(ctx, SendEmailRequest{
|
||||||
|
To: []string{email},
|
||||||
|
Subject: subject,
|
||||||
|
TextBody: textBody,
|
||||||
|
HTMLBody: htmlBody,
|
||||||
|
}); err != nil {
|
||||||
|
s.cache.Delete(cacheKey)
|
||||||
|
return fmt.Errorf("send verification email failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCodeServiceImpl) VerifyCode(purpose, email, code string) error {
|
||||||
|
if strings.TrimSpace(email) == "" || strings.TrimSpace(code) == "" {
|
||||||
|
return ErrVerificationCodeInvalid
|
||||||
|
}
|
||||||
|
if s.cache == nil {
|
||||||
|
return ErrVerificationCodeUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := verificationCodeCacheKey(purpose, email)
|
||||||
|
raw, ok := s.cache.Get(cacheKey)
|
||||||
|
if !ok {
|
||||||
|
return ErrVerificationCodeExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload verificationCodePayload
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
if err := json.Unmarshal([]byte(v), &payload); err != nil {
|
||||||
|
return ErrVerificationCodeInvalid
|
||||||
|
}
|
||||||
|
case []byte:
|
||||||
|
if err := json.Unmarshal(v, &payload); err != nil {
|
||||||
|
return ErrVerificationCodeInvalid
|
||||||
|
}
|
||||||
|
case verificationCodePayload:
|
||||||
|
payload = v
|
||||||
|
default:
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return ErrVerificationCodeInvalid
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &payload); err != nil {
|
||||||
|
return ErrVerificationCodeInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload.Purpose != purpose || payload.Email != strings.ToLower(strings.TrimSpace(email)) {
|
||||||
|
return ErrVerificationCodeInvalid
|
||||||
|
}
|
||||||
|
if payload.ExpiresAt > 0 && time.Now().Unix() > payload.ExpiresAt {
|
||||||
|
s.cache.Delete(cacheKey)
|
||||||
|
return ErrVerificationCodeExpired
|
||||||
|
}
|
||||||
|
if payload.Code != strings.TrimSpace(code) {
|
||||||
|
return ErrVerificationCodeInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
s.cache.Delete(cacheKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verificationEmailMeta(purpose string) (subject string, sceneText string) {
|
||||||
|
switch purpose {
|
||||||
|
case CodePurposeRegister:
|
||||||
|
return "Carrot BBS 注册验证码", "注册账号"
|
||||||
|
case CodePurposePasswordReset:
|
||||||
|
return "Carrot BBS 找回密码验证码", "找回密码"
|
||||||
|
case CodePurposeEmailVerify:
|
||||||
|
return "Carrot BBS 邮箱验证验证码", "验证邮箱"
|
||||||
|
case CodePurposeChangePassword:
|
||||||
|
return "Carrot BBS 修改密码验证码", "修改密码"
|
||||||
|
default:
|
||||||
|
return "Carrot BBS 验证码", "身份验证"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildVerificationEmailHTML(sceneText, code string) string {
|
||||||
|
return fmt.Sprintf(`<!doctype html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>Carrot BBS 验证码</title>
|
||||||
|
</head>
|
||||||
|
<body style="margin:0;padding:0;background:#f4f6fb;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,'PingFang SC','Microsoft YaHei',sans-serif;color:#1f2937;">
|
||||||
|
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="background:#f4f6fb;padding:24px 12px;">
|
||||||
|
<tr>
|
||||||
|
<td align="center">
|
||||||
|
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="max-width:560px;background:#ffffff;border-radius:14px;overflow:hidden;box-shadow:0 8px 30px rgba(15,23,42,0.08);">
|
||||||
|
<tr>
|
||||||
|
<td style="background:linear-gradient(135deg,#ff6b35,#ff8f66);padding:24px 28px;color:#ffffff;">
|
||||||
|
<div style="font-size:22px;font-weight:700;line-height:1.2;">Carrot BBS</div>
|
||||||
|
<div style="margin-top:6px;font-size:14px;opacity:0.95;">%s 验证</div>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td style="padding:28px;">
|
||||||
|
<p style="margin:0 0 14px;font-size:15px;line-height:1.75;">你好,</p>
|
||||||
|
<p style="margin:0 0 20px;font-size:15px;line-height:1.75;">你正在进行 <strong>%s</strong> 操作,请使用下方验证码完成验证:</p>
|
||||||
|
<div style="margin:0 auto 18px;max-width:320px;border:1px dashed #ff8f66;background:#fff8f4;border-radius:12px;padding:14px 12px;text-align:center;">
|
||||||
|
<div style="font-size:13px;color:#9a3412;letter-spacing:0.5px;">验证码(10分钟内有效)</div>
|
||||||
|
<div style="margin-top:8px;font-size:34px;line-height:1;font-weight:800;letter-spacing:8px;color:#ea580c;">%s</div>
|
||||||
|
</div>
|
||||||
|
<p style="margin:0 0 8px;font-size:13px;color:#6b7280;line-height:1.7;">如果不是你本人操作,请忽略此邮件,并及时检查账号安全。</p>
|
||||||
|
<p style="margin:0;font-size:13px;color:#6b7280;line-height:1.7;">请勿向任何人透露验证码,平台不会以任何理由索取验证码。</p>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td style="padding:14px 28px;background:#f8fafc;border-top:1px solid #e5e7eb;color:#94a3b8;font-size:12px;line-height:1.7;">
|
||||||
|
此邮件由系统自动发送,请勿直接回复。<br/>
|
||||||
|
© Carrot BBS
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</body>
|
||||||
|
</html>`, sceneText, sceneText, code)
|
||||||
|
}
|
||||||
82
internal/service/email_service.go
Normal file
82
internal/service/email_service.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
emailpkg "carrot_bbs/internal/pkg/email"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SendEmailRequest 发信请求
|
||||||
|
type SendEmailRequest struct {
|
||||||
|
To []string
|
||||||
|
Cc []string
|
||||||
|
Bcc []string
|
||||||
|
ReplyTo []string
|
||||||
|
Subject string
|
||||||
|
TextBody string
|
||||||
|
HTMLBody string
|
||||||
|
Attachments []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmailService interface {
|
||||||
|
IsEnabled() bool
|
||||||
|
Send(ctx context.Context, req SendEmailRequest) error
|
||||||
|
SendText(ctx context.Context, to []string, subject, body string) error
|
||||||
|
SendHTML(ctx context.Context, to []string, subject, html string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type emailServiceImpl struct {
|
||||||
|
client emailpkg.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEmailService(client emailpkg.Client) EmailService {
|
||||||
|
return &emailServiceImpl{client: client}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailServiceImpl) IsEnabled() bool {
|
||||||
|
return s.client != nil && s.client.IsEnabled()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailServiceImpl) Send(ctx context.Context, req SendEmailRequest) error {
|
||||||
|
if s.client == nil {
|
||||||
|
return fmt.Errorf("email client is nil")
|
||||||
|
}
|
||||||
|
if !s.client.IsEnabled() {
|
||||||
|
return fmt.Errorf("email service is disabled")
|
||||||
|
}
|
||||||
|
if len(req.To) == 0 {
|
||||||
|
return fmt.Errorf("email recipient is empty")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Subject) == "" {
|
||||||
|
return fmt.Errorf("email subject is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.client.Send(ctx, emailpkg.Message{
|
||||||
|
To: req.To,
|
||||||
|
Cc: req.Cc,
|
||||||
|
Bcc: req.Bcc,
|
||||||
|
ReplyTo: req.ReplyTo,
|
||||||
|
Subject: req.Subject,
|
||||||
|
TextBody: req.TextBody,
|
||||||
|
HTMLBody: req.HTMLBody,
|
||||||
|
Attachments: req.Attachments,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailServiceImpl) SendText(ctx context.Context, to []string, subject, body string) error {
|
||||||
|
return s.Send(ctx, SendEmailRequest{
|
||||||
|
To: to,
|
||||||
|
Subject: subject,
|
||||||
|
TextBody: body,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailServiceImpl) SendHTML(ctx context.Context, to []string, subject, html string) error {
|
||||||
|
return s.Send(ctx, SendEmailRequest{
|
||||||
|
To: to,
|
||||||
|
Subject: subject,
|
||||||
|
HTMLBody: html,
|
||||||
|
})
|
||||||
|
}
|
||||||
1491
internal/service/group_service.go
Normal file
1491
internal/service/group_service.go
Normal file
File diff suppressed because it is too large
Load Diff
38
internal/service/jwt_service.go
Normal file
38
internal/service/jwt_service.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/pkg/jwt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWTService JWT服务
|
||||||
|
type JWTService struct {
|
||||||
|
jwt *jwt.JWT
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJWTService 创建JWT服务
|
||||||
|
func NewJWTService(secret string, accessExpire, refreshExpire int64) *JWTService {
|
||||||
|
return &JWTService{
|
||||||
|
jwt: jwt.New(secret, time.Duration(accessExpire)*time.Second, time.Duration(refreshExpire)*time.Second),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateAccessToken 生成访问令牌
|
||||||
|
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
|
||||||
|
return s.jwt.GenerateAccessToken(userID, username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRefreshToken 生成刷新令牌
|
||||||
|
func (s *JWTService) GenerateRefreshToken(userID, username string) (string, error) {
|
||||||
|
return s.jwt.GenerateRefreshToken(userID, username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseToken 解析令牌
|
||||||
|
func (s *JWTService) ParseToken(tokenString string) (*jwt.Claims, error) {
|
||||||
|
return s.jwt.ParseToken(tokenString)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken 验证令牌
|
||||||
|
func (s *JWTService) ValidateToken(tokenString string) error {
|
||||||
|
return s.jwt.ValidateToken(tokenString)
|
||||||
|
}
|
||||||
215
internal/service/message_service.go
Normal file
215
internal/service/message_service.go
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 缓存TTL常量
|
||||||
|
const (
|
||||||
|
ConversationListTTL = 60 * time.Second // 会话列表缓存60秒
|
||||||
|
ConversationDetailTTL = 60 * time.Second // 会话详情缓存60秒
|
||||||
|
UnreadCountTTL = 30 * time.Second // 未读数缓存30秒
|
||||||
|
ConversationNullTTL = 5 * time.Second
|
||||||
|
UnreadNullTTL = 5 * time.Second
|
||||||
|
CacheJitterRatio = 0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageService 消息服务
|
||||||
|
type MessageService struct {
|
||||||
|
messageRepo *repository.MessageRepository
|
||||||
|
cache cache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMessageService 创建消息服务
|
||||||
|
func NewMessageService(messageRepo *repository.MessageRepository) *MessageService {
|
||||||
|
return &MessageService{
|
||||||
|
messageRepo: messageRepo,
|
||||||
|
cache: cache.GetCache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationListResult 会话列表缓存结果
|
||||||
|
type ConversationListResult struct {
|
||||||
|
Conversations []*model.Conversation
|
||||||
|
Total int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage 发送消息(使用 segments)
|
||||||
|
// senderID 和 receiverID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID string, segments model.MessageSegments) (*model.Message, error) {
|
||||||
|
// 获取或创建会话
|
||||||
|
conv, err := s.messageRepo.GetOrCreatePrivateConversation(senderID, receiverID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := &model.Message{
|
||||||
|
ConversationID: conv.ID,
|
||||||
|
SenderID: senderID,
|
||||||
|
Segments: segments,
|
||||||
|
Status: model.MessageStatusNormal,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用事务创建消息并更新seq
|
||||||
|
err = s.messageRepo.CreateMessageWithSeq(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效会话列表缓存(发送者和接收者)
|
||||||
|
cache.InvalidateConversationList(s.cache, senderID)
|
||||||
|
cache.InvalidateConversationList(s.cache, receiverID)
|
||||||
|
|
||||||
|
// 失效未读数缓存
|
||||||
|
cache.InvalidateUnreadConversation(s.cache, receiverID)
|
||||||
|
cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID)
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversations 获取会话列表(带缓存)
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||||
|
cacheSettings := cache.GetSettings()
|
||||||
|
conversationTTL := cacheSettings.ConversationTTL
|
||||||
|
if conversationTTL <= 0 {
|
||||||
|
conversationTTL = ConversationListTTL
|
||||||
|
}
|
||||||
|
nullTTL := cacheSettings.NullTTL
|
||||||
|
if nullTTL <= 0 {
|
||||||
|
nullTTL = ConversationNullTTL
|
||||||
|
}
|
||||||
|
jitter := cacheSettings.JitterRatio
|
||||||
|
if jitter <= 0 {
|
||||||
|
jitter = CacheJitterRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成缓存键
|
||||||
|
cacheKey := cache.ConversationListKey(userID, page, pageSize)
|
||||||
|
result, err := cache.GetOrLoadTyped[*ConversationListResult](
|
||||||
|
s.cache,
|
||||||
|
cacheKey,
|
||||||
|
conversationTTL,
|
||||||
|
jitter,
|
||||||
|
nullTTL,
|
||||||
|
func() (*ConversationListResult, error) {
|
||||||
|
conversations, total, err := s.messageRepo.GetConversations(userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ConversationListResult{
|
||||||
|
Conversations: conversations,
|
||||||
|
Total: total,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if result == nil {
|
||||||
|
return []*model.Conversation{}, 0, nil
|
||||||
|
}
|
||||||
|
return result.Conversations, result.Total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages 获取消息列表
|
||||||
|
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||||
|
return s.messageRepo.GetMessages(conversationID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesAfterSeq 获取指定seq之后的消息(增量同步)
|
||||||
|
func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记为已读
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
|
||||||
|
err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效未读数缓存
|
||||||
|
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||||
|
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
||||||
|
|
||||||
|
// 失效会话列表缓存
|
||||||
|
cache.InvalidateConversationList(s.cache, userID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取未读消息数(带缓存)
|
||||||
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||||||
|
cacheSettings := cache.GetSettings()
|
||||||
|
unreadTTL := cacheSettings.UnreadCountTTL
|
||||||
|
if unreadTTL <= 0 {
|
||||||
|
unreadTTL = UnreadCountTTL
|
||||||
|
}
|
||||||
|
nullTTL := cacheSettings.NullTTL
|
||||||
|
if nullTTL <= 0 {
|
||||||
|
nullTTL = UnreadNullTTL
|
||||||
|
}
|
||||||
|
jitter := cacheSettings.JitterRatio
|
||||||
|
if jitter <= 0 {
|
||||||
|
jitter = CacheJitterRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成缓存键
|
||||||
|
cacheKey := cache.UnreadDetailKey(userID, conversationID)
|
||||||
|
|
||||||
|
return cache.GetOrLoadTyped[int64](
|
||||||
|
s.cache,
|
||||||
|
cacheKey,
|
||||||
|
unreadTTL,
|
||||||
|
jitter,
|
||||||
|
nullTTL,
|
||||||
|
func() (int64, error) {
|
||||||
|
return s.messageRepo.GetUnreadCount(conversationID, userID)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateConversation 获取或创建私聊会话
|
||||||
|
// user1ID 和 user2ID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
|
func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||||
|
conv, err := s.messageRepo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效会话列表缓存
|
||||||
|
cache.InvalidateConversationList(s.cache, user1ID)
|
||||||
|
cache.InvalidateConversationList(s.cache, user2ID)
|
||||||
|
|
||||||
|
return conv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationParticipants 获取会话参与者列表
|
||||||
|
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
|
||||||
|
return s.messageRepo.GetConversationParticipants(conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConversationID 辅助函数:直接返回字符串ID(已经是string类型)
|
||||||
|
func ParseConversationID(idStr string) (string, error) {
|
||||||
|
return idStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用)
|
||||||
|
func (s *MessageService) InvalidateUserConversationCache(userID string) {
|
||||||
|
cache.InvalidateConversationList(s.cache, userID)
|
||||||
|
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用)
|
||||||
|
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
|
||||||
|
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||||
|
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
||||||
|
}
|
||||||
169
internal/service/notification_service.go
Normal file
169
internal/service/notification_service.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 缓存TTL常量
|
||||||
|
const (
|
||||||
|
NotificationUnreadCountTTL = 30 * time.Second // 通知未读数缓存30秒
|
||||||
|
NotificationNullTTL = 5 * time.Second
|
||||||
|
NotificationCacheJitter = 0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
// NotificationService 通知服务
|
||||||
|
type NotificationService struct {
|
||||||
|
notificationRepo *repository.NotificationRepository
|
||||||
|
cache cache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNotificationService 创建通知服务
|
||||||
|
func NewNotificationService(notificationRepo *repository.NotificationRepository) *NotificationService {
|
||||||
|
return &NotificationService{
|
||||||
|
notificationRepo: notificationRepo,
|
||||||
|
cache: cache.GetCache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建通知
|
||||||
|
func (s *NotificationService) Create(ctx context.Context, userID string, notificationType model.NotificationType, title, content string) (*model.Notification, error) {
|
||||||
|
notification := &model.Notification{
|
||||||
|
UserID: userID,
|
||||||
|
Type: notificationType,
|
||||||
|
Title: title,
|
||||||
|
Content: content,
|
||||||
|
IsRead: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.notificationRepo.Create(notification)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效未读数缓存
|
||||||
|
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||||
|
|
||||||
|
return notification, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByUserID 获取用户通知
|
||||||
|
func (s *NotificationService) GetByUserID(ctx context.Context, userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
|
||||||
|
return s.notificationRepo.GetByUserID(userID, page, pageSize, unreadOnly)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsRead 标记为已读
|
||||||
|
func (s *NotificationService) MarkAsRead(ctx context.Context, id string) error {
|
||||||
|
err := s.notificationRepo.MarkAsRead(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注意:这里无法获取userID,所以不在缓存中失效
|
||||||
|
// 调用方应该使用MarkAsReadWithUserID方法
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAsReadWithUserID 标记为已读(带用户ID,用于缓存失效)
|
||||||
|
func (s *NotificationService) MarkAsReadWithUserID(ctx context.Context, id, userID string) error {
|
||||||
|
err := s.notificationRepo.MarkAsRead(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效未读数缓存
|
||||||
|
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkAllAsRead 标记所有为已读
|
||||||
|
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID string) error {
|
||||||
|
err := s.notificationRepo.MarkAllAsRead(userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效未读数缓存
|
||||||
|
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除通知
|
||||||
|
func (s *NotificationService) Delete(ctx context.Context, id string) error {
|
||||||
|
return s.notificationRepo.Delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取未读数量(带缓存)
|
||||||
|
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID string) (int64, error) {
|
||||||
|
cacheSettings := cache.GetSettings()
|
||||||
|
unreadTTL := cacheSettings.UnreadCountTTL
|
||||||
|
if unreadTTL <= 0 {
|
||||||
|
unreadTTL = NotificationUnreadCountTTL
|
||||||
|
}
|
||||||
|
nullTTL := cacheSettings.NullTTL
|
||||||
|
if nullTTL <= 0 {
|
||||||
|
nullTTL = NotificationNullTTL
|
||||||
|
}
|
||||||
|
jitter := cacheSettings.JitterRatio
|
||||||
|
if jitter <= 0 {
|
||||||
|
jitter = NotificationCacheJitter
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成缓存键
|
||||||
|
cacheKey := cache.UnreadSystemKey(userID)
|
||||||
|
return cache.GetOrLoadTyped[int64](
|
||||||
|
s.cache,
|
||||||
|
cacheKey,
|
||||||
|
unreadTTL,
|
||||||
|
jitter,
|
||||||
|
nullTTL,
|
||||||
|
func() (int64, error) {
|
||||||
|
return s.notificationRepo.GetUnreadCount(userID)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteNotification 删除通知(带用户验证)
|
||||||
|
func (s *NotificationService) DeleteNotification(ctx context.Context, id, userID string) error {
|
||||||
|
// 先检查通知是否属于该用户
|
||||||
|
notification, err := s.notificationRepo.GetByID(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if notification.UserID != userID {
|
||||||
|
return ErrUnauthorizedNotification
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.notificationRepo.Delete(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效未读数缓存
|
||||||
|
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearAllNotifications 清空所有通知
|
||||||
|
func (s *NotificationService) ClearAllNotifications(ctx context.Context, userID string) error {
|
||||||
|
err := s.notificationRepo.DeleteAllByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效未读数缓存
|
||||||
|
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 错误定义
|
||||||
|
var ErrUnauthorizedNotification = &ServiceError{Code: 403, Message: "unauthorized to delete this notification"}
|
||||||
103
internal/service/post_ai_service.go
Normal file
103
internal/service/post_ai_service.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostModerationRejectedError 帖子审核拒绝错误
|
||||||
|
type PostModerationRejectedError struct {
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *PostModerationRejectedError) Error() string {
|
||||||
|
if strings.TrimSpace(e.Reason) == "" {
|
||||||
|
return "post rejected by moderation"
|
||||||
|
}
|
||||||
|
return "post rejected by moderation: " + e.Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserMessage 返回给前端的用户可读文案
|
||||||
|
func (e *PostModerationRejectedError) UserMessage() string {
|
||||||
|
if strings.TrimSpace(e.Reason) == "" {
|
||||||
|
return "内容未通过审核,请修改后重试"
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(e.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommentModerationRejectedError 评论审核拒绝错误
|
||||||
|
type CommentModerationRejectedError struct {
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *CommentModerationRejectedError) Error() string {
|
||||||
|
if strings.TrimSpace(e.Reason) == "" {
|
||||||
|
return "comment rejected by moderation"
|
||||||
|
}
|
||||||
|
return "comment rejected by moderation: " + e.Reason
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserMessage 返回给前端的用户可读文案
|
||||||
|
func (e *CommentModerationRejectedError) UserMessage() string {
|
||||||
|
if strings.TrimSpace(e.Reason) == "" {
|
||||||
|
return "评论未通过审核,请修改后重试"
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(e.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
type PostAIService struct {
|
||||||
|
openAIClient openai.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostAIService(openAIClient openai.Client) *PostAIService {
|
||||||
|
return &PostAIService{
|
||||||
|
openAIClient: openAIClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PostAIService) IsEnabled() bool {
|
||||||
|
return s != nil && s.openAIClient != nil && s.openAIClient.IsEnabled()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModeratePost 审核帖子内容,返回 nil 表示通过
|
||||||
|
func (s *PostAIService) ModeratePost(ctx context.Context, title, content string, images []string) error {
|
||||||
|
if !s.IsEnabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
approved, reason, err := s.openAIClient.ModeratePost(ctx, title, content, images)
|
||||||
|
if err != nil {
|
||||||
|
if s.openAIClient.Config().StrictModeration {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("[WARN] AI moderation failed, fallback allow: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !approved {
|
||||||
|
return &PostModerationRejectedError{Reason: reason}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModerateComment 审核评论内容,返回 nil 表示通过
|
||||||
|
func (s *PostAIService) ModerateComment(ctx context.Context, content string, images []string) error {
|
||||||
|
if !s.IsEnabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
approved, reason, err := s.openAIClient.ModerateComment(ctx, content, images)
|
||||||
|
if err != nil {
|
||||||
|
if s.openAIClient.Config().StrictModeration {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("[WARN] AI comment moderation failed, fallback allow: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !approved {
|
||||||
|
return &CommentModerationRejectedError{Reason: reason}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
593
internal/service/post_service.go
Normal file
593
internal/service/post_service.go
Normal file
@@ -0,0 +1,593 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/gorse"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 缓存TTL常量
|
||||||
|
const (
|
||||||
|
PostListTTL = 30 * time.Second // 帖子列表缓存30秒
|
||||||
|
PostListNullTTL = 5 * time.Second
|
||||||
|
PostListJitterRatio = 0.15
|
||||||
|
anonymousViewUserID = "_anon_view"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostService 帖子服务
|
||||||
|
type PostService struct {
|
||||||
|
postRepo *repository.PostRepository
|
||||||
|
systemMessageService SystemMessageService
|
||||||
|
cache cache.Cache
|
||||||
|
gorseClient gorse.Client
|
||||||
|
postAIService *PostAIService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPostService 创建帖子服务
|
||||||
|
func NewPostService(postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *PostService {
|
||||||
|
return &PostService{
|
||||||
|
postRepo: postRepo,
|
||||||
|
systemMessageService: systemMessageService,
|
||||||
|
cache: cache.GetCache(),
|
||||||
|
gorseClient: gorseClient,
|
||||||
|
postAIService: postAIService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostListResult 帖子列表缓存结果
|
||||||
|
type PostListResult struct {
|
||||||
|
Posts []*model.Post
|
||||||
|
Total int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 创建帖子
|
||||||
|
func (s *PostService) Create(ctx context.Context, userID, title, content string, images []string) (*model.Post, error) {
|
||||||
|
post := &model.Post{
|
||||||
|
UserID: userID,
|
||||||
|
Title: title,
|
||||||
|
Content: content,
|
||||||
|
Status: model.PostStatusPending,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.postRepo.Create(post, images)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效帖子列表缓存
|
||||||
|
cache.InvalidatePostList(s.cache)
|
||||||
|
|
||||||
|
// 同步到Gorse推荐系统(异步)
|
||||||
|
go s.reviewPostAsync(post.ID, userID, title, content, images)
|
||||||
|
|
||||||
|
// 重新查询以获取关联的 User 和 Images
|
||||||
|
return s.postRepo.GetByID(post.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) {
|
||||||
|
// 未启用AI时,直接发布
|
||||||
|
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||||
|
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.postAIService.ModeratePost(context.Background(), title, content, images)
|
||||||
|
if err != nil {
|
||||||
|
var rejectedErr *PostModerationRejectedError
|
||||||
|
if errors.As(err, &rejectedErr) {
|
||||||
|
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
|
||||||
|
}
|
||||||
|
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 规则审核不可用时,降级为发布,避免长时间pending
|
||||||
|
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
|
||||||
|
}
|
||||||
|
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.gorseClient.IsEnabled() {
|
||||||
|
post, getErr := s.postRepo.GetByID(postID)
|
||||||
|
if getErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to load published post for gorse sync: %v", getErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
categories := s.buildPostCategories(post)
|
||||||
|
comment := post.Title
|
||||||
|
textToEmbed := post.Title + " " + post.Content
|
||||||
|
if upsertErr := s.gorseClient.UpsertItemWithEmbedding(context.Background(), post.ID, categories, comment, textToEmbed); upsertErr != nil {
|
||||||
|
log.Printf("[WARN] Failed to upsert item to Gorse: %v", upsertErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PostService) notifyModerationRejected(userID, reason string) {
|
||||||
|
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content := "您发布的帖子未通过AI审核,请修改后重试。"
|
||||||
|
if strings.TrimSpace(reason) != "" {
|
||||||
|
content = fmt.Sprintf("您发布的帖子未通过AI审核,原因:%s。请修改后重试。", reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.systemMessageService.SendSystemAnnouncement(
|
||||||
|
context.Background(),
|
||||||
|
[]string{userID},
|
||||||
|
"帖子审核未通过",
|
||||||
|
content,
|
||||||
|
); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to send moderation reject notification: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetByID 根据ID获取帖子
|
||||||
|
func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, error) {
|
||||||
|
return s.postRepo.GetByID(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新帖子
|
||||||
|
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
|
||||||
|
err := s.postRepo.Update(post)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效帖子详情缓存和列表缓存
|
||||||
|
cache.InvalidatePostDetail(s.cache, post.ID)
|
||||||
|
cache.InvalidatePostList(s.cache)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除帖子
|
||||||
|
func (s *PostService) Delete(ctx context.Context, id string) error {
|
||||||
|
err := s.postRepo.Delete(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效帖子详情缓存和列表缓存
|
||||||
|
cache.InvalidatePostDetail(s.cache, id)
|
||||||
|
cache.InvalidatePostList(s.cache)
|
||||||
|
|
||||||
|
// 从Gorse中删除帖子(异步)
|
||||||
|
go func() {
|
||||||
|
if s.gorseClient.IsEnabled() {
|
||||||
|
if err := s.gorseClient.DeleteItem(context.Background(), id); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to delete item from Gorse: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 获取帖子列表(带缓存)
|
||||||
|
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||||
|
cacheSettings := cache.GetSettings()
|
||||||
|
postListTTL := cacheSettings.PostListTTL
|
||||||
|
if postListTTL <= 0 {
|
||||||
|
postListTTL = PostListTTL
|
||||||
|
}
|
||||||
|
nullTTL := cacheSettings.NullTTL
|
||||||
|
if nullTTL <= 0 {
|
||||||
|
nullTTL = PostListNullTTL
|
||||||
|
}
|
||||||
|
jitter := cacheSettings.JitterRatio
|
||||||
|
if jitter <= 0 {
|
||||||
|
jitter = PostListJitterRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染)
|
||||||
|
cacheKey := cache.PostListKey("latest", userID, page, pageSize)
|
||||||
|
|
||||||
|
result, err := cache.GetOrLoadTyped[*PostListResult](
|
||||||
|
s.cache,
|
||||||
|
cacheKey,
|
||||||
|
postListTTL,
|
||||||
|
jitter,
|
||||||
|
nullTTL,
|
||||||
|
func() (*PostListResult, error) {
|
||||||
|
posts, total, err := s.postRepo.List(page, pageSize, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &PostListResult{Posts: posts, Total: total}, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if result == nil {
|
||||||
|
return []*model.Post{}, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 兼容历史脏缓存:旧缓存序列化会丢失 Post.User,导致前端显示“匿名用户”
|
||||||
|
// 这里检测并回源重建一次缓存,避免在 TTL 内持续返回缺失作者的数据。
|
||||||
|
missingAuthor := false
|
||||||
|
for _, post := range result.Posts {
|
||||||
|
if post != nil && post.UserID != "" && post.User == nil {
|
||||||
|
missingAuthor = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if missingAuthor {
|
||||||
|
posts, total, loadErr := s.postRepo.List(page, pageSize, userID)
|
||||||
|
if loadErr != nil {
|
||||||
|
return nil, 0, loadErr
|
||||||
|
}
|
||||||
|
result = &PostListResult{Posts: posts, Total: total}
|
||||||
|
cache.SetWithJitter(s.cache, cacheKey, result, postListTTL, jitter)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.Posts, result.Total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLatestPosts 获取最新帖子(语义化别名)
|
||||||
|
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||||
|
return s.List(ctx, page, pageSize, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserPosts 获取用户帖子
|
||||||
|
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
return s.postRepo.GetUserPosts(userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Like 点赞
|
||||||
|
func (s *PostService) Like(ctx context.Context, postID, userID string) error {
|
||||||
|
// 获取帖子信息用于发送通知
|
||||||
|
post, err := s.postRepo.GetByID(postID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.postRepo.Like(postID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效帖子详情缓存
|
||||||
|
cache.InvalidatePostDetail(s.cache, postID)
|
||||||
|
|
||||||
|
// 发送点赞通知(不给自己发通知)
|
||||||
|
if s.systemMessageService != nil && post.UserID != userID {
|
||||||
|
go func() {
|
||||||
|
notifyErr := s.systemMessageService.SendLikeNotification(context.Background(), post.UserID, userID, postID)
|
||||||
|
if notifyErr != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[DEBUG] Like notification sent successfully\n")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送点赞行为到Gorse(异步)
|
||||||
|
go func() {
|
||||||
|
if s.gorseClient.IsEnabled() {
|
||||||
|
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to insert like feedback to Gorse: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlike 取消点赞
|
||||||
|
func (s *PostService) Unlike(ctx context.Context, postID, userID string) error {
|
||||||
|
err := s.postRepo.Unlike(postID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效帖子详情缓存
|
||||||
|
cache.InvalidatePostDetail(s.cache, postID)
|
||||||
|
|
||||||
|
// 删除Gorse中的点赞反馈(异步)
|
||||||
|
go func() {
|
||||||
|
if s.gorseClient.IsEnabled() {
|
||||||
|
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to delete like feedback from Gorse: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLiked 检查是否点赞
|
||||||
|
func (s *PostService) IsLiked(ctx context.Context, postID, userID string) bool {
|
||||||
|
return s.postRepo.IsLiked(postID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Favorite 收藏
|
||||||
|
func (s *PostService) Favorite(ctx context.Context, postID, userID string) error {
|
||||||
|
// 获取帖子信息用于发送通知
|
||||||
|
post, err := s.postRepo.GetByID(postID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.postRepo.Favorite(postID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效帖子详情缓存
|
||||||
|
cache.InvalidatePostDetail(s.cache, postID)
|
||||||
|
|
||||||
|
// 发送收藏通知(不给自己发通知)
|
||||||
|
if s.systemMessageService != nil && post.UserID != userID {
|
||||||
|
go func() {
|
||||||
|
notifyErr := s.systemMessageService.SendFavoriteNotification(context.Background(), post.UserID, userID, postID)
|
||||||
|
if notifyErr != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error sending favorite notification: %v\n", notifyErr)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[DEBUG] Favorite notification sent successfully\n")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送收藏行为到Gorse(异步)
|
||||||
|
go func() {
|
||||||
|
if s.gorseClient.IsEnabled() {
|
||||||
|
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to insert favorite feedback to Gorse: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unfavorite 取消收藏
|
||||||
|
func (s *PostService) Unfavorite(ctx context.Context, postID, userID string) error {
|
||||||
|
err := s.postRepo.Unfavorite(postID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效帖子详情缓存
|
||||||
|
cache.InvalidatePostDetail(s.cache, postID)
|
||||||
|
|
||||||
|
// 删除Gorse中的收藏反馈(异步)
|
||||||
|
go func() {
|
||||||
|
if s.gorseClient.IsEnabled() {
|
||||||
|
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to delete favorite feedback from Gorse: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFavorited 检查是否收藏
|
||||||
|
func (s *PostService) IsFavorited(ctx context.Context, postID, userID string) bool {
|
||||||
|
return s.postRepo.IsFavorited(postID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementViews 增加帖子观看量并同步到Gorse
|
||||||
|
func (s *PostService) IncrementViews(ctx context.Context, postID, userID string) error {
|
||||||
|
if err := s.postRepo.IncrementViews(postID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同步浏览行为到Gorse(异步)
|
||||||
|
go func() {
|
||||||
|
if !s.gorseClient.IsEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
feedbackUserID := userID
|
||||||
|
if feedbackUserID == "" {
|
||||||
|
feedbackUserID = anonymousViewUserID
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeRead, feedbackUserID, postID); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to insert read feedback to Gorse: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFavorites 获取收藏列表
|
||||||
|
func (s *PostService) GetFavorites(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
return s.postRepo.GetFavorites(userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search 搜索帖子
|
||||||
|
func (s *PostService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
return s.postRepo.Search(keyword, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowingPosts 获取关注用户的帖子(带缓存)
|
||||||
|
func (s *PostService) GetFollowingPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
cacheSettings := cache.GetSettings()
|
||||||
|
postListTTL := cacheSettings.PostListTTL
|
||||||
|
if postListTTL <= 0 {
|
||||||
|
postListTTL = PostListTTL
|
||||||
|
}
|
||||||
|
nullTTL := cacheSettings.NullTTL
|
||||||
|
if nullTTL <= 0 {
|
||||||
|
nullTTL = PostListNullTTL
|
||||||
|
}
|
||||||
|
jitter := cacheSettings.JitterRatio
|
||||||
|
if jitter <= 0 {
|
||||||
|
jitter = PostListJitterRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成缓存键
|
||||||
|
cacheKey := cache.PostListKey("follow", userID, page, pageSize)
|
||||||
|
|
||||||
|
result, err := cache.GetOrLoadTyped[*PostListResult](
|
||||||
|
s.cache,
|
||||||
|
cacheKey,
|
||||||
|
postListTTL,
|
||||||
|
jitter,
|
||||||
|
nullTTL,
|
||||||
|
func() (*PostListResult, error) {
|
||||||
|
posts, total, err := s.postRepo.GetFollowingPosts(userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &PostListResult{Posts: posts, Total: total}, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
if result == nil {
|
||||||
|
return []*model.Post{}, 0, nil
|
||||||
|
}
|
||||||
|
return result.Posts, result.Total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHotPosts 获取热门帖子(使用Gorse非个性化推荐)
|
||||||
|
func (s *PostService) GetHotPosts(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
// 如果Gorse启用,使用自定义的非个性化推荐器
|
||||||
|
if s.gorseClient.IsEnabled() {
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
// 使用 most_liked_weekly 推荐器获取周热门
|
||||||
|
// 多取1条用于判断是否还有下一页
|
||||||
|
itemIDs, err := s.gorseClient.GetNonPersonalized(ctx, "most_liked_weekly", pageSize+1, offset, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] Gorse GetNonPersonalized failed: %v, fallback to database", err)
|
||||||
|
return s.getHotPostsFromDB(ctx, page, pageSize)
|
||||||
|
}
|
||||||
|
if len(itemIDs) > 0 {
|
||||||
|
hasNext := len(itemIDs) > pageSize
|
||||||
|
if hasNext {
|
||||||
|
itemIDs = itemIDs[:pageSize]
|
||||||
|
}
|
||||||
|
posts, err := s.postRepo.GetByIDs(itemIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
// 近似 total:当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
|
||||||
|
estimatedTotal := int64(offset + len(posts))
|
||||||
|
if hasNext {
|
||||||
|
estimatedTotal = int64(offset + pageSize + 1)
|
||||||
|
}
|
||||||
|
return posts, estimatedTotal, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 降级:从数据库获取
|
||||||
|
return s.getHotPostsFromDB(ctx, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHotPostsFromDB 从数据库获取热门帖子(降级路径)
|
||||||
|
func (s *PostService) getHotPostsFromDB(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
// 直接查询数据库,不再使用本地缓存(Gorse失败降级时使用)
|
||||||
|
posts, total, err := s.postRepo.GetHotPosts(page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return posts, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRecommendedPosts 获取推荐帖子
|
||||||
|
func (s *PostService) GetRecommendedPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||||
|
// 如果Gorse未启用或用户未登录,降级为热门帖子
|
||||||
|
if !s.gorseClient.IsEnabled() || userID == "" {
|
||||||
|
return s.GetHotPosts(ctx, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算偏移量
|
||||||
|
offset := (page - 1) * pageSize
|
||||||
|
|
||||||
|
// 从Gorse获取推荐列表
|
||||||
|
// 多取1条用于判断是否还有下一页
|
||||||
|
itemIDs, err := s.gorseClient.GetRecommend(ctx, userID, pageSize+1, offset)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[WARN] Gorse recommendation failed: %v, fallback to hot posts", err)
|
||||||
|
return s.GetHotPosts(ctx, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有推荐结果,降级为热门帖子
|
||||||
|
if len(itemIDs) == 0 {
|
||||||
|
return s.GetHotPosts(ctx, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasNext := len(itemIDs) > pageSize
|
||||||
|
if hasNext {
|
||||||
|
itemIDs = itemIDs[:pageSize]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据ID列表查询帖子详情
|
||||||
|
posts, err := s.postRepo.GetByIDs(itemIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 近似 total:当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
|
||||||
|
estimatedTotal := int64(offset + len(posts))
|
||||||
|
if hasNext {
|
||||||
|
estimatedTotal = int64(offset + pageSize + 1)
|
||||||
|
}
|
||||||
|
return posts, estimatedTotal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildPostCategories 构建帖子的类别标签
|
||||||
|
func (s *PostService) buildPostCategories(post *model.Post) []string {
|
||||||
|
var categories []string
|
||||||
|
|
||||||
|
// 热度标签
|
||||||
|
if post.ViewsCount > 1000 {
|
||||||
|
categories = append(categories, "hot_high")
|
||||||
|
} else if post.ViewsCount > 100 {
|
||||||
|
categories = append(categories, "hot_medium")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 点赞标签
|
||||||
|
if post.LikesCount > 100 {
|
||||||
|
categories = append(categories, "likes_100+")
|
||||||
|
} else if post.LikesCount > 50 {
|
||||||
|
categories = append(categories, "likes_50+")
|
||||||
|
} else if post.LikesCount > 10 {
|
||||||
|
categories = append(categories, "likes_10+")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 评论标签
|
||||||
|
if post.CommentsCount > 50 {
|
||||||
|
categories = append(categories, "comments_50+")
|
||||||
|
} else if post.CommentsCount > 10 {
|
||||||
|
categories = append(categories, "comments_10+")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 时间标签
|
||||||
|
age := time.Since(post.CreatedAt)
|
||||||
|
if age < 24*time.Hour {
|
||||||
|
categories = append(categories, "today")
|
||||||
|
} else if age < 7*24*time.Hour {
|
||||||
|
categories = append(categories, "this_week")
|
||||||
|
} else if age < 30*24*time.Hour {
|
||||||
|
categories = append(categories, "this_month")
|
||||||
|
}
|
||||||
|
|
||||||
|
return categories
|
||||||
|
}
|
||||||
575
internal/service/push_service.go
Normal file
575
internal/service/push_service.go
Normal file
@@ -0,0 +1,575 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/websocket"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 推送相关常量
|
||||||
|
const (
|
||||||
|
// DefaultPushTimeout 默认推送超时时间
|
||||||
|
DefaultPushTimeout = 30 * time.Second
|
||||||
|
// MaxRetryCount 最大重试次数
|
||||||
|
MaxRetryCount = 3
|
||||||
|
// DefaultExpiredTime 默认消息过期时间(24小时)
|
||||||
|
DefaultExpiredTime = 24 * time.Hour
|
||||||
|
// PushQueueSize 推送队列大小
|
||||||
|
PushQueueSize = 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
// PushPriority 推送优先级
|
||||||
|
type PushPriority int
|
||||||
|
|
||||||
|
const (
|
||||||
|
PriorityLow PushPriority = 1 // 低优先级(营销消息等)
|
||||||
|
PriorityNormal PushPriority = 5 // 普通优先级(系统通知)
|
||||||
|
PriorityHigh PushPriority = 8 // 高优先级(聊天消息)
|
||||||
|
PriorityCritical PushPriority = 10 // 最高优先级(重要系统通知)
|
||||||
|
)
|
||||||
|
|
||||||
|
// PushService 推送服务接口
|
||||||
|
type PushService interface {
|
||||||
|
// 推送核心方法
|
||||||
|
PushMessage(ctx context.Context, userID string, message *model.Message) error
|
||||||
|
PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error
|
||||||
|
|
||||||
|
// 系统消息推送
|
||||||
|
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
|
||||||
|
PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error
|
||||||
|
PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error
|
||||||
|
|
||||||
|
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
|
||||||
|
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
|
||||||
|
|
||||||
|
// 设备管理
|
||||||
|
RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error
|
||||||
|
UnregisterDevice(ctx context.Context, deviceID string) error
|
||||||
|
UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error
|
||||||
|
|
||||||
|
// 推送记录管理
|
||||||
|
CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error)
|
||||||
|
GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error)
|
||||||
|
|
||||||
|
// 后台任务
|
||||||
|
StartPushWorker(ctx context.Context)
|
||||||
|
StopPushWorker()
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushServiceImpl 推送服务实现
|
||||||
|
type pushServiceImpl struct {
|
||||||
|
pushRepo *repository.PushRecordRepository
|
||||||
|
deviceRepo *repository.DeviceTokenRepository
|
||||||
|
messageRepo *repository.MessageRepository
|
||||||
|
wsManager *websocket.WebSocketManager
|
||||||
|
|
||||||
|
// 推送队列
|
||||||
|
pushQueue chan *pushTask
|
||||||
|
stopChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushTask 推送任务
|
||||||
|
type pushTask struct {
|
||||||
|
userID string
|
||||||
|
message *model.Message
|
||||||
|
priority int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPushService 创建推送服务
|
||||||
|
func NewPushService(
|
||||||
|
pushRepo *repository.PushRecordRepository,
|
||||||
|
deviceRepo *repository.DeviceTokenRepository,
|
||||||
|
messageRepo *repository.MessageRepository,
|
||||||
|
wsManager *websocket.WebSocketManager,
|
||||||
|
) PushService {
|
||||||
|
return &pushServiceImpl{
|
||||||
|
pushRepo: pushRepo,
|
||||||
|
deviceRepo: deviceRepo,
|
||||||
|
messageRepo: messageRepo,
|
||||||
|
wsManager: wsManager,
|
||||||
|
pushQueue: make(chan *pushTask, PushQueueSize),
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushMessage 推送消息给用户
|
||||||
|
func (s *pushServiceImpl) PushMessage(ctx context.Context, userID string, message *model.Message) error {
|
||||||
|
return s.PushToUser(ctx, userID, message, int(PriorityNormal))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushToUser 带优先级的推送
|
||||||
|
func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error {
|
||||||
|
// 首先尝试WebSocket推送(实时推送)
|
||||||
|
if s.pushViaWebSocket(ctx, userID, message) {
|
||||||
|
// WebSocket推送成功,记录推送状态
|
||||||
|
record, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelWebSocket)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create push record: %w", err)
|
||||||
|
}
|
||||||
|
record.MarkPushed()
|
||||||
|
if err := s.pushRepo.Update(record); err != nil {
|
||||||
|
return fmt.Errorf("failed to update push record: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket推送失败,加入推送队列等待移动端推送
|
||||||
|
select {
|
||||||
|
case s.pushQueue <- &pushTask{
|
||||||
|
userID: userID,
|
||||||
|
message: message,
|
||||||
|
priority: priority,
|
||||||
|
}:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
// 队列已满,直接创建待推送记录
|
||||||
|
_, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelFCM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create pending push record: %w", err)
|
||||||
|
}
|
||||||
|
return errors.New("push queue is full, message queued for later delivery")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushViaWebSocket 通过WebSocket推送消息
|
||||||
|
// 返回true表示推送成功,false表示用户不在线
|
||||||
|
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.wsManager.IsUserOnline(userID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 判断是否为系统消息/通知消息
|
||||||
|
if message.IsSystemMessage() || message.Category == model.CategoryNotification {
|
||||||
|
// 使用 NotificationMessage 格式推送系统通知
|
||||||
|
// 从 segments 中提取文本内容
|
||||||
|
content := dto.ExtractTextContentFromModel(message.Segments)
|
||||||
|
|
||||||
|
notification := &websocket.NotificationMessage{
|
||||||
|
ID: fmt.Sprintf("%s", message.ID),
|
||||||
|
Type: string(message.SystemType),
|
||||||
|
Content: content,
|
||||||
|
Extra: make(map[string]interface{}),
|
||||||
|
CreatedAt: message.CreatedAt.UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 填充额外数据
|
||||||
|
if message.ExtraData != nil {
|
||||||
|
notification.Extra["actor_id"] = message.ExtraData.ActorID
|
||||||
|
notification.Extra["actor_name"] = message.ExtraData.ActorName
|
||||||
|
notification.Extra["avatar_url"] = message.ExtraData.AvatarURL
|
||||||
|
notification.Extra["target_id"] = message.ExtraData.TargetID
|
||||||
|
notification.Extra["target_type"] = message.ExtraData.TargetType
|
||||||
|
notification.Extra["action_url"] = message.ExtraData.ActionURL
|
||||||
|
notification.Extra["action_time"] = message.ExtraData.ActionTime
|
||||||
|
|
||||||
|
// 设置触发用户信息
|
||||||
|
if message.ExtraData.ActorID > 0 {
|
||||||
|
notification.TriggerUser = &websocket.NotificationUser{
|
||||||
|
ID: fmt.Sprintf("%d", message.ExtraData.ActorID),
|
||||||
|
Username: message.ExtraData.ActorName,
|
||||||
|
Avatar: message.ExtraData.AvatarURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||||
|
s.wsManager.SendToUser(userID, wsMsg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建普通聊天消息的 WebSocket 消息 - 使用新的 WSEventResponse 格式
|
||||||
|
// 获取会话类型 (private/group)
|
||||||
|
detailType := "private"
|
||||||
|
if message.ConversationID != "" {
|
||||||
|
// 从会话中获取类型,需要查询数据库或从消息中判断
|
||||||
|
// 这里暂时默认为 private,group 类型需要额外逻辑
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用 message.Segments
|
||||||
|
segments := message.Segments
|
||||||
|
|
||||||
|
event := &dto.WSEventResponse{
|
||||||
|
ID: fmt.Sprintf("%s", message.ID),
|
||||||
|
Time: message.CreatedAt.UnixMilli(),
|
||||||
|
Type: "message",
|
||||||
|
DetailType: detailType,
|
||||||
|
Seq: fmt.Sprintf("%d", message.Seq),
|
||||||
|
Segments: segments,
|
||||||
|
SenderID: message.SenderID,
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event)
|
||||||
|
s.wsManager.SendToUser(userID, wsMsg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushViaFCM 通过FCM推送(预留接口)
|
||||||
|
func (s *pushServiceImpl) pushViaFCM(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
|
||||||
|
// TODO: 实现FCM推送
|
||||||
|
// 1. 构建FCM消息
|
||||||
|
// 2. 调用Firebase Admin SDK发送消息
|
||||||
|
// 3. 处理发送结果
|
||||||
|
return errors.New("FCM push not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushViaAPNs 通过APNs推送(预留接口)
|
||||||
|
func (s *pushServiceImpl) pushViaAPNs(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
|
||||||
|
// TODO: 实现APNs推送
|
||||||
|
// 1. 构建APNs消息
|
||||||
|
// 2. 调用APNs SDK发送消息
|
||||||
|
// 3. 处理发送结果
|
||||||
|
return errors.New("APNs push not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterDevice 注册设备
|
||||||
|
func (s *pushServiceImpl) RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error {
|
||||||
|
deviceToken := &model.DeviceToken{
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
DeviceType: deviceType,
|
||||||
|
PushToken: pushToken,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
deviceToken.UpdateLastUsed()
|
||||||
|
|
||||||
|
return s.deviceRepo.Upsert(deviceToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterDevice 注销设备
|
||||||
|
func (s *pushServiceImpl) UnregisterDevice(ctx context.Context, deviceID string) error {
|
||||||
|
return s.deviceRepo.Deactivate(deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDeviceToken 更新设备Token
|
||||||
|
func (s *pushServiceImpl) UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error {
|
||||||
|
deviceToken, err := s.deviceRepo.GetByDeviceID(deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("device not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceToken.PushToken = newPushToken
|
||||||
|
deviceToken.Activate()
|
||||||
|
|
||||||
|
return s.deviceRepo.Update(deviceToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePushRecord 创建推送记录
|
||||||
|
func (s *pushServiceImpl) CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error) {
|
||||||
|
expiredAt := time.Now().Add(DefaultExpiredTime)
|
||||||
|
record := &model.PushRecord{
|
||||||
|
UserID: userID,
|
||||||
|
MessageID: messageID,
|
||||||
|
PushChannel: channel,
|
||||||
|
PushStatus: model.PushStatusPending,
|
||||||
|
MaxRetry: MaxRetryCount,
|
||||||
|
ExpiredAt: &expiredAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.pushRepo.Create(record); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create push record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPendingPushes 获取待推送记录
|
||||||
|
func (s *pushServiceImpl) GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error) {
|
||||||
|
return s.pushRepo.GetByUserID(userID, 100, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartPushWorker 启动推送工作协程
|
||||||
|
func (s *pushServiceImpl) StartPushWorker(ctx context.Context) {
|
||||||
|
go s.processPushQueue()
|
||||||
|
go s.retryFailedPushes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopPushWorker 停止推送工作协程
|
||||||
|
func (s *pushServiceImpl) StopPushWorker() {
|
||||||
|
close(s.stopChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// processPushQueue 处理推送队列
|
||||||
|
func (s *pushServiceImpl) processPushQueue() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.stopChan:
|
||||||
|
return
|
||||||
|
case task := <-s.pushQueue:
|
||||||
|
s.processPushTask(task)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processPushTask 处理单个推送任务
|
||||||
|
func (s *pushServiceImpl) processPushTask(task *pushTask) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), DefaultPushTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 获取用户活跃设备
|
||||||
|
devices, err := s.deviceRepo.GetActiveByUserID(task.userID)
|
||||||
|
if err != nil || len(devices) == 0 {
|
||||||
|
// 没有可用设备,创建待推送记录
|
||||||
|
s.CreatePushRecord(ctx, task.userID, task.message.ID, model.PushChannelFCM)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对每个设备创建推送记录并尝试推送
|
||||||
|
for _, device := range devices {
|
||||||
|
record, err := s.CreatePushRecord(ctx, task.userID, task.message.ID, s.getChannelForDevice(device))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var pushErr error
|
||||||
|
switch {
|
||||||
|
case device.IsIOS():
|
||||||
|
pushErr = s.pushViaAPNs(ctx, device, task.message)
|
||||||
|
case device.IsAndroid():
|
||||||
|
pushErr = s.pushViaFCM(ctx, device, task.message)
|
||||||
|
default:
|
||||||
|
// Web设备只支持WebSocket
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if pushErr != nil {
|
||||||
|
record.MarkFailed(pushErr.Error())
|
||||||
|
} else {
|
||||||
|
record.MarkPushed()
|
||||||
|
}
|
||||||
|
s.pushRepo.Update(record)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getChannelForDevice 根据设备类型获取推送通道
|
||||||
|
func (s *pushServiceImpl) getChannelForDevice(device *model.DeviceToken) model.PushChannel {
|
||||||
|
switch device.DeviceType {
|
||||||
|
case model.DeviceTypeIOS:
|
||||||
|
return model.PushChannelAPNs
|
||||||
|
case model.DeviceTypeAndroid:
|
||||||
|
return model.PushChannelFCM
|
||||||
|
default:
|
||||||
|
return model.PushChannelWebSocket
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// retryFailedPushes 重试失败的推送
|
||||||
|
func (s *pushServiceImpl) retryFailedPushes() {
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.stopChan:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.doRetry()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// doRetry 执行重试
|
||||||
|
func (s *pushServiceImpl) doRetry() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 获取失败待重试的推送
|
||||||
|
records, err := s.pushRepo.GetFailedPushesForRetry(100)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, record := range records {
|
||||||
|
// 检查是否过期
|
||||||
|
if record.IsExpired() {
|
||||||
|
record.MarkExpired()
|
||||||
|
s.pushRepo.Update(record)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取消息
|
||||||
|
message, err := s.messageRepo.GetMessageByID(record.MessageID)
|
||||||
|
if err != nil {
|
||||||
|
record.MarkFailed("message not found")
|
||||||
|
s.pushRepo.Update(record)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试WebSocket推送
|
||||||
|
if s.pushViaWebSocket(ctx, record.UserID, message) {
|
||||||
|
record.MarkDelivered()
|
||||||
|
s.pushRepo.Update(record)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取设备并尝试移动端推送
|
||||||
|
if record.DeviceToken != "" {
|
||||||
|
device, err := s.deviceRepo.GetByPushToken(record.DeviceToken)
|
||||||
|
if err != nil {
|
||||||
|
record.MarkFailed("device not found")
|
||||||
|
s.pushRepo.Update(record)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var pushErr error
|
||||||
|
switch {
|
||||||
|
case device.IsIOS():
|
||||||
|
pushErr = s.pushViaAPNs(ctx, device, message)
|
||||||
|
case device.IsAndroid():
|
||||||
|
pushErr = s.pushViaFCM(ctx, device, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pushErr != nil {
|
||||||
|
record.MarkFailed(pushErr.Error())
|
||||||
|
} else {
|
||||||
|
record.MarkPushed()
|
||||||
|
}
|
||||||
|
s.pushRepo.Update(record)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushSystemMessage 推送系统消息
|
||||||
|
func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error {
|
||||||
|
// 首先尝试WebSocket推送
|
||||||
|
if s.pushSystemViaWebSocket(ctx, userID, msgType, title, content, data) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户不在线,创建待推送记录(移动端上线后可通过其他方式获取)
|
||||||
|
// 系统消息通常不需要离线推送,客户端上线后会主动拉取
|
||||||
|
return errors.New("user is offline, system message will be available on next sync")
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushSystemViaWebSocket 通过WebSocket推送系统消息
|
||||||
|
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.wsManager.IsUserOnline(userID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
sysMsg := &websocket.SystemMessage{
|
||||||
|
Type: msgType,
|
||||||
|
Title: title,
|
||||||
|
Content: content,
|
||||||
|
Data: data,
|
||||||
|
CreatedAt: time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||||||
|
s.wsManager.SendToUser(userID, wsMsg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushNotification 推送通知消息
|
||||||
|
func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error {
|
||||||
|
// 首先尝试WebSocket推送
|
||||||
|
if s.pushNotificationViaWebSocket(ctx, userID, notification) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户不在线,创建待推送记录
|
||||||
|
// 通知消息可以等用户上线后拉取
|
||||||
|
return errors.New("user is offline, notification will be available on next sync")
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushNotificationViaWebSocket 通过WebSocket推送通知消息
|
||||||
|
func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.wsManager.IsUserOnline(userID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if notification.CreatedAt == 0 {
|
||||||
|
notification.CreatedAt = time.Now().UnixMilli()
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||||
|
s.wsManager.SendToUser(userID, wsMsg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushAnnouncement 广播公告消息
|
||||||
|
func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return errors.New("websocket manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if announcement.CreatedAt == 0 {
|
||||||
|
announcement.CreatedAt = time.Now().UnixMilli()
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||||||
|
s.wsManager.Broadcast(wsMsg)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
|
||||||
|
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
|
||||||
|
// 首先尝试WebSocket推送
|
||||||
|
if s.pushSystemNotificationViaWebSocket(ctx, userID, notification) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户不在线,系统通知已存储在数据库中,用户上线后会主动拉取
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
|
||||||
|
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
|
||||||
|
if s.wsManager == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.wsManager.IsUserOnline(userID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 WebSocket 通知消息
|
||||||
|
wsNotification := &websocket.NotificationMessage{
|
||||||
|
ID: fmt.Sprintf("%d", notification.ID),
|
||||||
|
Type: string(notification.Type),
|
||||||
|
Title: notification.Title,
|
||||||
|
Content: notification.Content,
|
||||||
|
Extra: make(map[string]interface{}),
|
||||||
|
CreatedAt: notification.CreatedAt.UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 填充额外数据
|
||||||
|
if notification.ExtraData != nil {
|
||||||
|
wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr
|
||||||
|
wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName
|
||||||
|
wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL
|
||||||
|
wsNotification.Extra["target_id"] = notification.ExtraData.TargetID
|
||||||
|
wsNotification.Extra["target_type"] = notification.ExtraData.TargetType
|
||||||
|
wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL
|
||||||
|
wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime
|
||||||
|
|
||||||
|
// 设置触发用户信息
|
||||||
|
if notification.ExtraData.ActorIDStr != "" {
|
||||||
|
wsNotification.TriggerUser = &websocket.NotificationUser{
|
||||||
|
ID: notification.ExtraData.ActorIDStr,
|
||||||
|
Username: notification.ExtraData.ActorName,
|
||||||
|
Avatar: notification.ExtraData.AvatarURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification)
|
||||||
|
s.wsManager.SendToUser(userID, wsMsg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
559
internal/service/sensitive_service.go
Normal file
559
internal/service/sensitive_service.go
Normal file
@@ -0,0 +1,559 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
redisclient "carrot_bbs/internal/pkg/redis"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==================== DFA 敏感词过滤实现 ====================
|
||||||
|
|
||||||
|
// SensitiveNode 敏感词树节点
|
||||||
|
type SensitiveNode struct {
|
||||||
|
// 子节点映射
|
||||||
|
Children map[rune]*SensitiveNode
|
||||||
|
// 是否为敏感词结尾
|
||||||
|
IsEnd bool
|
||||||
|
// 敏感词信息(仅在 IsEnd 为 true 时有效)
|
||||||
|
Word string
|
||||||
|
Level model.SensitiveWordLevel
|
||||||
|
Category model.SensitiveWordCategory
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSensitiveNode 创建新的敏感词节点
|
||||||
|
func NewSensitiveNode() *SensitiveNode {
|
||||||
|
return &SensitiveNode{
|
||||||
|
Children: make(map[rune]*SensitiveNode),
|
||||||
|
IsEnd: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SensitiveWordTree 敏感词树
|
||||||
|
type SensitiveWordTree struct {
|
||||||
|
root *SensitiveNode
|
||||||
|
wordCount int
|
||||||
|
mu sync.RWMutex
|
||||||
|
lastReload time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSensitiveWordTree 创建新的敏感词树
|
||||||
|
func NewSensitiveWordTree() *SensitiveWordTree {
|
||||||
|
return &SensitiveWordTree{
|
||||||
|
root: NewSensitiveNode(),
|
||||||
|
wordCount: 0,
|
||||||
|
lastReload: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddWord 添加敏感词到树中
|
||||||
|
func (t *SensitiveWordTree) AddWord(word string, level model.SensitiveWordLevel, category model.SensitiveWordCategory) {
|
||||||
|
if word == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
node := t.root
|
||||||
|
// 转换为小写进行匹配(不区分大小写)
|
||||||
|
lowerWord := strings.ToLower(word)
|
||||||
|
runes := []rune(lowerWord)
|
||||||
|
|
||||||
|
for _, r := range runes {
|
||||||
|
child, exists := node.Children[r]
|
||||||
|
if !exists {
|
||||||
|
child = NewSensitiveNode()
|
||||||
|
node.Children[r] = child
|
||||||
|
}
|
||||||
|
node = child
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果不是已存在的敏感词,则计数+1
|
||||||
|
if !node.IsEnd {
|
||||||
|
t.wordCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
node.IsEnd = true
|
||||||
|
node.Word = word
|
||||||
|
node.Level = level
|
||||||
|
node.Category = category
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveWord 从树中移除敏感词
|
||||||
|
func (t *SensitiveWordTree) RemoveWord(word string) {
|
||||||
|
if word == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
lowerWord := strings.ToLower(word)
|
||||||
|
runes := []rune(lowerWord)
|
||||||
|
|
||||||
|
// 查找节点
|
||||||
|
node := t.root
|
||||||
|
for _, r := range runes {
|
||||||
|
child, exists := node.Children[r]
|
||||||
|
if !exists {
|
||||||
|
return // 敏感词不存在
|
||||||
|
}
|
||||||
|
node = child
|
||||||
|
}
|
||||||
|
|
||||||
|
if node.IsEnd {
|
||||||
|
node.IsEnd = false
|
||||||
|
node.Word = ""
|
||||||
|
t.wordCount--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check 检查文本是否包含敏感词,返回是否包含及敏感词列表
|
||||||
|
func (t *SensitiveWordTree) Check(text string) (bool, []string) {
|
||||||
|
if text == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
|
||||||
|
var foundWords []string
|
||||||
|
runes := []rune(strings.ToLower(text))
|
||||||
|
length := len(runes)
|
||||||
|
|
||||||
|
// 用于标记已找到的敏感词位置,避免重复计算
|
||||||
|
marked := make([]bool, length)
|
||||||
|
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
// 从当前位置开始搜索
|
||||||
|
node := t.root
|
||||||
|
matchEnd := -1
|
||||||
|
matchWord := ""
|
||||||
|
|
||||||
|
for j := i; j < length; j++ {
|
||||||
|
child, exists := node.Children[runes[j]]
|
||||||
|
if !exists {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
node = child
|
||||||
|
|
||||||
|
if node.IsEnd {
|
||||||
|
matchEnd = j
|
||||||
|
matchWord = node.Word
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 标记找到的敏感词位置
|
||||||
|
if matchEnd >= 0 && !marked[i] {
|
||||||
|
for k := i; k <= matchEnd; k++ {
|
||||||
|
marked[k] = true
|
||||||
|
}
|
||||||
|
foundWords = append(foundWords, matchWord)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(foundWords) > 0, foundWords
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace 替换文本中的敏感词
|
||||||
|
func (t *SensitiveWordTree) Replace(text string, repl string) string {
|
||||||
|
if text == "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
|
||||||
|
runes := []rune(text)
|
||||||
|
length := len(runes)
|
||||||
|
result := make([]rune, 0, length)
|
||||||
|
|
||||||
|
// 用于标记已替换的位置
|
||||||
|
marked := make([]bool, length)
|
||||||
|
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
if marked[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从当前位置开始搜索
|
||||||
|
node := t.root
|
||||||
|
matchEnd := -1
|
||||||
|
|
||||||
|
for j := i; j < length; j++ {
|
||||||
|
child, exists := node.Children[runes[j]]
|
||||||
|
if !exists {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
node = child
|
||||||
|
|
||||||
|
if node.IsEnd {
|
||||||
|
matchEnd = j
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if matchEnd >= 0 {
|
||||||
|
// 标记已替换的位置
|
||||||
|
for k := i; k <= matchEnd; k++ {
|
||||||
|
marked[k] = true
|
||||||
|
}
|
||||||
|
// 追加替换符
|
||||||
|
replRunes := []rune(repl)
|
||||||
|
result = append(result, replRunes...)
|
||||||
|
// 跳过已匹配的字符
|
||||||
|
i = matchEnd
|
||||||
|
} else {
|
||||||
|
// 追加原字符
|
||||||
|
result = append(result, runes[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WordCount 获取敏感词数量
|
||||||
|
func (t *SensitiveWordTree) WordCount() int {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
return t.wordCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 敏感词服务实现 ====================
|
||||||
|
|
||||||
|
// SensitiveService 敏感词服务接口
|
||||||
|
type SensitiveService interface {
|
||||||
|
// Check 检查文本是否包含敏感词
|
||||||
|
Check(ctx context.Context, text string) (bool, []string)
|
||||||
|
// Replace 替换敏感词
|
||||||
|
Replace(ctx context.Context, text string, repl string) string
|
||||||
|
// AddWord 添加敏感词
|
||||||
|
AddWord(ctx context.Context, word string, category string, level int) error
|
||||||
|
// RemoveWord 移除敏感词
|
||||||
|
RemoveWord(ctx context.Context, word string) error
|
||||||
|
// Reload 重新加载敏感词库
|
||||||
|
Reload(ctx context.Context) error
|
||||||
|
// GetWordCount 获取敏感词数量
|
||||||
|
GetWordCount(ctx context.Context) int
|
||||||
|
}
|
||||||
|
|
||||||
|
// sensitiveServiceImpl 敏感词服务实现
|
||||||
|
type sensitiveServiceImpl struct {
|
||||||
|
tree *SensitiveWordTree
|
||||||
|
db *gorm.DB
|
||||||
|
redis *redisclient.Client
|
||||||
|
config *SensitiveConfig
|
||||||
|
mu sync.RWMutex
|
||||||
|
replaceStr string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SensitiveConfig 敏感词服务配置
|
||||||
|
type SensitiveConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
|
||||||
|
ReplaceStr string `mapstructure:"replace_str" yaml:"replace_str"`
|
||||||
|
// 最小匹配长度
|
||||||
|
MinMatchLen int `mapstructure:"min_match_len" yaml:"min_match_len"`
|
||||||
|
// 是否从数据库加载
|
||||||
|
LoadFromDB bool `mapstructure:"load_from_db" yaml:"load_from_db"`
|
||||||
|
// 是否从Redis加载
|
||||||
|
LoadFromRedis bool `mapstructure:"load_from_redis" yaml:"load_from_redis"`
|
||||||
|
// Redis键前缀
|
||||||
|
RedisKeyPrefix string `mapstructure:"redis_key_prefix" yaml:"redis_key_prefix"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSensitiveService 创建敏感词服务
|
||||||
|
func NewSensitiveService(db *gorm.DB, redisClient *redisclient.Client, config *SensitiveConfig) SensitiveService {
|
||||||
|
s := &sensitiveServiceImpl{
|
||||||
|
tree: NewSensitiveWordTree(),
|
||||||
|
db: db,
|
||||||
|
redis: redisClient,
|
||||||
|
config: config,
|
||||||
|
replaceStr: config.ReplaceStr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果未设置替换符,默认使用 ***
|
||||||
|
if s.replaceStr == "" {
|
||||||
|
s.replaceStr = "***"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化加载敏感词
|
||||||
|
if config.LoadFromDB {
|
||||||
|
if err := s.loadFromDB(context.Background()); err != nil {
|
||||||
|
log.Printf("Failed to load sensitive words from database: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.LoadFromRedis && redisClient != nil {
|
||||||
|
if err := s.loadFromRedis(context.Background()); err != nil {
|
||||||
|
log.Printf("Failed to load sensitive words from redis: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check 检查文本是否包含敏感词
|
||||||
|
func (s *sensitiveServiceImpl) Check(ctx context.Context, text string) (bool, []string) {
|
||||||
|
if !s.config.Enabled {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if text == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return s.tree.Check(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace 替换敏感词
|
||||||
|
func (s *sensitiveServiceImpl) Replace(ctx context.Context, text string, repl string) string {
|
||||||
|
if !s.config.Enabled {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
if text == "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果未指定替换符,使用默认替换符
|
||||||
|
if repl == "" {
|
||||||
|
repl = s.replaceStr
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.tree.Replace(text, repl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddWord 添加敏感词
|
||||||
|
func (s *sensitiveServiceImpl) AddWord(ctx context.Context, word string, category string, level int) error {
|
||||||
|
if word == "" {
|
||||||
|
return fmt.Errorf("word cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为敏感词级别
|
||||||
|
wordLevel := model.SensitiveWordLevel(level)
|
||||||
|
if wordLevel < 1 || wordLevel > 3 {
|
||||||
|
wordLevel = model.SensitiveWordLevelLow
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换为敏感词分类
|
||||||
|
wordCategory := model.SensitiveWordCategory(category)
|
||||||
|
if wordCategory == "" {
|
||||||
|
wordCategory = model.SensitiveWordCategoryOther
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加到树
|
||||||
|
s.tree.AddWord(word, wordLevel, wordCategory)
|
||||||
|
|
||||||
|
// 持久化到数据库
|
||||||
|
if s.db != nil {
|
||||||
|
sensitiveWord := model.SensitiveWord{
|
||||||
|
Word: word,
|
||||||
|
Category: wordCategory,
|
||||||
|
Level: wordLevel,
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 upsert 逻辑
|
||||||
|
var existing model.SensitiveWord
|
||||||
|
result := s.db.Where("word = ?", word).First(&existing)
|
||||||
|
if result.Error == gorm.ErrRecordNotFound {
|
||||||
|
if err := s.db.Create(&sensitiveWord).Error; err != nil {
|
||||||
|
log.Printf("Failed to save sensitive word to database: %v", err)
|
||||||
|
}
|
||||||
|
} else if result.Error == nil {
|
||||||
|
// 更新已存在的记录
|
||||||
|
existing.Category = wordCategory
|
||||||
|
existing.Level = wordLevel
|
||||||
|
existing.IsActive = true
|
||||||
|
if err := s.db.Save(&existing).Error; err != nil {
|
||||||
|
log.Printf("Failed to update sensitive word in database: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同步到 Redis
|
||||||
|
if s.redis != nil && s.config.RedisKeyPrefix != "" {
|
||||||
|
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"word": word,
|
||||||
|
"category": category,
|
||||||
|
"level": level,
|
||||||
|
}
|
||||||
|
jsonData, _ := json.Marshal(data)
|
||||||
|
s.redis.Set(ctx, key, jsonData, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveWord 移除敏感词
|
||||||
|
func (s *sensitiveServiceImpl) RemoveWord(ctx context.Context, word string) error {
|
||||||
|
if word == "" {
|
||||||
|
return fmt.Errorf("word cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从树中移除
|
||||||
|
s.tree.RemoveWord(word)
|
||||||
|
|
||||||
|
// 从数据库中标记为不活跃
|
||||||
|
if s.db != nil {
|
||||||
|
result := s.db.Model(&model.SensitiveWord{}).Where("word = ?", word).Update("is_active", false)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Printf("Failed to deactivate sensitive word in database: %v", result.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 Redis 中删除
|
||||||
|
if s.redis != nil && s.config.RedisKeyPrefix != "" {
|
||||||
|
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
|
||||||
|
s.redis.Del(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload 重新加载敏感词库
|
||||||
|
func (s *sensitiveServiceImpl) Reload(ctx context.Context) error {
|
||||||
|
// 清空现有树
|
||||||
|
s.tree = NewSensitiveWordTree()
|
||||||
|
|
||||||
|
// 从数据库加载
|
||||||
|
if s.config.LoadFromDB {
|
||||||
|
if err := s.loadFromDB(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to load from database: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 Redis 加载
|
||||||
|
if s.config.LoadFromRedis && s.redis != nil {
|
||||||
|
if err := s.loadFromRedis(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to load from redis: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWordCount 获取敏感词数量
|
||||||
|
func (s *sensitiveServiceImpl) GetWordCount(ctx context.Context) int {
|
||||||
|
return s.tree.WordCount()
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadFromDB 从数据库加载敏感词
|
||||||
|
func (s *sensitiveServiceImpl) loadFromDB(ctx context.Context) error {
|
||||||
|
if s.db == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var words []model.SensitiveWord
|
||||||
|
if err := s.db.Where("is_active = ?", true).Find(&words).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, word := range words {
|
||||||
|
s.tree.AddWord(word.Word, word.Level, word.Category)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Loaded %d sensitive words from database", len(words))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadFromRedis 从 Redis 加载敏感词
|
||||||
|
func (s *sensitiveServiceImpl) loadFromRedis(ctx context.Context) error {
|
||||||
|
if s.redis == nil || s.config.RedisKeyPrefix == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 SCAN 命令代替 KEYS,避免阻塞
|
||||||
|
pattern := fmt.Sprintf("%s:*", s.config.RedisKeyPrefix)
|
||||||
|
var cursor uint64
|
||||||
|
for {
|
||||||
|
keys, nextCursor, err := s.redis.GetClient().Scan(ctx, cursor, pattern, 100).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
data, err := s.redis.Get(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var wordData map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(data), &wordData); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
word, _ := wordData["word"].(string)
|
||||||
|
category, _ := wordData["category"].(string)
|
||||||
|
level, _ := wordData["level"].(float64)
|
||||||
|
|
||||||
|
if word != "" {
|
||||||
|
s.tree.AddWord(word, model.SensitiveWordLevel(int(level)), model.SensitiveWordCategory(category))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = nextCursor
|
||||||
|
if cursor == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 辅助函数 ====================
|
||||||
|
|
||||||
|
// ContainsSensitiveWord 快速检查文本是否包含敏感词
|
||||||
|
func ContainsSensitiveWord(text string, tree *SensitiveWordTree) bool {
|
||||||
|
if tree == nil || text == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
hasSensitive, _ := tree.Check(text)
|
||||||
|
return hasSensitive
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterSensitiveWords 过滤敏感词并返回替换后的文本
|
||||||
|
func FilterSensitiveWords(text string, tree *SensitiveWordTree, repl string) string {
|
||||||
|
if tree == nil || text == "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
if repl == "" {
|
||||||
|
repl = "***"
|
||||||
|
}
|
||||||
|
return tree.Replace(text, repl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTextLength 验证文本长度是否合法
|
||||||
|
func ValidateTextLength(text string, minLen, maxLen int) bool {
|
||||||
|
length := utf8.RuneCountInString(text)
|
||||||
|
return length >= minLen && length <= maxLen
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeText 清理文本,移除多余空白字符
|
||||||
|
func SanitizeText(text string) string {
|
||||||
|
// 替换多个连续空白字符为单个空格
|
||||||
|
spaceReg := regexp.MustCompile(`\s+`)
|
||||||
|
text = spaceReg.ReplaceAllString(text, " ")
|
||||||
|
// 去除首尾空白
|
||||||
|
return strings.TrimSpace(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 默认敏感词列表 ====================
|
||||||
|
|
||||||
|
// DefaultSensitiveWords 返回默认敏感词列表(示例)
|
||||||
|
func DefaultSensitiveWords() map[string]struct{} {
|
||||||
|
return map[string]struct{}{
|
||||||
|
// 示例敏感词,实际需要从数据库或配置加载
|
||||||
|
"测试敏感词1": {},
|
||||||
|
"测试敏感词2": {},
|
||||||
|
"测试敏感词3": {},
|
||||||
|
}
|
||||||
|
}
|
||||||
139
internal/service/sticker_service.go
Normal file
139
internal/service/sticker_service.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
"errors"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrStickerAlreadyExists = errors.New("sticker already exists")
|
||||||
|
ErrInvalidStickerURL = errors.New("invalid sticker url")
|
||||||
|
)
|
||||||
|
|
||||||
|
// StickerService 自定义表情服务接口
|
||||||
|
type StickerService interface {
|
||||||
|
// 获取用户的所有表情
|
||||||
|
GetUserStickers(userID string) ([]model.UserSticker, error)
|
||||||
|
// 添加表情
|
||||||
|
AddSticker(userID string, url string, width, height int) (*model.UserSticker, error)
|
||||||
|
// 删除表情
|
||||||
|
DeleteSticker(userID string, stickerID string) error
|
||||||
|
// 检查表情是否已存在
|
||||||
|
CheckExists(userID string, url string) (bool, error)
|
||||||
|
// 重新排序
|
||||||
|
ReorderStickers(userID string, orders map[string]int) error
|
||||||
|
// 获取用户表情数量
|
||||||
|
GetStickerCount(userID string) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stickerService 自定义表情服务实现
|
||||||
|
type stickerService struct {
|
||||||
|
stickerRepo repository.StickerRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStickerService 创建自定义表情服务
|
||||||
|
func NewStickerService(stickerRepo repository.StickerRepository) StickerService {
|
||||||
|
return &stickerService{
|
||||||
|
stickerRepo: stickerRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserStickers 获取用户的所有表情
|
||||||
|
func (s *stickerService) GetUserStickers(userID string) ([]model.UserSticker, error) {
|
||||||
|
stickers, err := s.stickerRepo.GetByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 兼容历史脏数据:过滤本地文件 URI,避免客户端加载 file:// 报错
|
||||||
|
filtered := make([]model.UserSticker, 0, len(stickers))
|
||||||
|
for _, sticker := range stickers {
|
||||||
|
if isValidStickerURL(sticker.URL) {
|
||||||
|
filtered = append(filtered, sticker)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSticker 添加表情
|
||||||
|
func (s *stickerService) AddSticker(userID string, url string, width, height int) (*model.UserSticker, error) {
|
||||||
|
if !isValidStickerURL(url) {
|
||||||
|
return nil, ErrInvalidStickerURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否已存在
|
||||||
|
exists, err := s.stickerRepo.Exists(userID, url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return nil, ErrStickerAlreadyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前数量用于设置排序
|
||||||
|
count, err := s.stickerRepo.CountByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sticker := &model.UserSticker{
|
||||||
|
UserID: userID,
|
||||||
|
URL: url,
|
||||||
|
Width: width,
|
||||||
|
Height: height,
|
||||||
|
SortOrder: int(count), // 新表情添加到末尾
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.stickerRepo.Create(sticker); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sticker, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidStickerURL(raw string) bool {
|
||||||
|
trimmed := strings.TrimSpace(raw)
|
||||||
|
if trimmed == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := url.Parse(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
scheme := strings.ToLower(parsed.Scheme)
|
||||||
|
return scheme == "http" || scheme == "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSticker 删除表情
|
||||||
|
func (s *stickerService) DeleteSticker(userID string, stickerID string) error {
|
||||||
|
// 先检查表情是否属于该用户
|
||||||
|
sticker, err := s.stickerRepo.GetByID(stickerID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if sticker.UserID != userID {
|
||||||
|
return errors.New("sticker not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.stickerRepo.Delete(stickerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckExists 检查表情是否已存在
|
||||||
|
func (s *stickerService) CheckExists(userID string, url string) (bool, error) {
|
||||||
|
return s.stickerRepo.Exists(userID, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReorderStickers 重新排序
|
||||||
|
func (s *stickerService) ReorderStickers(userID string, orders map[string]int) error {
|
||||||
|
return s.stickerRepo.BatchUpdateSortOrder(userID, orders)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStickerCount 获取用户表情数量
|
||||||
|
func (s *stickerService) GetStickerCount(userID string) (int64, error) {
|
||||||
|
return s.stickerRepo.CountByUserID(userID)
|
||||||
|
}
|
||||||
462
internal/service/system_message_service.go
Normal file
462
internal/service/system_message_service.go
Normal file
@@ -0,0 +1,462 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SystemMessageService 系统消息服务接口
|
||||||
|
type SystemMessageService interface {
|
||||||
|
// 发送互动通知
|
||||||
|
SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||||
|
SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error
|
||||||
|
SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error
|
||||||
|
SendFollowNotification(ctx context.Context, userID string, operatorID string) error
|
||||||
|
SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||||
|
SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||||
|
SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error
|
||||||
|
SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error
|
||||||
|
|
||||||
|
// 发送系统公告
|
||||||
|
SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error
|
||||||
|
SendBroadcastAnnouncement(ctx context.Context, title string, content string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type systemMessageServiceImpl struct {
|
||||||
|
notifyRepo *repository.SystemNotificationRepository
|
||||||
|
pushService PushService
|
||||||
|
userRepo *repository.UserRepository
|
||||||
|
postRepo *repository.PostRepository
|
||||||
|
cache cache.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSystemMessageService 创建系统消息服务
|
||||||
|
func NewSystemMessageService(
|
||||||
|
notifyRepo *repository.SystemNotificationRepository,
|
||||||
|
pushService PushService,
|
||||||
|
userRepo *repository.UserRepository,
|
||||||
|
postRepo *repository.PostRepository,
|
||||||
|
) SystemMessageService {
|
||||||
|
return &systemMessageServiceImpl{
|
||||||
|
notifyRepo: notifyRepo,
|
||||||
|
pushService: pushService,
|
||||||
|
userRepo: userRepo,
|
||||||
|
postRepo: postRepo,
|
||||||
|
cache: cache.GetCache(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendLikeNotification 发送点赞通知
|
||||||
|
func (s *systemMessageServiceImpl) SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||||
|
// 获取操作者信息
|
||||||
|
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取帖子标题
|
||||||
|
postTitle, err := s.getPostTitle(postID)
|
||||||
|
if err != nil {
|
||||||
|
postTitle = "您的帖子"
|
||||||
|
}
|
||||||
|
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
ActorIDStr: operatorID,
|
||||||
|
ActorName: actorName,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
TargetID: postID,
|
||||||
|
TargetTitle: postTitle,
|
||||||
|
TargetType: "post",
|
||||||
|
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s 赞了「%s」", actorName, postTitle)
|
||||||
|
|
||||||
|
// 创建通知
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikePost, content, extraData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create like notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送通知
|
||||||
|
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendCommentNotification 发送评论通知
|
||||||
|
func (s *systemMessageServiceImpl) SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error {
|
||||||
|
// 获取操作者信息
|
||||||
|
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取帖子标题
|
||||||
|
postTitle, err := s.getPostTitle(postID)
|
||||||
|
if err != nil {
|
||||||
|
postTitle = "您的帖子"
|
||||||
|
}
|
||||||
|
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
ActorIDStr: operatorID,
|
||||||
|
ActorName: actorName,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
TargetID: postID,
|
||||||
|
TargetTitle: postTitle,
|
||||||
|
TargetType: "comment",
|
||||||
|
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s 评论了「%s」", actorName, postTitle)
|
||||||
|
|
||||||
|
// 创建通知
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyComment, content, extraData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create comment notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送通知
|
||||||
|
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendReplyNotification 发送回复通知
|
||||||
|
func (s *systemMessageServiceImpl) SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error {
|
||||||
|
// 获取操作者信息
|
||||||
|
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取帖子标题
|
||||||
|
postTitle, err := s.getPostTitle(postID)
|
||||||
|
if err != nil {
|
||||||
|
postTitle = "您的帖子"
|
||||||
|
}
|
||||||
|
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
ActorIDStr: operatorID,
|
||||||
|
ActorName: actorName,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
TargetID: replyID,
|
||||||
|
TargetTitle: postTitle,
|
||||||
|
TargetType: "reply",
|
||||||
|
ActionURL: fmt.Sprintf("/posts/%s?comment=%s&reply=%s", postID, commentID, replyID),
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s 回复了您在「%s」的评论", actorName, postTitle)
|
||||||
|
|
||||||
|
// 创建通知
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyReply, content, extraData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create reply notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送通知
|
||||||
|
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendFollowNotification 发送关注通知
|
||||||
|
func (s *systemMessageServiceImpl) SendFollowNotification(ctx context.Context, userID string, operatorID string) error {
|
||||||
|
fmt.Printf("[DEBUG] SendFollowNotification: userID=%s, operatorID=%s\n", userID, operatorID)
|
||||||
|
|
||||||
|
// 获取操作者信息
|
||||||
|
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] SendFollowNotification: getActorInfo error: %v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("[DEBUG] SendFollowNotification: actorName=%s, avatarURL=%s\n", actorName, avatarURL)
|
||||||
|
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
ActorIDStr: operatorID,
|
||||||
|
ActorName: actorName,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
TargetID: "",
|
||||||
|
TargetType: "user",
|
||||||
|
ActionURL: fmt.Sprintf("/users/%s", operatorID),
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s 关注了你", actorName)
|
||||||
|
|
||||||
|
// 创建通知
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyFollow, content, extraData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create follow notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] SendFollowNotification: notification created, ID=%d, Content=%s\n", notification.ID, notification.Content)
|
||||||
|
|
||||||
|
// 推送通知
|
||||||
|
pushErr := s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||||
|
if pushErr != nil {
|
||||||
|
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification error: %v\n", pushErr)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification success\n")
|
||||||
|
}
|
||||||
|
return pushErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendFavoriteNotification 发送收藏通知
|
||||||
|
func (s *systemMessageServiceImpl) SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||||
|
// 获取操作者信息
|
||||||
|
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取帖子标题
|
||||||
|
postTitle, err := s.getPostTitle(postID)
|
||||||
|
if err != nil {
|
||||||
|
postTitle = "您的帖子"
|
||||||
|
}
|
||||||
|
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
ActorIDStr: operatorID,
|
||||||
|
ActorName: actorName,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
TargetID: postID,
|
||||||
|
TargetTitle: postTitle,
|
||||||
|
TargetType: "post",
|
||||||
|
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s 收藏了「%s」", actorName, postTitle)
|
||||||
|
|
||||||
|
// 创建通知
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyFavoritePost, content, extraData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create favorite notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送通知
|
||||||
|
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendLikeCommentNotification 发送评论点赞通知
|
||||||
|
func (s *systemMessageServiceImpl) SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error {
|
||||||
|
// 获取操作者信息
|
||||||
|
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 截取评论内容预览(最多50字)
|
||||||
|
preview := commentContent
|
||||||
|
runes := []rune(preview)
|
||||||
|
if len(runes) > 50 {
|
||||||
|
preview = string(runes[:50]) + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
ActorIDStr: operatorID,
|
||||||
|
ActorName: actorName,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
TargetID: postID,
|
||||||
|
TargetTitle: preview,
|
||||||
|
TargetType: "comment",
|
||||||
|
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s 赞了您的评论", actorName)
|
||||||
|
|
||||||
|
// 创建通知
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeComment, content, extraData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create like comment notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送通知
|
||||||
|
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendLikeReplyNotification 发送回复点赞通知
|
||||||
|
func (s *systemMessageServiceImpl) SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error {
|
||||||
|
// 获取操作者信息
|
||||||
|
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 截取回复内容预览(最多50字)
|
||||||
|
preview := replyContent
|
||||||
|
runes := []rune(preview)
|
||||||
|
if len(runes) > 50 {
|
||||||
|
preview = string(runes[:50]) + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
ActorIDStr: operatorID,
|
||||||
|
ActorName: actorName,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
TargetID: postID,
|
||||||
|
TargetTitle: preview,
|
||||||
|
TargetType: "reply",
|
||||||
|
ActionURL: fmt.Sprintf("/posts/%s?reply=%s", postID, replyID),
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s 赞了您的回复", actorName)
|
||||||
|
|
||||||
|
// 创建通知
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeReply, content, extraData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create like reply notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送通知
|
||||||
|
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMentionNotification 发送@提及通知
|
||||||
|
func (s *systemMessageServiceImpl) SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||||
|
// 获取操作者信息
|
||||||
|
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取帖子标题
|
||||||
|
postTitle, err := s.getPostTitle(postID)
|
||||||
|
if err != nil {
|
||||||
|
postTitle = "您的帖子"
|
||||||
|
}
|
||||||
|
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
ActorIDStr: operatorID,
|
||||||
|
ActorName: actorName,
|
||||||
|
AvatarURL: avatarURL,
|
||||||
|
TargetID: postID,
|
||||||
|
TargetTitle: postTitle,
|
||||||
|
TargetType: "post",
|
||||||
|
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
content := fmt.Sprintf("%s 在「%s」中提到了你", actorName, postTitle)
|
||||||
|
|
||||||
|
// 创建通知
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyMention, content, extraData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create mention notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送通知
|
||||||
|
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendSystemAnnouncement 发送系统公告给指定用户
|
||||||
|
func (s *systemMessageServiceImpl) SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error {
|
||||||
|
for _, userID := range userIDs {
|
||||||
|
extraData := &model.SystemNotificationExtra{
|
||||||
|
TargetType: "announcement",
|
||||||
|
ActionTime: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
notification, err := s.createNotification(ctx, userID, model.SysNotifyAnnounce, fmt.Sprintf("【%s】%s", title, content), extraData)
|
||||||
|
if err != nil {
|
||||||
|
continue // 单个失败不影响其他用户
|
||||||
|
}
|
||||||
|
|
||||||
|
// 推送通知(使用高优先级)
|
||||||
|
if err := s.pushService.PushSystemNotification(ctx, userID, notification); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendBroadcastAnnouncement 发送广播公告给所有在线用户
|
||||||
|
func (s *systemMessageServiceImpl) SendBroadcastAnnouncement(ctx context.Context, title string, content string) error {
|
||||||
|
// TODO: 实现广播公告
|
||||||
|
// 1. 获取所有在线用户
|
||||||
|
// 2. 批量发送公告
|
||||||
|
// 3. 对于离线用户,存储为待推送记录
|
||||||
|
return fmt.Errorf("broadcast announcement not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// createNotification 创建系统通知(存储到独立表)
|
||||||
|
func (s *systemMessageServiceImpl) createNotification(ctx context.Context, userID string, notifyType model.SystemNotificationType, content string, extraData *model.SystemNotificationExtra) (*model.SystemNotification, error) {
|
||||||
|
fmt.Printf("[DEBUG] createNotification: userID=%s, notifyType=%s\n", userID, notifyType)
|
||||||
|
|
||||||
|
// 生成雪花算法ID
|
||||||
|
id, err := utils.GetSnowflake().GenerateID()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] createNotification: failed to generate ID: %v\n", err)
|
||||||
|
return nil, fmt.Errorf("failed to generate notification ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
notification := &model.SystemNotification{
|
||||||
|
ID: id,
|
||||||
|
ReceiverID: userID,
|
||||||
|
Type: notifyType,
|
||||||
|
Content: content,
|
||||||
|
ExtraData: extraData,
|
||||||
|
IsRead: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] createNotification: notification created with ID=%d, ReceiverID=%s\n", id, userID)
|
||||||
|
|
||||||
|
// 保存通知到数据库
|
||||||
|
if err := s.notifyRepo.Create(notification); err != nil {
|
||||||
|
fmt.Printf("[DEBUG] createNotification: failed to save notification: %v\n", err)
|
||||||
|
return nil, fmt.Errorf("failed to save notification: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效系统消息未读数缓存
|
||||||
|
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] createNotification: notification saved successfully, ID=%d\n", notification.ID)
|
||||||
|
return notification, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getActorInfo 获取操作者信息
|
||||||
|
func (s *systemMessageServiceImpl) getActorInfo(ctx context.Context, operatorID string) (string, string, error) {
|
||||||
|
// 从用户仓储获取用户信息
|
||||||
|
if s.userRepo != nil {
|
||||||
|
user, err := s.userRepo.GetByID(operatorID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] getActorInfo: failed to get user %s: %v\n", operatorID, err)
|
||||||
|
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil // 返回默认值,不阻断流程
|
||||||
|
}
|
||||||
|
avatar := utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar)
|
||||||
|
return user.Nickname, avatar, nil
|
||||||
|
}
|
||||||
|
// 如果没有用户仓储,返回默认值
|
||||||
|
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPostTitle 获取帖子标题
|
||||||
|
func (s *systemMessageServiceImpl) getPostTitle(postID string) (string, error) {
|
||||||
|
if s.postRepo == nil {
|
||||||
|
if len(postID) >= 8 {
|
||||||
|
return fmt.Sprintf("帖子#%s", postID[:8]), nil
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("帖子#%s", postID), nil
|
||||||
|
}
|
||||||
|
post, err := s.postRepo.GetByID(postID)
|
||||||
|
if err != nil {
|
||||||
|
if len(postID) >= 8 {
|
||||||
|
return fmt.Sprintf("帖子#%s", postID[:8]), nil
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("帖子#%s", postID), nil
|
||||||
|
}
|
||||||
|
if post.Title != "" {
|
||||||
|
return post.Title, nil
|
||||||
|
}
|
||||||
|
// 如果没有标题,返回内容前20个字符
|
||||||
|
if len(post.Content) > 20 {
|
||||||
|
return post.Content[:20] + "...", nil
|
||||||
|
}
|
||||||
|
return post.Content, nil
|
||||||
|
}
|
||||||
273
internal/service/upload_service.go
Normal file
273
internal/service/upload_service.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/jpeg"
|
||||||
|
"image/png"
|
||||||
|
"io"
|
||||||
|
"mime"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/s3"
|
||||||
|
_ "golang.org/x/image/bmp"
|
||||||
|
_ "golang.org/x/image/tiff"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UploadService 上传服务
|
||||||
|
type UploadService struct {
|
||||||
|
s3Client *s3.Client
|
||||||
|
userService *UserService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUploadService 创建上传服务
|
||||||
|
func NewUploadService(s3Client *s3.Client, userService *UserService) *UploadService {
|
||||||
|
return &UploadService{
|
||||||
|
s3Client: s3Client,
|
||||||
|
userService: userService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadImage 上传图片
|
||||||
|
func (s *UploadService) UploadImage(ctx context.Context, file *multipart.FileHeader) (string, error) {
|
||||||
|
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 压缩后再计算哈希,确保同一压缩结果映射同一对象名
|
||||||
|
hash := sha256.Sum256(processedData)
|
||||||
|
hashStr := fmt.Sprintf("%x", hash)
|
||||||
|
|
||||||
|
objectName := fmt.Sprintf("images/%s%s", hashStr, ext)
|
||||||
|
|
||||||
|
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return url, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExtFromContentType 根据Content-Type获取文件扩展名
|
||||||
|
func getExtFromContentType(contentType string) string {
|
||||||
|
baseType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err == nil && baseType != "" {
|
||||||
|
contentType = baseType
|
||||||
|
}
|
||||||
|
|
||||||
|
switch contentType {
|
||||||
|
case "image/jpg", "image/jpeg":
|
||||||
|
return ".jpg"
|
||||||
|
case "image/png":
|
||||||
|
return ".png"
|
||||||
|
case "image/gif":
|
||||||
|
return ".gif"
|
||||||
|
case "image/webp":
|
||||||
|
return ".webp"
|
||||||
|
case "image/bmp", "image/x-ms-bmp":
|
||||||
|
return ".bmp"
|
||||||
|
case "image/tiff":
|
||||||
|
return ".tiff"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadAvatar 上传头像
|
||||||
|
func (s *UploadService) UploadAvatar(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
|
||||||
|
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 压缩后再计算哈希
|
||||||
|
hash := sha256.Sum256(processedData)
|
||||||
|
hashStr := fmt.Sprintf("%x", hash)
|
||||||
|
|
||||||
|
objectName := fmt.Sprintf("avatars/%s%s", hashStr, ext)
|
||||||
|
|
||||||
|
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户头像
|
||||||
|
if s.userService != nil {
|
||||||
|
user, err := s.userService.GetUserByID(ctx, userID)
|
||||||
|
if err == nil && user != nil {
|
||||||
|
user.Avatar = url
|
||||||
|
err = s.userService.UpdateUser(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
// 更新失败不影响上传结果,只记录日志
|
||||||
|
fmt.Printf("[UploadAvatar] failed to update user avatar: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return url, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UploadCover 上传头图(个人主页封面)
|
||||||
|
func (s *UploadService) UploadCover(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
|
||||||
|
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 压缩后再计算哈希
|
||||||
|
hash := sha256.Sum256(processedData)
|
||||||
|
hashStr := fmt.Sprintf("%x", hash)
|
||||||
|
|
||||||
|
objectName := fmt.Sprintf("covers/%s%s", hashStr, ext)
|
||||||
|
|
||||||
|
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户头图
|
||||||
|
if s.userService != nil {
|
||||||
|
user, err := s.userService.GetUserByID(ctx, userID)
|
||||||
|
if err == nil && user != nil {
|
||||||
|
user.CoverURL = url
|
||||||
|
err = s.userService.UpdateUser(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
// 更新失败不影响上传结果,只记录日志
|
||||||
|
fmt.Printf("[UploadCover] failed to update user cover: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return url, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetURL 获取文件URL
|
||||||
|
func (s *UploadService) GetURL(ctx context.Context, objectName string) (string, error) {
|
||||||
|
return s.s3Client.GetURL(ctx, objectName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除文件
|
||||||
|
func (s *UploadService) Delete(ctx context.Context, objectName string) error {
|
||||||
|
return s.s3Client.Delete(ctx, objectName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareImageForUpload(file *multipart.FileHeader) ([]byte, string, string, error) {
|
||||||
|
f, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("failed to open file: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
originalData, err := io.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", fmt.Errorf("failed to read file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先从文件字节探测真实类型,避免前端压缩/转码后 header 与实际格式不一致
|
||||||
|
detectedType := normalizeImageContentType(http.DetectContentType(originalData))
|
||||||
|
headerType := normalizeImageContentType(file.Header.Get("Content-Type"))
|
||||||
|
contentType := detectedType
|
||||||
|
if contentType == "" || contentType == "application/octet-stream" {
|
||||||
|
contentType = headerType
|
||||||
|
}
|
||||||
|
|
||||||
|
compressedData, compressedType, err := compressImageData(originalData, contentType)
|
||||||
|
if err != nil {
|
||||||
|
// 压缩失败时回退到原图,保证上传可用性
|
||||||
|
compressedData = originalData
|
||||||
|
compressedType = contentType
|
||||||
|
}
|
||||||
|
|
||||||
|
if compressedType == "" {
|
||||||
|
compressedType = contentType
|
||||||
|
}
|
||||||
|
if compressedType == "" {
|
||||||
|
compressedType = http.DetectContentType(compressedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
ext := getExtFromContentType(compressedType)
|
||||||
|
if ext == "" {
|
||||||
|
ext = strings.ToLower(filepath.Ext(file.Filename))
|
||||||
|
}
|
||||||
|
if ext == "" {
|
||||||
|
// 最终兜底,避免对象名无扩展名导致 URL 语义不明确
|
||||||
|
ext = ".jpg"
|
||||||
|
}
|
||||||
|
|
||||||
|
return compressedData, compressedType, ext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func compressImageData(data []byte, contentType string) ([]byte, string, error) {
|
||||||
|
contentType = normalizeImageContentType(contentType)
|
||||||
|
|
||||||
|
// GIF/WebP 等格式先保留原图,避免动画和透明通道丢失
|
||||||
|
if contentType == "image/gif" || contentType == "image/webp" {
|
||||||
|
return data, contentType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if contentType != "image/jpeg" &&
|
||||||
|
contentType != "image/png" &&
|
||||||
|
contentType != "image/bmp" &&
|
||||||
|
contentType != "image/x-ms-bmp" &&
|
||||||
|
contentType != "image/tiff" {
|
||||||
|
return data, contentType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
img, _, err := image.Decode(bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", fmt.Errorf("failed to decode image: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
switch contentType {
|
||||||
|
case "image/png":
|
||||||
|
encoder := png.Encoder{CompressionLevel: png.BestCompression}
|
||||||
|
if err := encoder.Encode(&buf, img); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("failed to encode png: %w", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes(), "image/png", nil
|
||||||
|
default:
|
||||||
|
// BMP/TIFF 等无损大图统一压缩为 JPEG,控制体积
|
||||||
|
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 82}); err != nil {
|
||||||
|
return nil, "", fmt.Errorf("failed to encode jpeg: %w", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes(), "image/jpeg", nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeImageContentType(contentType string) string {
|
||||||
|
if contentType == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
baseType, _, err := mime.ParseMediaType(contentType)
|
||||||
|
if err == nil && baseType != "" {
|
||||||
|
contentType = baseType
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(contentType) {
|
||||||
|
case "image/jpg":
|
||||||
|
return "image/jpeg"
|
||||||
|
case "image/jpeg":
|
||||||
|
return "image/jpeg"
|
||||||
|
case "image/png":
|
||||||
|
return "image/png"
|
||||||
|
case "image/gif":
|
||||||
|
return "image/gif"
|
||||||
|
case "image/webp":
|
||||||
|
return "image/webp"
|
||||||
|
case "image/bmp", "image/x-ms-bmp":
|
||||||
|
return "image/bmp"
|
||||||
|
case "image/tiff":
|
||||||
|
return "image/tiff"
|
||||||
|
default:
|
||||||
|
return contentType
|
||||||
|
}
|
||||||
|
}
|
||||||
592
internal/service/user_service.go
Normal file
592
internal/service/user_service.go
Normal file
@@ -0,0 +1,592 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/pkg/utils"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserService 用户服务
|
||||||
|
type UserService struct {
|
||||||
|
userRepo *repository.UserRepository
|
||||||
|
systemMessageService SystemMessageService
|
||||||
|
emailCodeService EmailCodeService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUserService 创建用户服务
|
||||||
|
func NewUserService(
|
||||||
|
userRepo *repository.UserRepository,
|
||||||
|
systemMessageService SystemMessageService,
|
||||||
|
emailService EmailService,
|
||||||
|
cacheBackend cache.Cache,
|
||||||
|
) *UserService {
|
||||||
|
return &UserService{
|
||||||
|
userRepo: userRepo,
|
||||||
|
systemMessageService: systemMessageService,
|
||||||
|
emailCodeService: NewEmailCodeService(emailService, cacheBackend),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRegisterCode 发送注册验证码
|
||||||
|
func (s *UserService) SendRegisterCode(ctx context.Context, email string) error {
|
||||||
|
user, err := s.userRepo.GetByEmail(email)
|
||||||
|
if err == nil && user != nil {
|
||||||
|
return ErrEmailExists
|
||||||
|
}
|
||||||
|
return s.emailCodeService.SendCode(ctx, CodePurposeRegister, email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendPasswordResetCode 发送找回密码验证码
|
||||||
|
func (s *UserService) SendPasswordResetCode(ctx context.Context, email string) error {
|
||||||
|
user, err := s.userRepo.GetByEmail(email)
|
||||||
|
if err != nil || user == nil {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
return s.emailCodeService.SendCode(ctx, CodePurposePasswordReset, email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendCurrentUserEmailVerifyCode 发送当前用户邮箱验证验证码
|
||||||
|
func (s *UserService) SendCurrentUserEmailVerifyCode(ctx context.Context, userID, email string) error {
|
||||||
|
user, err := s.userRepo.GetByID(userID)
|
||||||
|
if err != nil || user == nil {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
targetEmail := strings.TrimSpace(email)
|
||||||
|
if targetEmail == "" && user.Email != nil {
|
||||||
|
targetEmail = strings.TrimSpace(*user.Email)
|
||||||
|
}
|
||||||
|
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
|
||||||
|
return ErrInvalidEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.EmailVerified && user.Email != nil && strings.EqualFold(strings.TrimSpace(*user.Email), targetEmail) {
|
||||||
|
return ErrEmailAlreadyVerified
|
||||||
|
}
|
||||||
|
|
||||||
|
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
|
||||||
|
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
|
||||||
|
return ErrEmailExists
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.emailCodeService.SendCode(ctx, CodePurposeEmailVerify, targetEmail)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyCurrentUserEmail 验证当前用户邮箱
|
||||||
|
func (s *UserService) VerifyCurrentUserEmail(ctx context.Context, userID, email, verificationCode string) error {
|
||||||
|
user, err := s.userRepo.GetByID(userID)
|
||||||
|
if err != nil || user == nil {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
targetEmail := strings.TrimSpace(email)
|
||||||
|
if targetEmail == "" && user.Email != nil {
|
||||||
|
targetEmail = strings.TrimSpace(*user.Email)
|
||||||
|
}
|
||||||
|
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
|
||||||
|
return ErrInvalidEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.emailCodeService.VerifyCode(CodePurposeEmailVerify, targetEmail, verificationCode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
|
||||||
|
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
|
||||||
|
return ErrEmailExists
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Email = &targetEmail
|
||||||
|
user.EmailVerified = true
|
||||||
|
return s.userRepo.Update(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendChangePasswordCode 发送修改密码验证码
|
||||||
|
func (s *UserService) SendChangePasswordCode(ctx context.Context, userID string) error {
|
||||||
|
user, err := s.userRepo.GetByID(userID)
|
||||||
|
if err != nil || user == nil {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
|
||||||
|
return ErrEmailNotBound
|
||||||
|
}
|
||||||
|
return s.emailCodeService.SendCode(ctx, CodePurposeChangePassword, *user.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register 用户注册
|
||||||
|
func (s *UserService) Register(ctx context.Context, username, email, password, nickname, phone, verificationCode string) (*model.User, error) {
|
||||||
|
// 验证用户名
|
||||||
|
if !utils.ValidateUsername(username) {
|
||||||
|
return nil, ErrInvalidUsername
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注册必须提供邮箱并完成验证码校验
|
||||||
|
if email == "" || !utils.ValidateEmail(email) {
|
||||||
|
return nil, ErrInvalidEmail
|
||||||
|
}
|
||||||
|
if err := s.emailCodeService.VerifyCode(CodePurposeRegister, email, verificationCode); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证密码
|
||||||
|
if !utils.ValidatePassword(password) {
|
||||||
|
return nil, ErrWeakPassword
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证手机号(如果提供)
|
||||||
|
if phone != "" && !utils.ValidatePhone(phone) {
|
||||||
|
return nil, ErrInvalidPhone
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查用户名是否已存在
|
||||||
|
existingUser, err := s.userRepo.GetByUsername(username)
|
||||||
|
if err == nil && existingUser != nil {
|
||||||
|
return nil, ErrUsernameExists
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查邮箱是否已存在(如果提供)
|
||||||
|
if email != "" {
|
||||||
|
existingUser, err = s.userRepo.GetByEmail(email)
|
||||||
|
if err == nil && existingUser != nil {
|
||||||
|
return nil, ErrEmailExists
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查手机号是否已存在(如果提供)
|
||||||
|
if phone != "" {
|
||||||
|
existingUser, err = s.userRepo.GetByPhone(phone)
|
||||||
|
if err == nil && existingUser != nil {
|
||||||
|
return nil, ErrPhoneExists
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 密码哈希
|
||||||
|
hashedPassword, err := utils.HashPassword(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建用户
|
||||||
|
user := &model.User{
|
||||||
|
Username: username,
|
||||||
|
Nickname: nickname,
|
||||||
|
EmailVerified: true,
|
||||||
|
PasswordHash: hashedPassword,
|
||||||
|
Status: model.UserStatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果提供了邮箱,设置指针值
|
||||||
|
if email != "" {
|
||||||
|
user.Email = &email
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果提供了手机号,设置指针值
|
||||||
|
if phone != "" {
|
||||||
|
user.Phone = &phone
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.userRepo.Create(user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login 用户登录
|
||||||
|
func (s *UserService) Login(ctx context.Context, account, password string) (*model.User, error) {
|
||||||
|
account = strings.TrimSpace(account)
|
||||||
|
var (
|
||||||
|
user *model.User
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if utils.ValidateEmail(account) {
|
||||||
|
user, err = s.userRepo.GetByEmail(account)
|
||||||
|
} else if utils.ValidatePhone(account) {
|
||||||
|
user, err = s.userRepo.GetByPhone(account)
|
||||||
|
} else {
|
||||||
|
user, err = s.userRepo.GetByUsername(account)
|
||||||
|
}
|
||||||
|
if err != nil || user == nil {
|
||||||
|
return nil, ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
if !utils.CheckPasswordHash(password, user.PasswordHash) {
|
||||||
|
return nil, ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Status != model.UserStatusActive {
|
||||||
|
return nil, ErrUserBanned
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByID 根据ID获取用户
|
||||||
|
func (s *UserService) GetUserByID(ctx context.Context, id string) (*model.User, error) {
|
||||||
|
return s.userRepo.GetByID(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserPostCount 获取用户帖子数(实时计算)
|
||||||
|
func (s *UserService) GetUserPostCount(ctx context.Context, userID string) (int64, error) {
|
||||||
|
return s.userRepo.GetPostsCount(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserPostCountBatch 批量获取用户帖子数(实时计算)
|
||||||
|
func (s *UserService) GetUserPostCountBatch(ctx context.Context, userIDs []string) (map[string]int64, error) {
|
||||||
|
return s.userRepo.GetPostsCountBatch(userIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByIDWithFollowingStatus 根据ID获取用户(包含当前用户是否关注的状态)
|
||||||
|
func (s *UserService) GetUserByIDWithFollowingStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, error) {
|
||||||
|
user, err := s.userRepo.GetByID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果查询的是当前用户自己,不需要检查关注状态
|
||||||
|
if userID == currentUserID {
|
||||||
|
return user, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return user, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, isFollowing, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByIDWithMutualFollowStatus 根据ID获取用户(包含双向关注状态)
|
||||||
|
func (s *UserService) GetUserByIDWithMutualFollowStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, bool, error) {
|
||||||
|
user, err := s.userRepo.GetByID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果查询的是当前用户自己,不需要检查关注状态
|
||||||
|
if userID == currentUserID {
|
||||||
|
return user, false, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 当前用户是否关注了该用户
|
||||||
|
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return user, false, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 该用户是否关注了当前用户
|
||||||
|
isFollowingMe, err := s.userRepo.IsFollowing(userID, currentUserID)
|
||||||
|
if err != nil {
|
||||||
|
return user, isFollowing, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, isFollowing, isFollowingMe, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUser 更新用户
|
||||||
|
func (s *UserService) UpdateUser(ctx context.Context, user *model.User) error {
|
||||||
|
return s.userRepo.Update(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowers 获取粉丝
|
||||||
|
func (s *UserService) GetFollowers(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
return s.userRepo.GetFollowers(userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowing 获取关注
|
||||||
|
func (s *UserService) GetFollowing(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
return s.userRepo.GetFollowing(userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FollowUser 关注用户
|
||||||
|
func (s *UserService) FollowUser(ctx context.Context, followerID, followeeID string) error {
|
||||||
|
fmt.Printf("[DEBUG] FollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||||
|
|
||||||
|
blocked, err := s.userRepo.IsBlockedEitherDirection(followerID, followeeID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
return ErrUserBlocked
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否已经关注
|
||||||
|
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if isFollowing {
|
||||||
|
fmt.Printf("[DEBUG] Already following, skip creation\n")
|
||||||
|
return nil // 已经关注,直接返回成功
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建关注关系
|
||||||
|
follow := &model.Follow{
|
||||||
|
FollowerID: followerID,
|
||||||
|
FollowingID: followeeID,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.userRepo.CreateFollow(follow)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] CreateFollow error: %v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] Follow record created successfully\n")
|
||||||
|
|
||||||
|
// 刷新关注者的关注数(通过实际计数,更可靠)
|
||||||
|
err = s.userRepo.RefreshFollowingCount(followerID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
|
||||||
|
// 不回滚,计数可以通过其他方式修复
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
|
||||||
|
err = s.userRepo.RefreshFollowersCount(followeeID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
|
||||||
|
// 不回滚,计数可以通过其他方式修复
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送关注通知给被关注者
|
||||||
|
if s.systemMessageService != nil {
|
||||||
|
// 异步发送通知,不阻塞主流程
|
||||||
|
go func() {
|
||||||
|
notifyErr := s.systemMessageService.SendFollowNotification(context.Background(), followeeID, followerID)
|
||||||
|
if notifyErr != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error sending follow notification: %v\n", notifyErr)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("[DEBUG] Follow notification sent successfully to %s\n", followeeID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] FollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnfollowUser 取消关注用户
|
||||||
|
func (s *UserService) UnfollowUser(ctx context.Context, followerID, followeeID string) error {
|
||||||
|
fmt.Printf("[DEBUG] UnfollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||||
|
|
||||||
|
// 检查是否已经关注
|
||||||
|
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !isFollowing {
|
||||||
|
fmt.Printf("[DEBUG] Not following, skip deletion\n")
|
||||||
|
return nil // 没有关注,直接返回成功
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除关注关系
|
||||||
|
err = s.userRepo.DeleteFollow(followerID, followeeID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] DeleteFollow error: %v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] Follow record deleted successfully\n")
|
||||||
|
|
||||||
|
// 刷新关注者的关注数(通过实际计数,更可靠)
|
||||||
|
err = s.userRepo.RefreshFollowingCount(followerID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
|
||||||
|
err = s.userRepo.RefreshFollowersCount(followeeID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] UnfollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockUser 拉黑用户,并自动清理双向关注/粉丝关系
|
||||||
|
func (s *UserService) BlockUser(ctx context.Context, blockerID, blockedID string) error {
|
||||||
|
if blockerID == blockedID {
|
||||||
|
return ErrInvalidOperation
|
||||||
|
}
|
||||||
|
return s.userRepo.BlockUserAndCleanupRelations(blockerID, blockedID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnblockUser 取消拉黑
|
||||||
|
func (s *UserService) UnblockUser(ctx context.Context, blockerID, blockedID string) error {
|
||||||
|
if blockerID == blockedID {
|
||||||
|
return ErrInvalidOperation
|
||||||
|
}
|
||||||
|
return s.userRepo.UnblockUser(blockerID, blockedID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockedUsers 获取黑名单列表
|
||||||
|
func (s *UserService) GetBlockedUsers(ctx context.Context, blockerID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
return s.userRepo.GetBlockedUsers(blockerID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlocked 检查当前用户是否已拉黑目标用户
|
||||||
|
func (s *UserService) IsBlocked(ctx context.Context, blockerID, blockedID string) (bool, error) {
|
||||||
|
return s.userRepo.IsBlocked(blockerID, blockedID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowingList 获取关注列表(字符串参数版本)
|
||||||
|
func (s *UserService) GetFollowingList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
|
||||||
|
// 转换字符串参数为整数
|
||||||
|
pageInt := 1
|
||||||
|
pageSizeInt := 20
|
||||||
|
if page != "" {
|
||||||
|
_, err := fmt.Sscanf(page, "%d", &pageInt)
|
||||||
|
if err != nil {
|
||||||
|
pageInt = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pageSize != "" {
|
||||||
|
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
|
||||||
|
if err != nil {
|
||||||
|
pageSizeInt = 20
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
users, _, err := s.userRepo.GetFollowing(userID, pageInt, pageSizeInt)
|
||||||
|
return users, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFollowersList 获取粉丝列表(字符串参数版本)
|
||||||
|
func (s *UserService) GetFollowersList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
|
||||||
|
// 转换字符串参数为整数
|
||||||
|
pageInt := 1
|
||||||
|
pageSizeInt := 20
|
||||||
|
if page != "" {
|
||||||
|
_, err := fmt.Sscanf(page, "%d", &pageInt)
|
||||||
|
if err != nil {
|
||||||
|
pageInt = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pageSize != "" {
|
||||||
|
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
|
||||||
|
if err != nil {
|
||||||
|
pageSizeInt = 20
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
users, _, err := s.userRepo.GetFollowers(userID, pageInt, pageSizeInt)
|
||||||
|
return users, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMutualFollowStatus 批量获取双向关注状态
|
||||||
|
func (s *UserService) GetMutualFollowStatus(ctx context.Context, currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
|
||||||
|
return s.userRepo.GetMutualFollowStatus(currentUserID, targetUserIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUsernameAvailable 检查用户名是否可用
|
||||||
|
func (s *UserService) CheckUsernameAvailable(ctx context.Context, username string) (bool, error) {
|
||||||
|
user, err := s.userRepo.GetByUsername(username)
|
||||||
|
if err != nil {
|
||||||
|
return true, nil // 用户不存在,可用
|
||||||
|
}
|
||||||
|
return user == nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangePassword 修改密码
|
||||||
|
func (s *UserService) ChangePassword(ctx context.Context, userID, oldPassword, newPassword, verificationCode string) error {
|
||||||
|
// 获取用户
|
||||||
|
user, err := s.userRepo.GetByID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
|
||||||
|
return ErrEmailNotBound
|
||||||
|
}
|
||||||
|
if err := s.emailCodeService.VerifyCode(CodePurposeChangePassword, *user.Email, verificationCode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证旧密码
|
||||||
|
if !utils.CheckPasswordHash(oldPassword, user.PasswordHash) {
|
||||||
|
return ErrInvalidCredentials
|
||||||
|
}
|
||||||
|
|
||||||
|
// 哈希新密码
|
||||||
|
hashedPassword, err := utils.HashPassword(newPassword)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新密码
|
||||||
|
user.PasswordHash = hashedPassword
|
||||||
|
return s.userRepo.Update(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetPasswordByEmail 通过邮箱重置密码
|
||||||
|
func (s *UserService) ResetPasswordByEmail(ctx context.Context, email, verificationCode, newPassword string) error {
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
if !utils.ValidateEmail(email) {
|
||||||
|
return ErrInvalidEmail
|
||||||
|
}
|
||||||
|
if !utils.ValidatePassword(newPassword) {
|
||||||
|
return ErrWeakPassword
|
||||||
|
}
|
||||||
|
if err := s.emailCodeService.VerifyCode(CodePurposePasswordReset, email, verificationCode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := s.userRepo.GetByEmail(email)
|
||||||
|
if err != nil || user == nil {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := utils.HashPassword(newPassword)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
user.PasswordHash = hashedPassword
|
||||||
|
return s.userRepo.Update(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search 搜索用户
|
||||||
|
func (s *UserService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.User, int64, error) {
|
||||||
|
return s.userRepo.Search(keyword, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 错误定义
|
||||||
|
var (
|
||||||
|
ErrInvalidUsername = &ServiceError{Code: 400, Message: "invalid username"}
|
||||||
|
ErrInvalidEmail = &ServiceError{Code: 400, Message: "invalid email"}
|
||||||
|
ErrInvalidPhone = &ServiceError{Code: 400, Message: "invalid phone number"}
|
||||||
|
ErrWeakPassword = &ServiceError{Code: 400, Message: "password too weak"}
|
||||||
|
ErrUsernameExists = &ServiceError{Code: 400, Message: "username already exists"}
|
||||||
|
ErrEmailExists = &ServiceError{Code: 400, Message: "email already exists"}
|
||||||
|
ErrPhoneExists = &ServiceError{Code: 400, Message: "phone number already exists"}
|
||||||
|
ErrUserNotFound = &ServiceError{Code: 404, Message: "user not found"}
|
||||||
|
ErrUserBanned = &ServiceError{Code: 403, Message: "user is banned"}
|
||||||
|
ErrUserBlocked = &ServiceError{Code: 403, Message: "blocked relationship exists"}
|
||||||
|
ErrInvalidOperation = &ServiceError{Code: 400, Message: "invalid operation"}
|
||||||
|
ErrEmailServiceUnavailable = &ServiceError{Code: 503, Message: "email service unavailable"}
|
||||||
|
ErrVerificationCodeTooFrequent = &ServiceError{Code: 429, Message: "verification code sent too frequently"}
|
||||||
|
ErrVerificationCodeInvalid = &ServiceError{Code: 400, Message: "invalid verification code"}
|
||||||
|
ErrVerificationCodeExpired = &ServiceError{Code: 400, Message: "verification code expired"}
|
||||||
|
ErrVerificationCodeUnavailable = &ServiceError{Code: 500, Message: "verification code storage unavailable"}
|
||||||
|
ErrEmailAlreadyVerified = &ServiceError{Code: 400, Message: "email already verified"}
|
||||||
|
ErrEmailNotBound = &ServiceError{Code: 400, Message: "email not bound"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServiceError 服务错误
|
||||||
|
type ServiceError struct {
|
||||||
|
Code int
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ServiceError) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrInvalidCredentials = &ServiceError{Code: 401, Message: "invalid username or password"}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user