Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor
This commit is contained in:
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
# Build artifacts
|
||||
server
|
||||
*.tar
|
||||
|
||||
# Runtime files
|
||||
data/
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# Local docker metadata
|
||||
.last_docker_tag
|
||||
.last_docker_tar
|
||||
|
||||
# Local environment files
|
||||
.env
|
||||
.env.*
|
||||
28
Dockerfile
Normal file
28
Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
||||
# 使用Debian bookworm-slim作为生产镜像
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# 安装必要的运行时依赖
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
tzdata \
|
||||
libsqlite3-0 \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& apt-get clean
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 从宿主机复制编译好的二进制文件
|
||||
COPY server ./carrot_bbs
|
||||
|
||||
# 复制配置文件
|
||||
COPY configs ./configs
|
||||
|
||||
# 给可执行文件添加执行权限
|
||||
RUN chmod +x ./carrot_bbs
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8080
|
||||
|
||||
# 默认命令
|
||||
CMD ["./carrot_bbs"]
|
||||
172
configs/config.yaml
Normal file
172
configs/config.yaml
Normal file
@@ -0,0 +1,172 @@
|
||||
# 服务器配置
|
||||
# 环境变量: APP_SERVER_HOST, APP_SERVER_PORT, APP_SERVER_MODE
|
||||
server:
|
||||
host: 0.0.0.0
|
||||
port: 8080
|
||||
mode: debug
|
||||
|
||||
# 数据库配置
|
||||
# 环境变量:
|
||||
# SQLite: APP_DATABASE_SQLITE_PATH
|
||||
# Postgres: APP_DATABASE_POSTGRES_HOST, APP_DATABASE_POSTGRES_PORT, APP_DATABASE_POSTGRES_USER,
|
||||
# APP_DATABASE_POSTGRES_PASSWORD, APP_DATABASE_POSTGRES_DBNAME
|
||||
database:
|
||||
type: sqlite # sqlite 或 postgres
|
||||
sqlite:
|
||||
path: ./data/carrot_bbs.db
|
||||
postgres:
|
||||
host: localhost
|
||||
port: 5432
|
||||
user: postgres
|
||||
password: postgres
|
||||
dbname: carrot_bbs
|
||||
sslmode: disable
|
||||
max_idle_conns: 10
|
||||
max_open_conns: 100
|
||||
log_level: warn
|
||||
slow_threshold_ms: 200
|
||||
ignore_record_not_found: true
|
||||
parameterized_queries: true
|
||||
|
||||
# Redis配置
|
||||
# 环境变量:
|
||||
# 类型: APP_REDIS_TYPE (miniredis/redis)
|
||||
# Redis: APP_REDIS_REDIS_HOST, APP_REDIS_REDIS_PORT, APP_REDIS_REDIS_PASSWORD, APP_REDIS_REDIS_DB
|
||||
# Miniredis: APP_REDIS_MINIREDIS_HOST, APP_REDIS_MINIREDIS_PORT
|
||||
redis:
|
||||
type: miniredis # miniredis 或 redis
|
||||
redis:
|
||||
host: localhost
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
miniredis:
|
||||
host: localhost
|
||||
port: 6379
|
||||
pool_size: 10
|
||||
|
||||
# 缓存配置
|
||||
# 环境变量:
|
||||
# APP_CACHE_ENABLED, APP_CACHE_KEY_PREFIX, APP_CACHE_DEFAULT_TTL, APP_CACHE_NULL_TTL
|
||||
# APP_CACHE_JITTER_RATIO, APP_CACHE_DISABLE_FLUSHDB
|
||||
# APP_CACHE_MODULES_POST_LIST_TTL, APP_CACHE_MODULES_CONVERSATION_TTL
|
||||
# APP_CACHE_MODULES_UNREAD_COUNT_TTL, APP_CACHE_MODULES_GROUP_MEMBERS_TTL
|
||||
cache:
|
||||
enabled: true
|
||||
key_prefix: ""
|
||||
default_ttl: 30
|
||||
null_ttl: 5
|
||||
jitter_ratio: 0.1
|
||||
disable_flushdb: true
|
||||
modules:
|
||||
post_list_ttl: 30
|
||||
conversation_ttl: 60
|
||||
unread_count_ttl: 30
|
||||
group_members_ttl: 120
|
||||
|
||||
# S3对象存储配置
|
||||
# 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN
|
||||
s3:
|
||||
endpoint: ""
|
||||
access_key: ""
|
||||
secret_key: ""
|
||||
bucket: ""
|
||||
use_ssl: true
|
||||
region: us-east-1
|
||||
domain: ""
|
||||
# JWT配置
|
||||
# 环境变量: APP_JWT_SECRET
|
||||
jwt:
|
||||
secret: your-jwt-secret-key-change-in-production
|
||||
access_token_expire: 86400 # 24 hours in seconds
|
||||
refresh_token_expire: 604800 # 7 days in seconds
|
||||
|
||||
log:
|
||||
level: info
|
||||
encoding: json
|
||||
output_paths:
|
||||
- stdout
|
||||
- ./logs/app.log
|
||||
|
||||
rate_limit:
|
||||
enabled: true
|
||||
requests_per_minute: 60
|
||||
|
||||
upload:
|
||||
max_file_size: 10485760 # 10MB
|
||||
allowed_types:
|
||||
- image/jpeg
|
||||
- image/png
|
||||
- image/gif
|
||||
- image/webp
|
||||
|
||||
# 敏感词过滤配置
|
||||
sensitive:
|
||||
enabled: true
|
||||
replace_str: "***"
|
||||
min_match_len: 1
|
||||
load_from_db: true
|
||||
load_from_redis: false
|
||||
redis_key_prefix: "sensitive_words"
|
||||
|
||||
# 内容审核服务配置
|
||||
audit:
|
||||
enabled: false # 暂时关闭第三方审核
|
||||
# 审核服务提供商: local, aliyun, tencent, baidu
|
||||
provider: "local"
|
||||
auto_audit: true
|
||||
timeout: 30
|
||||
# 阿里云配置
|
||||
aliyun_access_key: ""
|
||||
aliyun_secret_key: ""
|
||||
aliyun_region: "cn-shanghai"
|
||||
# 腾讯云配置
|
||||
tencent_secret_id: ""
|
||||
tencent_secret_key: ""
|
||||
# 百度云配置
|
||||
baidu_api_key: ""
|
||||
baidu_secret_key: ""
|
||||
|
||||
# Gorse推荐系统配置
|
||||
# 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD
|
||||
gorse:
|
||||
enabled: false
|
||||
address: "" # Gorse server地址
|
||||
api_key: "" # API密钥
|
||||
dashboard: "" # Gorse dashboard地址
|
||||
import_password: "" # 导入数据密码
|
||||
embedding_api_key: ""
|
||||
embedding_url: "https://api.littlelan.cn/v1/embeddings"
|
||||
embedding_model: "BAAI/bge-m3"
|
||||
|
||||
# OpenAI兼容接口配置(用于帖子审核,支持图文)
|
||||
# 环境变量:
|
||||
# APP_OPENAI_ENABLED, APP_OPENAI_BASE_URL, APP_OPENAI_API_KEY
|
||||
# APP_OPENAI_MODERATION_MODEL, APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST
|
||||
# APP_OPENAI_REQUEST_TIMEOUT, APP_OPENAI_STRICT_MODERATION
|
||||
openai:
|
||||
enabled: true
|
||||
base_url: "https://api.littlelan.cn/"
|
||||
api_key: ""
|
||||
moderation_model: "qwen3.5-122b"
|
||||
moderation_max_images_per_request: 1
|
||||
request_timeout: 30
|
||||
strict_moderation: false
|
||||
|
||||
# SMTP发信配置(gomail.v2)
|
||||
# 环境变量:
|
||||
# APP_EMAIL_ENABLED, APP_EMAIL_HOST, APP_EMAIL_PORT
|
||||
# APP_EMAIL_USERNAME, APP_EMAIL_PASSWORD
|
||||
# APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME
|
||||
# APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT
|
||||
email:
|
||||
enabled: false
|
||||
host: ""
|
||||
port: 587
|
||||
username: ""
|
||||
password: ""
|
||||
from_address: ""
|
||||
from_name: "Carrot BBS"
|
||||
use_tls: true
|
||||
insecure_skip_verify: false
|
||||
timeout: 15
|
||||
78
docker-compose.yml
Normal file
78
docker-compose.yml
Normal file
@@ -0,0 +1,78 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# 开发环境使用 SQLite 和 miniredis,不需要外部数据库服务
|
||||
# 如需使用 PostgreSQL 和 Redis,切换 config.yaml 中的配置并取消注释以下服务
|
||||
|
||||
# PostgreSQL (生产环境使用)
|
||||
# postgres:
|
||||
# image: postgres:15-alpine
|
||||
# container_name: carrot_bbs_postgres
|
||||
# environment:
|
||||
# POSTGRES_USER: postgres
|
||||
# POSTGRES_PASSWORD: postgres
|
||||
# POSTGRES_DB: carrot_bbs
|
||||
# ports:
|
||||
# - "5432:5432"
|
||||
# volumes:
|
||||
# - postgres_data:/var/lib/postgresql/data
|
||||
# healthcheck:
|
||||
# test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
# interval: 10s
|
||||
# timeout: 5s
|
||||
# retries: 5
|
||||
|
||||
# Redis (生产环境使用)
|
||||
# redis:
|
||||
# image: redis:7-alpine
|
||||
# container_name: carrot_bbs_redis
|
||||
# ports:
|
||||
# - "6379:6379"
|
||||
# volumes:
|
||||
# - redis_data:/data
|
||||
# healthcheck:
|
||||
# test: ["CMD", "redis-cli", "ping"]
|
||||
# interval: 10s
|
||||
# timeout: 5s
|
||||
# retries: 5
|
||||
|
||||
# MinIO (对象存储,可选)
|
||||
minio:
|
||||
image: minio/minio:latest
|
||||
container_name: carrot_bbs_minio
|
||||
environment:
|
||||
MINIO_ROOT_USER: minioadmin
|
||||
MINIO_ROOT_PASSWORD: minioadmin
|
||||
ports:
|
||||
- "9000:9000"
|
||||
- "9001:9001"
|
||||
volumes:
|
||||
- minio_data:/data
|
||||
command: server /data --console-address ":9001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
|
||||
app:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: carrot_bbs_app
|
||||
ports:
|
||||
- "8080:8080"
|
||||
depends_on:
|
||||
minio:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
- CONFIG_PATH=/app/configs/config.yaml
|
||||
volumes:
|
||||
- ./configs:/app/configs
|
||||
- ./logs:/app/logs
|
||||
- ./data:/app/data
|
||||
|
||||
volumes:
|
||||
# postgres_data:
|
||||
# redis_data:
|
||||
minio_data:
|
||||
86
go.mod
Normal file
86
go.mod
Normal file
@@ -0,0 +1,86 @@
|
||||
module carrot_bbs
|
||||
|
||||
go 1.25
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.31.0
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||
github.com/google/uuid v1.5.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/gorse-io/gorse-go v0.5.0-alpha.3
|
||||
github.com/minio/minio-go/v7 v7.0.66
|
||||
github.com/redis/go-redis/v9 v9.3.1
|
||||
github.com/spf13/viper v1.18.2
|
||||
go.uber.org/zap v1.26.0
|
||||
golang.org/x/crypto v0.17.0
|
||||
golang.org/x/image v0.24.0
|
||||
gorm.io/driver/postgres v1.5.4
|
||||
gorm.io/driver/sqlite v1.5.4
|
||||
gorm.io/gorm v1.25.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.3 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-logr/logr v1.2.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.4.3 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.17.4 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spf13/afero v1.11.0 // indirect
|
||||
github.com/spf13/cast v1.6.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 // indirect
|
||||
go.opentelemetry.io/otel v1.11.1 // indirect
|
||||
go.opentelemetry.io/otel/metric v0.33.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.11.1 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/net v0.19.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.31.0 // indirect
|
||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
|
||||
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
216
go.sum
Normal file
216
go.sum
Normal file
@@ -0,0 +1,216 @@
|
||||
github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0=
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
|
||||
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||
github.com/alicebob/miniredis/v2 v2.31.0 h1:ObEFUNlJwoIiyjxdrYF0QIDE7qXcLc7D3WpSH4c22PU=
|
||||
github.com/alicebob/miniredis/v2 v2.31.0/go.mod h1:UB/T2Uztp7MlFSDakaX1sTXUv5CASoprx0wulRT6HBg=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
||||
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
|
||||
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
|
||||
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
|
||||
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorse-io/gorse-go v0.5.0-alpha.3 h1:GR/OWzq016VyFyTTxgQWeayGahRVzB1cGFIW/AaShC4=
|
||||
github.com/gorse-io/gorse-go v0.5.0-alpha.3/go.mod h1:ZxmVHzZPKm5pmEIlqaRDwK0rkfTRHlfziO033XZ+RW0=
|
||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
|
||||
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
|
||||
github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs=
|
||||
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
||||
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
||||
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds=
|
||||
github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
||||
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE=
|
||||
github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 h1:aUEBEdCa6iamGzg6fuYxDA8ThxvOG240mAvWDU+XLio=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4/go.mod h1:l2MdsbKTocpPS5nQZscqTR9jd8u96VYZdcpF8Sye7mA=
|
||||
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
|
||||
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
|
||||
go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E=
|
||||
go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI=
|
||||
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
|
||||
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
|
||||
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
|
||||
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
||||
golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ=
|
||||
golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8=
|
||||
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
||||
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
||||
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
|
||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE=
|
||||
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
||||
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
||||
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
604
internal/cache/cache.go
vendored
Normal file
604
internal/cache/cache.go
vendored
Normal file
@@ -0,0 +1,604 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
redisPkg "carrot_bbs/internal/pkg/redis"
|
||||
)
|
||||
|
||||
// Cache 缓存接口
|
||||
type Cache interface {
|
||||
// Set 设置缓存值,支持TTL
|
||||
Set(key string, value interface{}, ttl time.Duration)
|
||||
// Get 获取缓存值
|
||||
Get(key string) (interface{}, bool)
|
||||
// Delete 删除缓存
|
||||
Delete(key string)
|
||||
// DeleteByPrefix 根据前缀删除缓存
|
||||
DeleteByPrefix(prefix string)
|
||||
// Clear 清空所有缓存
|
||||
Clear()
|
||||
// Exists 检查键是否存在
|
||||
Exists(key string) bool
|
||||
// Increment 增加计数器的值
|
||||
Increment(key string) int64
|
||||
// IncrementBy 增加指定值
|
||||
IncrementBy(key string, value int64) int64
|
||||
}
|
||||
|
||||
// cacheItem 缓存项(用于内存缓存降级)
|
||||
type cacheItem struct {
|
||||
value interface{}
|
||||
expiration int64 // 过期时间戳(纳秒)
|
||||
}
|
||||
|
||||
const nullMarkerValue = "__carrot_cache_null__"
|
||||
|
||||
type cacheMetrics struct {
|
||||
hit atomic.Int64
|
||||
miss atomic.Int64
|
||||
decodeError atomic.Int64
|
||||
setError atomic.Int64
|
||||
invalidate atomic.Int64
|
||||
}
|
||||
|
||||
var metrics cacheMetrics
|
||||
var loadLocks sync.Map
|
||||
|
||||
type MetricsSnapshot struct {
|
||||
Hit int64
|
||||
Miss int64
|
||||
DecodeError int64
|
||||
SetError int64
|
||||
Invalidate int64
|
||||
}
|
||||
|
||||
type Settings struct {
|
||||
Enabled bool
|
||||
KeyPrefix string
|
||||
DefaultTTL time.Duration
|
||||
NullTTL time.Duration
|
||||
JitterRatio float64
|
||||
PostListTTL time.Duration
|
||||
ConversationTTL time.Duration
|
||||
UnreadCountTTL time.Duration
|
||||
GroupMembersTTL time.Duration
|
||||
DisableFlushDB bool
|
||||
}
|
||||
|
||||
var settings = Settings{
|
||||
Enabled: true,
|
||||
DefaultTTL: 30 * time.Second,
|
||||
NullTTL: 5 * time.Second,
|
||||
JitterRatio: 0.1,
|
||||
PostListTTL: 30 * time.Second,
|
||||
ConversationTTL: 60 * time.Second,
|
||||
UnreadCountTTL: 30 * time.Second,
|
||||
GroupMembersTTL: 120 * time.Second,
|
||||
DisableFlushDB: true,
|
||||
}
|
||||
|
||||
func Configure(s Settings) {
|
||||
settings.Enabled = s.Enabled
|
||||
if s.KeyPrefix != "" {
|
||||
settings.KeyPrefix = s.KeyPrefix
|
||||
}
|
||||
if s.DefaultTTL > 0 {
|
||||
settings.DefaultTTL = s.DefaultTTL
|
||||
}
|
||||
if s.NullTTL > 0 {
|
||||
settings.NullTTL = s.NullTTL
|
||||
}
|
||||
if s.JitterRatio > 0 {
|
||||
settings.JitterRatio = s.JitterRatio
|
||||
}
|
||||
if s.PostListTTL > 0 {
|
||||
settings.PostListTTL = s.PostListTTL
|
||||
}
|
||||
if s.ConversationTTL > 0 {
|
||||
settings.ConversationTTL = s.ConversationTTL
|
||||
}
|
||||
if s.UnreadCountTTL > 0 {
|
||||
settings.UnreadCountTTL = s.UnreadCountTTL
|
||||
}
|
||||
if s.GroupMembersTTL > 0 {
|
||||
settings.GroupMembersTTL = s.GroupMembersTTL
|
||||
}
|
||||
settings.DisableFlushDB = s.DisableFlushDB
|
||||
}
|
||||
|
||||
func GetSettings() Settings {
|
||||
return settings
|
||||
}
|
||||
|
||||
func normalizeKey(key string) string {
|
||||
if settings.KeyPrefix == "" {
|
||||
return key
|
||||
}
|
||||
return settings.KeyPrefix + ":" + key
|
||||
}
|
||||
|
||||
func normalizePrefix(prefix string) string {
|
||||
if settings.KeyPrefix == "" {
|
||||
return prefix
|
||||
}
|
||||
return settings.KeyPrefix + ":" + prefix
|
||||
}
|
||||
|
||||
func GetMetricsSnapshot() MetricsSnapshot {
|
||||
return MetricsSnapshot{
|
||||
Hit: metrics.hit.Load(),
|
||||
Miss: metrics.miss.Load(),
|
||||
DecodeError: metrics.decodeError.Load(),
|
||||
SetError: metrics.setError.Load(),
|
||||
Invalidate: metrics.invalidate.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// isExpired 检查是否过期
|
||||
func (item *cacheItem) isExpired() bool {
|
||||
if item.expiration == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().UnixNano() > item.expiration
|
||||
}
|
||||
|
||||
// MemoryCache 内存缓存实现(降级使用)
|
||||
type MemoryCache struct {
|
||||
items sync.Map
|
||||
// cleanupInterval 清理过期缓存的间隔
|
||||
cleanupInterval time.Duration
|
||||
// stopCleanup 停止清理协程的通道
|
||||
stopCleanup chan struct{}
|
||||
}
|
||||
|
||||
// NewMemoryCache 创建内存缓存
|
||||
func NewMemoryCache() *MemoryCache {
|
||||
c := &MemoryCache{
|
||||
cleanupInterval: 1 * time.Minute,
|
||||
stopCleanup: make(chan struct{}),
|
||||
}
|
||||
// 启动后台清理协程
|
||||
go c.cleanup()
|
||||
return c
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) {
|
||||
key = normalizeKey(key)
|
||||
var expiration int64
|
||||
if ttl > 0 {
|
||||
expiration = time.Now().Add(ttl).UnixNano()
|
||||
}
|
||||
c.items.Store(key, &cacheItem{
|
||||
value: value,
|
||||
expiration: expiration,
|
||||
})
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func (c *MemoryCache) Get(key string) (interface{}, bool) {
|
||||
key = normalizeKey(key)
|
||||
val, ok := c.items.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
item := val.(*cacheItem)
|
||||
if item.isExpired() {
|
||||
c.items.Delete(key)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return item.value, true
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (c *MemoryCache) Delete(key string) {
|
||||
key = normalizeKey(key)
|
||||
metrics.invalidate.Add(1)
|
||||
c.items.Delete(key)
|
||||
}
|
||||
|
||||
// DeleteByPrefix 根据前缀删除缓存
|
||||
func (c *MemoryCache) DeleteByPrefix(prefix string) {
|
||||
prefix = normalizePrefix(prefix)
|
||||
c.items.Range(func(key, value interface{}) bool {
|
||||
if keyStr, ok := key.(string); ok {
|
||||
if strings.HasPrefix(keyStr, prefix) {
|
||||
metrics.invalidate.Add(1)
|
||||
c.items.Delete(key)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Clear 清空所有缓存
|
||||
func (c *MemoryCache) Clear() {
|
||||
c.items.Range(func(key, value interface{}) bool {
|
||||
metrics.invalidate.Add(1)
|
||||
c.items.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (c *MemoryCache) Exists(key string) bool {
|
||||
_, ok := c.Get(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Increment 增加计数器的值
|
||||
func (c *MemoryCache) Increment(key string) int64 {
|
||||
return c.IncrementBy(key, 1)
|
||||
}
|
||||
|
||||
// IncrementBy 增加指定值
|
||||
func (c *MemoryCache) IncrementBy(key string, value int64) int64 {
|
||||
key = normalizeKey(key)
|
||||
for {
|
||||
val, ok := c.items.Load(key)
|
||||
if !ok {
|
||||
// 键不存在,创建新值
|
||||
c.items.Store(key, &cacheItem{
|
||||
value: value,
|
||||
expiration: 0,
|
||||
})
|
||||
return value
|
||||
}
|
||||
|
||||
item := val.(*cacheItem)
|
||||
if item.isExpired() {
|
||||
// 已过期,创建新值
|
||||
c.items.Store(key, &cacheItem{
|
||||
value: value,
|
||||
expiration: 0,
|
||||
})
|
||||
return value
|
||||
}
|
||||
|
||||
// 尝试更新
|
||||
currentValue, ok := item.value.(int64)
|
||||
if !ok {
|
||||
// 类型不匹配,覆盖为新值
|
||||
c.items.Store(key, &cacheItem{
|
||||
value: value,
|
||||
expiration: item.expiration,
|
||||
})
|
||||
return value
|
||||
}
|
||||
|
||||
newValue := currentValue + value
|
||||
// 使用 CAS 操作确保并发安全
|
||||
if c.items.CompareAndSwap(key, val, &cacheItem{
|
||||
value: newValue,
|
||||
expiration: item.expiration,
|
||||
}) {
|
||||
return newValue
|
||||
}
|
||||
// CAS 失败,重试
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 定期清理过期缓存
|
||||
func (c *MemoryCache) cleanup() {
|
||||
ticker := time.NewTicker(c.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
c.cleanExpired()
|
||||
case <-c.stopCleanup:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanExpired 清理过期缓存
|
||||
func (c *MemoryCache) cleanExpired() {
|
||||
count := 0
|
||||
c.items.Range(func(key, value interface{}) bool {
|
||||
item := value.(*cacheItem)
|
||||
if item.isExpired() {
|
||||
c.items.Delete(key)
|
||||
count++
|
||||
}
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
log.Printf("[Cache] Cleaned %d expired items", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止缓存清理协程
|
||||
func (c *MemoryCache) Stop() {
|
||||
close(c.stopCleanup)
|
||||
}
|
||||
|
||||
// RedisCache Redis缓存实现
|
||||
type RedisCache struct {
|
||||
client *redisPkg.Client
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewRedisCache 创建Redis缓存
|
||||
func NewRedisCache(client *redisPkg.Client) *RedisCache {
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func (c *RedisCache) Set(key string, value interface{}, ttl time.Duration) {
|
||||
key = normalizeKey(key)
|
||||
// 将值序列化为JSON
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
metrics.setError.Add(1)
|
||||
log.Printf("[RedisCache] Failed to marshal value for key %s: %v", key, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.client.Set(c.ctx, key, data, ttl); err != nil {
|
||||
metrics.setError.Add(1)
|
||||
log.Printf("[RedisCache] Failed to set key %s: %v", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func (c *RedisCache) Get(key string) (interface{}, bool) {
|
||||
key = normalizeKey(key)
|
||||
data, err := c.client.Get(c.ctx, key)
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, false
|
||||
}
|
||||
log.Printf("[RedisCache] Failed to get key %s: %v", key, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 返回原始字符串,由调用侧决定如何解码为目标类型
|
||||
return data, true
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (c *RedisCache) Delete(key string) {
|
||||
key = normalizeKey(key)
|
||||
metrics.invalidate.Add(1)
|
||||
if err := c.client.Del(c.ctx, key); err != nil {
|
||||
log.Printf("[RedisCache] Failed to delete key %s: %v", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteByPrefix 根据前缀删除缓存
|
||||
func (c *RedisCache) DeleteByPrefix(prefix string) {
|
||||
prefix = normalizePrefix(prefix)
|
||||
// 使用原生客户端执行SCAN命令
|
||||
rdb := c.client.GetClient()
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := rdb.Scan(c.ctx, cursor, prefix+"*", 100).Result()
|
||||
if err != nil {
|
||||
log.Printf("[RedisCache] Failed to scan keys with prefix %s: %v", prefix, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
metrics.invalidate.Add(int64(len(keys)))
|
||||
if err := c.client.Del(c.ctx, keys...); err != nil {
|
||||
log.Printf("[RedisCache] Failed to delete keys with prefix %s: %v", prefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear 清空所有缓存
|
||||
func (c *RedisCache) Clear() {
|
||||
if settings.DisableFlushDB {
|
||||
log.Printf("[RedisCache] Skip FlushDB because cache.disable_flushdb=true")
|
||||
return
|
||||
}
|
||||
metrics.invalidate.Add(1)
|
||||
rdb := c.client.GetClient()
|
||||
if err := rdb.FlushDB(c.ctx).Err(); err != nil {
|
||||
log.Printf("[RedisCache] Failed to clear cache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (c *RedisCache) Exists(key string) bool {
|
||||
key = normalizeKey(key)
|
||||
n, err := c.client.Exists(c.ctx, key)
|
||||
if err != nil {
|
||||
log.Printf("[RedisCache] Failed to check existence of key %s: %v", key, err)
|
||||
return false
|
||||
}
|
||||
return n > 0
|
||||
}
|
||||
|
||||
// Increment 增加计数器的值
|
||||
func (c *RedisCache) Increment(key string) int64 {
|
||||
return c.IncrementBy(key, 1)
|
||||
}
|
||||
|
||||
// IncrementBy 增加指定值
|
||||
func (c *RedisCache) IncrementBy(key string, value int64) int64 {
|
||||
key = normalizeKey(key)
|
||||
rdb := c.client.GetClient()
|
||||
result, err := rdb.IncrBy(c.ctx, key, value).Result()
|
||||
if err != nil {
|
||||
log.Printf("[RedisCache] Failed to increment key %s: %v", key, err)
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// 全局缓存实例
|
||||
var globalCache Cache
|
||||
var once sync.Once
|
||||
|
||||
// InitCache 初始化全局缓存实例(使用Redis)
|
||||
func InitCache(redisClient *redisPkg.Client) {
|
||||
once.Do(func() {
|
||||
if redisClient != nil {
|
||||
globalCache = NewRedisCache(redisClient)
|
||||
log.Println("[Cache] Initialized Redis cache")
|
||||
} else {
|
||||
globalCache = NewMemoryCache()
|
||||
log.Println("[Cache] Initialized Memory cache (Redis not available)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GetCache 获取全局缓存实例
|
||||
func GetCache() Cache {
|
||||
if globalCache == nil {
|
||||
// 如果未初始化,返回内存缓存作为降级
|
||||
log.Println("[Cache] Warning: Cache not initialized, using Memory cache")
|
||||
return NewMemoryCache()
|
||||
}
|
||||
return globalCache
|
||||
}
|
||||
|
||||
// GetRedisClient 从缓存中获取Redis客户端(仅在Redis模式下有效)
|
||||
func GetRedisClient() (*redisPkg.Client, error) {
|
||||
if redisCache, ok := globalCache.(*RedisCache); ok {
|
||||
return redisCache.client, nil
|
||||
}
|
||||
return nil, fmt.Errorf("cache is not using Redis backend")
|
||||
}
|
||||
|
||||
func SetWithJitter(c Cache, key string, value interface{}, ttl time.Duration, jitterRatio float64) {
|
||||
if !settings.Enabled {
|
||||
return
|
||||
}
|
||||
c.Set(key, value, ApplyTTLJitter(ttl, jitterRatio))
|
||||
}
|
||||
|
||||
func SetNull(c Cache, key string, ttl time.Duration) {
|
||||
if !settings.Enabled {
|
||||
return
|
||||
}
|
||||
c.Set(key, nullMarkerValue, ttl)
|
||||
}
|
||||
|
||||
func ApplyTTLJitter(ttl time.Duration, jitterRatio float64) time.Duration {
|
||||
if ttl <= 0 || jitterRatio <= 0 {
|
||||
return ttl
|
||||
}
|
||||
if jitterRatio > 1 {
|
||||
jitterRatio = 1
|
||||
}
|
||||
maxJitter := int64(float64(ttl) * jitterRatio)
|
||||
if maxJitter <= 0 {
|
||||
return ttl
|
||||
}
|
||||
delta := rand.Int63n(maxJitter + 1)
|
||||
return ttl + time.Duration(delta)
|
||||
}
|
||||
|
||||
func GetTyped[T any](c Cache, key string) (T, bool) {
|
||||
var zero T
|
||||
if !settings.Enabled {
|
||||
return zero, false
|
||||
}
|
||||
raw, ok := c.Get(key)
|
||||
if !ok {
|
||||
metrics.miss.Add(1)
|
||||
return zero, false
|
||||
}
|
||||
if str, ok := raw.(string); ok && str == nullMarkerValue {
|
||||
metrics.hit.Add(1)
|
||||
return zero, false
|
||||
}
|
||||
|
||||
if typed, ok := raw.(T); ok {
|
||||
metrics.hit.Add(1)
|
||||
return typed, true
|
||||
}
|
||||
|
||||
var out T
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
if err := json.Unmarshal([]byte(v), &out); err != nil {
|
||||
metrics.decodeError.Add(1)
|
||||
return zero, false
|
||||
}
|
||||
metrics.hit.Add(1)
|
||||
return out, true
|
||||
case []byte:
|
||||
if err := json.Unmarshal(v, &out); err != nil {
|
||||
metrics.decodeError.Add(1)
|
||||
return zero, false
|
||||
}
|
||||
metrics.hit.Add(1)
|
||||
return out, true
|
||||
default:
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
metrics.decodeError.Add(1)
|
||||
return zero, false
|
||||
}
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
metrics.decodeError.Add(1)
|
||||
return zero, false
|
||||
}
|
||||
metrics.hit.Add(1)
|
||||
return out, true
|
||||
}
|
||||
}
|
||||
|
||||
func GetOrLoadTyped[T any](
|
||||
c Cache,
|
||||
key string,
|
||||
ttl time.Duration,
|
||||
jitterRatio float64,
|
||||
nullTTL time.Duration,
|
||||
loader func() (T, error),
|
||||
) (T, error) {
|
||||
if cached, ok := GetTyped[T](c, key); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
lockValue, _ := loadLocks.LoadOrStore(key, &sync.Mutex{})
|
||||
lock := lockValue.(*sync.Mutex)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if cached, ok := GetTyped[T](c, key); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
loaded, err := loader()
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
|
||||
encoded, marshalErr := json.Marshal(loaded)
|
||||
if marshalErr == nil && string(encoded) == "null" && nullTTL > 0 {
|
||||
SetNull(c, key, nullTTL)
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
SetWithJitter(c, key, loaded, ttl, jitterRatio)
|
||||
return loaded, nil
|
||||
}
|
||||
147
internal/cache/keys.go
vendored
Normal file
147
internal/cache/keys.go
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// 缓存键前缀常量
|
||||
const (
|
||||
// 帖子相关
|
||||
PrefixPostList = "posts:list"
|
||||
PrefixPost = "posts:detail"
|
||||
|
||||
// 会话相关
|
||||
PrefixConversationList = "conversations:list"
|
||||
PrefixConversationDetail = "conversations:detail"
|
||||
|
||||
// 群组相关
|
||||
PrefixGroupMembers = "groups:members"
|
||||
PrefixGroupInfo = "groups:info"
|
||||
|
||||
// 未读数相关
|
||||
PrefixUnreadSystem = "unread:system"
|
||||
PrefixUnreadConversation = "unread:conversation"
|
||||
PrefixUnreadDetail = "unread:detail"
|
||||
|
||||
// 用户相关
|
||||
PrefixUserInfo = "users:info"
|
||||
PrefixUserMe = "users:me"
|
||||
)
|
||||
|
||||
// PostListKey 生成帖子列表缓存键
|
||||
// postType: 帖子类型 (recommend, hot, follow, latest)
|
||||
// page: 页码
|
||||
// pageSize: 每页数量
|
||||
// userID: 用户维度(仅在个性化列表如 follow 场景使用)
|
||||
func PostListKey(postType string, userID string, page, pageSize int) string {
|
||||
if userID == "" {
|
||||
return fmt.Sprintf("%s:%s:%d:%d", PrefixPostList, postType, page, pageSize)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s:%s:%d:%d", PrefixPostList, postType, userID, page, pageSize)
|
||||
}
|
||||
|
||||
// PostDetailKey 生成帖子详情缓存键
|
||||
func PostDetailKey(postID string) string {
|
||||
return fmt.Sprintf("%s:%s", PrefixPost, postID)
|
||||
}
|
||||
|
||||
// ConversationListKey 生成会话列表缓存键
|
||||
func ConversationListKey(userID string, page, pageSize int) string {
|
||||
return fmt.Sprintf("%s:%s:%d:%d", PrefixConversationList, userID, page, pageSize)
|
||||
}
|
||||
|
||||
// ConversationDetailKey 生成会话详情缓存键
|
||||
func ConversationDetailKey(conversationID, userID string) string {
|
||||
return fmt.Sprintf("%s:%s:%s", PrefixConversationDetail, conversationID, userID)
|
||||
}
|
||||
|
||||
// GroupMembersKey 生成群组成员缓存键
|
||||
func GroupMembersKey(groupID string, page, pageSize int) string {
|
||||
return fmt.Sprintf("%s:%s:page:%d:size:%d", PrefixGroupMembers, groupID, page, pageSize)
|
||||
}
|
||||
|
||||
// GroupMembersAllKey 生成群组全量成员ID列表缓存键
|
||||
func GroupMembersAllKey(groupID string) string {
|
||||
return fmt.Sprintf("%s:all:%s", PrefixGroupMembers, groupID)
|
||||
}
|
||||
|
||||
// GroupInfoKey 生成群组信息缓存键
|
||||
func GroupInfoKey(groupID string) string {
|
||||
return fmt.Sprintf("%s:%s", PrefixGroupInfo, groupID)
|
||||
}
|
||||
|
||||
// UnreadSystemKey 生成系统消息未读数缓存键
|
||||
func UnreadSystemKey(userID string) string {
|
||||
return fmt.Sprintf("%s:%s", PrefixUnreadSystem, userID)
|
||||
}
|
||||
|
||||
// UnreadConversationKey 生成会话未读总数缓存键
|
||||
func UnreadConversationKey(userID string) string {
|
||||
return fmt.Sprintf("%s:%s", PrefixUnreadConversation, userID)
|
||||
}
|
||||
|
||||
// UnreadDetailKey 生成单个会话未读数缓存键
|
||||
func UnreadDetailKey(userID, conversationID string) string {
|
||||
return fmt.Sprintf("%s:%s:%s", PrefixUnreadDetail, userID, conversationID)
|
||||
}
|
||||
|
||||
// UserInfoKey 生成用户信息缓存键
|
||||
func UserInfoKey(userID string) string {
|
||||
return fmt.Sprintf("%s:%s", PrefixUserInfo, userID)
|
||||
}
|
||||
|
||||
// UserMeKey 生成当前用户信息缓存键
|
||||
func UserMeKey(userID string) string {
|
||||
return fmt.Sprintf("%s:%s", PrefixUserMe, userID)
|
||||
}
|
||||
|
||||
// InvalidatePostList 失效帖子列表缓存
|
||||
func InvalidatePostList(cache Cache) {
|
||||
cache.DeleteByPrefix(PrefixPostList)
|
||||
}
|
||||
|
||||
// InvalidatePostDetail 失效帖子详情缓存
|
||||
func InvalidatePostDetail(cache Cache, postID string) {
|
||||
cache.Delete(PostDetailKey(postID))
|
||||
}
|
||||
|
||||
// InvalidateConversationList 失效会话列表缓存
|
||||
func InvalidateConversationList(cache Cache, userID string) {
|
||||
cache.DeleteByPrefix(PrefixConversationList + ":" + userID + ":")
|
||||
}
|
||||
|
||||
// InvalidateConversationDetail 失效会话详情缓存
|
||||
func InvalidateConversationDetail(cache Cache, conversationID, userID string) {
|
||||
cache.Delete(ConversationDetailKey(conversationID, userID))
|
||||
}
|
||||
|
||||
// InvalidateGroupMembers 失效群组成员缓存
|
||||
func InvalidateGroupMembers(cache Cache, groupID string) {
|
||||
cache.DeleteByPrefix(PrefixGroupMembers + ":" + groupID)
|
||||
}
|
||||
|
||||
// InvalidateGroupInfo 失效群组信息缓存
|
||||
func InvalidateGroupInfo(cache Cache, groupID string) {
|
||||
cache.Delete(GroupInfoKey(groupID))
|
||||
}
|
||||
|
||||
// InvalidateUnreadSystem 失效系统消息未读数缓存
|
||||
func InvalidateUnreadSystem(cache Cache, userID string) {
|
||||
cache.Delete(UnreadSystemKey(userID))
|
||||
}
|
||||
|
||||
// InvalidateUnreadConversation 失效会话未读数缓存
|
||||
func InvalidateUnreadConversation(cache Cache, userID string) {
|
||||
cache.Delete(UnreadConversationKey(userID))
|
||||
}
|
||||
|
||||
// InvalidateUnreadDetail 失效单个会话未读数缓存
|
||||
func InvalidateUnreadDetail(cache Cache, userID, conversationID string) {
|
||||
cache.Delete(UnreadDetailKey(userID, conversationID))
|
||||
}
|
||||
|
||||
// InvalidateUserInfo 失效用户信息缓存
|
||||
func InvalidateUserInfo(cache Cache, userID string) {
|
||||
cache.Delete(UserInfoKey(userID))
|
||||
cache.Delete(UserMeKey(userID))
|
||||
}
|
||||
393
internal/config/config.go
Normal file
393
internal/config/config.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
S3 S3Config `mapstructure:"s3"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
Log LogConfig `mapstructure:"log"`
|
||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||
Upload UploadConfig `mapstructure:"upload"`
|
||||
Gorse GorseConfig `mapstructure:"gorse"`
|
||||
OpenAI OpenAIConfig `mapstructure:"openai"`
|
||||
Email EmailConfig `mapstructure:"email"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Type string `mapstructure:"type"`
|
||||
SQLite SQLiteConfig `mapstructure:"sqlite"`
|
||||
Postgres PostgresConfig `mapstructure:"postgres"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||
LogLevel string `mapstructure:"log_level"`
|
||||
SlowThresholdMs int `mapstructure:"slow_threshold_ms"`
|
||||
IgnoreRecordNotFound bool `mapstructure:"ignore_record_not_found"`
|
||||
ParameterizedQueries bool `mapstructure:"parameterized_queries"`
|
||||
}
|
||||
|
||||
type SQLiteConfig struct {
|
||||
Path string `mapstructure:"path"`
|
||||
}
|
||||
|
||||
type PostgresConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
DBName string `mapstructure:"dbname"`
|
||||
SSLMode string `mapstructure:"sslmode"`
|
||||
}
|
||||
|
||||
func (d PostgresConfig) DSN() string {
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
|
||||
)
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Type string `mapstructure:"type"`
|
||||
Redis RedisServerConfig `mapstructure:"redis"`
|
||||
Miniredis MiniredisConfig `mapstructure:"miniredis"`
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
}
|
||||
|
||||
type CacheConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
KeyPrefix string `mapstructure:"key_prefix"`
|
||||
DefaultTTL int `mapstructure:"default_ttl"`
|
||||
NullTTL int `mapstructure:"null_ttl"`
|
||||
JitterRatio float64 `mapstructure:"jitter_ratio"`
|
||||
DisableFlushDB bool `mapstructure:"disable_flushdb"`
|
||||
Modules CacheModuleTTL `mapstructure:"modules"`
|
||||
}
|
||||
|
||||
type CacheModuleTTL struct {
|
||||
PostList int `mapstructure:"post_list_ttl"`
|
||||
Conversation int `mapstructure:"conversation_ttl"`
|
||||
UnreadCount int `mapstructure:"unread_count_ttl"`
|
||||
GroupMembers int `mapstructure:"group_members_ttl"`
|
||||
}
|
||||
|
||||
type RedisServerConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
DB int `mapstructure:"db"`
|
||||
}
|
||||
|
||||
func (r RedisServerConfig) Addr() string {
|
||||
return fmt.Sprintf("%s:%d", r.Host, r.Port)
|
||||
}
|
||||
|
||||
type MiniredisConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
}
|
||||
|
||||
type S3Config struct {
|
||||
Endpoint string `mapstructure:"endpoint"`
|
||||
AccessKey string `mapstructure:"access_key"`
|
||||
SecretKey string `mapstructure:"secret_key"`
|
||||
Bucket string `mapstructure:"bucket"`
|
||||
UseSSL bool `mapstructure:"use_ssl"`
|
||||
Region string `mapstructure:"region"`
|
||||
Domain string `mapstructure:"domain"` // 自定义域名,如 s3.carrot.skin
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
AccessTokenExpire time.Duration `mapstructure:"access_token_expire"`
|
||||
RefreshTokenExpire time.Duration `mapstructure:"refresh_token_expire"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
Level string `mapstructure:"level"`
|
||||
Encoding string `mapstructure:"encoding"`
|
||||
OutputPaths []string `mapstructure:"output_paths"`
|
||||
}
|
||||
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
RequestsPerMinute int `mapstructure:"requests_per_minute"`
|
||||
}
|
||||
|
||||
type UploadConfig struct {
|
||||
MaxFileSize int64 `mapstructure:"max_file_size"`
|
||||
AllowedTypes []string `mapstructure:"allowed_types"`
|
||||
}
|
||||
|
||||
type GorseConfig struct {
|
||||
Address string `mapstructure:"address"`
|
||||
APIKey string `mapstructure:"api_key"`
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Dashboard string `mapstructure:"dashboard"`
|
||||
ImportPassword string `mapstructure:"import_password"`
|
||||
EmbeddingAPIKey string `mapstructure:"embedding_api_key"`
|
||||
EmbeddingURL string `mapstructure:"embedding_url"`
|
||||
EmbeddingModel string `mapstructure:"embedding_model"`
|
||||
}
|
||||
|
||||
type OpenAIConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
APIKey string `mapstructure:"api_key"`
|
||||
ModerationModel string `mapstructure:"moderation_model"`
|
||||
ModerationMaxImagesPerRequest int `mapstructure:"moderation_max_images_per_request"`
|
||||
RequestTimeout int `mapstructure:"request_timeout"`
|
||||
StrictModeration bool `mapstructure:"strict_moderation"`
|
||||
}
|
||||
|
||||
type EmailConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
FromAddress string `mapstructure:"from_address"`
|
||||
FromName string `mapstructure:"from_name"`
|
||||
UseTLS bool `mapstructure:"use_tls"`
|
||||
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
|
||||
Timeout int `mapstructure:"timeout"`
|
||||
}
|
||||
|
||||
func Load(configPath string) (*Config, error) {
|
||||
viper.SetConfigFile(configPath)
|
||||
viper.SetConfigType("yaml")
|
||||
|
||||
// 启用环境变量支持
|
||||
viper.SetEnvPrefix("APP")
|
||||
viper.AutomaticEnv()
|
||||
// 允许环境变量使用下划线或连字符
|
||||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_"))
|
||||
|
||||
// Set default values
|
||||
viper.SetDefault("server.port", 8080)
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.host", "0.0.0.0")
|
||||
viper.SetDefault("database.type", "sqlite")
|
||||
viper.SetDefault("database.sqlite.path", "./data/carrot_bbs.db")
|
||||
viper.SetDefault("database.max_idle_conns", 10)
|
||||
viper.SetDefault("database.max_open_conns", 100)
|
||||
viper.SetDefault("database.log_level", "warn")
|
||||
viper.SetDefault("database.slow_threshold_ms", 200)
|
||||
viper.SetDefault("database.ignore_record_not_found", true)
|
||||
viper.SetDefault("database.parameterized_queries", true)
|
||||
viper.SetDefault("redis.type", "miniredis")
|
||||
viper.SetDefault("redis.redis.host", "localhost")
|
||||
viper.SetDefault("redis.redis.port", 6379)
|
||||
viper.SetDefault("redis.redis.password", "")
|
||||
viper.SetDefault("redis.redis.db", 0)
|
||||
viper.SetDefault("redis.miniredis.host", "localhost")
|
||||
viper.SetDefault("redis.miniredis.port", 6379)
|
||||
viper.SetDefault("redis.pool_size", 10)
|
||||
viper.SetDefault("cache.enabled", true)
|
||||
viper.SetDefault("cache.key_prefix", "")
|
||||
viper.SetDefault("cache.default_ttl", 30)
|
||||
viper.SetDefault("cache.null_ttl", 5)
|
||||
viper.SetDefault("cache.jitter_ratio", 0.1)
|
||||
viper.SetDefault("cache.disable_flushdb", true)
|
||||
viper.SetDefault("cache.modules.post_list_ttl", 30)
|
||||
viper.SetDefault("cache.modules.conversation_ttl", 60)
|
||||
viper.SetDefault("cache.modules.unread_count_ttl", 30)
|
||||
viper.SetDefault("cache.modules.group_members_ttl", 120)
|
||||
viper.SetDefault("jwt.secret", "your-jwt-secret-key-change-in-production")
|
||||
viper.SetDefault("jwt.access_token_expire", 86400)
|
||||
viper.SetDefault("jwt.refresh_token_expire", 604800)
|
||||
viper.SetDefault("log.level", "info")
|
||||
viper.SetDefault("log.encoding", "json")
|
||||
viper.SetDefault("log.output_paths", []string{"stdout", "./logs/app.log"})
|
||||
viper.SetDefault("rate_limit.enabled", true)
|
||||
viper.SetDefault("rate_limit.requests_per_minute", 60)
|
||||
viper.SetDefault("upload.max_file_size", 10485760)
|
||||
viper.SetDefault("upload.allowed_types", []string{"image/jpeg", "image/png", "image/gif", "image/webp"})
|
||||
viper.SetDefault("s3.endpoint", "")
|
||||
viper.SetDefault("s3.access_key", "")
|
||||
viper.SetDefault("s3.secret_key", "")
|
||||
viper.SetDefault("s3.bucket", "")
|
||||
viper.SetDefault("s3.use_ssl", true)
|
||||
viper.SetDefault("s3.region", "us-east-1")
|
||||
viper.SetDefault("s3.domain", "")
|
||||
viper.SetDefault("sensitive.enabled", true)
|
||||
viper.SetDefault("sensitive.replace_str", "***")
|
||||
viper.SetDefault("audit.enabled", false)
|
||||
viper.SetDefault("audit.provider", "local")
|
||||
viper.SetDefault("gorse.enabled", false)
|
||||
viper.SetDefault("gorse.address", "http://localhost:8087")
|
||||
viper.SetDefault("gorse.api_key", "")
|
||||
viper.SetDefault("gorse.dashboard", "http://localhost:8088")
|
||||
viper.SetDefault("gorse.import_password", "")
|
||||
viper.SetDefault("gorse.embedding_api_key", "")
|
||||
viper.SetDefault("gorse.embedding_url", "https://api.littlelan.cn/v1/embeddings")
|
||||
viper.SetDefault("gorse.embedding_model", "BAAI/bge-m3")
|
||||
viper.SetDefault("openai.enabled", true)
|
||||
viper.SetDefault("openai.base_url", "https://api.littlelan.cn/")
|
||||
viper.SetDefault("openai.api_key", "")
|
||||
viper.SetDefault("openai.moderation_model", "qwen3.5-122b")
|
||||
viper.SetDefault("openai.moderation_max_images_per_request", 1)
|
||||
viper.SetDefault("openai.request_timeout", 30)
|
||||
viper.SetDefault("openai.strict_moderation", false)
|
||||
viper.SetDefault("email.enabled", false)
|
||||
viper.SetDefault("email.host", "")
|
||||
viper.SetDefault("email.port", 587)
|
||||
viper.SetDefault("email.username", "")
|
||||
viper.SetDefault("email.password", "")
|
||||
viper.SetDefault("email.from_address", "")
|
||||
viper.SetDefault("email.from_name", "Carrot BBS")
|
||||
viper.SetDefault("email.use_tls", true)
|
||||
viper.SetDefault("email.insecure_skip_verify", false)
|
||||
viper.SetDefault("email.timeout", 15)
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := viper.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Convert seconds to duration
|
||||
cfg.JWT.AccessTokenExpire = time.Duration(viper.GetInt("jwt.access_token_expire")) * time.Second
|
||||
cfg.JWT.RefreshTokenExpire = time.Duration(viper.GetInt("jwt.refresh_token_expire")) * time.Second
|
||||
|
||||
// 环境变量覆盖(显式处理敏感配置)
|
||||
cfg.JWT.Secret = getEnvOrDefault("APP_JWT_SECRET", cfg.JWT.Secret)
|
||||
cfg.Database.SQLite.Path = getEnvOrDefault("APP_DATABASE_SQLITE_PATH", cfg.Database.SQLite.Path)
|
||||
cfg.Database.Postgres.Host = getEnvOrDefault("APP_DATABASE_POSTGRES_HOST", cfg.Database.Postgres.Host)
|
||||
cfg.Database.Postgres.Port, _ = strconv.Atoi(getEnvOrDefault("APP_DATABASE_POSTGRES_PORT", fmt.Sprintf("%d", cfg.Database.Postgres.Port)))
|
||||
cfg.Database.Postgres.User = getEnvOrDefault("APP_DATABASE_POSTGRES_USER", cfg.Database.Postgres.User)
|
||||
cfg.Database.Postgres.Password = getEnvOrDefault("APP_DATABASE_POSTGRES_PASSWORD", cfg.Database.Postgres.Password)
|
||||
cfg.Database.Postgres.DBName = getEnvOrDefault("APP_DATABASE_POSTGRES_DBNAME", cfg.Database.Postgres.DBName)
|
||||
cfg.Database.LogLevel = getEnvOrDefault("APP_DATABASE_LOG_LEVEL", cfg.Database.LogLevel)
|
||||
cfg.Database.SlowThresholdMs, _ = strconv.Atoi(getEnvOrDefault("APP_DATABASE_SLOW_THRESHOLD_MS", fmt.Sprintf("%d", cfg.Database.SlowThresholdMs)))
|
||||
cfg.Database.IgnoreRecordNotFound, _ = strconv.ParseBool(getEnvOrDefault("APP_DATABASE_IGNORE_RECORD_NOT_FOUND", fmt.Sprintf("%t", cfg.Database.IgnoreRecordNotFound)))
|
||||
cfg.Database.ParameterizedQueries, _ = strconv.ParseBool(getEnvOrDefault("APP_DATABASE_PARAMETERIZED_QUERIES", fmt.Sprintf("%t", cfg.Database.ParameterizedQueries)))
|
||||
cfg.Redis.Redis.Host = getEnvOrDefault("APP_REDIS_REDIS_HOST", cfg.Redis.Redis.Host)
|
||||
cfg.Redis.Redis.Port, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_REDIS_PORT", fmt.Sprintf("%d", cfg.Redis.Redis.Port)))
|
||||
cfg.Redis.Redis.Password = getEnvOrDefault("APP_REDIS_REDIS_PASSWORD", cfg.Redis.Redis.Password)
|
||||
cfg.Redis.Redis.DB, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_REDIS_DB", fmt.Sprintf("%d", cfg.Redis.Redis.DB)))
|
||||
cfg.Redis.Miniredis.Host = getEnvOrDefault("APP_REDIS_MINIREDIS_HOST", cfg.Redis.Miniredis.Host)
|
||||
cfg.Redis.Miniredis.Port, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_MINIREDIS_PORT", fmt.Sprintf("%d", cfg.Redis.Miniredis.Port)))
|
||||
cfg.Redis.Type = getEnvOrDefault("APP_REDIS_TYPE", cfg.Redis.Type)
|
||||
cfg.Cache.KeyPrefix = getEnvOrDefault("APP_CACHE_KEY_PREFIX", cfg.Cache.KeyPrefix)
|
||||
cfg.Cache.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_CACHE_ENABLED", fmt.Sprintf("%t", cfg.Cache.Enabled)))
|
||||
cfg.Cache.DisableFlushDB, _ = strconv.ParseBool(getEnvOrDefault("APP_CACHE_DISABLE_FLUSHDB", fmt.Sprintf("%t", cfg.Cache.DisableFlushDB)))
|
||||
cfg.Cache.DefaultTTL, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_DEFAULT_TTL", fmt.Sprintf("%d", cfg.Cache.DefaultTTL)))
|
||||
cfg.Cache.NullTTL, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_NULL_TTL", fmt.Sprintf("%d", cfg.Cache.NullTTL)))
|
||||
cfg.Cache.JitterRatio, _ = strconv.ParseFloat(getEnvOrDefault("APP_CACHE_JITTER_RATIO", fmt.Sprintf("%.2f", cfg.Cache.JitterRatio)), 64)
|
||||
cfg.Cache.Modules.PostList, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_POST_LIST_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.PostList)))
|
||||
cfg.Cache.Modules.Conversation, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_CONVERSATION_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.Conversation)))
|
||||
cfg.Cache.Modules.UnreadCount, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_UNREAD_COUNT_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.UnreadCount)))
|
||||
cfg.Cache.Modules.GroupMembers, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_GROUP_MEMBERS_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.GroupMembers)))
|
||||
cfg.S3.Endpoint = getEnvOrDefault("APP_S3_ENDPOINT", cfg.S3.Endpoint)
|
||||
cfg.S3.AccessKey = getEnvOrDefault("APP_S3_ACCESS_KEY", cfg.S3.AccessKey)
|
||||
cfg.S3.SecretKey = getEnvOrDefault("APP_S3_SECRET_KEY", cfg.S3.SecretKey)
|
||||
cfg.S3.Bucket = getEnvOrDefault("APP_S3_BUCKET", cfg.S3.Bucket)
|
||||
cfg.S3.Domain = getEnvOrDefault("APP_S3_DOMAIN", cfg.S3.Domain)
|
||||
cfg.Server.Host = getEnvOrDefault("APP_SERVER_HOST", cfg.Server.Host)
|
||||
cfg.Server.Port, _ = strconv.Atoi(getEnvOrDefault("APP_SERVER_PORT", fmt.Sprintf("%d", cfg.Server.Port)))
|
||||
cfg.Server.Mode = getEnvOrDefault("APP_SERVER_MODE", cfg.Server.Mode)
|
||||
cfg.Gorse.Address = getEnvOrDefault("APP_GORSE_ADDRESS", cfg.Gorse.Address)
|
||||
cfg.Gorse.APIKey = getEnvOrDefault("APP_GORSE_API_KEY", cfg.Gorse.APIKey)
|
||||
cfg.Gorse.Dashboard = getEnvOrDefault("APP_GORSE_DASHBOARD", cfg.Gorse.Dashboard)
|
||||
cfg.Gorse.ImportPassword = getEnvOrDefault("APP_GORSE_IMPORT_PASSWORD", cfg.Gorse.ImportPassword)
|
||||
cfg.Gorse.EmbeddingAPIKey = getEnvOrDefault("APP_GORSE_EMBEDDING_API_KEY", cfg.Gorse.EmbeddingAPIKey)
|
||||
cfg.Gorse.EmbeddingURL = getEnvOrDefault("APP_GORSE_EMBEDDING_URL", cfg.Gorse.EmbeddingURL)
|
||||
cfg.Gorse.EmbeddingModel = getEnvOrDefault("APP_GORSE_EMBEDDING_MODEL", cfg.Gorse.EmbeddingModel)
|
||||
cfg.OpenAI.BaseURL = getEnvOrDefault("APP_OPENAI_BASE_URL", cfg.OpenAI.BaseURL)
|
||||
cfg.OpenAI.APIKey = getEnvOrDefault("APP_OPENAI_API_KEY", cfg.OpenAI.APIKey)
|
||||
cfg.OpenAI.ModerationModel = getEnvOrDefault("APP_OPENAI_MODERATION_MODEL", cfg.OpenAI.ModerationModel)
|
||||
cfg.OpenAI.ModerationMaxImagesPerRequest, _ = strconv.Atoi(getEnvOrDefault("APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST", fmt.Sprintf("%d", cfg.OpenAI.ModerationMaxImagesPerRequest)))
|
||||
cfg.OpenAI.RequestTimeout, _ = strconv.Atoi(getEnvOrDefault("APP_OPENAI_REQUEST_TIMEOUT", fmt.Sprintf("%d", cfg.OpenAI.RequestTimeout)))
|
||||
cfg.OpenAI.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_OPENAI_ENABLED", fmt.Sprintf("%t", cfg.OpenAI.Enabled)))
|
||||
cfg.OpenAI.StrictModeration, _ = strconv.ParseBool(getEnvOrDefault("APP_OPENAI_STRICT_MODERATION", fmt.Sprintf("%t", cfg.OpenAI.StrictModeration)))
|
||||
cfg.Email.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_ENABLED", fmt.Sprintf("%t", cfg.Email.Enabled)))
|
||||
cfg.Email.Host = getEnvOrDefault("APP_EMAIL_HOST", cfg.Email.Host)
|
||||
cfg.Email.Port, _ = strconv.Atoi(getEnvOrDefault("APP_EMAIL_PORT", fmt.Sprintf("%d", cfg.Email.Port)))
|
||||
cfg.Email.Username = getEnvOrDefault("APP_EMAIL_USERNAME", cfg.Email.Username)
|
||||
cfg.Email.Password = getEnvOrDefault("APP_EMAIL_PASSWORD", cfg.Email.Password)
|
||||
cfg.Email.FromAddress = getEnvOrDefault("APP_EMAIL_FROM_ADDRESS", cfg.Email.FromAddress)
|
||||
cfg.Email.FromName = getEnvOrDefault("APP_EMAIL_FROM_NAME", cfg.Email.FromName)
|
||||
cfg.Email.UseTLS, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_USE_TLS", fmt.Sprintf("%t", cfg.Email.UseTLS)))
|
||||
cfg.Email.InsecureSkipVerify, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_INSECURE_SKIP_VERIFY", fmt.Sprintf("%t", cfg.Email.InsecureSkipVerify)))
|
||||
cfg.Email.Timeout, _ = strconv.Atoi(getEnvOrDefault("APP_EMAIL_TIMEOUT", fmt.Sprintf("%d", cfg.Email.Timeout)))
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// getEnvOrDefault 获取环境变量值,如果未设置则返回默认值
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// NewRedis 创建Redis客户端(真实Redis)
|
||||
func NewRedis(cfg *RedisConfig) (*redis.Client, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Addr(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
PoolSize: cfg.PoolSize,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// NewS3 创建S3客户端
|
||||
func NewS3(cfg *S3Config) (*minio.Client, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
client, err := minio.New(cfg.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
|
||||
Secure: cfg.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create S3 client: %w", err)
|
||||
}
|
||||
|
||||
exists, err := client.BucketExists(ctx, cfg.Bucket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check bucket: %w", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{
|
||||
Region: cfg.Region,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("failed to create bucket: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
885
internal/dto/converter.go
Normal file
885
internal/dto/converter.go
Normal file
@@ -0,0 +1,885 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ==================== User 转换 ====================
|
||||
|
||||
// getAvatarOrDefault 获取头像URL,如果为空则返回在线头像生成服务的URL
|
||||
func getAvatarOrDefault(user *model.User) string {
|
||||
return utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar)
|
||||
}
|
||||
|
||||
// ConvertUserToResponse 将User转换为UserResponse
|
||||
func ConvertUserToResponse(user *model.User) *UserResponse {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
Email: user.Email,
|
||||
Phone: user.Phone,
|
||||
EmailVerified: user.EmailVerified,
|
||||
Avatar: getAvatarOrDefault(user),
|
||||
CoverURL: user.CoverURL,
|
||||
Bio: user.Bio,
|
||||
Website: user.Website,
|
||||
Location: user.Location,
|
||||
PostsCount: user.PostsCount,
|
||||
FollowersCount: user.FollowersCount,
|
||||
FollowingCount: user.FollowingCount,
|
||||
CreatedAt: FormatTime(user.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertUserToResponseWithFollowing 将User转换为UserResponse(包含关注状态)
|
||||
func ConvertUserToResponseWithFollowing(user *model.User, isFollowing bool) *UserResponse {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
Email: user.Email,
|
||||
Phone: user.Phone,
|
||||
EmailVerified: user.EmailVerified,
|
||||
Avatar: getAvatarOrDefault(user),
|
||||
CoverURL: user.CoverURL,
|
||||
Bio: user.Bio,
|
||||
Website: user.Website,
|
||||
Location: user.Location,
|
||||
PostsCount: user.PostsCount,
|
||||
FollowersCount: user.FollowersCount,
|
||||
FollowingCount: user.FollowingCount,
|
||||
IsFollowing: isFollowing,
|
||||
IsFollowingMe: false, // 默认false,需要单独计算
|
||||
CreatedAt: FormatTime(user.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertUserToResponseWithPostsCount 将User转换为UserResponse(使用实时计算的帖子数量)
|
||||
func ConvertUserToResponseWithPostsCount(user *model.User, postsCount int) *UserResponse {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
Email: user.Email,
|
||||
Phone: user.Phone,
|
||||
EmailVerified: user.EmailVerified,
|
||||
Avatar: getAvatarOrDefault(user),
|
||||
CoverURL: user.CoverURL,
|
||||
Bio: user.Bio,
|
||||
Website: user.Website,
|
||||
Location: user.Location,
|
||||
PostsCount: postsCount,
|
||||
FollowersCount: user.FollowersCount,
|
||||
FollowingCount: user.FollowingCount,
|
||||
CreatedAt: FormatTime(user.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertUserToResponseWithMutualFollow 将User转换为UserResponse(包含双向关注状态)
|
||||
func ConvertUserToResponseWithMutualFollow(user *model.User, isFollowing, isFollowingMe bool) *UserResponse {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
Email: user.Email,
|
||||
Phone: user.Phone,
|
||||
EmailVerified: user.EmailVerified,
|
||||
Avatar: getAvatarOrDefault(user),
|
||||
CoverURL: user.CoverURL,
|
||||
Bio: user.Bio,
|
||||
Website: user.Website,
|
||||
Location: user.Location,
|
||||
PostsCount: user.PostsCount,
|
||||
FollowersCount: user.FollowersCount,
|
||||
FollowingCount: user.FollowingCount,
|
||||
IsFollowing: isFollowing,
|
||||
IsFollowingMe: isFollowingMe,
|
||||
CreatedAt: FormatTime(user.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertUserToResponseWithMutualFollowAndPostsCount 将User转换为UserResponse(包含双向关注状态和实时计算的帖子数量)
|
||||
func ConvertUserToResponseWithMutualFollowAndPostsCount(user *model.User, isFollowing, isFollowingMe bool, postsCount int) *UserResponse {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
Email: user.Email,
|
||||
Phone: user.Phone,
|
||||
EmailVerified: user.EmailVerified,
|
||||
Avatar: getAvatarOrDefault(user),
|
||||
CoverURL: user.CoverURL,
|
||||
Bio: user.Bio,
|
||||
Website: user.Website,
|
||||
Location: user.Location,
|
||||
PostsCount: postsCount,
|
||||
FollowersCount: user.FollowersCount,
|
||||
FollowingCount: user.FollowingCount,
|
||||
IsFollowing: isFollowing,
|
||||
IsFollowingMe: isFollowingMe,
|
||||
CreatedAt: FormatTime(user.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertUserToDetailResponse 将User转换为UserDetailResponse
|
||||
func ConvertUserToDetailResponse(user *model.User) *UserDetailResponse {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserDetailResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
Email: user.Email,
|
||||
EmailVerified: user.EmailVerified,
|
||||
Avatar: getAvatarOrDefault(user),
|
||||
CoverURL: user.CoverURL,
|
||||
Bio: user.Bio,
|
||||
Website: user.Website,
|
||||
Location: user.Location,
|
||||
PostsCount: user.PostsCount,
|
||||
FollowersCount: user.FollowersCount,
|
||||
FollowingCount: user.FollowingCount,
|
||||
IsVerified: user.IsVerified,
|
||||
CreatedAt: FormatTime(user.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertUserToDetailResponseWithPostsCount 将User转换为UserDetailResponse(使用实时计算的帖子数量)
|
||||
func ConvertUserToDetailResponseWithPostsCount(user *model.User, postsCount int) *UserDetailResponse {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserDetailResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
Email: user.Email,
|
||||
EmailVerified: user.EmailVerified,
|
||||
Phone: user.Phone, // 仅当前用户自己可见
|
||||
Avatar: getAvatarOrDefault(user),
|
||||
CoverURL: user.CoverURL,
|
||||
Bio: user.Bio,
|
||||
Website: user.Website,
|
||||
Location: user.Location,
|
||||
PostsCount: postsCount,
|
||||
FollowersCount: user.FollowersCount,
|
||||
FollowingCount: user.FollowingCount,
|
||||
IsVerified: user.IsVerified,
|
||||
CreatedAt: FormatTime(user.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertUsersToResponse 将User列表转换为响应列表
|
||||
func ConvertUsersToResponse(users []*model.User) []*UserResponse {
|
||||
result := make([]*UserResponse, 0, len(users))
|
||||
for _, user := range users {
|
||||
result = append(result, ConvertUserToResponse(user))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConvertUsersToResponseWithMutualFollow 将User列表转换为响应列表(包含双向关注状态)
|
||||
// followingStatusMap: key是用户ID,value是[isFollowing, isFollowingMe]
|
||||
func ConvertUsersToResponseWithMutualFollow(users []*model.User, followingStatusMap map[string][2]bool) []*UserResponse {
|
||||
result := make([]*UserResponse, 0, len(users))
|
||||
for _, user := range users {
|
||||
status, ok := followingStatusMap[user.ID]
|
||||
if ok {
|
||||
result = append(result, ConvertUserToResponseWithMutualFollow(user, status[0], status[1]))
|
||||
} else {
|
||||
result = append(result, ConvertUserToResponse(user))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConvertUsersToResponseWithMutualFollowAndPostsCount 将User列表转换为响应列表(包含双向关注状态和实时计算的帖子数量)
|
||||
// followingStatusMap: key是用户ID,value是[isFollowing, isFollowingMe]
|
||||
// postsCountMap: key是用户ID,value是帖子数量
|
||||
func ConvertUsersToResponseWithMutualFollowAndPostsCount(users []*model.User, followingStatusMap map[string][2]bool, postsCountMap map[string]int64) []*UserResponse {
|
||||
result := make([]*UserResponse, 0, len(users))
|
||||
for _, user := range users {
|
||||
status, hasStatus := followingStatusMap[user.ID]
|
||||
postsCount, hasPostsCount := postsCountMap[user.ID]
|
||||
|
||||
// 如果没有帖子数量,使用数据库中的值
|
||||
if !hasPostsCount {
|
||||
postsCount = int64(user.PostsCount)
|
||||
}
|
||||
|
||||
if hasStatus {
|
||||
result = append(result, ConvertUserToResponseWithMutualFollowAndPostsCount(user, status[0], status[1], int(postsCount)))
|
||||
} else {
|
||||
result = append(result, ConvertUserToResponseWithPostsCount(user, int(postsCount)))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ==================== Post 转换 ====================
|
||||
|
||||
// ConvertPostImageToResponse 将PostImage转换为PostImageResponse
|
||||
func ConvertPostImageToResponse(img *model.PostImage) PostImageResponse {
|
||||
if img == nil {
|
||||
return PostImageResponse{}
|
||||
}
|
||||
return PostImageResponse{
|
||||
ID: img.ID,
|
||||
URL: img.URL,
|
||||
ThumbnailURL: img.ThumbnailURL,
|
||||
Width: img.Width,
|
||||
Height: img.Height,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertPostImagesToResponse 将PostImage列表转换为响应列表
|
||||
func ConvertPostImagesToResponse(images []model.PostImage) []PostImageResponse {
|
||||
result := make([]PostImageResponse, 0, len(images))
|
||||
for i := range images {
|
||||
result = append(result, ConvertPostImageToResponse(&images[i]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConvertPostToResponse 将Post转换为PostResponse(列表用)
|
||||
func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostResponse {
|
||||
if post == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
images := make([]PostImageResponse, 0)
|
||||
for _, img := range post.Images {
|
||||
images = append(images, ConvertPostImageToResponse(&img))
|
||||
}
|
||||
|
||||
var author *UserResponse
|
||||
if post.User != nil {
|
||||
author = ConvertUserToResponse(post.User)
|
||||
}
|
||||
|
||||
return &PostResponse{
|
||||
ID: post.ID,
|
||||
UserID: post.UserID,
|
||||
Title: post.Title,
|
||||
Content: post.Content,
|
||||
Images: images,
|
||||
LikesCount: post.LikesCount,
|
||||
CommentsCount: post.CommentsCount,
|
||||
FavoritesCount: post.FavoritesCount,
|
||||
SharesCount: post.SharesCount,
|
||||
ViewsCount: post.ViewsCount,
|
||||
IsPinned: post.IsPinned,
|
||||
IsLocked: post.IsLocked,
|
||||
IsVote: post.IsVote,
|
||||
CreatedAt: FormatTime(post.CreatedAt),
|
||||
Author: author,
|
||||
IsLiked: isLiked,
|
||||
IsFavorited: isFavorited,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertPostToDetailResponse 将Post转换为PostDetailResponse
|
||||
func ConvertPostToDetailResponse(post *model.Post, isLiked, isFavorited bool) *PostDetailResponse {
|
||||
if post == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
images := make([]PostImageResponse, 0)
|
||||
for _, img := range post.Images {
|
||||
images = append(images, ConvertPostImageToResponse(&img))
|
||||
}
|
||||
|
||||
var author *UserResponse
|
||||
if post.User != nil {
|
||||
author = ConvertUserToResponse(post.User)
|
||||
}
|
||||
|
||||
return &PostDetailResponse{
|
||||
ID: post.ID,
|
||||
UserID: post.UserID,
|
||||
Title: post.Title,
|
||||
Content: post.Content,
|
||||
Images: images,
|
||||
Status: string(post.Status),
|
||||
LikesCount: post.LikesCount,
|
||||
CommentsCount: post.CommentsCount,
|
||||
FavoritesCount: post.FavoritesCount,
|
||||
SharesCount: post.SharesCount,
|
||||
ViewsCount: post.ViewsCount,
|
||||
IsPinned: post.IsPinned,
|
||||
IsLocked: post.IsLocked,
|
||||
IsVote: post.IsVote,
|
||||
CreatedAt: FormatTime(post.CreatedAt),
|
||||
UpdatedAt: FormatTime(post.UpdatedAt),
|
||||
Author: author,
|
||||
IsLiked: isLiked,
|
||||
IsFavorited: isFavorited,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertPostsToResponse 将Post列表转换为响应列表(每个帖子独立检查点赞/收藏状态)
|
||||
func ConvertPostsToResponse(posts []*model.Post, isLikedMap, isFavoritedMap map[string]bool) []*PostResponse {
|
||||
result := make([]*PostResponse, 0, len(posts))
|
||||
for _, post := range posts {
|
||||
isLiked := false
|
||||
isFavorited := false
|
||||
if isLikedMap != nil {
|
||||
isLiked = isLikedMap[post.ID]
|
||||
}
|
||||
if isFavoritedMap != nil {
|
||||
isFavorited = isFavoritedMap[post.ID]
|
||||
}
|
||||
result = append(result, ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ==================== Comment 转换 ====================
|
||||
|
||||
// ConvertCommentToResponse 将Comment转换为CommentResponse
|
||||
func ConvertCommentToResponse(comment *model.Comment, isLiked bool) *CommentResponse {
|
||||
if comment == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var author *UserResponse
|
||||
if comment.User != nil {
|
||||
author = ConvertUserToResponse(comment.User)
|
||||
}
|
||||
|
||||
// 转换子回复(扁平化结构)
|
||||
var replies []*CommentResponse
|
||||
if len(comment.Replies) > 0 {
|
||||
replies = make([]*CommentResponse, 0, len(comment.Replies))
|
||||
for _, reply := range comment.Replies {
|
||||
replies = append(replies, ConvertCommentToResponse(reply, false))
|
||||
}
|
||||
}
|
||||
|
||||
// TargetID 就是 ParentID,前端根据这个 ID 找到被回复用户的昵称
|
||||
var targetID *string
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
targetID = comment.ParentID
|
||||
}
|
||||
|
||||
// 解析图片JSON
|
||||
var images []CommentImageResponse
|
||||
if comment.Images != "" {
|
||||
var urlList []string
|
||||
if err := json.Unmarshal([]byte(comment.Images), &urlList); err == nil {
|
||||
images = make([]CommentImageResponse, 0, len(urlList))
|
||||
for _, url := range urlList {
|
||||
images = append(images, CommentImageResponse{URL: url})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &CommentResponse{
|
||||
ID: comment.ID,
|
||||
PostID: comment.PostID,
|
||||
UserID: comment.UserID,
|
||||
ParentID: comment.ParentID,
|
||||
RootID: comment.RootID,
|
||||
Content: comment.Content,
|
||||
Images: images,
|
||||
LikesCount: comment.LikesCount,
|
||||
RepliesCount: comment.RepliesCount,
|
||||
CreatedAt: FormatTime(comment.CreatedAt),
|
||||
Author: author,
|
||||
IsLiked: isLiked,
|
||||
TargetID: targetID,
|
||||
Replies: replies,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertCommentsToResponse 将Comment列表转换为响应列表
|
||||
func ConvertCommentsToResponse(comments []*model.Comment, isLiked bool) []*CommentResponse {
|
||||
result := make([]*CommentResponse, 0, len(comments))
|
||||
for _, comment := range comments {
|
||||
result = append(result, ConvertCommentToResponse(comment, isLiked))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// IsLikedChecker 点赞状态检查器接口
|
||||
type IsLikedChecker interface {
|
||||
IsLiked(ctx context.Context, commentID, userID string) bool
|
||||
}
|
||||
|
||||
// ConvertCommentToResponseWithUser 将Comment转换为CommentResponse(根据用户ID检查点赞状态)
|
||||
func ConvertCommentToResponseWithUser(comment *model.Comment, userID string, checker IsLikedChecker) *CommentResponse {
|
||||
if comment == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查当前用户是否点赞了该评论
|
||||
isLiked := false
|
||||
if userID != "" && checker != nil {
|
||||
isLiked = checker.IsLiked(context.Background(), comment.ID, userID)
|
||||
}
|
||||
|
||||
var author *UserResponse
|
||||
if comment.User != nil {
|
||||
author = ConvertUserToResponse(comment.User)
|
||||
}
|
||||
|
||||
// 转换子回复(扁平化结构),递归检查点赞状态
|
||||
var replies []*CommentResponse
|
||||
if len(comment.Replies) > 0 {
|
||||
replies = make([]*CommentResponse, 0, len(comment.Replies))
|
||||
for _, reply := range comment.Replies {
|
||||
replies = append(replies, ConvertCommentToResponseWithUser(reply, userID, checker))
|
||||
}
|
||||
}
|
||||
|
||||
// TargetID 就是 ParentID,前端根据这个 ID 找到被回复用户的昵称
|
||||
var targetID *string
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
targetID = comment.ParentID
|
||||
}
|
||||
|
||||
// 解析图片JSON
|
||||
var images []CommentImageResponse
|
||||
if comment.Images != "" {
|
||||
var urlList []string
|
||||
if err := json.Unmarshal([]byte(comment.Images), &urlList); err == nil {
|
||||
images = make([]CommentImageResponse, 0, len(urlList))
|
||||
for _, url := range urlList {
|
||||
images = append(images, CommentImageResponse{URL: url})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &CommentResponse{
|
||||
ID: comment.ID,
|
||||
PostID: comment.PostID,
|
||||
UserID: comment.UserID,
|
||||
ParentID: comment.ParentID,
|
||||
RootID: comment.RootID,
|
||||
Content: comment.Content,
|
||||
Images: images,
|
||||
LikesCount: comment.LikesCount,
|
||||
RepliesCount: comment.RepliesCount,
|
||||
CreatedAt: FormatTime(comment.CreatedAt),
|
||||
Author: author,
|
||||
IsLiked: isLiked,
|
||||
TargetID: targetID,
|
||||
Replies: replies,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertCommentsToResponseWithUser 将Comment列表转换为响应列表(根据用户ID检查点赞状态)
|
||||
func ConvertCommentsToResponseWithUser(comments []*model.Comment, userID string, checker IsLikedChecker) []*CommentResponse {
|
||||
result := make([]*CommentResponse, 0, len(comments))
|
||||
for _, comment := range comments {
|
||||
result = append(result, ConvertCommentToResponseWithUser(comment, userID, checker))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ==================== Notification 转换 ====================
|
||||
|
||||
// ConvertNotificationToResponse 将Notification转换为NotificationResponse
|
||||
func ConvertNotificationToResponse(notification *model.Notification) *NotificationResponse {
|
||||
if notification == nil {
|
||||
return nil
|
||||
}
|
||||
return &NotificationResponse{
|
||||
ID: notification.ID,
|
||||
UserID: notification.UserID,
|
||||
Type: string(notification.Type),
|
||||
Title: notification.Title,
|
||||
Content: notification.Content,
|
||||
Data: notification.Data,
|
||||
IsRead: notification.IsRead,
|
||||
CreatedAt: FormatTime(notification.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertNotificationsToResponse 将Notification列表转换为响应列表
|
||||
func ConvertNotificationsToResponse(notifications []*model.Notification) []*NotificationResponse {
|
||||
result := make([]*NotificationResponse, 0, len(notifications))
|
||||
for _, n := range notifications {
|
||||
result = append(result, ConvertNotificationToResponse(n))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ==================== Message 转换 ====================
|
||||
|
||||
// ConvertMessageToResponse 将Message转换为MessageResponse
|
||||
func ConvertMessageToResponse(message *model.Message) *MessageResponse {
|
||||
if message == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 直接使用 segments,不需要解析
|
||||
segments := make(model.MessageSegments, len(message.Segments))
|
||||
for i, seg := range message.Segments {
|
||||
segments[i] = model.MessageSegment{
|
||||
Type: seg.Type,
|
||||
Data: seg.Data,
|
||||
}
|
||||
}
|
||||
|
||||
return &MessageResponse{
|
||||
ID: message.ID,
|
||||
ConversationID: message.ConversationID,
|
||||
SenderID: message.SenderID,
|
||||
Seq: message.Seq,
|
||||
Segments: segments,
|
||||
ReplyToID: message.ReplyToID,
|
||||
Status: string(message.Status),
|
||||
Category: string(message.Category),
|
||||
CreatedAt: FormatTime(message.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertConversationToResponse 将Conversation转换为ConversationResponse
|
||||
// participants: 会话参与者列表(用户信息)
|
||||
// unreadCount: 当前用户的未读消息数
|
||||
// lastMessage: 最后一条消息
|
||||
func ConvertConversationToResponse(conv *model.Conversation, participants []*model.User, unreadCount int, lastMessage *model.Message, isPinned bool) *ConversationResponse {
|
||||
if conv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var participantResponses []*UserResponse
|
||||
for _, p := range participants {
|
||||
participantResponses = append(participantResponses, ConvertUserToResponse(p))
|
||||
}
|
||||
|
||||
// 转换群组信息
|
||||
var groupResponse *GroupResponse
|
||||
if conv.Group != nil {
|
||||
groupResponse = GroupToResponse(conv.Group)
|
||||
}
|
||||
|
||||
return &ConversationResponse{
|
||||
ID: conv.ID,
|
||||
Type: string(conv.Type),
|
||||
IsPinned: isPinned,
|
||||
Group: groupResponse,
|
||||
LastSeq: conv.LastSeq,
|
||||
LastMessage: ConvertMessageToResponse(lastMessage),
|
||||
LastMessageAt: FormatTimePointer(conv.LastMsgTime),
|
||||
UnreadCount: unreadCount,
|
||||
Participants: participantResponses,
|
||||
CreatedAt: FormatTime(conv.CreatedAt),
|
||||
UpdatedAt: FormatTime(conv.UpdatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertConversationToDetailResponse 将Conversation转换为ConversationDetailResponse
|
||||
func ConvertConversationToDetailResponse(conv *model.Conversation, participants []*model.User, unreadCount int64, lastMessage *model.Message, myLastReadSeq int64, otherLastReadSeq int64, isPinned bool) *ConversationDetailResponse {
|
||||
if conv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var participantResponses []*UserResponse
|
||||
for _, p := range participants {
|
||||
participantResponses = append(participantResponses, ConvertUserToResponse(p))
|
||||
}
|
||||
|
||||
return &ConversationDetailResponse{
|
||||
ID: conv.ID,
|
||||
Type: string(conv.Type),
|
||||
IsPinned: isPinned,
|
||||
LastSeq: conv.LastSeq,
|
||||
LastMessage: ConvertMessageToResponse(lastMessage),
|
||||
LastMessageAt: FormatTimePointer(conv.LastMsgTime),
|
||||
UnreadCount: unreadCount,
|
||||
Participants: participantResponses,
|
||||
MyLastReadSeq: myLastReadSeq,
|
||||
OtherLastReadSeq: otherLastReadSeq,
|
||||
CreatedAt: FormatTime(conv.CreatedAt),
|
||||
UpdatedAt: FormatTime(conv.UpdatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertMessagesToResponse 将Message列表转换为响应列表
|
||||
func ConvertMessagesToResponse(messages []*model.Message) []*MessageResponse {
|
||||
result := make([]*MessageResponse, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
result = append(result, ConvertMessageToResponse(msg))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ConvertConversationsToResponse 将Conversation列表转换为响应列表
|
||||
func ConvertConversationsToResponse(convs []*model.Conversation) []*ConversationResponse {
|
||||
result := make([]*ConversationResponse, 0, len(convs))
|
||||
for _, conv := range convs {
|
||||
result = append(result, ConvertConversationToResponse(conv, nil, 0, nil, false))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ==================== PushRecord 转换 ====================
|
||||
|
||||
// PushRecordToResponse 将PushRecord转换为PushRecordResponse
|
||||
func PushRecordToResponse(record *model.PushRecord) *PushRecordResponse {
|
||||
if record == nil {
|
||||
return nil
|
||||
}
|
||||
resp := &PushRecordResponse{
|
||||
ID: record.ID,
|
||||
MessageID: record.MessageID,
|
||||
PushChannel: string(record.PushChannel),
|
||||
PushStatus: string(record.PushStatus),
|
||||
RetryCount: record.RetryCount,
|
||||
CreatedAt: record.CreatedAt,
|
||||
}
|
||||
if record.PushedAt != nil {
|
||||
resp.PushedAt = *record.PushedAt
|
||||
}
|
||||
if record.DeliveredAt != nil {
|
||||
resp.DeliveredAt = *record.DeliveredAt
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// PushRecordsToResponse 将PushRecord列表转换为响应列表
|
||||
func PushRecordsToResponse(records []*model.PushRecord) []*PushRecordResponse {
|
||||
result := make([]*PushRecordResponse, 0, len(records))
|
||||
for _, record := range records {
|
||||
result = append(result, PushRecordToResponse(record))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ==================== DeviceToken 转换 ====================
|
||||
|
||||
// DeviceTokenToResponse 将DeviceToken转换为DeviceTokenResponse
|
||||
func DeviceTokenToResponse(token *model.DeviceToken) *DeviceTokenResponse {
|
||||
if token == nil {
|
||||
return nil
|
||||
}
|
||||
resp := &DeviceTokenResponse{
|
||||
ID: token.ID,
|
||||
DeviceID: token.DeviceID,
|
||||
DeviceType: string(token.DeviceType),
|
||||
IsActive: token.IsActive,
|
||||
DeviceName: token.DeviceName,
|
||||
CreatedAt: token.CreatedAt,
|
||||
}
|
||||
if token.LastUsedAt != nil {
|
||||
resp.LastUsedAt = *token.LastUsedAt
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// DeviceTokensToResponse 将DeviceToken列表转换为响应列表
|
||||
func DeviceTokensToResponse(tokens []*model.DeviceToken) []*DeviceTokenResponse {
|
||||
result := make([]*DeviceTokenResponse, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
result = append(result, DeviceTokenToResponse(token))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ==================== SystemMessage 转换 ====================
|
||||
|
||||
// SystemMessageToResponse 将Message转换为SystemMessageResponse
|
||||
func SystemMessageToResponse(msg *model.Message) *SystemMessageResponse {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 从 segments 中提取文本内容
|
||||
content := ExtractTextContentFromModel(msg.Segments)
|
||||
|
||||
resp := &SystemMessageResponse{
|
||||
ID: msg.ID,
|
||||
SenderID: msg.SenderID,
|
||||
ReceiverID: "", // 系统消息的接收者需要从上下文获取
|
||||
Content: content,
|
||||
Category: string(msg.Category),
|
||||
SystemType: string(msg.SystemType),
|
||||
CreatedAt: msg.CreatedAt,
|
||||
}
|
||||
if msg.ExtraData != nil {
|
||||
resp.ExtraData = map[string]interface{}{
|
||||
"actor_id": msg.ExtraData.ActorID,
|
||||
"actor_name": msg.ExtraData.ActorName,
|
||||
"avatar_url": msg.ExtraData.AvatarURL,
|
||||
"target_id": msg.ExtraData.TargetID,
|
||||
"target_type": msg.ExtraData.TargetType,
|
||||
"action_url": msg.ExtraData.ActionURL,
|
||||
"action_time": msg.ExtraData.ActionTime,
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// SystemMessagesToResponse 将Message列表转换为SystemMessageResponse列表
|
||||
func SystemMessagesToResponse(messages []*model.Message) []*SystemMessageResponse {
|
||||
result := make([]*SystemMessageResponse, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
result = append(result, SystemMessageToResponse(msg))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SystemNotificationToResponse 将SystemNotification转换为SystemMessageResponse
|
||||
func SystemNotificationToResponse(n *model.SystemNotification) *SystemMessageResponse {
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
resp := &SystemMessageResponse{
|
||||
ID: strconv.FormatInt(n.ID, 10),
|
||||
SenderID: model.SystemSenderIDStr, // 系统发送者
|
||||
ReceiverID: n.ReceiverID,
|
||||
Content: n.Content,
|
||||
Category: "notification",
|
||||
SystemType: string(n.Type),
|
||||
IsRead: n.IsRead,
|
||||
CreatedAt: n.CreatedAt,
|
||||
}
|
||||
if n.ExtraData != nil {
|
||||
resp.ExtraData = map[string]interface{}{
|
||||
"actor_id": n.ExtraData.ActorID,
|
||||
"actor_id_str": n.ExtraData.ActorIDStr,
|
||||
"actor_name": n.ExtraData.ActorName,
|
||||
"avatar_url": n.ExtraData.AvatarURL,
|
||||
"target_id": n.ExtraData.TargetID,
|
||||
"target_title": n.ExtraData.TargetTitle,
|
||||
"target_type": n.ExtraData.TargetType,
|
||||
"action_url": n.ExtraData.ActionURL,
|
||||
"action_time": n.ExtraData.ActionTime,
|
||||
"group_id": n.ExtraData.GroupID,
|
||||
"group_name": n.ExtraData.GroupName,
|
||||
"group_avatar": n.ExtraData.GroupAvatar,
|
||||
"group_description": n.ExtraData.GroupDescription,
|
||||
"flag": n.ExtraData.Flag,
|
||||
"request_type": n.ExtraData.RequestType,
|
||||
"request_status": n.ExtraData.RequestStatus,
|
||||
"reason": n.ExtraData.Reason,
|
||||
"target_user_id": n.ExtraData.TargetUserID,
|
||||
"target_user_name": n.ExtraData.TargetUserName,
|
||||
"target_user_avatar": n.ExtraData.TargetUserAvatar,
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// SystemNotificationsToResponse 将SystemNotification列表转换为SystemMessageResponse列表
|
||||
func SystemNotificationsToResponse(notifications []*model.SystemNotification) []*SystemMessageResponse {
|
||||
result := make([]*SystemMessageResponse, 0, len(notifications))
|
||||
for _, n := range notifications {
|
||||
result = append(result, SystemNotificationToResponse(n))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ==================== Group 转换 ====================
|
||||
|
||||
// GroupToResponse 将Group转换为GroupResponse
|
||||
func GroupToResponse(group *model.Group) *GroupResponse {
|
||||
if group == nil {
|
||||
return nil
|
||||
}
|
||||
return &GroupResponse{
|
||||
ID: group.ID,
|
||||
Name: group.Name,
|
||||
Avatar: group.Avatar,
|
||||
Description: group.Description,
|
||||
OwnerID: group.OwnerID,
|
||||
MemberCount: group.MemberCount,
|
||||
MaxMembers: group.MaxMembers,
|
||||
JoinType: int(group.JoinType),
|
||||
MuteAll: group.MuteAll,
|
||||
CreatedAt: FormatTime(group.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// GroupsToResponse 将Group列表转换为GroupResponse列表
|
||||
func GroupsToResponse(groups []model.Group) []*GroupResponse {
|
||||
result := make([]*GroupResponse, 0, len(groups))
|
||||
for i := range groups {
|
||||
result = append(result, GroupToResponse(&groups[i]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GroupMemberToResponse 将GroupMember转换为GroupMemberResponse
|
||||
func GroupMemberToResponse(member *model.GroupMember) *GroupMemberResponse {
|
||||
if member == nil {
|
||||
return nil
|
||||
}
|
||||
return &GroupMemberResponse{
|
||||
ID: member.ID,
|
||||
GroupID: member.GroupID,
|
||||
UserID: member.UserID,
|
||||
Role: member.Role,
|
||||
Nickname: member.Nickname,
|
||||
Muted: member.Muted,
|
||||
JoinTime: FormatTime(member.JoinTime),
|
||||
}
|
||||
}
|
||||
|
||||
// GroupMemberToResponseWithUser 将GroupMember转换为GroupMemberResponse(包含用户信息)
|
||||
func GroupMemberToResponseWithUser(member *model.GroupMember, user *model.User) *GroupMemberResponse {
|
||||
if member == nil {
|
||||
return nil
|
||||
}
|
||||
resp := GroupMemberToResponse(member)
|
||||
if user != nil {
|
||||
resp.User = ConvertUserToResponse(user)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// GroupMembersToResponse 将GroupMember列表转换为GroupMemberResponse列表
|
||||
func GroupMembersToResponse(members []model.GroupMember) []*GroupMemberResponse {
|
||||
result := make([]*GroupMemberResponse, 0, len(members))
|
||||
for i := range members {
|
||||
result = append(result, GroupMemberToResponse(&members[i]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GroupAnnouncementToResponse 将GroupAnnouncement转换为GroupAnnouncementResponse
|
||||
func GroupAnnouncementToResponse(announcement *model.GroupAnnouncement) *GroupAnnouncementResponse {
|
||||
if announcement == nil {
|
||||
return nil
|
||||
}
|
||||
return &GroupAnnouncementResponse{
|
||||
ID: announcement.ID,
|
||||
GroupID: announcement.GroupID,
|
||||
Content: announcement.Content,
|
||||
AuthorID: announcement.AuthorID,
|
||||
IsPinned: announcement.IsPinned,
|
||||
CreatedAt: FormatTime(announcement.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// GroupAnnouncementsToResponse 将GroupAnnouncement列表转换为GroupAnnouncementResponse列表
|
||||
func GroupAnnouncementsToResponse(announcements []model.GroupAnnouncement) []*GroupAnnouncementResponse {
|
||||
result := make([]*GroupAnnouncementResponse, 0, len(announcements))
|
||||
for i := range announcements {
|
||||
result = append(result, GroupAnnouncementToResponse(&announcements[i]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
819
internal/dto/dto.go
Normal file
819
internal/dto/dto.go
Normal file
@@ -0,0 +1,819 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ==================== User DTOs ====================
|
||||
|
||||
// UserResponse 用户信息响应
|
||||
type UserResponse struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Nickname string `json:"nickname"`
|
||||
Email *string `json:"email,omitempty"`
|
||||
Phone *string `json:"phone,omitempty"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Avatar string `json:"avatar"`
|
||||
CoverURL string `json:"cover_url"` // 头图URL
|
||||
Bio string `json:"bio"`
|
||||
Website string `json:"website"`
|
||||
Location string `json:"location"`
|
||||
PostsCount int `json:"posts_count"`
|
||||
FollowersCount int `json:"followers_count"`
|
||||
FollowingCount int `json:"following_count"`
|
||||
IsFollowing bool `json:"is_following"` // 当前用户是否关注了该用户
|
||||
IsFollowingMe bool `json:"is_following_me"` // 该用户是否关注了当前用户
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// UserDetailResponse 用户详情响应
|
||||
type UserDetailResponse struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Nickname string `json:"nickname"`
|
||||
Email *string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Phone *string `json:"phone,omitempty"` // 仅当前用户自己可见
|
||||
Avatar string `json:"avatar"`
|
||||
CoverURL string `json:"cover_url"` // 头图URL
|
||||
Bio string `json:"bio"`
|
||||
Website string `json:"website"`
|
||||
Location string `json:"location"`
|
||||
PostsCount int `json:"posts_count"`
|
||||
FollowersCount int `json:"followers_count"`
|
||||
FollowingCount int `json:"following_count"`
|
||||
IsVerified bool `json:"is_verified"`
|
||||
IsFollowing bool `json:"is_following"` // 当前用户是否关注了该用户
|
||||
IsFollowingMe bool `json:"is_following_me"` // 该用户是否关注了当前用户
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// ==================== Post DTOs ====================
|
||||
|
||||
// PostImageResponse 帖子图片响应
|
||||
type PostImageResponse struct {
|
||||
ID string `json:"id"`
|
||||
URL string `json:"url"`
|
||||
ThumbnailURL string `json:"thumbnail_url"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
// PostResponse 帖子响应(列表用)
|
||||
type PostResponse struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Images []PostImageResponse `json:"images"`
|
||||
LikesCount int `json:"likes_count"`
|
||||
CommentsCount int `json:"comments_count"`
|
||||
FavoritesCount int `json:"favorites_count"`
|
||||
SharesCount int `json:"shares_count"`
|
||||
ViewsCount int `json:"views_count"`
|
||||
IsPinned bool `json:"is_pinned"`
|
||||
IsLocked bool `json:"is_locked"`
|
||||
IsVote bool `json:"is_vote"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Author *UserResponse `json:"author"`
|
||||
IsLiked bool `json:"is_liked"`
|
||||
IsFavorited bool `json:"is_favorited"`
|
||||
}
|
||||
|
||||
// PostDetailResponse 帖子详情响应
|
||||
type PostDetailResponse struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Images []PostImageResponse `json:"images"`
|
||||
Status string `json:"status"`
|
||||
LikesCount int `json:"likes_count"`
|
||||
CommentsCount int `json:"comments_count"`
|
||||
FavoritesCount int `json:"favorites_count"`
|
||||
SharesCount int `json:"shares_count"`
|
||||
ViewsCount int `json:"views_count"`
|
||||
IsPinned bool `json:"is_pinned"`
|
||||
IsLocked bool `json:"is_locked"`
|
||||
IsVote bool `json:"is_vote"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Author *UserResponse `json:"author"`
|
||||
IsLiked bool `json:"is_liked"`
|
||||
IsFavorited bool `json:"is_favorited"`
|
||||
}
|
||||
|
||||
// ==================== Comment DTOs ====================
|
||||
|
||||
// CommentImageResponse 评论图片响应
|
||||
type CommentImageResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// CommentResponse 评论响应(扁平化结构,类似B站/抖音)
|
||||
// 第一层级正常展示,第二三四五层级在第一层级的评论区扁平展示
|
||||
type CommentResponse struct {
|
||||
ID string `json:"id"`
|
||||
PostID string `json:"post_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ParentID *string `json:"parent_id"`
|
||||
RootID *string `json:"root_id"`
|
||||
Content string `json:"content"`
|
||||
Images []CommentImageResponse `json:"images"`
|
||||
LikesCount int `json:"likes_count"`
|
||||
RepliesCount int `json:"replies_count"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Author *UserResponse `json:"author"`
|
||||
IsLiked bool `json:"is_liked"`
|
||||
TargetID *string `json:"target_id,omitempty"` // 被回复的评论ID,前端根据此ID找到被回复用户的昵称
|
||||
Replies []*CommentResponse `json:"replies,omitempty"` // 子回复列表(扁平化,所有层级都在这里)
|
||||
}
|
||||
|
||||
// ==================== Notification DTOs ====================
|
||||
|
||||
// NotificationResponse 通知响应
|
||||
type NotificationResponse struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Type string `json:"type"`
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Data string `json:"data"`
|
||||
IsRead bool `json:"is_read"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// ==================== Message Segment DTOs ====================
|
||||
|
||||
// SegmentType Segment类型
|
||||
type SegmentType string
|
||||
|
||||
const (
|
||||
SegmentTypeText SegmentType = "text"
|
||||
SegmentTypeImage SegmentType = "image"
|
||||
SegmentTypeVoice SegmentType = "voice"
|
||||
SegmentTypeVideo SegmentType = "video"
|
||||
SegmentTypeFile SegmentType = "file"
|
||||
SegmentTypeAt SegmentType = "at"
|
||||
SegmentTypeReply SegmentType = "reply"
|
||||
SegmentTypeFace SegmentType = "face"
|
||||
SegmentTypeLink SegmentType = "link"
|
||||
)
|
||||
|
||||
// TextSegmentData 文本数据
|
||||
type TextSegmentData struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// ImageSegmentData 图片数据
|
||||
type ImageSegmentData struct {
|
||||
URL string `json:"url"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
ThumbnailURL string `json:"thumbnail_url,omitempty"`
|
||||
FileSize int64 `json:"file_size,omitempty"`
|
||||
}
|
||||
|
||||
// VoiceSegmentData 语音数据
|
||||
type VoiceSegmentData struct {
|
||||
URL string `json:"url"`
|
||||
Duration int `json:"duration,omitempty"` // 秒
|
||||
FileSize int64 `json:"file_size,omitempty"`
|
||||
}
|
||||
|
||||
// VideoSegmentData 视频数据
|
||||
type VideoSegmentData struct {
|
||||
URL string `json:"url"`
|
||||
Width int `json:"width,omitempty"`
|
||||
Height int `json:"height,omitempty"`
|
||||
Duration int `json:"duration,omitempty"` // 秒
|
||||
ThumbnailURL string `json:"thumbnail_url,omitempty"`
|
||||
FileSize int64 `json:"file_size,omitempty"`
|
||||
}
|
||||
|
||||
// FileSegmentData 文件数据
|
||||
type FileSegmentData struct {
|
||||
URL string `json:"url"`
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size,omitempty"`
|
||||
MimeType string `json:"mime_type,omitempty"`
|
||||
}
|
||||
|
||||
// AtSegmentData @数据
|
||||
type AtSegmentData struct {
|
||||
UserID string `json:"user_id"` // "all" 表示@所有人
|
||||
Nickname string `json:"nickname,omitempty"`
|
||||
}
|
||||
|
||||
// ReplySegmentData 回复数据
|
||||
type ReplySegmentData struct {
|
||||
ID string `json:"id"` // 被回复消息的ID
|
||||
}
|
||||
|
||||
// FaceSegmentData 表情数据
|
||||
type FaceSegmentData struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// LinkSegmentData 链接数据
|
||||
type LinkSegmentData struct {
|
||||
URL string `json:"url"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
}
|
||||
|
||||
// ==================== Message DTOs ====================
|
||||
|
||||
// MessageResponse 消息响应
|
||||
type MessageResponse struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
SenderID string `json:"sender_id"`
|
||||
Seq int64 `json:"seq"`
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(必须)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"` // 被回复消息的ID(用于关联查找)
|
||||
Status string `json:"status"`
|
||||
Category string `json:"category,omitempty"` // 消息类别:chat, notification, announcement
|
||||
CreatedAt string `json:"created_at"`
|
||||
Sender *UserResponse `json:"sender"`
|
||||
}
|
||||
|
||||
// ConversationResponse 会话响应
|
||||
type ConversationResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
IsPinned bool `json:"is_pinned"`
|
||||
Group *GroupResponse `json:"group,omitempty"`
|
||||
LastSeq int64 `json:"last_seq"`
|
||||
LastMessage *MessageResponse `json:"last_message"`
|
||||
LastMessageAt string `json:"last_message_at"`
|
||||
UnreadCount int `json:"unread_count"`
|
||||
Participants []*UserResponse `json:"participants,omitempty"` // 私聊时使用
|
||||
MemberCount int `json:"member_count,omitempty"` // 群聊时使用
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ConversationParticipantResponse 会话参与者响应
|
||||
type ConversationParticipantResponse struct {
|
||||
UserID string `json:"user_id"`
|
||||
LastReadSeq int64 `json:"last_read_seq"`
|
||||
Muted bool `json:"muted"`
|
||||
IsPinned bool `json:"is_pinned"`
|
||||
}
|
||||
|
||||
// ==================== Auth DTOs ====================
|
||||
|
||||
// LoginResponse 登录响应
|
||||
type LoginResponse struct {
|
||||
User *UserResponse `json:"user"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
// RegisterResponse 注册响应
|
||||
type RegisterResponse struct {
|
||||
User *UserResponse `json:"user"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
// RefreshTokenResponse 刷新Token响应
|
||||
type RefreshTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
// ==================== Common DTOs ====================
|
||||
|
||||
// SuccessResponse 通用成功响应
|
||||
type SuccessResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// AvailableResponse 可用性检查响应
|
||||
type AvailableResponse struct {
|
||||
Available bool `json:"available"`
|
||||
}
|
||||
|
||||
// CountResponse 数量响应
|
||||
type CountResponse struct {
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// URLResponse URL响应
|
||||
type URLResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// ==================== Chat Request DTOs ====================
|
||||
|
||||
// CreateConversationRequest 创建会话请求
|
||||
type CreateConversationRequest struct {
|
||||
UserID string `json:"user_id" binding:"required"` // 目标用户ID (UUID格式)
|
||||
}
|
||||
|
||||
// SendMessageRequest 发送消息请求
|
||||
type SendMessageRequest struct {
|
||||
Segments model.MessageSegments `json:"segments" binding:"required"` // 消息链(必须)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID (string类型)
|
||||
}
|
||||
|
||||
// MarkReadRequest 标记已读请求
|
||||
type MarkReadRequest struct {
|
||||
LastReadSeq int64 `json:"last_read_seq" binding:"required"` // 已读到的seq位置
|
||||
}
|
||||
|
||||
// SetConversationPinnedRequest 设置会话置顶请求
|
||||
type SetConversationPinnedRequest struct {
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
IsPinned bool `json:"is_pinned"`
|
||||
}
|
||||
|
||||
// ==================== Chat Response DTOs ====================
|
||||
|
||||
// ConversationListResponse 会话列表响应
|
||||
type ConversationListResponse struct {
|
||||
Conversations []*ConversationResponse `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// ConversationDetailResponse 会话详情响应
|
||||
type ConversationDetailResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
IsPinned bool `json:"is_pinned"`
|
||||
LastSeq int64 `json:"last_seq"`
|
||||
LastMessage *MessageResponse `json:"last_message"`
|
||||
LastMessageAt string `json:"last_message_at"`
|
||||
UnreadCount int64 `json:"unread_count"`
|
||||
Participants []*UserResponse `json:"participants"`
|
||||
MyLastReadSeq int64 `json:"my_last_read_seq"` // 当前用户的已读位置
|
||||
OtherLastReadSeq int64 `json:"other_last_read_seq"` // 对方用户的已读位置
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// UnreadCountResponse 未读数响应
|
||||
type UnreadCountResponse struct {
|
||||
TotalUnreadCount int64 `json:"total_unread_count"` // 所有会话的未读总数
|
||||
}
|
||||
|
||||
// ConversationUnreadCountResponse 单个会话未读数响应
|
||||
type ConversationUnreadCountResponse struct {
|
||||
ConversationID string `json:"conversation_id"`
|
||||
UnreadCount int64 `json:"unread_count"`
|
||||
}
|
||||
|
||||
// MessageListResponse 消息列表响应
|
||||
type MessageListResponse struct {
|
||||
Messages []*MessageResponse `json:"messages"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// MessageSyncResponse 消息同步响应(增量同步)
|
||||
type MessageSyncResponse struct {
|
||||
Messages []*MessageResponse `json:"messages"`
|
||||
HasMore bool `json:"has_more"`
|
||||
}
|
||||
|
||||
// ==================== 设备Token DTOs ====================
|
||||
|
||||
// RegisterDeviceRequest 注册设备请求
|
||||
type RegisterDeviceRequest struct {
|
||||
DeviceID string `json:"device_id" binding:"required"`
|
||||
DeviceType string `json:"device_type" binding:"required,oneof=ios android web"`
|
||||
PushToken string `json:"push_token"`
|
||||
DeviceName string `json:"device_name"`
|
||||
}
|
||||
|
||||
// DeviceTokenResponse 设备Token响应
|
||||
type DeviceTokenResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
DeviceID string `json:"device_id"`
|
||||
DeviceType string `json:"device_type"`
|
||||
IsActive bool `json:"is_active"`
|
||||
DeviceName string `json:"device_name"`
|
||||
LastUsedAt time.Time `json:"last_used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ==================== 推送记录 DTOs ====================
|
||||
|
||||
// PushRecordResponse 推送记录响应
|
||||
type PushRecordResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
MessageID string `json:"message_id"`
|
||||
PushChannel string `json:"push_channel"`
|
||||
PushStatus string `json:"push_status"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
PushedAt time.Time `json:"pushed_at,omitempty"`
|
||||
DeliveredAt time.Time `json:"delivered_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// PushRecordListResponse 推送记录列表响应
|
||||
type PushRecordListResponse struct {
|
||||
Records []*PushRecordResponse `json:"records"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
// ==================== 系统消息 DTOs ====================
|
||||
|
||||
// SystemMessageResponse 系统消息响应
|
||||
type SystemMessageResponse struct {
|
||||
ID string `json:"id"`
|
||||
SenderID string `json:"sender_id"`
|
||||
ReceiverID string `json:"receiver_id"`
|
||||
Content string `json:"content"`
|
||||
Category string `json:"category"`
|
||||
SystemType string `json:"system_type"`
|
||||
ExtraData map[string]interface{} `json:"extra_data,omitempty"`
|
||||
IsRead bool `json:"is_read"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SystemMessageListResponse 系统消息列表响应
|
||||
type SystemMessageListResponse struct {
|
||||
Messages []*SystemMessageResponse `json:"messages"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// SystemUnreadCountResponse 系统消息未读数响应
|
||||
type SystemUnreadCountResponse struct {
|
||||
UnreadCount int64 `json:"unread_count"`
|
||||
}
|
||||
|
||||
// ==================== 时间格式化 ====================
|
||||
|
||||
// FormatTime 格式化时间
|
||||
func FormatTime(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format("2006-01-02T15:04:05Z07:00")
|
||||
}
|
||||
|
||||
// FormatTimePointer 格式化时间指针
|
||||
func FormatTimePointer(t *time.Time) string {
|
||||
if t == nil {
|
||||
return ""
|
||||
}
|
||||
return FormatTime(*t)
|
||||
}
|
||||
|
||||
// ==================== Group DTOs ====================
|
||||
|
||||
// CreateGroupRequest 创建群组请求
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name" binding:"required,max=50"`
|
||||
Description string `json:"description" binding:"max=500"`
|
||||
MemberIDs []string `json:"member_ids"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest 更新群组请求
|
||||
type UpdateGroupRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,max=50"`
|
||||
Description string `json:"description" binding:"omitempty,max=500"`
|
||||
Avatar string `json:"avatar" binding:"omitempty,url"`
|
||||
}
|
||||
|
||||
// InviteMembersRequest 邀请成员请求
|
||||
type InviteMembersRequest struct {
|
||||
MemberIDs []string `json:"member_ids" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
// TransferOwnerRequest 转让群主请求
|
||||
type TransferOwnerRequest struct {
|
||||
NewOwnerID string `json:"new_owner_id" binding:"required"`
|
||||
}
|
||||
|
||||
// SetRoleRequest 设置角色请求
|
||||
type SetRoleRequest struct {
|
||||
Role string `json:"role" binding:"required,oneof=admin member"`
|
||||
}
|
||||
|
||||
// SetNicknameRequest 设置昵称请求
|
||||
type SetNicknameRequest struct {
|
||||
Nickname string `json:"nickname" binding:"max=50"`
|
||||
}
|
||||
|
||||
// MuteMemberRequest 禁言成员请求
|
||||
type MuteMemberRequest struct {
|
||||
Muted bool `json:"muted"`
|
||||
}
|
||||
|
||||
// SetMuteAllRequest 设置全员禁言请求
|
||||
type SetMuteAllRequest struct {
|
||||
MuteAll bool `json:"mute_all"`
|
||||
}
|
||||
|
||||
// SetJoinTypeRequest 设置加群方式请求
|
||||
type SetJoinTypeRequest struct {
|
||||
JoinType int `json:"join_type" binding:"min=0,max=2"`
|
||||
}
|
||||
|
||||
// CreateAnnouncementRequest 创建群公告请求
|
||||
type CreateAnnouncementRequest struct {
|
||||
Content string `json:"content" binding:"required,max=2000"`
|
||||
}
|
||||
|
||||
// GroupResponse 群组响应
|
||||
type GroupResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Avatar string `json:"avatar"`
|
||||
Description string `json:"description"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
MemberCount int `json:"member_count"`
|
||||
MaxMembers int `json:"max_members"`
|
||||
JoinType int `json:"join_type"`
|
||||
MuteAll bool `json:"mute_all"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// GroupMemberResponse 群成员响应
|
||||
type GroupMemberResponse struct {
|
||||
ID string `json:"id"`
|
||||
GroupID string `json:"group_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Role string `json:"role"`
|
||||
Nickname string `json:"nickname"`
|
||||
Muted bool `json:"muted"`
|
||||
JoinTime string `json:"join_time"`
|
||||
User *UserResponse `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// GroupAnnouncementResponse 群公告响应
|
||||
type GroupAnnouncementResponse struct {
|
||||
ID string `json:"id"`
|
||||
GroupID string `json:"group_id"`
|
||||
Content string `json:"content"`
|
||||
AuthorID string `json:"author_id"`
|
||||
IsPinned bool `json:"is_pinned"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// GroupListResponse 群组列表响应
|
||||
type GroupListResponse struct {
|
||||
List []*GroupResponse `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// GroupMemberListResponse 群成员列表响应
|
||||
type GroupMemberListResponse struct {
|
||||
List []*GroupMemberResponse `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// GroupAnnouncementListResponse 群公告列表响应
|
||||
type GroupAnnouncementListResponse struct {
|
||||
List []*GroupAnnouncementResponse `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
// ==================== WebSocket Event DTOs ====================
|
||||
|
||||
// WSEventResponse WebSocket事件响应结构体
|
||||
// 用于后端推送消息给前端的标准格式
|
||||
type WSEventResponse struct {
|
||||
ID string `json:"id"` // 事件唯一ID (UUID)
|
||||
Time int64 `json:"time"` // 时间戳 (毫秒)
|
||||
Type string `json:"type"` // 事件类型 (message, notification, system等)
|
||||
DetailType string `json:"detail_type"` // 详细类型 (private, group, like, comment等)
|
||||
Seq string `json:"seq"` // 消息序列号
|
||||
Segments model.MessageSegments `json:"segments"` // 消息段数组
|
||||
SenderID string `json:"sender_id"` // 发送者用户ID
|
||||
}
|
||||
|
||||
// ==================== WebSocket Request DTOs ====================
|
||||
|
||||
// SendMessageParams 发送消息参数(用于 REST API)
|
||||
type SendMessageParams struct {
|
||||
DetailType string `json:"detail_type"` // 消息类型: private, group
|
||||
ConversationID string `json:"conversation_id"` // 会话ID
|
||||
Segments model.MessageSegments `json:"segments"` // 消息内容(消息段数组)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
|
||||
}
|
||||
|
||||
// DeleteMsgParams 撤回消息参数
|
||||
type DeleteMsgParams struct {
|
||||
MessageID string `json:"message_id"` // 消息ID
|
||||
}
|
||||
|
||||
// ==================== Group Action Params ====================
|
||||
|
||||
// SetGroupKickParams 群组踢人参数
|
||||
type SetGroupKickParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
UserID string `json:"user_id"` // 被踢用户ID
|
||||
RejectAddRequest bool `json:"reject_add_request"` // 是否拒绝再次加群
|
||||
}
|
||||
|
||||
// SetGroupBanParams 群组单人禁言参数
|
||||
type SetGroupBanParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
UserID string `json:"user_id"` // 被禁言用户ID
|
||||
Duration int64 `json:"duration"` // 禁言时长(秒),0表示解除禁言
|
||||
}
|
||||
|
||||
// SetGroupWholeBanParams 群组全员禁言参数
|
||||
type SetGroupWholeBanParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Enable bool `json:"enable"` // 是否开启全员禁言
|
||||
}
|
||||
|
||||
// SetGroupAdminParams 群组设置管理员参数
|
||||
type SetGroupAdminParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
UserID string `json:"user_id"` // 被设置的用户ID
|
||||
Enable bool `json:"enable"` // 是否设置为管理员
|
||||
}
|
||||
|
||||
// SetGroupNameParams 设置群名参数
|
||||
type SetGroupNameParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
GroupName string `json:"group_name"` // 新群名
|
||||
}
|
||||
|
||||
// SetGroupAvatarParams 设置群头像参数
|
||||
type SetGroupAvatarParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Avatar string `json:"avatar"` // 头像URL
|
||||
}
|
||||
|
||||
// SetGroupLeaveParams 退出群组参数
|
||||
type SetGroupLeaveParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
}
|
||||
|
||||
// SetGroupAddRequestParams 处理加群请求参数
|
||||
type SetGroupAddRequestParams struct {
|
||||
Flag string `json:"flag"` // 加群请求的flag标识
|
||||
Approve bool `json:"approve"` // 是否同意
|
||||
Reason string `json:"reason"` // 拒绝理由(当approve为false时)
|
||||
}
|
||||
|
||||
// GetConversationListParams 获取会话列表参数
|
||||
type GetConversationListParams struct {
|
||||
Page int `json:"page"` // 页码
|
||||
PageSize int `json:"page_size"` // 每页数量
|
||||
}
|
||||
|
||||
// GetGroupInfoParams 获取群信息参数
|
||||
type GetGroupInfoParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
}
|
||||
|
||||
// GetGroupMemberListParams 获取群成员列表参数
|
||||
type GetGroupMemberListParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Page int `json:"page"` // 页码
|
||||
PageSize int `json:"page_size"` // 每页数量
|
||||
}
|
||||
|
||||
// ==================== Conversation Action Params ====================
|
||||
|
||||
// CreateConversationParams 创建会话参数
|
||||
type CreateConversationParams struct {
|
||||
UserID string `json:"user_id"` // 目标用户ID(私聊)
|
||||
}
|
||||
|
||||
// MarkReadParams 标记已读参数
|
||||
type MarkReadParams struct {
|
||||
ConversationID string `json:"conversation_id"` // 会话ID
|
||||
LastReadSeq int64 `json:"last_read_seq"` // 最后已读消息序号
|
||||
}
|
||||
|
||||
// SetConversationPinnedParams 设置会话置顶参数
|
||||
type SetConversationPinnedParams struct {
|
||||
ConversationID string `json:"conversation_id"` // 会话ID
|
||||
IsPinned bool `json:"is_pinned"` // 是否置顶
|
||||
}
|
||||
|
||||
// ==================== Group Action Params (Additional) ====================
|
||||
|
||||
// CreateGroupParams 创建群组参数
|
||||
type CreateGroupParams struct {
|
||||
Name string `json:"name"` // 群名
|
||||
Description string `json:"description,omitempty"` // 群描述
|
||||
MemberIDs []string `json:"member_ids,omitempty"` // 初始成员ID列表
|
||||
}
|
||||
|
||||
// GetUserGroupsParams 获取用户群组列表参数
|
||||
type GetUserGroupsParams struct {
|
||||
Page int `json:"page"` // 页码
|
||||
PageSize int `json:"page_size"` // 每页数量
|
||||
}
|
||||
|
||||
// TransferOwnerParams 转让群主参数
|
||||
type TransferOwnerParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
NewOwnerID string `json:"new_owner_id"` // 新群主ID
|
||||
}
|
||||
|
||||
// InviteMembersParams 邀请成员参数
|
||||
type InviteMembersParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
MemberIDs []string `json:"member_ids"` // 被邀请的用户ID列表
|
||||
}
|
||||
|
||||
// JoinGroupParams 加入群组参数
|
||||
type JoinGroupParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
}
|
||||
|
||||
// SetNicknameParams 设置群内昵称参数
|
||||
type SetNicknameParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Nickname string `json:"nickname"` // 群内昵称
|
||||
}
|
||||
|
||||
// SetJoinTypeParams 设置加群方式参数
|
||||
type SetJoinTypeParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
JoinType int `json:"join_type"` // 加群方式:0-允许任何人加入,1-需要审批,2-不允许加入
|
||||
}
|
||||
|
||||
// CreateAnnouncementParams 创建群公告参数
|
||||
type CreateAnnouncementParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Content string `json:"content"` // 公告内容
|
||||
}
|
||||
|
||||
// GetAnnouncementsParams 获取群公告列表参数
|
||||
type GetAnnouncementsParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Page int `json:"page"` // 页码
|
||||
PageSize int `json:"page_size"` // 每页数量
|
||||
}
|
||||
|
||||
// DeleteAnnouncementParams 删除群公告参数
|
||||
type DeleteAnnouncementParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
AnnouncementID string `json:"announcement_id"` // 公告ID
|
||||
}
|
||||
|
||||
// DissolveGroupParams 解散群组参数
|
||||
type DissolveGroupParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
}
|
||||
|
||||
// GetMyMemberInfoParams 获取我在群组中的成员信息参数
|
||||
type GetMyMemberInfoParams struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
}
|
||||
|
||||
// ==================== Vote DTOs ====================
|
||||
|
||||
// CreateVotePostRequest 创建投票帖子请求
|
||||
type CreateVotePostRequest struct {
|
||||
Title string `json:"title" binding:"required,max=200"`
|
||||
Content string `json:"content" binding:"max=2000"`
|
||||
CommunityID string `json:"community_id"`
|
||||
Images []string `json:"images"`
|
||||
VoteOptions []string `json:"vote_options" binding:"required,min=2,max=10"` // 投票选项,至少2个最多10个
|
||||
}
|
||||
|
||||
// VoteOptionDTO 投票选项DTO
|
||||
type VoteOptionDTO struct {
|
||||
ID string `json:"id"`
|
||||
Content string `json:"content"`
|
||||
VotesCount int `json:"votes_count"`
|
||||
}
|
||||
|
||||
// VoteResultDTO 投票结果DTO
|
||||
type VoteResultDTO struct {
|
||||
Options []VoteOptionDTO `json:"options"`
|
||||
TotalVotes int `json:"total_votes"`
|
||||
HasVoted bool `json:"has_voted"`
|
||||
VotedOptionID string `json:"voted_option_id,omitempty"`
|
||||
}
|
||||
|
||||
// ==================== WebSocket Response DTOs ====================
|
||||
|
||||
// WSResponse WebSocket响应结构体
|
||||
type WSResponse struct {
|
||||
Success bool `json:"success"` // 是否成功
|
||||
Action string `json:"action"` // 响应原action
|
||||
Data interface{} `json:"data,omitempty"` // 响应数据
|
||||
Error string `json:"error,omitempty"` // 错误信息
|
||||
}
|
||||
362
internal/dto/segment.go
Normal file
362
internal/dto/segment.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
)
|
||||
|
||||
// ParseSegmentData 解析Segment数据到目标结构体
|
||||
func ParseSegmentData(segment model.MessageSegment, target interface{}) error {
|
||||
dataBytes, err := json.Marshal(segment.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(dataBytes, target)
|
||||
}
|
||||
|
||||
// NewTextSegment 创建文本Segment
|
||||
func NewTextSegment(content string) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeText),
|
||||
Data: map[string]interface{}{"text": content},
|
||||
}
|
||||
}
|
||||
|
||||
// NewImageSegment 创建图片Segment
|
||||
func NewImageSegment(url string, width, height int, thumbnailURL string) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeImage),
|
||||
Data: map[string]interface{}{
|
||||
"url": url,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"thumbnail_url": thumbnailURL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewImageSegmentWithSize 创建带文件大小的图片Segment
|
||||
func NewImageSegmentWithSize(url string, width, height int, thumbnailURL string, fileSize int64) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeImage),
|
||||
Data: map[string]interface{}{
|
||||
"url": url,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"thumbnail_url": thumbnailURL,
|
||||
"file_size": fileSize,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewVoiceSegment 创建语音Segment
|
||||
func NewVoiceSegment(url string, duration int, fileSize int64) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeVoice),
|
||||
Data: map[string]interface{}{
|
||||
"url": url,
|
||||
"duration": duration,
|
||||
"file_size": fileSize,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewVideoSegment 创建视频Segment
|
||||
func NewVideoSegment(url string, width, height, duration int, thumbnailURL string, fileSize int64) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeVideo),
|
||||
Data: map[string]interface{}{
|
||||
"url": url,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"duration": duration,
|
||||
"thumbnail_url": thumbnailURL,
|
||||
"file_size": fileSize,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewFileSegment 创建文件Segment
|
||||
func NewFileSegment(url, name string, size int64, mimeType string) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeFile),
|
||||
Data: map[string]interface{}{
|
||||
"url": url,
|
||||
"name": name,
|
||||
"size": size,
|
||||
"mime_type": mimeType,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewAtSegment 创建@Segment,只存储 user_id,昵称由前端根据群成员列表实时解析
|
||||
func NewAtSegment(userID string) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeAt),
|
||||
Data: map[string]interface{}{
|
||||
"user_id": userID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewAtAllSegment 创建@所有人Segment
|
||||
func NewAtAllSegment() model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeAt),
|
||||
Data: map[string]interface{}{
|
||||
"user_id": "all",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewReplySegment 创建回复Segment
|
||||
func NewReplySegment(messageID string) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeReply),
|
||||
Data: map[string]interface{}{"id": messageID},
|
||||
}
|
||||
}
|
||||
|
||||
// NewFaceSegment 创建表情Segment
|
||||
func NewFaceSegment(id int, name, url string) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeFace),
|
||||
Data: map[string]interface{}{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"url": url,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewLinkSegment 创建链接Segment
|
||||
func NewLinkSegment(url, title, description, image string) model.MessageSegment {
|
||||
return model.MessageSegment{
|
||||
Type: string(SegmentTypeLink),
|
||||
Data: map[string]interface{}{
|
||||
"url": url,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"image": image,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractTextContentFromJSON 从JSON格式的segments中提取纯文本内容
|
||||
// 用于从数据库读取的 []byte 格式的 segments
|
||||
// 已废弃:现在数据库直接存储 model.MessageSegments 类型
|
||||
func ExtractTextContentFromJSON(segmentsJSON []byte) string {
|
||||
if len(segmentsJSON) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var segments model.MessageSegments
|
||||
if err := json.Unmarshal(segmentsJSON, &segments); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return ExtractTextContentFromModel(segments)
|
||||
}
|
||||
|
||||
// ExtractTextContentFromModel 从 model.MessageSegments 中提取纯文本内容
|
||||
func ExtractTextContentFromModel(segments model.MessageSegments) string {
|
||||
var result string
|
||||
for _, segment := range segments {
|
||||
switch segment.Type {
|
||||
case "text":
|
||||
if text, ok := segment.Data["text"].(string); ok {
|
||||
result += text
|
||||
}
|
||||
case "at":
|
||||
userID, _ := segment.Data["user_id"].(string)
|
||||
if userID == "all" {
|
||||
result += "@所有人 "
|
||||
} else if userID != "" {
|
||||
// 昵称由前端实时解析,后端文本提取仅用于推送通知兜底
|
||||
result += "@某人 "
|
||||
}
|
||||
case "image":
|
||||
result += "[图片]"
|
||||
case "voice":
|
||||
result += "[语音]"
|
||||
case "video":
|
||||
result += "[视频]"
|
||||
case "file":
|
||||
if name, ok := segment.Data["name"].(string); ok && name != "" {
|
||||
result += fmt.Sprintf("[文件:%s]", name)
|
||||
} else {
|
||||
result += "[文件]"
|
||||
}
|
||||
case "face":
|
||||
if name, ok := segment.Data["name"].(string); ok && name != "" {
|
||||
result += fmt.Sprintf("[%s]", name)
|
||||
} else {
|
||||
result += "[表情]"
|
||||
}
|
||||
case "link":
|
||||
if title, ok := segment.Data["title"].(string); ok && title != "" {
|
||||
result += fmt.Sprintf("[链接:%s]", title)
|
||||
} else {
|
||||
result += "[链接]"
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ExtractTextContent 从消息链中提取纯文本内容
|
||||
// 用于搜索、通知展示等场景
|
||||
func ExtractTextContent(segments model.MessageSegments) string {
|
||||
return ExtractTextContentFromModel(segments)
|
||||
}
|
||||
|
||||
// ExtractMentionedUsers 从消息链中提取被@的用户ID列表
|
||||
// 不包括 "all"(@所有人)
|
||||
func ExtractMentionedUsers(segments model.MessageSegments) []string {
|
||||
var userIDs []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, segment := range segments {
|
||||
if segment.Type == string(SegmentTypeAt) {
|
||||
userID, _ := segment.Data["user_id"].(string)
|
||||
if userID != "all" && userID != "" && !seen[userID] {
|
||||
userIDs = append(userIDs, userID)
|
||||
seen[userID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return userIDs
|
||||
}
|
||||
|
||||
// IsAtAll 检查消息是否@了所有人
|
||||
func IsAtAll(segments model.MessageSegments) bool {
|
||||
for _, segment := range segments {
|
||||
if segment.Type == string(SegmentTypeAt) {
|
||||
if userID, ok := segment.Data["user_id"].(string); ok && userID == "all" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetReplyMessageID 从消息链中获取被回复的消息ID
|
||||
// 如果没有回复segment,返回空字符串
|
||||
func GetReplyMessageID(segments model.MessageSegments) string {
|
||||
for _, segment := range segments {
|
||||
if segment.Type == string(SegmentTypeReply) {
|
||||
if id, ok := segment.Data["id"].(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildSegmentsFromContent 从旧版content构建segments
|
||||
// 用于兼容旧版本消息
|
||||
func BuildSegmentsFromContent(contentType, content string, mediaURL *string) model.MessageSegments {
|
||||
var segments model.MessageSegments
|
||||
|
||||
switch contentType {
|
||||
case "text":
|
||||
segments = append(segments, NewTextSegment(content))
|
||||
case "image":
|
||||
if mediaURL != nil {
|
||||
segments = append(segments, NewImageSegment(*mediaURL, 0, 0, ""))
|
||||
}
|
||||
case "voice":
|
||||
if mediaURL != nil {
|
||||
segments = append(segments, NewVoiceSegment(*mediaURL, 0, 0))
|
||||
}
|
||||
case "video":
|
||||
if mediaURL != nil {
|
||||
segments = append(segments, NewVideoSegment(*mediaURL, 0, 0, 0, "", 0))
|
||||
}
|
||||
case "file":
|
||||
if mediaURL != nil {
|
||||
segments = append(segments, NewFileSegment(*mediaURL, content, 0, ""))
|
||||
}
|
||||
default:
|
||||
// 默认当作文本处理
|
||||
if content != "" {
|
||||
segments = append(segments, NewTextSegment(content))
|
||||
}
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
// HasSegmentType 检查消息链中是否包含指定类型的segment
|
||||
func HasSegmentType(segments model.MessageSegments, segmentType SegmentType) bool {
|
||||
for _, segment := range segments {
|
||||
if segment.Type == string(segmentType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSegmentsByType 获取消息链中所有指定类型的segment
|
||||
func GetSegmentsByType(segments model.MessageSegments, segmentType SegmentType) []model.MessageSegment {
|
||||
var result []model.MessageSegment
|
||||
for _, segment := range segments {
|
||||
if segment.Type == string(segmentType) {
|
||||
result = append(result, segment)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetFirstImageURL 获取消息链中第一张图片的URL
|
||||
// 如果没有图片,返回空字符串
|
||||
func GetFirstImageURL(segments model.MessageSegments) string {
|
||||
for _, segment := range segments {
|
||||
if segment.Type == string(SegmentTypeImage) {
|
||||
if url, ok := segment.Data["url"].(string); ok {
|
||||
return url
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetFirstMediaURL 获取消息链中第一个媒体文件的URL(图片/视频/语音/文件)
|
||||
// 用于兼容旧版本API
|
||||
func GetFirstMediaURL(segments model.MessageSegments) string {
|
||||
for _, segment := range segments {
|
||||
switch segment.Type {
|
||||
case string(SegmentTypeImage), string(SegmentTypeVideo), string(SegmentTypeVoice), string(SegmentTypeFile):
|
||||
if url, ok := segment.Data["url"].(string); ok {
|
||||
return url
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// DetermineContentType 从消息链推断消息类型(用于兼容旧版本)
|
||||
func DetermineContentType(segments model.MessageSegments) string {
|
||||
if len(segments) == 0 {
|
||||
return "text"
|
||||
}
|
||||
|
||||
// 优先检查媒体类型
|
||||
for _, segment := range segments {
|
||||
switch segment.Type {
|
||||
case string(SegmentTypeImage):
|
||||
return "image"
|
||||
case string(SegmentTypeVideo):
|
||||
return "video"
|
||||
case string(SegmentTypeVoice):
|
||||
return "voice"
|
||||
case string(SegmentTypeFile):
|
||||
return "file"
|
||||
}
|
||||
}
|
||||
|
||||
// 默认返回text
|
||||
return "text"
|
||||
}
|
||||
253
internal/handler/comment_handler.go
Normal file
253
internal/handler/comment_handler.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// CommentHandler 评论处理器
|
||||
type CommentHandler struct {
|
||||
commentService *service.CommentService
|
||||
}
|
||||
|
||||
// NewCommentHandler 创建评论处理器
|
||||
func NewCommentHandler(commentService *service.CommentService) *CommentHandler {
|
||||
return &CommentHandler{
|
||||
commentService: commentService,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建评论
|
||||
func (h *CommentHandler) Create(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type CreateRequest struct {
|
||||
PostID string `json:"post_id" binding:"required"`
|
||||
Content string `json:"content"` // 内容可选,允许只发图片
|
||||
ParentID *string `json:"parent_id"`
|
||||
Images []string `json:"images"` // 图片URL列表
|
||||
}
|
||||
|
||||
var req CreateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证:评论必须有内容或图片
|
||||
if req.Content == "" && len(req.Images) == 0 {
|
||||
response.BadRequest(c, "评论内容或图片不能同时为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 将图片列表转换为JSON字符串
|
||||
var imagesJSON string
|
||||
if len(req.Images) > 0 {
|
||||
imagesBytes, _ := json.Marshal(req.Images)
|
||||
imagesJSON = string(imagesBytes)
|
||||
}
|
||||
|
||||
comment, err := h.commentService.Create(c.Request.Context(), req.PostID, userID, req.Content, req.ParentID, imagesJSON, req.Images)
|
||||
if err != nil {
|
||||
var moderationErr *service.CommentModerationRejectedError
|
||||
if errors.As(err, &moderationErr) {
|
||||
response.BadRequest(c, moderationErr.UserMessage())
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to create comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertCommentToResponse(comment, false))
|
||||
}
|
||||
|
||||
// GetByID 获取单条评论详情
|
||||
func (h *CommentHandler) GetByID(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
id := c.Param("id")
|
||||
|
||||
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "comment not found")
|
||||
return
|
||||
}
|
||||
|
||||
resp := dto.ConvertCommentToResponse(comment, h.commentService.IsLiked(c.Request.Context(), id, userID))
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// GetByPostID 获取帖子评论
|
||||
func (h *CommentHandler) GetByPostID(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
postID := c.Param("id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
comments, total, err := h.commentService.GetByPostID(c.Request.Context(), postID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get comments")
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应结构,检查每个评论的点赞状态
|
||||
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
|
||||
|
||||
response.Paginated(c, commentResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetReplies 获取回复
|
||||
func (h *CommentHandler) GetReplies(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
parentID := c.Param("id")
|
||||
|
||||
comments, err := h.commentService.GetReplies(c.Request.Context(), parentID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get replies")
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应结构,检查每个回复的点赞状态
|
||||
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
|
||||
|
||||
response.Success(c, commentResponses)
|
||||
}
|
||||
|
||||
// GetRepliesByRootID 根据根评论ID分页获取回复(扁平化)
|
||||
func (h *CommentHandler) GetRepliesByRootID(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
rootID := c.Param("id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
|
||||
|
||||
replies, total, err := h.commentService.GetRepliesByRootID(c.Request.Context(), rootID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get replies")
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应结构,检查每个回复的点赞状态
|
||||
replyResponses := dto.ConvertCommentsToResponseWithUser(replies, userID, h.commentService)
|
||||
|
||||
response.Paginated(c, replyResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Update 更新评论
|
||||
func (h *CommentHandler) Update(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "comment not found")
|
||||
return
|
||||
}
|
||||
|
||||
if comment.UserID != userID {
|
||||
response.Forbidden(c, "cannot update others' comment")
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
|
||||
var req UpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
comment.Content = req.Content
|
||||
|
||||
err = h.commentService.Update(c.Request.Context(), comment)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertCommentToResponse(comment, false))
|
||||
}
|
||||
|
||||
// Delete 删除评论
|
||||
func (h *CommentHandler) Delete(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
comment, err := h.commentService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "comment not found")
|
||||
return
|
||||
}
|
||||
|
||||
if comment.UserID != userID {
|
||||
response.Forbidden(c, "cannot delete others' comment")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.commentService.Delete(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to delete comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "comment deleted", nil)
|
||||
}
|
||||
|
||||
// Like 点赞评论
|
||||
func (h *CommentHandler) Like(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.commentService.Like(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to like comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "liked", nil)
|
||||
}
|
||||
|
||||
// Unlike 取消点赞评论
|
||||
func (h *CommentHandler) Unlike(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.commentService.Unlike(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unlike comment")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "unliked", nil)
|
||||
}
|
||||
234
internal/handler/gorse_handler.go
Normal file
234
internal/handler/gorse_handler.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/config"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/gorse"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
|
||||
gorseio "github.com/gorse-io/gorse-go"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GorseHandler Gorse推荐处理器
|
||||
type GorseHandler struct {
|
||||
importPassword string
|
||||
gorseConfig config.GorseConfig
|
||||
}
|
||||
|
||||
// NewGorseHandler 创建Gorse处理器
|
||||
func NewGorseHandler(cfg config.GorseConfig) *GorseHandler {
|
||||
return &GorseHandler{
|
||||
importPassword: cfg.ImportPassword,
|
||||
gorseConfig: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// ImportRequest 导入请求
|
||||
type ImportRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// ImportData 导入数据到Gorse
|
||||
// POST /api/v1/gorse/import
|
||||
func (h *GorseHandler) ImportData(c *gin.Context) {
|
||||
// 验证密码
|
||||
if h.importPassword == "" {
|
||||
response.BadRequest(c, "Gorse import is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
var req ImportRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Password != h.importPassword {
|
||||
response.Unauthorized(c, "invalid password")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
stats, err := h.importAllData(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] gorse import failed: %v", err)
|
||||
response.InternalServerError(c, "gorse import failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "import completed",
|
||||
"status": "done",
|
||||
"stats": stats,
|
||||
})
|
||||
}
|
||||
|
||||
// GetStatus 获取Gorse状态
|
||||
// GET /api/v1/gorse/status
|
||||
func (h *GorseHandler) GetStatus(c *gin.Context) {
|
||||
// 返回Gorse连接状态和配置信息
|
||||
hasPassword := h.importPassword != ""
|
||||
response.Success(c, gin.H{
|
||||
"enabled": h.gorseConfig.Enabled,
|
||||
"has_password": hasPassword,
|
||||
"address": h.gorseConfig.Address,
|
||||
"api_key": strings.Repeat("*", 8), // 不返回实际APIKey
|
||||
})
|
||||
}
|
||||
|
||||
func (h *GorseHandler) importAllData(ctx context.Context) (gin.H, error) {
|
||||
gorseClient := gorseio.NewGorseClient(h.gorseConfig.Address, h.gorseConfig.APIKey)
|
||||
gorse.InitEmbeddingWithConfig(h.gorseConfig.EmbeddingAPIKey, h.gorseConfig.EmbeddingURL, h.gorseConfig.EmbeddingModel)
|
||||
|
||||
stats := gin.H{
|
||||
"items": 0,
|
||||
"users": 0,
|
||||
"likes": 0,
|
||||
"favorites": 0,
|
||||
"comments": 0,
|
||||
}
|
||||
|
||||
// 导入帖子
|
||||
var posts []model.Post
|
||||
if err := model.DB.Find(&posts).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, post := range posts {
|
||||
embedding, err := gorse.GetEmbedding(strings.TrimSpace(post.Title + " " + post.Content))
|
||||
if err != nil {
|
||||
log.Printf("[WARN] get embedding failed for post %s: %v", post.ID, err)
|
||||
embedding = make([]float64, 1024)
|
||||
}
|
||||
_, err = gorseClient.InsertItem(ctx, gorseio.Item{
|
||||
ItemId: post.ID,
|
||||
IsHidden: post.DeletedAt.Valid,
|
||||
Categories: buildPostCategories(&post),
|
||||
Comment: post.Title,
|
||||
Timestamp: post.CreatedAt.UTC().Truncate(time.Second),
|
||||
Labels: map[string]any{
|
||||
"embedding": embedding,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import item failed (%s): %v", post.ID, err)
|
||||
continue
|
||||
}
|
||||
stats["items"] = stats["items"].(int) + 1
|
||||
}
|
||||
|
||||
// 导入用户
|
||||
var users []model.User
|
||||
if err := model.DB.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, user := range users {
|
||||
_, err := gorseClient.InsertUser(ctx, gorseio.User{
|
||||
UserId: user.ID,
|
||||
Labels: map[string]any{
|
||||
"posts_count": user.PostsCount,
|
||||
"followers_count": user.FollowersCount,
|
||||
"following_count": user.FollowingCount,
|
||||
},
|
||||
Comment: user.Nickname,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import user failed (%s): %v", user.ID, err)
|
||||
continue
|
||||
}
|
||||
stats["users"] = stats["users"].(int) + 1
|
||||
}
|
||||
|
||||
// 导入点赞
|
||||
var likes []model.PostLike
|
||||
if err := model.DB.Find(&likes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, like := range likes {
|
||||
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||
FeedbackType: string(gorse.FeedbackTypeLike),
|
||||
UserId: like.UserID,
|
||||
ItemId: like.PostID,
|
||||
Timestamp: like.CreatedAt.UTC().Truncate(time.Second),
|
||||
}})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import like failed (%s/%s): %v", like.UserID, like.PostID, err)
|
||||
continue
|
||||
}
|
||||
stats["likes"] = stats["likes"].(int) + 1
|
||||
}
|
||||
|
||||
// 导入收藏
|
||||
var favorites []model.Favorite
|
||||
if err := model.DB.Find(&favorites).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, fav := range favorites {
|
||||
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||
FeedbackType: string(gorse.FeedbackTypeStar),
|
||||
UserId: fav.UserID,
|
||||
ItemId: fav.PostID,
|
||||
Timestamp: fav.CreatedAt.UTC().Truncate(time.Second),
|
||||
}})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import favorite failed (%s/%s): %v", fav.UserID, fav.PostID, err)
|
||||
continue
|
||||
}
|
||||
stats["favorites"] = stats["favorites"].(int) + 1
|
||||
}
|
||||
|
||||
// 导入评论(按用户-帖子去重)
|
||||
var comments []model.Comment
|
||||
if err := model.DB.Where("status = ?", model.CommentStatusPublished).Find(&comments).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
for _, cm := range comments {
|
||||
key := cm.UserID + ":" + cm.PostID
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
|
||||
FeedbackType: string(gorse.FeedbackTypeComment),
|
||||
UserId: cm.UserID,
|
||||
ItemId: cm.PostID,
|
||||
Timestamp: cm.CreatedAt.UTC().Truncate(time.Second),
|
||||
}})
|
||||
if err != nil {
|
||||
log.Printf("[WARN] import comment failed (%s/%s): %v", cm.UserID, cm.PostID, err)
|
||||
continue
|
||||
}
|
||||
stats["comments"] = stats["comments"].(int) + 1
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func buildPostCategories(post *model.Post) []string {
|
||||
var categories []string
|
||||
if post.ViewsCount > 1000 {
|
||||
categories = append(categories, "hot_high")
|
||||
} else if post.ViewsCount > 100 {
|
||||
categories = append(categories, "hot_medium")
|
||||
}
|
||||
if post.LikesCount > 100 {
|
||||
categories = append(categories, "likes_100+")
|
||||
} else if post.LikesCount > 10 {
|
||||
categories = append(categories, "likes_10+")
|
||||
}
|
||||
age := time.Since(post.CreatedAt)
|
||||
if age < 24*time.Hour {
|
||||
categories = append(categories, "today")
|
||||
} else if age < 7*24*time.Hour {
|
||||
categories = append(categories, "this_week")
|
||||
}
|
||||
return categories
|
||||
}
|
||||
1801
internal/handler/group_handler.go
Normal file
1801
internal/handler/group_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
879
internal/handler/message_handler.go
Normal file
879
internal/handler/message_handler.go
Normal file
@@ -0,0 +1,879 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// MessageHandler 消息处理器
|
||||
type MessageHandler struct {
|
||||
chatService service.ChatService
|
||||
messageService *service.MessageService
|
||||
userService *service.UserService
|
||||
groupService service.GroupService
|
||||
}
|
||||
|
||||
// NewMessageHandler 创建消息处理器
|
||||
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler {
|
||||
return &MessageHandler{
|
||||
chatService: chatService,
|
||||
messageService: messageService,
|
||||
userService: userService,
|
||||
groupService: groupService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConversations 获取会话列表
|
||||
// GET /api/conversations
|
||||
func (h *MessageHandler) GetConversations(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
// 添加调试日志
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get conversations")
|
||||
return
|
||||
}
|
||||
|
||||
// 过滤掉系统会话(系统通知现在使用独立的表)
|
||||
filteredConvs := make([]*model.Conversation, 0)
|
||||
for _, conv := range convs {
|
||||
if conv.ID != model.SystemConversationID {
|
||||
filteredConvs = append(filteredConvs, conv)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]*dto.ConversationResponse, len(filteredConvs))
|
||||
for i, conv := range filteredConvs {
|
||||
// 获取未读数
|
||||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
|
||||
|
||||
// 获取最后一条消息
|
||||
var lastMessage *model.Message
|
||||
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
|
||||
if len(messages) > 0 {
|
||||
lastMessage = messages[0]
|
||||
}
|
||||
|
||||
// 群聊时返回member_count,私聊时返回participants
|
||||
var resp *dto.ConversationResponse
|
||||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
|
||||
// 群聊:实时计算群成员数量
|
||||
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
|
||||
// 创建响应并设置member_count
|
||||
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
|
||||
resp.MemberCount = memberCount
|
||||
} else {
|
||||
// 私聊:获取参与者信息
|
||||
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
|
||||
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
|
||||
}
|
||||
result[i] = resp
|
||||
}
|
||||
|
||||
// 更新 total 为过滤后的数量
|
||||
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
|
||||
}
|
||||
|
||||
// CreateConversation 创建私聊会话
|
||||
// POST /api/conversations
|
||||
func (h *MessageHandler) CreateConversation(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.CreateConversationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证目标用户是否存在
|
||||
targetUser, err := h.userService.GetUserByID(c.Request.Context(), req.UserID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "target user not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 不能和自己创建会话
|
||||
if userID == req.UserID {
|
||||
response.BadRequest(c, "cannot create conversation with yourself")
|
||||
return
|
||||
}
|
||||
|
||||
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, req.UserID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to create conversation")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取参与者信息
|
||||
participants := []*model.User{targetUser}
|
||||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||
|
||||
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
|
||||
}
|
||||
|
||||
// GetConversationByID 获取会话详情
|
||||
// GET /api/conversations/:id
|
||||
func (h *MessageHandler) GetConversationByID(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
fmt.Printf("[DEBUG] GetConversationByID: conversationIDStr = %s\n", conversationIDStr)
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] GetConversationByID: failed to parse conversation ID: %v\n", err)
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
fmt.Printf("[DEBUG] GetConversationByID: conversationID = %s\n", conversationID)
|
||||
|
||||
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取未读数
|
||||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||
|
||||
// 获取参与者信息
|
||||
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
|
||||
|
||||
// 获取当前用户的已读位置
|
||||
myLastReadSeq := int64(0)
|
||||
isPinned := false
|
||||
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
|
||||
for _, p := range allParticipants {
|
||||
if p.UserID == userID {
|
||||
myLastReadSeq = p.LastReadSeq
|
||||
isPinned = p.IsPinned
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 获取对方用户的已读位置
|
||||
otherLastReadSeq := int64(0)
|
||||
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
|
||||
}
|
||||
|
||||
// GetMessages 获取消息列表
|
||||
// GET /api/conversations/:id/messages
|
||||
func (h *MessageHandler) GetMessages(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否使用增量同步(after_seq参数)
|
||||
afterSeqStr := c.Query("after_seq")
|
||||
if afterSeqStr != "" {
|
||||
// 增量同步模式
|
||||
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid after_seq")
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
|
||||
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Success(c, &dto.MessageSyncResponse{
|
||||
Messages: result,
|
||||
HasMore: len(messages) == limit,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否使用历史消息加载(before_seq参数)
|
||||
beforeSeqStr := c.Query("before_seq")
|
||||
if beforeSeqStr != "" {
|
||||
// 加载更早的消息(下拉加载更多)
|
||||
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid before_seq")
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
|
||||
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Success(c, &dto.MessageSyncResponse{
|
||||
Messages: result,
|
||||
HasMore: len(messages) == limit,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 分页模式
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// SendMessage 发送消息
|
||||
// POST /api/conversations/:id/messages
|
||||
func (h *MessageHandler) SendMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
fmt.Printf("[DEBUG] SendMessage: conversationIDStr = %s\n", conversationIDStr)
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] SendMessage: failed to parse conversation ID: %v\n", err)
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
fmt.Printf("[DEBUG] SendMessage: conversationID = %s, userID = %s\n", conversationID, userID)
|
||||
|
||||
var req dto.SendMessageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用 segments
|
||||
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, req.Segments, req.ReplyToID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertMessageToResponse(msg))
|
||||
}
|
||||
|
||||
// HandleSendMessage RESTful 风格的发送消息端点
|
||||
// POST /api/v1/conversations/send_message
|
||||
// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
|
||||
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.SendMessageParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if params.ConversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
if params.DetailType == "" {
|
||||
response.BadRequest(c, "detail_type is required")
|
||||
return
|
||||
}
|
||||
if params.Segments == nil || len(params.Segments) == 0 {
|
||||
response.BadRequest(c, "segments is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 构建 WSEventResponse 格式响应
|
||||
wsResponse := dto.WSEventResponse{
|
||||
ID: msg.ID,
|
||||
Time: msg.CreatedAt.UnixMilli(),
|
||||
Type: "message",
|
||||
DetailType: params.DetailType,
|
||||
Seq: strconv.FormatInt(msg.Seq, 10),
|
||||
Segments: params.Segments,
|
||||
SenderID: userID,
|
||||
}
|
||||
|
||||
response.Success(c, wsResponse)
|
||||
}
|
||||
|
||||
// HandleDeleteMsg 撤回消息
|
||||
// POST /api/v1/messages/delete_msg
|
||||
// 请求体格式: {"message_id": "xxx"}
|
||||
func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.DeleteMsgParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if params.MessageID == "" {
|
||||
response.BadRequest(c, "message_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 撤回消息
|
||||
err := h.chatService.RecallMessage(c.Request.Context(), params.MessageID, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "消息已撤回", nil)
|
||||
}
|
||||
|
||||
// HandleGetConversationList 获取会话列表
|
||||
// GET /api/v1/conversations/list
|
||||
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get conversations")
|
||||
return
|
||||
}
|
||||
|
||||
// 过滤掉系统会话(系统通知现在使用独立的表)
|
||||
filteredConvs := make([]*model.Conversation, 0)
|
||||
for _, conv := range convs {
|
||||
if conv.ID != model.SystemConversationID {
|
||||
filteredConvs = append(filteredConvs, conv)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]*dto.ConversationResponse, len(filteredConvs))
|
||||
for i, conv := range filteredConvs {
|
||||
// 获取未读数
|
||||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
|
||||
|
||||
// 获取最后一条消息
|
||||
var lastMessage *model.Message
|
||||
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
|
||||
if len(messages) > 0 {
|
||||
lastMessage = messages[0]
|
||||
}
|
||||
|
||||
// 群聊时返回member_count,私聊时返回participants
|
||||
var resp *dto.ConversationResponse
|
||||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
|
||||
// 群聊:实时计算群成员数量
|
||||
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
|
||||
// 创建响应并设置member_count
|
||||
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
|
||||
resp.MemberCount = memberCount
|
||||
} else {
|
||||
// 私聊:获取参与者信息
|
||||
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
|
||||
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
|
||||
}
|
||||
result[i] = resp
|
||||
}
|
||||
|
||||
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
|
||||
}
|
||||
|
||||
// HandleDeleteConversationForSelf 仅自己删除会话
|
||||
// DELETE /api/v1/conversations/:id/self
|
||||
func (h *MessageHandler) HandleDeleteConversationForSelf(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationID := c.Param("id")
|
||||
if conversationID == "" {
|
||||
response.BadRequest(c, "conversation id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.chatService.DeleteConversationForSelf(c.Request.Context(), conversationID, userID); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "conversation deleted for self", nil)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
// POST /api/conversations/:id/read
|
||||
func (h *MessageHandler) MarkAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.MarkReadRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "last_read_seq is required")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "marked as read", nil)
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读消息总数
|
||||
// GET /api/conversations/unread/count
|
||||
func (h *MessageHandler) GetUnreadCount(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
// 添加调试日志
|
||||
fmt.Printf("[DEBUG] GetUnreadCount: user_id from context = %q\n", userID)
|
||||
|
||||
if userID == "" {
|
||||
fmt.Printf("[DEBUG] GetUnreadCount: user_id is empty, returning 401\n")
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.chatService.GetAllUnreadCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get unread count")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, &dto.UnreadCountResponse{
|
||||
TotalUnreadCount: count,
|
||||
})
|
||||
}
|
||||
|
||||
// GetConversationUnreadCount 获取单个会话的未读数
|
||||
// GET /api/conversations/:id/unread/count
|
||||
func (h *MessageHandler) GetConversationUnreadCount(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr := c.Param("id")
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid conversation id")
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, &dto.ConversationUnreadCountResponse{
|
||||
ConversationID: conversationID,
|
||||
UnreadCount: count,
|
||||
})
|
||||
}
|
||||
|
||||
// RecallMessage 撤回消息
|
||||
// POST /api/messages/:id/recall
|
||||
func (h *MessageHandler) RecallMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
messageIDStr := c.Param("id")
|
||||
// 直接使用 string 类型的 messageID
|
||||
err := h.chatService.RecallMessage(c.Request.Context(), messageIDStr, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "message recalled", nil)
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息
|
||||
// DELETE /api/messages/:id
|
||||
func (h *MessageHandler) DeleteMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
messageIDStr := c.Param("id")
|
||||
// 直接使用 string 类型的 messageID
|
||||
err := h.chatService.DeleteMessage(c.Request.Context(), messageIDStr, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "message deleted", nil)
|
||||
}
|
||||
|
||||
// 辅助函数:验证内容类型
|
||||
func isValidContentType(contentType model.ContentType) bool {
|
||||
switch contentType {
|
||||
case model.ContentTypeText, model.ContentTypeImage, model.ContentTypeVideo, model.ContentTypeAudio, model.ContentTypeFile:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:获取会话参与者信息
|
||||
func (h *MessageHandler) getConversationParticipants(ctx context.Context, conversationID string, currentUserID string) ([]*model.User, error) {
|
||||
// 从repository获取参与者列表
|
||||
participants, err := h.messageService.GetConversationParticipants(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取参与者用户信息
|
||||
var users []*model.User
|
||||
for _, p := range participants {
|
||||
// 跳过当前用户
|
||||
if p.UserID == currentUserID {
|
||||
continue
|
||||
}
|
||||
user, err := h.userService.GetUserByID(ctx, p.UserID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// 获取当前用户在会话中的参与者信息
|
||||
func (h *MessageHandler) getMyConversationParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
|
||||
participants, err := h.messageService.GetConversationParticipants(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, p := range participants {
|
||||
if p.UserID == userID {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ==================== RESTful Action 端点 ====================
|
||||
|
||||
// HandleCreateConversation 创建会话
|
||||
// POST /api/v1/conversations/create
|
||||
func (h *MessageHandler) HandleCreateConversation(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.CreateConversationParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证目标用户是否存在
|
||||
targetUser, err := h.userService.GetUserByID(c.Request.Context(), params.UserID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "target user not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 不能和自己创建会话
|
||||
if userID == params.UserID {
|
||||
response.BadRequest(c, "cannot create conversation with yourself")
|
||||
return
|
||||
}
|
||||
|
||||
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, params.UserID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to create conversation")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取参与者信息
|
||||
participants := []*model.User{targetUser}
|
||||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||||
|
||||
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
|
||||
}
|
||||
|
||||
// HandleGetConversation 获取会话详情
|
||||
// GET /api/v1/conversations/get?conversation_id=xxx
|
||||
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationID := c.Query("conversation_id")
|
||||
if conversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取未读数
|
||||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||||
|
||||
// 获取参与者信息
|
||||
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
|
||||
|
||||
// 获取当前用户的已读位置
|
||||
myLastReadSeq := int64(0)
|
||||
isPinned := false
|
||||
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
|
||||
for _, p := range allParticipants {
|
||||
if p.UserID == userID {
|
||||
myLastReadSeq = p.LastReadSeq
|
||||
isPinned = p.IsPinned
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 获取对方用户的已读位置
|
||||
otherLastReadSeq := int64(0)
|
||||
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
|
||||
}
|
||||
|
||||
// HandleGetMessages 获取会话消息
|
||||
// GET /api/v1/conversations/get_messages?conversation_id=xxx
|
||||
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
conversationID := c.Query("conversation_id")
|
||||
if conversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否使用增量同步(after_seq参数)
|
||||
afterSeqStr := c.Query("after_seq")
|
||||
if afterSeqStr != "" {
|
||||
// 增量同步模式
|
||||
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid after_seq")
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
|
||||
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Success(c, &dto.MessageSyncResponse{
|
||||
Messages: result,
|
||||
HasMore: len(messages) == limit,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否使用历史消息加载(before_seq参数)
|
||||
beforeSeqStr := c.Query("before_seq")
|
||||
if beforeSeqStr != "" {
|
||||
// 加载更早的消息(下拉加载更多)
|
||||
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid before_seq")
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
|
||||
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Success(c, &dto.MessageSyncResponse{
|
||||
Messages: result,
|
||||
HasMore: len(messages) == limit,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 分页模式
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := dto.ConvertMessagesToResponse(messages)
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// HandleMarkRead 标记已读
|
||||
// POST /api/v1/conversations/mark_read
|
||||
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.MarkReadParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if params.ConversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq)
|
||||
if err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "marked as read", nil)
|
||||
}
|
||||
|
||||
// HandleSetConversationPinned 设置会话置顶
|
||||
// POST /api/v1/conversations/set_pinned
|
||||
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var params dto.SetConversationPinnedParams
|
||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if params.ConversationID == "" {
|
||||
response.BadRequest(c, "conversation_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
|
||||
"conversation_id": params.ConversationID,
|
||||
"is_pinned": params.IsPinned,
|
||||
})
|
||||
}
|
||||
132
internal/handler/notification_handler.go
Normal file
132
internal/handler/notification_handler.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// NotificationHandler 通知处理器
|
||||
type NotificationHandler struct {
|
||||
notificationService *service.NotificationService
|
||||
}
|
||||
|
||||
// NewNotificationHandler 创建通知处理器
|
||||
func NewNotificationHandler(notificationService *service.NotificationService) *NotificationHandler {
|
||||
return &NotificationHandler{
|
||||
notificationService: notificationService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetNotifications 获取通知列表
|
||||
func (h *NotificationHandler) GetNotifications(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
unreadOnly := c.Query("unread_only") == "true"
|
||||
|
||||
notifications, total, err := h.notificationService.GetByUserID(c.Request.Context(), userID, page, pageSize, unreadOnly)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get notifications")
|
||||
return
|
||||
}
|
||||
|
||||
response.Paginated(c, notifications, total, page, pageSize)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
func (h *NotificationHandler) MarkAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.notificationService.MarkAsReadWithUserID(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to mark as read")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "marked as read", nil)
|
||||
}
|
||||
|
||||
// MarkAllAsRead 标记所有为已读
|
||||
func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.notificationService.MarkAllAsRead(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to mark all as read")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "all marked as read", nil)
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读数量
|
||||
func (h *NotificationHandler) GetUnreadCount(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.notificationService.GetUnreadCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get unread count")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"count": count})
|
||||
}
|
||||
|
||||
// DeleteNotification 删除通知
|
||||
func (h *NotificationHandler) DeleteNotification(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
err := h.notificationService.DeleteNotification(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to delete notification")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// ClearAllNotifications 清空所有通知
|
||||
func (h *NotificationHandler) ClearAllNotifications(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.notificationService.ClearAllNotifications(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to clear notifications")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
511
internal/handler/post_handler.go
Normal file
511
internal/handler/post_handler.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// PostHandler 帖子处理器
|
||||
type PostHandler struct {
|
||||
postService *service.PostService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewPostHandler 创建帖子处理器
|
||||
func NewPostHandler(postService *service.PostService, userService *service.UserService) *PostHandler {
|
||||
return &PostHandler{
|
||||
postService: postService,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建帖子
|
||||
func (h *PostHandler) Create(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type CreateRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content" binding:"required"`
|
||||
Images []string `json:"images"`
|
||||
}
|
||||
|
||||
var req CreateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
post, err := h.postService.Create(c.Request.Context(), userID, req.Title, req.Content, req.Images)
|
||||
if err != nil {
|
||||
var moderationErr *service.PostModerationRejectedError
|
||||
if errors.As(err, &moderationErr) {
|
||||
response.BadRequest(c, moderationErr.UserMessage())
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to create post")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, false, false))
|
||||
}
|
||||
|
||||
// GetByID 获取帖子(不增加浏览量)
|
||||
func (h *PostHandler) GetByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 非作者不可查看未发布内容
|
||||
currentUserID := c.GetString("user_id")
|
||||
if post.Status != model.PostStatusPublished && post.UserID != currentUserID {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 注意:不再自动增加浏览量,浏览量通过 RecordView 端点单独记录
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
fmt.Printf("[DEBUG] GetByID - postID: %s, currentUserID: %s\n", id, currentUserID)
|
||||
|
||||
var isLiked, isFavorited bool
|
||||
if currentUserID != "" {
|
||||
isLiked = h.postService.IsLiked(c.Request.Context(), id, currentUserID)
|
||||
isFavorited = h.postService.IsFavorited(c.Request.Context(), id, currentUserID)
|
||||
fmt.Printf("[DEBUG] GetByID - isLiked: %v, isFavorited: %v\n", isLiked, isFavorited)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] GetByID - user not logged in, isLiked: false, isFavorited: false\n")
|
||||
}
|
||||
|
||||
// 如果有当前用户,检查与帖子作者的相互关注状态
|
||||
var authorWithFollowStatus *dto.UserResponse
|
||||
if currentUserID != "" && post.User != nil {
|
||||
_, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), post.UserID, currentUserID)
|
||||
if err == nil {
|
||||
authorWithFollowStatus = dto.ConvertUserToResponseWithMutualFollow(post.User, isFollowing, isFollowingMe)
|
||||
} else {
|
||||
// 如果出错,使用默认的author
|
||||
authorWithFollowStatus = dto.ConvertUserToResponse(post.User)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
responseData := &dto.PostResponse{
|
||||
ID: post.ID,
|
||||
UserID: post.UserID,
|
||||
Title: post.Title,
|
||||
Content: post.Content,
|
||||
Images: dto.ConvertPostImagesToResponse(post.Images),
|
||||
LikesCount: post.LikesCount,
|
||||
CommentsCount: post.CommentsCount,
|
||||
FavoritesCount: post.FavoritesCount,
|
||||
SharesCount: post.SharesCount,
|
||||
ViewsCount: post.ViewsCount,
|
||||
IsPinned: post.IsPinned,
|
||||
IsLocked: post.IsLocked,
|
||||
IsVote: post.IsVote,
|
||||
CreatedAt: dto.FormatTime(post.CreatedAt),
|
||||
Author: authorWithFollowStatus,
|
||||
IsLiked: isLiked,
|
||||
IsFavorited: isFavorited,
|
||||
}
|
||||
|
||||
response.Success(c, responseData)
|
||||
}
|
||||
|
||||
// RecordView 记录帖子浏览(增加浏览量)
|
||||
func (h *PostHandler) RecordView(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// 验证帖子存在
|
||||
_, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 增加浏览量
|
||||
if err := h.postService.IncrementViews(c.Request.Context(), id, userID); err != nil {
|
||||
fmt.Printf("[DEBUG] Failed to increment views for post %s: %v\n", id, err)
|
||||
response.InternalServerError(c, "failed to record view")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// List 获取帖子列表
|
||||
func (h *PostHandler) List(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
userID := c.Query("user_id")
|
||||
tab := c.Query("tab") // recommend, follow, hot, latest
|
||||
|
||||
// 获取当前用户ID
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
// 根据 tab 参数选择不同的获取方式
|
||||
switch tab {
|
||||
case "follow":
|
||||
// 获取关注用户的帖子,需要登录
|
||||
if currentUserID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
posts, total, err = h.postService.GetFollowingPosts(c.Request.Context(), currentUserID, page, pageSize)
|
||||
case "hot":
|
||||
// 获取热门帖子
|
||||
posts, total, err = h.postService.GetHotPosts(c.Request.Context(), page, pageSize)
|
||||
case "recommend":
|
||||
// 推荐帖子(从Gorse获取个性化推荐)
|
||||
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
|
||||
case "latest":
|
||||
// 最新帖子
|
||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||
default:
|
||||
// 默认获取最新帖子
|
||||
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get posts")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] List - tab: %s, currentUserID: %s, posts count: %d\n", tab, currentUserID, len(posts))
|
||||
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
for _, post := range posts {
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
isLikedMap[post.ID] = isLiked
|
||||
isFavoritedMap[post.ID] = isFavorited
|
||||
fmt.Printf("[DEBUG] List - postID: %s, isLiked: %v, isFavorited: %v\n", post.ID, isLiked, isFavorited)
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] List - user not logged in\n")
|
||||
}
|
||||
|
||||
// 转换为响应结构
|
||||
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||
|
||||
response.Paginated(c, postResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Update 更新帖子
|
||||
func (h *PostHandler) Update(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
if post.UserID != userID {
|
||||
response.Forbidden(c, "cannot update others' post")
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
var req UpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Title != "" {
|
||||
post.Title = req.Title
|
||||
}
|
||||
if req.Content != "" {
|
||||
post.Content = req.Content
|
||||
}
|
||||
|
||||
err = h.postService.Update(c.Request.Context(), post)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update post")
|
||||
return
|
||||
}
|
||||
|
||||
currentUserID := c.GetString("user_id")
|
||||
var isLiked, isFavorited bool
|
||||
if currentUserID != "" {
|
||||
isLiked = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavorited = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// Delete 删除帖子
|
||||
func (h *PostHandler) Delete(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.NotFound(c, "post not found")
|
||||
return
|
||||
}
|
||||
|
||||
if post.UserID != userID {
|
||||
response.Forbidden(c, "cannot delete others' post")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.postService.Delete(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to delete post")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "post deleted", nil)
|
||||
}
|
||||
|
||||
// Like 点赞帖子
|
||||
func (h *PostHandler) Like(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
fmt.Printf("[DEBUG] Like - postID: %s, userID: %s\n", id, userID)
|
||||
|
||||
err := h.postService.Like(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to like post")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的帖子状态
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get post")
|
||||
return
|
||||
}
|
||||
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||
fmt.Printf("[DEBUG] Like - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// Unlike 取消点赞
|
||||
func (h *PostHandler) Unlike(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
fmt.Printf("[DEBUG] Unlike - postID: %s, userID: %s\n", id, userID)
|
||||
|
||||
err := h.postService.Unlike(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unlike post")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的帖子状态
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get post")
|
||||
return
|
||||
}
|
||||
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||
fmt.Printf("[DEBUG] Unlike - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// Favorite 收藏帖子
|
||||
func (h *PostHandler) Favorite(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
fmt.Printf("[DEBUG] Favorite - postID: %s, userID: %s\n", id, userID)
|
||||
|
||||
err := h.postService.Favorite(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to favorite post")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的帖子状态
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get post")
|
||||
return
|
||||
}
|
||||
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||
fmt.Printf("[DEBUG] Favorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// Unfavorite 取消收藏
|
||||
func (h *PostHandler) Unfavorite(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
fmt.Printf("[DEBUG] Unfavorite - postID: %s, userID: %s\n", id, userID)
|
||||
|
||||
err := h.postService.Unfavorite(c.Request.Context(), id, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unfavorite post")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的帖子状态
|
||||
post, err := h.postService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get post")
|
||||
return
|
||||
}
|
||||
|
||||
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
|
||||
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
|
||||
fmt.Printf("[DEBUG] Unfavorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
|
||||
|
||||
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
|
||||
}
|
||||
|
||||
// GetUserPosts 获取用户帖子列表
|
||||
func (h *PostHandler) GetUserPosts(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get user posts")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
currentUserID := c.GetString("user_id")
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
for _, post := range posts {
|
||||
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应结构
|
||||
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||
|
||||
response.Paginated(c, postResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetFavorites 获取收藏列表
|
||||
func (h *PostHandler) GetFavorites(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
posts, total, err := h.postService.GetFavorites(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get favorites")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
currentUserID := c.GetString("user_id")
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
for _, post := range posts {
|
||||
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应结构
|
||||
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||
|
||||
response.Paginated(c, postResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// Search 搜索帖子
|
||||
func (h *PostHandler) Search(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
posts, total, err := h.postService.Search(c.Request.Context(), keyword, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to search posts")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID用于判断点赞和收藏状态
|
||||
currentUserID := c.GetString("user_id")
|
||||
isLikedMap := make(map[string]bool)
|
||||
isFavoritedMap := make(map[string]bool)
|
||||
if currentUserID != "" {
|
||||
for _, post := range posts {
|
||||
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
|
||||
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为响应结构
|
||||
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
|
||||
|
||||
response.Paginated(c, postResponses, total, page, pageSize)
|
||||
}
|
||||
157
internal/handler/push_handler.go
Normal file
157
internal/handler/push_handler.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PushHandler 推送处理器
|
||||
type PushHandler struct {
|
||||
pushService service.PushService
|
||||
}
|
||||
|
||||
// NewPushHandler 创建推送处理器
|
||||
func NewPushHandler(pushService service.PushService) *PushHandler {
|
||||
return &PushHandler{
|
||||
pushService: pushService,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterDevice 注册设备
|
||||
// POST /api/v1/push/devices
|
||||
func (h *PushHandler) RegisterDevice(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.RegisterDeviceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证设备类型
|
||||
deviceType := model.DeviceType(req.DeviceType)
|
||||
if !isValidDeviceType(deviceType) {
|
||||
response.BadRequest(c, "invalid device type")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.pushService.RegisterDevice(c.Request.Context(), userID, req.DeviceID, deviceType, req.PushToken)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to register device")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "device registered successfully", nil)
|
||||
}
|
||||
|
||||
// UnregisterDevice 注销设备
|
||||
// DELETE /api/v1/push/devices/:device_id
|
||||
func (h *PushHandler) UnregisterDevice(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := c.Param("device_id")
|
||||
if deviceID == "" {
|
||||
response.BadRequest(c, "device_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.pushService.UnregisterDevice(c.Request.Context(), deviceID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unregister device")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "device unregistered successfully", nil)
|
||||
}
|
||||
|
||||
// GetMyDevices 获取当前用户的设备列表
|
||||
// GET /api/v1/push/devices
|
||||
func (h *PushHandler) GetMyDevices(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
// 这里需要从DeviceTokenRepository获取设备列表
|
||||
// 由于PushService接口没有提供获取设备列表的方法,我们暂时返回空列表
|
||||
// TODO: 在PushService接口中添加GetUserDevices方法
|
||||
_ = userID // 避免未使用变量警告
|
||||
|
||||
response.Success(c, []*dto.DeviceTokenResponse{})
|
||||
}
|
||||
|
||||
// GetPushRecords 获取推送记录
|
||||
// GET /api/v1/push/records
|
||||
func (h *PushHandler) GetPushRecords(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
records, err := h.pushService.GetPendingPushes(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get push records")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, &dto.PushRecordListResponse{
|
||||
Records: dto.PushRecordsToResponse(records),
|
||||
Total: int64(len(records)),
|
||||
})
|
||||
}
|
||||
|
||||
// 辅助函数:验证设备类型
|
||||
func isValidDeviceType(deviceType model.DeviceType) bool {
|
||||
switch deviceType {
|
||||
case model.DeviceTypeIOS, model.DeviceTypeAndroid, model.DeviceTypeWeb:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateDeviceToken 更新设备推送Token
|
||||
// PUT /api/v1/push/devices/:device_id/token
|
||||
func (h *PushHandler) UpdateDeviceToken(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
deviceID := c.Param("device_id")
|
||||
if deviceID == "" {
|
||||
response.BadRequest(c, "device_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
PushToken string `json:"push_token" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err := h.pushService.UpdateDeviceToken(c.Request.Context(), deviceID, req.PushToken)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update device token")
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "device token updated successfully", nil)
|
||||
}
|
||||
164
internal/handler/sticker_handler.go
Normal file
164
internal/handler/sticker_handler.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// StickerHandler 自定义表情处理器
|
||||
type StickerHandler struct {
|
||||
stickerService service.StickerService
|
||||
}
|
||||
|
||||
// NewStickerHandler 创建自定义表情处理器
|
||||
func NewStickerHandler(stickerService service.StickerService) *StickerHandler {
|
||||
return &StickerHandler{
|
||||
stickerService: stickerService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetStickersRequest 获取表情列表请求
|
||||
type GetStickersRequest struct {
|
||||
UserID string `form:"user_id" binding:"required"`
|
||||
}
|
||||
|
||||
// AddStickerRequest 添加表情请求
|
||||
type AddStickerRequest struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
}
|
||||
|
||||
// DeleteStickerRequest 删除表情请求
|
||||
type DeleteStickerRequest struct {
|
||||
StickerID string `json:"sticker_id" binding:"required"`
|
||||
}
|
||||
|
||||
// ReorderStickersRequest 重新排序请求
|
||||
type ReorderStickersRequest struct {
|
||||
Orders map[string]int `json:"orders" binding:"required"`
|
||||
}
|
||||
|
||||
// CheckStickerRequest 检查表情是否存在请求
|
||||
type CheckStickerRequest struct {
|
||||
URL string `form:"url" binding:"required"`
|
||||
}
|
||||
|
||||
// GetStickers 获取用户的表情列表
|
||||
func (h *StickerHandler) GetStickers(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
stickers, err := h.stickerService.GetUserStickers(userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get stickers")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"stickers": stickers})
|
||||
}
|
||||
|
||||
// AddSticker 添加表情
|
||||
func (h *StickerHandler) AddSticker(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req AddStickerRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sticker, err := h.stickerService.AddSticker(userID, req.URL, req.Width, req.Height)
|
||||
if err != nil {
|
||||
if err == service.ErrStickerAlreadyExists {
|
||||
response.Error(c, http.StatusConflict, "sticker already exists")
|
||||
return
|
||||
}
|
||||
if err == service.ErrInvalidStickerURL {
|
||||
response.BadRequest(c, "invalid sticker url, only http/https is allowed")
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"sticker": sticker})
|
||||
}
|
||||
|
||||
// DeleteSticker 删除表情
|
||||
func (h *StickerHandler) DeleteSticker(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req DeleteStickerRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.stickerService.DeleteSticker(userID, req.StickerID); err != nil {
|
||||
response.InternalServerError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "sticker deleted successfully", nil)
|
||||
}
|
||||
|
||||
// ReorderStickers 重新排序表情
|
||||
func (h *StickerHandler) ReorderStickers(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
var req ReorderStickersRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.stickerService.ReorderStickers(userID, req.Orders); err != nil {
|
||||
response.InternalServerError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.SuccessWithMessage(c, "stickers reordered successfully", nil)
|
||||
}
|
||||
|
||||
// CheckStickerExists 检查表情是否存在
|
||||
func (h *StickerHandler) CheckStickerExists(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
url := c.Query("url")
|
||||
if url == "" {
|
||||
response.BadRequest(c, "url is required")
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := h.stickerService.CheckExists(userID, url)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"exists": exists})
|
||||
}
|
||||
154
internal/handler/system_message_handler.go
Normal file
154
internal/handler/system_message_handler.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/repository"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// SystemMessageHandler 系统消息处理器
|
||||
type SystemMessageHandler struct {
|
||||
systemMsgService service.SystemMessageService
|
||||
notifyRepo *repository.SystemNotificationRepository
|
||||
}
|
||||
|
||||
// NewSystemMessageHandler 创建系统消息处理器
|
||||
func NewSystemMessageHandler(
|
||||
systemMsgService service.SystemMessageService,
|
||||
notifyRepo *repository.SystemNotificationRepository,
|
||||
) *SystemMessageHandler {
|
||||
return &SystemMessageHandler{
|
||||
systemMsgService: systemMsgService,
|
||||
notifyRepo: notifyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSystemMessages 获取系统消息列表
|
||||
// GET /api/v1/messages/system
|
||||
func (h *SystemMessageHandler) GetSystemMessages(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
// 获取当前用户的系统通知(从独立表中获取)
|
||||
notifications, total, err := h.notifyRepo.GetByReceiverID(userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get system messages")
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]*dto.SystemMessageResponse, 0)
|
||||
for _, n := range notifications {
|
||||
resp := dto.SystemNotificationToResponse(n)
|
||||
result = append(result, resp)
|
||||
}
|
||||
|
||||
response.Paginated(c, result, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取系统消息未读数
|
||||
// GET /api/v1/messages/system/unread-count
|
||||
func (h *SystemMessageHandler) GetUnreadCount(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户的未读通知数
|
||||
unreadCount, err := h.notifyRepo.GetUnreadCount(userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get unread count")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, &dto.SystemUnreadCountResponse{
|
||||
UnreadCount: unreadCount,
|
||||
})
|
||||
}
|
||||
|
||||
// MarkAsRead 标记系统消息为已读
|
||||
// PUT /api/v1/messages/system/:id/read
|
||||
func (h *SystemMessageHandler) MarkAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
notificationIDStr := c.Param("id")
|
||||
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid notification id")
|
||||
return
|
||||
}
|
||||
|
||||
// 标记为已读
|
||||
err = h.notifyRepo.MarkAsRead(notificationID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to mark as read")
|
||||
return
|
||||
}
|
||||
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||
|
||||
response.SuccessWithMessage(c, "marked as read", nil)
|
||||
}
|
||||
|
||||
// MarkAllAsRead 标记所有系统消息为已读
|
||||
// PUT /api/v1/messages/system/read-all
|
||||
func (h *SystemMessageHandler) MarkAllAsRead(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
// 标记当前用户所有通知为已读
|
||||
err := h.notifyRepo.MarkAllAsRead(userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to mark all as read")
|
||||
return
|
||||
}
|
||||
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||
|
||||
response.SuccessWithMessage(c, "all messages marked as read", nil)
|
||||
}
|
||||
|
||||
// DeleteSystemMessage 删除系统消息
|
||||
// DELETE /api/v1/messages/system/:id
|
||||
func (h *SystemMessageHandler) DeleteSystemMessage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
notificationIDStr := c.Param("id")
|
||||
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "invalid notification id")
|
||||
return
|
||||
}
|
||||
|
||||
// 删除通知
|
||||
err = h.notifyRepo.Delete(notificationID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to delete notification")
|
||||
return
|
||||
}
|
||||
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
|
||||
|
||||
response.SuccessWithMessage(c, "notification deleted", nil)
|
||||
}
|
||||
90
internal/handler/upload_handler.go
Normal file
90
internal/handler/upload_handler.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// UploadHandler 上传处理器
|
||||
type UploadHandler struct {
|
||||
uploadService *service.UploadService
|
||||
}
|
||||
|
||||
// NewUploadHandler 创建上传处理器
|
||||
func NewUploadHandler(uploadService *service.UploadService) *UploadHandler {
|
||||
return &UploadHandler{
|
||||
uploadService: uploadService,
|
||||
}
|
||||
}
|
||||
|
||||
// UploadImage 上传图片
|
||||
func (h *UploadHandler) UploadImage(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := c.FormFile("image")
|
||||
if err != nil {
|
||||
response.BadRequest(c, "image file is required")
|
||||
return
|
||||
}
|
||||
|
||||
url, err := h.uploadService.UploadImage(c.Request.Context(), file)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to upload image")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// UploadAvatar 上传头像
|
||||
func (h *UploadHandler) UploadAvatar(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := c.FormFile("image")
|
||||
|
||||
if err != nil {
|
||||
response.BadRequest(c, "avatar file is required")
|
||||
return
|
||||
}
|
||||
|
||||
url, err := h.uploadService.UploadAvatar(c.Request.Context(), userID, file)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to upload avatar")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
|
||||
// UploadCover 上传头图(个人主页封面)
|
||||
func (h *UploadHandler) UploadCover(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := c.FormFile("image")
|
||||
if err != nil {
|
||||
response.BadRequest(c, "image file is required")
|
||||
return
|
||||
}
|
||||
|
||||
url, err := h.uploadService.UploadCover(c.Request.Context(), userID, file)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to upload cover")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"url": url})
|
||||
}
|
||||
705
internal/handler/user_handler.go
Normal file
705
internal/handler/user_handler.go
Normal file
@@ -0,0 +1,705 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// UserHandler 用户处理器
|
||||
type UserHandler struct {
|
||||
userService *service.UserService
|
||||
jwtService *service.JWTService
|
||||
}
|
||||
|
||||
// NewUserHandler 创建用户处理器
|
||||
func NewUserHandler(userService *service.UserService) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func (h *UserHandler) Register(c *gin.Context) {
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
Nickname string `json:"nickname" binding:"required"`
|
||||
Phone string `json:"phone"`
|
||||
VerificationCode string `json:"verification_code" binding:"required"`
|
||||
}
|
||||
|
||||
var req RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.Register(c.Request.Context(), req.Username, req.Email, req.Password, req.Nickname, req.Phone, req.VerificationCode)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to register")
|
||||
return
|
||||
}
|
||||
|
||||
// 生成Token
|
||||
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
|
||||
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"user": dto.ConvertUserToResponse(user),
|
||||
"token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
func (h *UserHandler) Login(c *gin.Context) {
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Account string `json:"account"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
var req LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
account := req.Account
|
||||
if account == "" {
|
||||
account = req.Username
|
||||
}
|
||||
if account == "" {
|
||||
response.BadRequest(c, "username or account is required")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.Login(c.Request.Context(), account, req.Password)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to login")
|
||||
return
|
||||
}
|
||||
|
||||
// 生成Token
|
||||
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
|
||||
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"user": dto.ConvertUserToResponse(user),
|
||||
"token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// SendRegisterCode 发送注册验证码
|
||||
func (h *UserHandler) SendRegisterCode(c *gin.Context) {
|
||||
type SendCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
|
||||
var req SendCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.SendRegisterCode(c.Request.Context(), req.Email); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to send register verification code")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// SendPasswordResetCode 发送找回密码验证码
|
||||
func (h *UserHandler) SendPasswordResetCode(c *gin.Context) {
|
||||
type SendCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
|
||||
var req SendCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.SendPasswordResetCode(c.Request.Context(), req.Email); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to send reset verification code")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// ResetPassword 找回密码并重置
|
||||
func (h *UserHandler) ResetPassword(c *gin.Context) {
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
VerificationCode string `json:"verification_code" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
var req ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ResetPasswordByEmail(c.Request.Context(), req.Email, req.VerificationCode, req.NewPassword); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to reset password")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetCurrentUser 获取当前用户
|
||||
func (h *UserHandler) GetCurrentUser(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 实时计算帖子数量
|
||||
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
// 如果获取失败,使用数据库中的值
|
||||
postsCount = int64(user.PostsCount)
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
|
||||
}
|
||||
|
||||
// GetUserByID 获取指定用户
|
||||
func (h *UserHandler) GetUserByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
// 获取用户信息,包含双向关注状态
|
||||
user, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), id, currentUserID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 实时计算帖子数量
|
||||
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
// 如果获取失败,使用数据库中的值
|
||||
postsCount = int64(user.PostsCount)
|
||||
}
|
||||
|
||||
// 转换为响应格式,包含双向关注状态和实时计算的帖子数量
|
||||
userResponse := dto.ConvertUserToResponseWithMutualFollowAndPostsCount(user, isFollowing, isFollowingMe, int(postsCount))
|
||||
|
||||
response.Success(c, userResponse)
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户
|
||||
func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateRequest struct {
|
||||
Nickname string `json:"nickname"`
|
||||
Bio string `json:"bio"`
|
||||
Website string `json:"website"`
|
||||
Location string `json:"location"`
|
||||
Avatar string `json:"avatar"`
|
||||
Phone *string `json:"phone"`
|
||||
Email *string `json:"email"`
|
||||
}
|
||||
|
||||
var req UpdateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "user not found")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Nickname != "" {
|
||||
user.Nickname = req.Nickname
|
||||
}
|
||||
if req.Bio != "" {
|
||||
user.Bio = req.Bio
|
||||
}
|
||||
if req.Website != "" {
|
||||
user.Website = req.Website
|
||||
}
|
||||
if req.Location != "" {
|
||||
user.Location = req.Location
|
||||
}
|
||||
if req.Avatar != "" {
|
||||
user.Avatar = req.Avatar
|
||||
}
|
||||
if req.Phone != nil {
|
||||
user.Phone = req.Phone
|
||||
}
|
||||
if req.Email != nil {
|
||||
if user.Email == nil || *user.Email != *req.Email {
|
||||
user.EmailVerified = false
|
||||
}
|
||||
user.Email = req.Email
|
||||
}
|
||||
|
||||
err = h.userService.UpdateUser(c.Request.Context(), user)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to update user")
|
||||
return
|
||||
}
|
||||
|
||||
// 实时计算帖子数量
|
||||
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
// 如果获取失败,使用数据库中的值
|
||||
postsCount = int64(user.PostsCount)
|
||||
}
|
||||
|
||||
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
|
||||
}
|
||||
|
||||
// SendEmailVerifyCode 发送当前用户邮箱验证码
|
||||
func (h *UserHandler) SendEmailVerifyCode(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type SendCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
}
|
||||
var req SendCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.SendCurrentUserEmailVerifyCode(c.Request.Context(), userID, req.Email); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to send email verify code")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// VerifyEmail 验证当前用户邮箱
|
||||
func (h *UserHandler) VerifyEmail(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
type VerifyEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
VerificationCode string `json:"verification_code" binding:"required"`
|
||||
}
|
||||
var req VerifyEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.VerifyCurrentUserEmail(c.Request.Context(), userID, req.Email, req.VerificationCode); err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to verify email")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// RefreshToken 刷新Token
|
||||
func (h *UserHandler) RefreshToken(c *gin.Context) {
|
||||
type RefreshRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
}
|
||||
|
||||
var req RefreshRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 refresh token
|
||||
claims, err := h.jwtService.ParseToken(req.RefreshToken)
|
||||
if err != nil {
|
||||
response.Unauthorized(c, "invalid refresh token")
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新 token
|
||||
accessToken, _ := h.jwtService.GenerateAccessToken(claims.UserID, claims.Username)
|
||||
refreshToken, _ := h.jwtService.GenerateRefreshToken(claims.UserID, claims.Username)
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"token": accessToken,
|
||||
"refresh_token": refreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// SetJWTService 设置JWT服务
|
||||
func (h *UserHandler) SetJWTService(jwtSvc *service.JWTService) {
|
||||
h.jwtService = jwtSvc
|
||||
}
|
||||
|
||||
// FollowUser 关注用户
|
||||
func (h *UserHandler) FollowUser(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
if userID == currentUserID {
|
||||
response.BadRequest(c, "cannot follow yourself")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.FollowUser(c.Request.Context(), currentUserID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to follow user")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// UnfollowUser 取消关注用户
|
||||
func (h *UserHandler) UnfollowUser(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
err := h.userService.UnfollowUser(c.Request.Context(), currentUserID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to unfollow user")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// BlockUser 拉黑用户
|
||||
func (h *UserHandler) BlockUser(c *gin.Context) {
|
||||
targetUserID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
if targetUserID == currentUserID {
|
||||
response.BadRequest(c, "cannot block yourself")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.BlockUser(c.Request.Context(), currentUserID, targetUserID)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to block user")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// UnblockUser 取消拉黑
|
||||
func (h *UserHandler) UnblockUser(c *gin.Context) {
|
||||
targetUserID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
if targetUserID == currentUserID {
|
||||
response.BadRequest(c, "cannot unblock yourself")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.UnblockUser(c.Request.Context(), currentUserID, targetUserID)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to unblock user")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetBlockedUsers 获取黑名单列表
|
||||
func (h *UserHandler) GetBlockedUsers(c *gin.Context) {
|
||||
currentUserID := c.GetString("user_id")
|
||||
if currentUserID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
users, total, err := h.userService.GetBlockedUsers(c.Request.Context(), currentUserID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get blocked users")
|
||||
return
|
||||
}
|
||||
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses := dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||
response.Paginated(c, userResponses, total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetBlockStatus 获取拉黑状态
|
||||
func (h *UserHandler) GetBlockStatus(c *gin.Context) {
|
||||
targetUserID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
if currentUserID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
if targetUserID == "" {
|
||||
response.BadRequest(c, "target user id is required")
|
||||
return
|
||||
}
|
||||
|
||||
isBlocked, err := h.userService.IsBlocked(c.Request.Context(), currentUserID, targetUserID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get block status")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"is_blocked": isBlocked})
|
||||
}
|
||||
|
||||
// GetFollowingList 获取关注列表
|
||||
func (h *UserHandler) GetFollowingList(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
page := c.DefaultQuery("page", "1")
|
||||
pageSize := c.DefaultQuery("page_size", "20")
|
||||
|
||||
users, err := h.userService.GetFollowingList(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get following list")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果已登录,获取双向关注状态和实时计算的帖子数量
|
||||
var userResponses []*dto.UserResponse
|
||||
if currentUserID != "" && len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
|
||||
} else if len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||
} else {
|
||||
userResponses = dto.ConvertUsersToResponse(users)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"list": userResponses,
|
||||
})
|
||||
}
|
||||
|
||||
// GetFollowersList 获取粉丝列表
|
||||
func (h *UserHandler) GetFollowersList(c *gin.Context) {
|
||||
userID := c.Param("id")
|
||||
currentUserID := c.GetString("user_id")
|
||||
page := c.DefaultQuery("page", "1")
|
||||
pageSize := c.DefaultQuery("page_size", "20")
|
||||
|
||||
fmt.Printf("[DEBUG] GetFollowersList: userID=%s, currentUserID=%s\n", userID, currentUserID)
|
||||
|
||||
users, err := h.userService.GetFollowersList(c.Request.Context(), userID, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to get followers list")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] GetFollowersList: found %d users\n", len(users))
|
||||
|
||||
// 如果已登录,获取双向关注状态和实时计算的帖子数量
|
||||
var userResponses []*dto.UserResponse
|
||||
if currentUserID != "" && len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
fmt.Printf("[DEBUG] GetFollowersList: checking mutual follow status for userIDs=%v\n", userIDs)
|
||||
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
|
||||
} else if len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||
} else {
|
||||
userResponses = dto.ConvertUsersToResponse(users)
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"list": userResponses,
|
||||
})
|
||||
}
|
||||
|
||||
// CheckUsername 检查用户名是否可用
|
||||
func (h *UserHandler) CheckUsername(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
if username == "" {
|
||||
response.BadRequest(c, "username is required")
|
||||
return
|
||||
}
|
||||
|
||||
available, err := h.userService.CheckUsernameAvailable(c.Request.Context(), username)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to check username")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"available": available})
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
currentUserID := c.GetString("user_id")
|
||||
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
VerificationCode string `json:"verification_code" binding:"required"`
|
||||
}
|
||||
|
||||
var req ChangePasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.ChangePassword(c.Request.Context(), currentUserID, req.OldPassword, req.NewPassword, req.VerificationCode)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to change password")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// SendChangePasswordCode 发送修改密码验证码
|
||||
func (h *UserHandler) SendChangePasswordCode(c *gin.Context) {
|
||||
currentUserID := c.GetString("user_id")
|
||||
if currentUserID == "" {
|
||||
response.Unauthorized(c, "")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.userService.SendChangePasswordCode(c.Request.Context(), currentUserID)
|
||||
if err != nil {
|
||||
if se, ok := err.(*service.ServiceError); ok {
|
||||
response.Error(c, se.Code, se.Message)
|
||||
return
|
||||
}
|
||||
response.InternalServerError(c, "failed to send change password code")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (h *UserHandler) Search(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
users, total, err := h.userService.Search(c.Request.Context(), keyword, page, pageSize)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "failed to search users")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取实时计算的帖子数量
|
||||
var userResponses []*dto.UserResponse
|
||||
if len(users) > 0 {
|
||||
userIDs := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userIDs[i] = u.ID
|
||||
}
|
||||
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
|
||||
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
|
||||
} else {
|
||||
userResponses = dto.ConvertUsersToResponse(users)
|
||||
}
|
||||
|
||||
response.Paginated(c, userResponses, total, page, pageSize)
|
||||
}
|
||||
216
internal/handler/vote_handler.go
Normal file
216
internal/handler/vote_handler.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// VoteHandler 投票处理器
|
||||
type VoteHandler struct {
|
||||
voteService *service.VoteService
|
||||
postService *service.PostService
|
||||
}
|
||||
|
||||
// NewVoteHandler 创建投票处理器
|
||||
func NewVoteHandler(voteService *service.VoteService, postService *service.PostService) *VoteHandler {
|
||||
return &VoteHandler{
|
||||
voteService: voteService,
|
||||
postService: postService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateVotePost 创建投票帖子
|
||||
// POST /api/v1/posts/vote
|
||||
func (h *VoteHandler) CreateVotePost(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.CreateVotePostRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
post, err := h.voteService.CreateVotePost(c.Request.Context(), userID, &req)
|
||||
if err != nil {
|
||||
var moderationErr *service.PostModerationRejectedError
|
||||
if errors.As(err, &moderationErr) {
|
||||
response.BadRequest(c, moderationErr.UserMessage())
|
||||
return
|
||||
}
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, post)
|
||||
}
|
||||
|
||||
// GetVoteResult 获取投票结果
|
||||
// GET /api/v1/posts/:id/vote
|
||||
func (h *VoteHandler) GetVoteResult(c *gin.Context) {
|
||||
postID := c.Param("id")
|
||||
if postID == "" {
|
||||
response.BadRequest(c, "帖子ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证帖子存在
|
||||
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "帖子不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户ID(可选登录)
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// 如果用户未登录,返回不带has_voted的结果
|
||||
if userID == "" {
|
||||
options, err := h.voteService.GetVoteOptions(postID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "获取投票选项失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 计算总票数
|
||||
totalVotes := 0
|
||||
for _, option := range options {
|
||||
totalVotes += option.VotesCount
|
||||
}
|
||||
|
||||
result := &dto.VoteResultDTO{
|
||||
Options: options,
|
||||
TotalVotes: totalVotes,
|
||||
HasVoted: false,
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
return
|
||||
}
|
||||
|
||||
// 用户已登录,获取完整的投票结果
|
||||
result, err := h.voteService.GetVoteResult(postID, userID)
|
||||
if err != nil {
|
||||
response.InternalServerError(c, "获取投票结果失败")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// Vote 投票
|
||||
// POST /api/v1/posts/:id/vote
|
||||
func (h *VoteHandler) Vote(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
|
||||
postID := c.Param("id")
|
||||
if postID == "" {
|
||||
response.BadRequest(c, "帖子ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证帖子存在
|
||||
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "帖子不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req struct {
|
||||
OptionID string `json:"option_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.voteService.Vote(c.Request.Context(), postID, userID, req.OptionID); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// Unvote 取消投票
|
||||
// DELETE /api/v1/posts/:id/vote
|
||||
func (h *VoteHandler) Unvote(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
|
||||
postID := c.Param("id")
|
||||
if postID == "" {
|
||||
response.BadRequest(c, "帖子ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证帖子存在
|
||||
_, err := h.postService.GetByID(c.Request.Context(), postID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "帖子不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.voteService.Unvote(c.Request.Context(), postID, userID); err != nil {
|
||||
response.Error(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// UpdateVoteOption 更新投票选项(仅作者)
|
||||
// PUT /api/v1/vote-options/:id
|
||||
func (h *VoteHandler) UpdateVoteOption(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
response.Unauthorized(c, "请先登录")
|
||||
return
|
||||
}
|
||||
|
||||
optionID := c.Param("id")
|
||||
if optionID == "" {
|
||||
response.BadRequest(c, "选项ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取帖子ID(从查询参数或请求体中获取)
|
||||
postID := c.Query("post_id")
|
||||
if postID == "" {
|
||||
response.BadRequest(c, "帖子ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.voteService.UpdateVoteOption(c.Request.Context(), postID, optionID, userID, req.Content); err != nil {
|
||||
response.Error(c, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
866
internal/handler/websocket_handler.go
Normal file
866
internal/handler/websocket_handler.go
Normal file
@@ -0,0 +1,866 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
ws "carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/repository"
|
||||
"carrot_bbs/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许所有来源,生产环境应该限制
|
||||
},
|
||||
}
|
||||
|
||||
// WebSocketHandler WebSocket处理器
|
||||
type WebSocketHandler struct {
|
||||
jwtService *service.JWTService
|
||||
chatService service.ChatService
|
||||
groupService service.GroupService
|
||||
groupRepo repository.GroupRepository
|
||||
userRepo *repository.UserRepository
|
||||
wsManager *ws.WebSocketManager
|
||||
}
|
||||
|
||||
// NewWebSocketHandler 创建WebSocket处理器
|
||||
func NewWebSocketHandler(
|
||||
jwtService *service.JWTService,
|
||||
chatService service.ChatService,
|
||||
groupService service.GroupService,
|
||||
groupRepo repository.GroupRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
wsManager *ws.WebSocketManager,
|
||||
) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
jwtService: jwtService,
|
||||
chatService: chatService,
|
||||
groupService: groupService,
|
||||
groupRepo: groupRepo,
|
||||
userRepo: userRepo,
|
||||
wsManager: wsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket 处理WebSocket连接
|
||||
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
// 调试:打印请求头信息
|
||||
log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
|
||||
log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
|
||||
c.GetHeader("Connection"),
|
||||
c.GetHeader("Upgrade"))
|
||||
log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
|
||||
c.GetHeader("Sec-WebSocket-Key"),
|
||||
c.GetHeader("Sec-WebSocket-Version"))
|
||||
|
||||
// 从query参数获取token
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
// 尝试从header获取
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := h.jwtService.ParseToken(token)
|
||||
if err != nil {
|
||||
log.Printf("Invalid token: %v", err)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := claims.UserID
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
|
||||
return
|
||||
}
|
||||
|
||||
// 升级HTTP连接为WebSocket连接
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to upgrade connection: %v", err)
|
||||
log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
|
||||
c.GetHeader("User-Agent"),
|
||||
c.GetHeader("Content-Type"))
|
||||
return
|
||||
}
|
||||
|
||||
// 如果用户已在线,先注销旧连接
|
||||
if h.wsManager.IsUserOnline(userID) {
|
||||
log.Printf("[DEBUG] 用户 %s 已有在线连接,复用该连接", userID)
|
||||
} else {
|
||||
log.Printf("[DEBUG] 用户 %s 当前不在线,创建新连接", userID)
|
||||
}
|
||||
|
||||
// 创建客户端
|
||||
client := &ws.Client{
|
||||
ID: userID,
|
||||
UserID: userID,
|
||||
Conn: conn,
|
||||
Send: make(chan []byte, 256),
|
||||
Manager: h.wsManager,
|
||||
}
|
||||
|
||||
// 注册客户端
|
||||
h.wsManager.Register(client)
|
||||
|
||||
// 启动读写协程
|
||||
go client.WritePump()
|
||||
go h.handleMessages(client)
|
||||
|
||||
log.Printf("[DEBUG] WebSocket连接建立: userID=%s, 当前在线=%v", userID, h.wsManager.IsUserOnline(userID))
|
||||
}
|
||||
|
||||
// handleMessages 处理客户端消息
|
||||
// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳
|
||||
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
|
||||
defer func() {
|
||||
h.wsManager.Unregister(client)
|
||||
client.Conn.Close()
|
||||
}()
|
||||
|
||||
client.Conn.SetReadLimit(512 * 1024) // 512KB
|
||||
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
||||
client.Conn.SetPongHandler(func(string) error {
|
||||
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
|
||||
return nil
|
||||
})
|
||||
|
||||
// 心跳定时器 - 服务端主动 ping 间隔保持 30 秒
|
||||
pingTicker := time.NewTicker(30 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pingTicker.C:
|
||||
// 发送心跳
|
||||
if err := client.SendPing(); err != nil {
|
||||
log.Printf("Failed to send ping: %v", err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
_, message, err := client.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var wsMsg ws.WSMessage
|
||||
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
||||
log.Printf("Failed to unmarshal message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
h.processMessage(client, &wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processMessage 处理消息
|
||||
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
switch msg.Type {
|
||||
case ws.MessageTypePing:
|
||||
// 响应心跳
|
||||
if err := client.SendPong(); err != nil {
|
||||
log.Printf("Failed to send pong: %v", err)
|
||||
}
|
||||
|
||||
case ws.MessageTypePong:
|
||||
// 客户端响应心跳
|
||||
|
||||
case ws.MessageTypeMessage:
|
||||
// 处理聊天消息
|
||||
h.handleChatMessage(client, msg)
|
||||
|
||||
case ws.MessageTypeTyping:
|
||||
// 处理正在输入状态
|
||||
h.handleTyping(client, msg)
|
||||
|
||||
case ws.MessageTypeRead:
|
||||
// 处理已读回执
|
||||
h.handleReadReceipt(client, msg)
|
||||
|
||||
// 群组消息处理
|
||||
case ws.MessageTypeGroupMessage:
|
||||
// 处理群消息
|
||||
h.handleGroupMessage(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupTyping:
|
||||
// 处理群输入状态
|
||||
h.handleGroupTyping(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupRead:
|
||||
// 处理群消息已读
|
||||
h.handleGroupReadReceipt(client, msg)
|
||||
|
||||
case ws.MessageTypeGroupRecall:
|
||||
// 处理群消息撤回
|
||||
h.handleGroupRecall(client, msg)
|
||||
|
||||
default:
|
||||
log.Printf("Unknown message type: %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// handleChatMessage 处理聊天消息
|
||||
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
log.Printf("Invalid message data format")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG handleChatMessage] 完整data: %+v", data)
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
|
||||
if conversationIDStr == "" {
|
||||
log.Printf("Missing conversationId")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析会话ID
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
log.Printf("Invalid conversation ID: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 segments
|
||||
var segments model.MessageSegments
|
||||
if data["segments"] != nil {
|
||||
segmentsBytes, err := json.Marshal(data["segments"])
|
||||
if err == nil {
|
||||
json.Unmarshal(segmentsBytes, &segments)
|
||||
}
|
||||
}
|
||||
|
||||
// 从 segments 中提取回复消息ID
|
||||
replyToID := dto.GetReplyMessageID(segments)
|
||||
var replyToIDPtr *string
|
||||
if replyToID != "" {
|
||||
replyToIDPtr = &replyToID
|
||||
}
|
||||
|
||||
// 发送消息 - 使用 segments
|
||||
message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to send message: %v", err)
|
||||
// 发送错误消息
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to send message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 发送确认消息(使用 meta 事件格式,包含完整的消息内容)
|
||||
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": ws.MetaDetailTypeAck,
|
||||
"conversation_id": conversationID,
|
||||
"id": message.ID,
|
||||
"user_id": client.UserID,
|
||||
"sender_id": client.UserID,
|
||||
"seq": message.Seq,
|
||||
"segments": message.Segments,
|
||||
"created_at": message.CreatedAt.UnixMilli(),
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(metaAckMsg)
|
||||
log.Printf("[DEBUG handleChatMessage] 私聊 ack 消息: %s", string(msgBytes))
|
||||
log.Printf("[DEBUG handleChatMessage] message.Segments 类型: %T, 值: %+v", message.Segments, message.Segments)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
}
|
||||
|
||||
// handleTyping 处理正在输入状态
|
||||
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
if conversationIDStr == "" {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用 string 类型的 userID
|
||||
h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
|
||||
}
|
||||
|
||||
// handleReadReceipt 处理已读回执
|
||||
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationIDStr, _ := data["conversationId"].(string)
|
||||
if conversationIDStr == "" {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取lastReadSeq
|
||||
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
||||
if lastReadSeq == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用 string 类型的 userID 和 conversationID
|
||||
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
||||
log.Printf("Failed to mark as read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 群组消息处理 ====================
|
||||
|
||||
// handleGroupMessage 处理群消息
|
||||
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
|
||||
// 打印接收到的消息类型和数据,用于调试
|
||||
log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
|
||||
log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)
|
||||
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
log.Printf("Invalid group message data format: data is not map[string]interface{}")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析群组ID(支持 camelCase 和 snake_case)
|
||||
var groupIDFloat float64
|
||||
groupID := "" // 使用 groupID 作为最终变量名
|
||||
if val, ok := data["groupId"].(float64); ok {
|
||||
groupIDFloat = val
|
||||
groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
} else if val, ok := data["group_id"].(string); ok {
|
||||
groupID = val
|
||||
}
|
||||
|
||||
if groupID == "" {
|
||||
log.Printf("Missing groupId in group message")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析会话ID(支持 camelCase 和 snake_case)
|
||||
var conversationID string
|
||||
if val, ok := data["conversationId"].(string); ok {
|
||||
conversationID = val
|
||||
} else if val, ok := data["conversation_id"].(string); ok {
|
||||
conversationID = val
|
||||
}
|
||||
if conversationID == "" {
|
||||
log.Printf("Missing conversationId in group message")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析 segments
|
||||
var segments model.MessageSegments
|
||||
if data["segments"] != nil {
|
||||
segmentsBytes, err := json.Marshal(data["segments"])
|
||||
if err == nil {
|
||||
json.Unmarshal(segmentsBytes, &segments)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析@用户列表(支持 camelCase 和 snake_case)
|
||||
var mentionUsers []uint64
|
||||
var mentionUsersInterface []interface{}
|
||||
if val, ok := data["mentionUsers"].([]interface{}); ok {
|
||||
mentionUsersInterface = val
|
||||
} else if val, ok := data["mention_users"].([]interface{}); ok {
|
||||
mentionUsersInterface = val
|
||||
}
|
||||
if len(mentionUsersInterface) > 0 {
|
||||
for _, uid := range mentionUsersInterface {
|
||||
if uidFloat, ok := uid.(float64); ok {
|
||||
mentionUsers = append(mentionUsers, uint64(uidFloat))
|
||||
} else if uidStr, ok := uid.(string); ok {
|
||||
// 处理字符串格式的用户ID
|
||||
if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil {
|
||||
mentionUsers = append(mentionUsers, uidInt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析@所有人(支持 camelCase 和 snake_case)
|
||||
var mentionAll bool
|
||||
if val, ok := data["mentionAll"].(bool); ok {
|
||||
mentionAll = val
|
||||
} else if val, ok := data["mention_all"].(bool); ok {
|
||||
mentionAll = val
|
||||
}
|
||||
|
||||
// 检查用户是否可以发送群消息(验证成员身份和禁言状态)
|
||||
// client.UserID 已经是 string 格式的 UUID
|
||||
if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
|
||||
log.Printf("User cannot send group message: %v", err)
|
||||
// 发送错误消息
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Cannot send group message",
|
||||
"reason": err.Error(),
|
||||
"type": "group_message_error",
|
||||
"groupId": groupID,
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 检查@所有人权限(只有群主和管理员可以@所有人)
|
||||
if mentionAll {
|
||||
if !h.groupService.IsGroupAdmin(client.UserID, groupID) {
|
||||
log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
|
||||
mentionAll = false // 取消@所有人标记
|
||||
}
|
||||
}
|
||||
|
||||
// 创建消息
|
||||
message := &model.Message{
|
||||
ConversationID: conversationID,
|
||||
SenderID: client.UserID,
|
||||
Segments: segments,
|
||||
Status: model.MessageStatusNormal,
|
||||
MentionAll: mentionAll,
|
||||
}
|
||||
|
||||
// 序列化mentionUsers为JSON
|
||||
if len(mentionUsers) > 0 {
|
||||
mentionUsersJSON, _ := json.Marshal(mentionUsers)
|
||||
message.MentionUsers = string(mentionUsersJSON)
|
||||
}
|
||||
|
||||
// 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播)
|
||||
savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to save group message: %v", err)
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to save group message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 更新消息的mention信息
|
||||
if len(mentionUsers) > 0 || mentionAll {
|
||||
message.ID = savedMessage.ID
|
||||
message.Seq = savedMessage.Seq
|
||||
}
|
||||
|
||||
// 构造群消息响应
|
||||
groupMsg := &ws.GroupMessage{
|
||||
ID: savedMessage.ID,
|
||||
ConversationID: conversationID,
|
||||
GroupID: groupID,
|
||||
SenderID: client.UserID,
|
||||
Seq: savedMessage.Seq,
|
||||
Segments: segments,
|
||||
MentionUsers: mentionUsers,
|
||||
MentionAll: mentionAll,
|
||||
CreatedAt: savedMessage.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 广播消息给群组所有成员(排除发送者)
|
||||
h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)
|
||||
|
||||
// 发送确认消息给发送者(作为meta事件)
|
||||
// 使用 meta 事件格式发送 ack
|
||||
log.Printf("[DEBUG HandleGroupMessageSend] 准备发送 ack 消息, userID=%s, messageID=%s, seq=%d",
|
||||
client.UserID, savedMessage.ID, savedMessage.Seq)
|
||||
|
||||
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": ws.MetaDetailTypeAck,
|
||||
"conversation_id": conversationID,
|
||||
"group_id": groupID,
|
||||
"id": savedMessage.ID,
|
||||
"user_id": client.UserID,
|
||||
"sender_id": client.UserID,
|
||||
"seq": savedMessage.Seq,
|
||||
"segments": segments,
|
||||
"created_at": savedMessage.CreatedAt.UnixMilli(),
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(metaAckMsg)
|
||||
log.Printf("[DEBUG HandleGroupMessageSend] 发送 ack 消息到 channel, userID=%s, msg=%s",
|
||||
client.UserID, string(msgBytes))
|
||||
client.Send <- msgBytes
|
||||
} else {
|
||||
log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
|
||||
}
|
||||
|
||||
// 处理@提及通知
|
||||
if len(mentionUsers) > 0 || mentionAll {
|
||||
// 提取文本正文(不含 @ 部分)
|
||||
textContent := dto.ExtractTextContentFromModel(segments)
|
||||
// 在通知内容前拼接被@的真实昵称,通过群成员列表查找
|
||||
mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
|
||||
h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGroupTyping 处理群输入状态
|
||||
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
groupIDFloat, _ := data["groupId"].(float64)
|
||||
if groupIDFloat == 0 {
|
||||
return
|
||||
}
|
||||
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
|
||||
isTyping, _ := data["isTyping"].(bool)
|
||||
|
||||
// 验证用户是否是群成员
|
||||
// client.UserID 已经是 string 格式的 UUID
|
||||
isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
|
||||
if err != nil || !isMember {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := h.userRepo.GetByID(client.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 构造输入状态消息
|
||||
typingMsg := &ws.GroupTypingMessage{
|
||||
GroupID: groupID,
|
||||
UserID: client.UserID,
|
||||
Username: user.Username,
|
||||
IsTyping: isTyping,
|
||||
}
|
||||
|
||||
// 广播给群组其他成员
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
||||
h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
|
||||
}
|
||||
|
||||
// handleGroupReadReceipt 处理群消息已读回执
|
||||
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
conversationID, _ := data["conversationId"].(string)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
lastReadSeq, _ := data["lastReadSeq"].(float64)
|
||||
if lastReadSeq == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 标记已读
|
||||
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
|
||||
log.Printf("Failed to mark group message as read: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGroupRecall 处理群消息撤回
|
||||
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
|
||||
data, ok := msg.Data.(map[string]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
messageID, _ := data["messageId"].(string)
|
||||
if messageID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
groupIDFloat, _ := data["groupId"].(float64)
|
||||
if groupIDFloat == 0 {
|
||||
return
|
||||
}
|
||||
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
|
||||
|
||||
// 撤回消息
|
||||
if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
|
||||
log.Printf("Failed to recall group message: %v", err)
|
||||
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
|
||||
"error": "Failed to recall message",
|
||||
})
|
||||
if client.Send != nil {
|
||||
msgBytes, _ := json.Marshal(errorMsg)
|
||||
client.Send <- msgBytes
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 广播撤回通知给群组所有成员
|
||||
recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
|
||||
"messageId": messageID,
|
||||
"groupId": groupID,
|
||||
"userId": client.UserID,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
})
|
||||
h.BroadcastGroupNotice(groupID, recallNotice)
|
||||
}
|
||||
|
||||
// handleGroupMention 处理群消息@提及通知
|
||||
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
|
||||
// 如果@所有人,获取所有群成员
|
||||
if mentionAll {
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for mention all: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, member := range members {
|
||||
// 不通知发送者自己
|
||||
memberIDStr := member.UserID
|
||||
if memberIDStr == senderID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送@提及通知
|
||||
mentionMsg := &ws.GroupMentionMessage{
|
||||
GroupID: groupID,
|
||||
MessageID: messageID,
|
||||
FromUserID: senderID,
|
||||
Content: truncateContent(content, 50),
|
||||
MentionAll: true,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 处理特定用户的@提及
|
||||
for _, userID := range mentionUsers {
|
||||
// userID 是 uint64,转换为 string
|
||||
userIDStr := strconv.FormatUint(userID, 10)
|
||||
if userIDStr == senderID {
|
||||
continue // 不通知发送者自己
|
||||
}
|
||||
|
||||
mentionMsg := &ws.GroupMentionMessage{
|
||||
GroupID: groupID,
|
||||
MessageID: messageID,
|
||||
FromUserID: senderID,
|
||||
Content: truncateContent(content, 50),
|
||||
MentionAll: false,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
|
||||
h.wsManager.SendToUser(userIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称
|
||||
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
|
||||
var prefix string
|
||||
if mentionAll {
|
||||
prefix = "@所有人 "
|
||||
} else if len(mentionUsers) > 0 {
|
||||
// 查询群成员列表,找到被@用户的昵称
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err == nil {
|
||||
memberNickMap := make(map[string]string, len(members))
|
||||
for _, m := range members {
|
||||
displayName := m.Nickname
|
||||
if displayName == "" {
|
||||
displayName = m.UserID
|
||||
}
|
||||
memberNickMap[m.UserID] = displayName
|
||||
}
|
||||
for _, uid := range mentionUsers {
|
||||
uidStr := strconv.FormatUint(uid, 10)
|
||||
if name, ok := memberNickMap[uidStr]; ok {
|
||||
prefix += "@" + name + " "
|
||||
} else {
|
||||
prefix += "@某人 "
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for range mentionUsers {
|
||||
prefix += "@某人 "
|
||||
}
|
||||
}
|
||||
}
|
||||
return prefix + textBody
|
||||
}
|
||||
|
||||
// BroadcastGroupMessage 向群组所有成员广播消息
|
||||
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除发送者
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupNotice 广播群组通知给所有成员
|
||||
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历成员,如果在线则发送通知
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
h.wsManager.SendToUser(memberIDStr, notice)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户)
|
||||
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for notice broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历成员,如果在线则发送通知
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
h.wsManager.SendToUser(memberIDStr, notice)
|
||||
}
|
||||
}
|
||||
|
||||
// SendGroupMemberNotice 发送群成员变动通知
|
||||
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
|
||||
notice := &ws.GroupNoticeMessage{
|
||||
NoticeType: noticeType,
|
||||
GroupID: groupID,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
|
||||
h.BroadcastGroupNotice(groupID, wsMsg)
|
||||
}
|
||||
|
||||
// truncateContent 截断内容
|
||||
func truncateContent(content string, maxLen int) string {
|
||||
if len(content) <= maxLen {
|
||||
return content
|
||||
}
|
||||
return content[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// BroadcastGroupTyping 向群组所有成员广播输入状态
|
||||
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for typing broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除指定用户
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastGroupRead 向群组所有成员广播已读状态
|
||||
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
|
||||
// 获取群组所有成员
|
||||
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get group members for read broadcast: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建WebSocket消息
|
||||
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
|
||||
|
||||
// 遍历成员,如果在线则发送消息
|
||||
for _, member := range members {
|
||||
memberIDStr := member.UserID
|
||||
|
||||
// 排除指定用户
|
||||
if memberIDStr == excludeUserID {
|
||||
continue
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
h.wsManager.SendToUser(memberIDStr, wsMsg)
|
||||
}
|
||||
}
|
||||
95
internal/middleware/auth.go
Normal file
95
internal/middleware/auth.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/pkg/response"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// Auth 认证中间件
|
||||
func Auth(jwtService *service.JWTService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
fmt.Printf("[DEBUG] Auth middleware: Authorization header = %q\n", authHeader)
|
||||
|
||||
if authHeader == "" {
|
||||
fmt.Printf("[DEBUG] Auth middleware: no Authorization header, returning 401\n")
|
||||
response.Unauthorized(c, "authorization header is required")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 提取Token
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
fmt.Printf("[DEBUG] Auth middleware: invalid Authorization header format\n")
|
||||
response.Unauthorized(c, "invalid authorization header format")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
fmt.Printf("[DEBUG] Auth middleware: token = %q\n", token[:min(20, len(token))]+"...")
|
||||
|
||||
// 验证Token
|
||||
claims, err := jwtService.ParseToken(token)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Auth middleware: failed to parse token: %v\n", err)
|
||||
response.Unauthorized(c, "invalid token")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] Auth middleware: parsed claims, user_id = %q, username = %q\n", claims.UserID, claims.Username)
|
||||
|
||||
// 将用户信息存入上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// OptionalAuth 可选认证中间件
|
||||
func OptionalAuth(jwtService *service.JWTService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 提取Token
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
|
||||
// 验证Token
|
||||
claims, err := jwtService.ParseToken(token)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存入上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
46
internal/middleware/cors.go
Normal file
46
internal/middleware/cors.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS CORS中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取请求路径
|
||||
path := c.Request.URL.Path
|
||||
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
// 添加 WebSocket 升级所需的头
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, Connection, Upgrade, Sec-WebSocket-Key, Sec-WebSocket-Version, Sec-WebSocket-Protocol, Sec-WebSocket-Extensions")
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Connection, Upgrade")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// 处理 WebSocket 升级请求的预检
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
log.Printf("[CORS] OPTIONS 预检请求: %s", path)
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
// 针对 WebSocket 路径的特殊处理
|
||||
if path == "/ws" {
|
||||
connection := c.GetHeader("Connection")
|
||||
upgrade := c.GetHeader("Upgrade")
|
||||
log.Printf("[CORS] WebSocket 请求: Connection=%s, Upgrade=%s", connection, upgrade)
|
||||
|
||||
// 检查是否是有效的 WebSocket 升级请求
|
||||
if strings.Contains(strings.ToLower(connection), "upgrade") && strings.ToLower(upgrade) == "websocket" {
|
||||
log.Printf("[CORS] 有效的 WebSocket 升级请求")
|
||||
} else {
|
||||
log.Printf("[CORS] 警告: 不是有效的 WebSocket 升级请求!")
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
49
internal/middleware/logger.go
Normal file
49
internal/middleware/logger.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logger 日志中间件
|
||||
func Logger(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
logger.Info("request",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Int("status", statusCode),
|
||||
zap.Duration("latency", latency),
|
||||
zap.String("ip", c.ClientIP()),
|
||||
zap.String("user-agent", c.Request.UserAgent()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Recovery 恢复中间件
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error("panic recovered",
|
||||
zap.Any("error", err),
|
||||
)
|
||||
c.JSON(500, gin.H{
|
||||
"code": 500,
|
||||
"message": "internal server error",
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
102
internal/middleware/ratelimit.go
Normal file
102
internal/middleware/ratelimit.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimiter 限流器
|
||||
type RateLimiter struct {
|
||||
requests map[string][]time.Time
|
||||
mu sync.Mutex
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建限流器
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
requests: make(map[string][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
|
||||
// 定期清理过期的记录
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(window)
|
||||
rl.cleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// cleanup 清理过期的记录
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, times := range rl.requests {
|
||||
var valid []time.Time
|
||||
for _, t := range times {
|
||||
if now.Sub(t) < rl.window {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
if len(valid) == 0 {
|
||||
delete(rl.requests, key)
|
||||
} else {
|
||||
rl.requests[key] = valid
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isAllowed 检查是否允许请求
|
||||
func (rl *RateLimiter) isAllowed(key string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
times := rl.requests[key]
|
||||
|
||||
// 过滤掉过期的
|
||||
var valid []time.Time
|
||||
for _, t := range times {
|
||||
if now.Sub(t) < rl.window {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
|
||||
if len(valid) >= rl.limit {
|
||||
rl.requests[key] = valid
|
||||
return false
|
||||
}
|
||||
|
||||
rl.requests[key] = append(valid, now)
|
||||
return true
|
||||
}
|
||||
|
||||
// RateLimit 限流中间件
|
||||
func RateLimit(requestsPerMinute int) gin.HandlerFunc {
|
||||
limiter := NewRateLimiter(requestsPerMinute, time.Minute)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
|
||||
if !limiter.isAllowed(ip) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"code": 429,
|
||||
"message": "too many requests",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
118
internal/model/audit_log.go
Normal file
118
internal/model/audit_log.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AuditTargetType 审核对象类型
|
||||
type AuditTargetType string
|
||||
|
||||
const (
|
||||
AuditTargetTypePost AuditTargetType = "post" // 帖子
|
||||
AuditTargetTypeComment AuditTargetType = "comment" // 评论
|
||||
AuditTargetTypeMessage AuditTargetType = "message" // 私信
|
||||
AuditTargetTypeUser AuditTargetType = "user" // 用户资料
|
||||
AuditTargetTypeImage AuditTargetType = "image" // 图片
|
||||
)
|
||||
|
||||
// AuditResult 审核结果
|
||||
type AuditResult string
|
||||
|
||||
const (
|
||||
AuditResultPass AuditResult = "pass" // 通过
|
||||
AuditResultReview AuditResult = "review" // 需人工复审
|
||||
AuditResultBlock AuditResult = "block" // 违规拦截
|
||||
AuditResultUnknown AuditResult = "unknown" // 未知
|
||||
)
|
||||
|
||||
// AuditRiskLevel 风险等级
|
||||
type AuditRiskLevel string
|
||||
|
||||
const (
|
||||
AuditRiskLevelLow AuditRiskLevel = "low" // 低风险
|
||||
AuditRiskLevelMedium AuditRiskLevel = "medium" // 中风险
|
||||
AuditRiskLevelHigh AuditRiskLevel = "high" // 高风险
|
||||
)
|
||||
|
||||
// AuditSource 审核来源
|
||||
type AuditSource string
|
||||
|
||||
const (
|
||||
AuditSourceAuto AuditSource = "auto" // 自动审核
|
||||
AuditSourceManual AuditSource = "manual" // 人工审核
|
||||
AuditSourceCallback AuditSource = "callback" // 回调审核
|
||||
)
|
||||
|
||||
// AuditLog 审核日志实体
|
||||
type AuditLog struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
TargetType AuditTargetType `json:"target_type" gorm:"type:varchar(50);index"`
|
||||
TargetID string `json:"target_id" gorm:"type:varchar(255);index"`
|
||||
Content string `json:"content" gorm:"type:text"` // 待审核内容
|
||||
ContentType string `json:"content_type" gorm:"type:varchar(50)"` // 内容类型: text, image
|
||||
ContentURL string `json:"content_url" gorm:"type:text"` // 图片/文件URL
|
||||
AuditType string `json:"audit_type" gorm:"type:varchar(50)"` // 审核类型: porn, violence, ad, political, fraud, gamble
|
||||
Result AuditResult `json:"result" gorm:"type:varchar(50);index"`
|
||||
RiskLevel AuditRiskLevel `json:"risk_level" gorm:"type:varchar(20)"`
|
||||
Labels string `json:"labels" gorm:"type:text"` // JSON数组,标签列表
|
||||
Suggestion string `json:"suggestion" gorm:"type:varchar(50)"` // pass, review, block
|
||||
Detail string `json:"detail" gorm:"type:text"` // 详细说明
|
||||
ThirdPartyID string `json:"third_party_id" gorm:"type:varchar(255)"` // 第三方审核服务返回的ID
|
||||
Source AuditSource `json:"source" gorm:"type:varchar(20);default:auto"`
|
||||
ReviewerID string `json:"reviewer_id" gorm:"type:varchar(255)"` // 审核人ID(人工审核时使用)
|
||||
ReviewerName string `json:"reviewer_name" gorm:"type:varchar(100)"` // 审核人名称
|
||||
ReviewTime *time.Time `json:"review_time" gorm:"index"` // 审核时间
|
||||
UserID string `json:"user_id" gorm:"type:varchar(255);index"` // 内容发布者ID
|
||||
UserIP string `json:"user_ip" gorm:"type:varchar(45)"` // 用户IP
|
||||
Status string `json:"status" gorm:"type:varchar(20);default:pending"` // pending, completed, failed
|
||||
RejectReason string `json:"reject_reason" gorm:"type:text"` // 拒绝原因(人工审核时使用)
|
||||
ExtraData string `json:"extra_data" gorm:"type:text"` // 额外数据,JSON格式
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (al *AuditLog) BeforeCreate(tx *gorm.DB) error {
|
||||
if al.ID == "" {
|
||||
al.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
// AuditLogRequest 创建审核日志请求
|
||||
type AuditLogRequest struct {
|
||||
TargetType AuditTargetType `json:"target_type" validate:"required"`
|
||||
TargetID string `json:"target_id" validate:"required"`
|
||||
Content string `json:"content"`
|
||||
ContentType string `json:"content_type"`
|
||||
ContentURL string `json:"content_url"`
|
||||
AuditType string `json:"audit_type"`
|
||||
UserID string `json:"user_id"`
|
||||
UserIP string `json:"user_ip"`
|
||||
}
|
||||
|
||||
// AuditLogListItem 审核日志列表项
|
||||
type AuditLogListItem struct {
|
||||
ID string `json:"id"`
|
||||
TargetType AuditTargetType `json:"target_type"`
|
||||
TargetID string `json:"target_id"`
|
||||
Content string `json:"content"`
|
||||
ContentType string `json:"content_type"`
|
||||
Result AuditResult `json:"result"`
|
||||
RiskLevel AuditRiskLevel `json:"risk_level"`
|
||||
Suggestion string `json:"suggestion"`
|
||||
Source AuditSource `json:"source"`
|
||||
ReviewerID string `json:"reviewer_id"`
|
||||
ReviewTime *time.Time `json:"review_time"`
|
||||
UserID string `json:"user_id"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
80
internal/model/comment.go
Normal file
80
internal/model/comment.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CommentStatus 评论状态
|
||||
type CommentStatus string
|
||||
|
||||
const (
|
||||
CommentStatusDraft CommentStatus = "draft"
|
||||
CommentStatusPending CommentStatus = "pending"
|
||||
CommentStatusPublished CommentStatus = "published"
|
||||
CommentStatusRejected CommentStatus = "rejected"
|
||||
CommentStatusDeleted CommentStatus = "deleted"
|
||||
)
|
||||
|
||||
// Comment 评论实体
|
||||
type Comment struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index:idx_comments_post_parent_status_created,priority:1"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
|
||||
ParentID *string `json:"parent_id" gorm:"type:varchar(36);index:idx_comments_post_parent_status_created,priority:2"` // 父评论 ID(支持嵌套)
|
||||
RootID *string `json:"root_id" gorm:"type:varchar(36);index:idx_comments_root_status_created,priority:1"` // 根评论 ID(用于高效查询)
|
||||
Content string `json:"content" gorm:"type:text;not null"`
|
||||
Images string `json:"images" gorm:"type:text"` // 图片URL列表,JSON数组格式
|
||||
|
||||
// 关联
|
||||
User *User `json:"-" gorm:"foreignKey:UserID"`
|
||||
Replies []*Comment `json:"-" gorm:"-"` // 子回复(手动加载,非 GORM 关联)
|
||||
|
||||
// 审核状态
|
||||
Status CommentStatus `json:"status" gorm:"type:varchar(20);default:published;index:idx_comments_post_parent_status_created,priority:3;index:idx_comments_root_status_created,priority:2"`
|
||||
|
||||
// 统计
|
||||
LikesCount int `json:"likes_count" gorm:"default:0"`
|
||||
RepliesCount int `json:"replies_count" gorm:"default:0"`
|
||||
|
||||
// 软删除
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_comments_post_parent_status_created,priority:4,sort:asc;index:idx_comments_root_status_created,priority:3,sort:asc"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (c *Comment) BeforeCreate(tx *gorm.DB) error {
|
||||
if c.ID == "" {
|
||||
c.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Comment) TableName() string {
|
||||
return "comments"
|
||||
}
|
||||
|
||||
// CommentLike 评论点赞
|
||||
type CommentLike struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
CommentID string `json:"comment_id" gorm:"type:varchar(36);index;not null"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (cl *CommentLike) BeforeCreate(tx *gorm.DB) error {
|
||||
if cl.ID == "" {
|
||||
cl.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (CommentLike) TableName() string {
|
||||
return "comment_likes"
|
||||
}
|
||||
68
internal/model/conversation.go
Normal file
68
internal/model/conversation.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
)
|
||||
|
||||
// ConversationType 会话类型
|
||||
type ConversationType string
|
||||
|
||||
const (
|
||||
ConversationTypePrivate ConversationType = "private" // 私聊
|
||||
ConversationTypeGroup ConversationType = "group" // 群聊
|
||||
ConversationTypeSystem ConversationType = "system" // 系统通知会话
|
||||
)
|
||||
|
||||
// Conversation 会话实体
|
||||
// 使用雪花算法ID(作为string存储)和seq机制实现消息排序和增量同步
|
||||
type Conversation struct {
|
||||
ID string `gorm:"primaryKey;size:20" json:"id"` // 雪花算法ID(转为string避免精度丢失)
|
||||
Type ConversationType `gorm:"type:varchar(20);default:'private'" json:"type"` // 会话类型
|
||||
GroupID *string `gorm:"index" json:"group_id,omitempty"` // 关联的群组ID(群聊时使用,string类型避免JS精度丢失);使用指针支持NULL值
|
||||
LastSeq int64 `gorm:"default:0" json:"last_seq"` // 最后一条消息的seq
|
||||
LastMsgTime *time.Time `json:"last_msg_time,omitempty"` // 最后消息时间
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
|
||||
// 关联 - 使用 polymorphic 模式避免外键约束问题
|
||||
Participants []ConversationParticipant `gorm:"foreignKey:ConversationID" json:"participants,omitempty"`
|
||||
Group *Group `gorm:"foreignKey:GroupID;references:ID" json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成雪花算法ID
|
||||
func (c *Conversation) BeforeCreate(tx *gorm.DB) error {
|
||||
if c.ID == "" {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.ID = strconv.FormatInt(id, 10)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Conversation) TableName() string {
|
||||
return "conversations"
|
||||
}
|
||||
|
||||
// ConversationParticipant 会话参与者
|
||||
type ConversationParticipant struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
ConversationID string `gorm:"not null;size:20;uniqueIndex:idx_conversation_user,priority:1;index:idx_cp_conversation_user,priority:1" json:"conversation_id"` // 雪花算法ID(string类型)
|
||||
UserID string `gorm:"column:user_id;type:varchar(50);not null;uniqueIndex:idx_conversation_user,priority:2;index:idx_cp_conversation_user,priority:2;index:idx_cp_user_hidden_pinned_updated,priority:1" json:"user_id"` // UUID格式,与JWT中user_id保持一致
|
||||
LastReadSeq int64 `gorm:"default:0" json:"last_read_seq"` // 已读到的seq位置
|
||||
Muted bool `gorm:"default:false" json:"muted"` // 是否免打扰
|
||||
IsPinned bool `gorm:"default:false;index:idx_cp_user_hidden_pinned_updated,priority:3" json:"is_pinned"` // 是否置顶会话(用户维度)
|
||||
HiddenAt *time.Time `gorm:"index:idx_cp_user_hidden_pinned_updated,priority:2" json:"hidden_at,omitempty"` // 仅自己删除会话时使用,收到新消息后自动恢复
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime;index:idx_cp_user_hidden_pinned_updated,priority:4,sort:desc"`
|
||||
}
|
||||
|
||||
func (ConversationParticipant) TableName() string {
|
||||
return "conversation_participants"
|
||||
}
|
||||
94
internal/model/device_token.go
Normal file
94
internal/model/device_token.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
)
|
||||
|
||||
// DeviceType 设备类型
|
||||
type DeviceType string
|
||||
|
||||
const (
|
||||
DeviceTypeIOS DeviceType = "ios" // iOS设备
|
||||
DeviceTypeAndroid DeviceType = "android" // Android设备
|
||||
DeviceTypeWeb DeviceType = "web" // Web端
|
||||
)
|
||||
|
||||
// DeviceToken 设备Token实体
|
||||
// 用于管理用户的多设备推送Token
|
||||
type DeviceToken struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` // 雪花算法ID
|
||||
UserID string `gorm:"column:user_id;type:varchar(50);index;not null" json:"user_id"` // 用户ID (UUID格式)
|
||||
DeviceID string `gorm:"type:varchar(100);not null" json:"device_id"` // 设备唯一标识
|
||||
DeviceType DeviceType `gorm:"type:varchar(20);not null" json:"device_type"` // 设备类型
|
||||
PushToken string `gorm:"type:varchar(256);not null" json:"push_token"` // 推送Token(FCM/APNs等)
|
||||
IsActive bool `gorm:"default:true" json:"is_active"` // 是否活跃
|
||||
DeviceName string `gorm:"type:varchar(100)" json:"device_name,omitempty"` // 设备名称(可选)
|
||||
|
||||
// 时间戳
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"` // 最后使用时间
|
||||
|
||||
// 软删除
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成雪花算法ID
|
||||
func (d *DeviceToken) BeforeCreate(tx *gorm.DB) error {
|
||||
if d.ID == 0 {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.ID = id
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (DeviceToken) TableName() string {
|
||||
return "device_tokens"
|
||||
}
|
||||
|
||||
// UpdateLastUsed 更新最后使用时间
|
||||
func (d *DeviceToken) UpdateLastUsed() {
|
||||
now := time.Now()
|
||||
d.LastUsedAt = &now
|
||||
}
|
||||
|
||||
// Deactivate 停用设备
|
||||
func (d *DeviceToken) Deactivate() {
|
||||
d.IsActive = false
|
||||
}
|
||||
|
||||
// Activate 激活设备
|
||||
func (d *DeviceToken) Activate() {
|
||||
d.IsActive = true
|
||||
now := time.Now()
|
||||
d.LastUsedAt = &now
|
||||
}
|
||||
|
||||
// IsIOS 判断是否为iOS设备
|
||||
func (d *DeviceToken) IsIOS() bool {
|
||||
return d.DeviceType == DeviceTypeIOS
|
||||
}
|
||||
|
||||
// IsAndroid 判断是否为Android设备
|
||||
func (d *DeviceToken) IsAndroid() bool {
|
||||
return d.DeviceType == DeviceTypeAndroid
|
||||
}
|
||||
|
||||
// IsWeb 判断是否为Web端
|
||||
func (d *DeviceToken) IsWeb() bool {
|
||||
return d.DeviceType == DeviceTypeWeb
|
||||
}
|
||||
|
||||
// SupportsMobilePush 判断是否支持手机推送
|
||||
func (d *DeviceToken) SupportsMobilePush() bool {
|
||||
return d.DeviceType == DeviceTypeIOS || d.DeviceType == DeviceTypeAndroid
|
||||
}
|
||||
28
internal/model/favorite.go
Normal file
28
internal/model/favorite.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Favorite 收藏
|
||||
type Favorite struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_favorite_post_user,priority:1"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_favorite_post_user,priority:2"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (f *Favorite) BeforeCreate(tx *gorm.DB) error {
|
||||
if f.ID == "" {
|
||||
f.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Favorite) TableName() string {
|
||||
return "favorites"
|
||||
}
|
||||
28
internal/model/follow.go
Normal file
28
internal/model/follow.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Follow 关注关系
|
||||
type Follow struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
FollowerID string `json:"follower_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_follower_following"` // 关注者
|
||||
FollowingID string `json:"following_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_follower_following"` // 被关注者
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (f *Follow) BeforeCreate(tx *gorm.DB) error {
|
||||
if f.ID == "" {
|
||||
f.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Follow) TableName() string {
|
||||
return "follows"
|
||||
}
|
||||
57
internal/model/group.go
Normal file
57
internal/model/group.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// JoinType 群组加入类型
|
||||
type JoinType int
|
||||
|
||||
const (
|
||||
JoinTypeAnyone JoinType = 0 // 允许任何人加入
|
||||
JoinTypeApproval JoinType = 1 // 需要审批
|
||||
JoinTypeForbidden JoinType = 2 // 不允许加入
|
||||
)
|
||||
|
||||
// Group 群组模型
|
||||
type Group struct {
|
||||
ID string `gorm:"primaryKey;size:20" json:"id"`
|
||||
Name string `gorm:"size:50;not null" json:"name"`
|
||||
Avatar string `gorm:"size:512" json:"avatar"`
|
||||
Description string `gorm:"size:500" json:"description"`
|
||||
OwnerID string `gorm:"type:varchar(36);not null;index" json:"owner_id"`
|
||||
MemberCount int `gorm:"default:0" json:"member_count"`
|
||||
MaxMembers int `gorm:"default:500" json:"max_members"`
|
||||
JoinType JoinType `gorm:"default:0" json:"join_type"` // 0:允许任何人加入 1:需要审批 2:不允许加入
|
||||
MuteAll bool `gorm:"default:false" json:"mute_all"` // 全员禁言
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成雪花算法ID
|
||||
func (g *Group) BeforeCreate(tx *gorm.DB) error {
|
||||
if g.ID == "" {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.ID = strconv.FormatInt(id, 10)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIDInt 获取数字类型的ID(用于比较)
|
||||
func (g *Group) GetIDInt() uint64 {
|
||||
id, _ := strconv.ParseUint(g.ID, 10, 64)
|
||||
return id
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Group) TableName() string {
|
||||
return "groups"
|
||||
}
|
||||
38
internal/model/group_announcement.go
Normal file
38
internal/model/group_announcement.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GroupAnnouncement 群公告模型
|
||||
type GroupAnnouncement struct {
|
||||
ID string `gorm:"primaryKey;size:20" json:"id"`
|
||||
GroupID string `gorm:"not null;index" json:"group_id"`
|
||||
Content string `gorm:"type:text;not null" json:"content"`
|
||||
AuthorID string `gorm:"type:varchar(36);not null" json:"author_id"`
|
||||
IsPinned bool `gorm:"default:false" json:"is_pinned"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成雪花算法ID
|
||||
func (ga *GroupAnnouncement) BeforeCreate(tx *gorm.DB) error {
|
||||
if ga.ID == "" {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ga.ID = strconv.FormatInt(id, 10)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (GroupAnnouncement) TableName() string {
|
||||
return "group_announcements"
|
||||
}
|
||||
59
internal/model/group_join_request.go
Normal file
59
internal/model/group_join_request.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupJoinRequestType string
|
||||
|
||||
const (
|
||||
GroupJoinRequestTypeInvite GroupJoinRequestType = "invite"
|
||||
GroupJoinRequestTypeJoinApply GroupJoinRequestType = "join_apply"
|
||||
)
|
||||
|
||||
type GroupJoinRequestStatus string
|
||||
|
||||
const (
|
||||
GroupJoinRequestStatusPending GroupJoinRequestStatus = "pending"
|
||||
GroupJoinRequestStatusAccepted GroupJoinRequestStatus = "accepted"
|
||||
GroupJoinRequestStatusRejected GroupJoinRequestStatus = "rejected"
|
||||
GroupJoinRequestStatusCancelled GroupJoinRequestStatus = "cancelled"
|
||||
GroupJoinRequestStatusExpired GroupJoinRequestStatus = "expired"
|
||||
)
|
||||
|
||||
// GroupJoinRequest 统一保存邀请入群和主动加群申请
|
||||
type GroupJoinRequest struct {
|
||||
ID string `gorm:"primaryKey;size:20" json:"id"`
|
||||
Flag string `gorm:"size:64;uniqueIndex;not null" json:"flag"`
|
||||
GroupID string `gorm:"not null;index;index:idx_gjr_group_target_type_status_created,priority:1" json:"group_id"`
|
||||
InitiatorID string `gorm:"type:varchar(36);not null;index" json:"initiator_id"`
|
||||
TargetUserID string `gorm:"type:varchar(36);not null;index;index:idx_gjr_group_target_type_status_created,priority:2" json:"target_user_id"`
|
||||
RequestType GroupJoinRequestType `gorm:"size:20;not null;index;index:idx_gjr_group_target_type_status_created,priority:3" json:"request_type"`
|
||||
Status GroupJoinRequestStatus `gorm:"size:20;not null;index;index:idx_gjr_group_target_type_status_created,priority:4" json:"status"`
|
||||
Reason string `gorm:"size:500" json:"reason"`
|
||||
ReviewerID string `gorm:"type:varchar(36);index" json:"reviewer_id"`
|
||||
ReviewedAt *time.Time `json:"reviewed_at"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
CreatedAt time.Time `gorm:"index:idx_gjr_group_target_type_status_created,priority:5,sort:desc" json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (r *GroupJoinRequest) BeforeCreate(tx *gorm.DB) error {
|
||||
if r.ID == "" {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.ID = strconv.FormatInt(id, 10)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (GroupJoinRequest) TableName() string {
|
||||
return "group_join_requests"
|
||||
}
|
||||
47
internal/model/group_member.go
Normal file
47
internal/model/group_member.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GroupMember 群成员模型
|
||||
type GroupMember struct {
|
||||
ID string `gorm:"primaryKey;size:20" json:"id"`
|
||||
GroupID string `gorm:"not null;uniqueIndex:idx_group_user" json:"group_id"`
|
||||
UserID string `gorm:"type:varchar(36);not null;uniqueIndex:idx_group_user" json:"user_id"`
|
||||
Role string `gorm:"size:20;default:'member'" json:"role"` // owner, admin, member
|
||||
Nickname string `gorm:"size:50" json:"nickname"` // 群内昵称
|
||||
Muted bool `gorm:"default:false" json:"muted"` // 是否被禁言
|
||||
JoinTime time.Time `json:"join_time"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成雪花算法ID
|
||||
func (gm *GroupMember) BeforeCreate(tx *gorm.DB) error {
|
||||
if gm.ID == "" {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gm.ID = strconv.FormatInt(id, 10)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (GroupMember) TableName() string {
|
||||
return "group_members"
|
||||
}
|
||||
|
||||
// Role 常量
|
||||
const (
|
||||
GroupRoleOwner = "owner"
|
||||
GroupRoleAdmin = "admin"
|
||||
GroupRoleMember = "member"
|
||||
)
|
||||
159
internal/model/init.go
Normal file
159
internal/model/init.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"carrot_bbs/internal/config"
|
||||
)
|
||||
|
||||
// DB 全局数据库连接
|
||||
var DB *gorm.DB
|
||||
|
||||
// InitDB 初始化数据库连接
|
||||
func InitDB(cfg *config.DatabaseConfig) error {
|
||||
var err error
|
||||
var db *gorm.DB
|
||||
gormLogger := logger.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
||||
logger.Config{
|
||||
SlowThreshold: time.Duration(cfg.SlowThresholdMs) * time.Millisecond,
|
||||
LogLevel: parseGormLogLevel(cfg.LogLevel),
|
||||
IgnoreRecordNotFoundError: cfg.IgnoreRecordNotFound,
|
||||
ParameterizedQueries: cfg.ParameterizedQueries,
|
||||
Colorful: false,
|
||||
},
|
||||
)
|
||||
|
||||
// 根据数据库类型选择驱动
|
||||
switch cfg.Type {
|
||||
case "sqlite":
|
||||
db, err = gorm.Open(sqlite.Open(cfg.SQLite.Path), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
case "postgres", "postgresql":
|
||||
dsn := cfg.Postgres.DSN()
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
default:
|
||||
// 默认使用PostgreSQL
|
||||
dsn := cfg.Postgres.DSN()
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
DB = db
|
||||
|
||||
// 配置连接池(SQLite不支持连接池配置,跳过)
|
||||
if cfg.Type != "sqlite" {
|
||||
sqlDB, err := DB.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get database instance: %w", err)
|
||||
}
|
||||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
}
|
||||
|
||||
// 自动迁移
|
||||
if err := autoMigrate(DB); err != nil {
|
||||
return fmt.Errorf("failed to auto migrate: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Database connected (%s) and migrated successfully", cfg.Type)
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseGormLogLevel(level string) logger.LogLevel {
|
||||
switch level {
|
||||
case "silent":
|
||||
return logger.Silent
|
||||
case "error":
|
||||
return logger.Error
|
||||
case "warn":
|
||||
return logger.Warn
|
||||
case "info":
|
||||
return logger.Info
|
||||
default:
|
||||
return logger.Warn
|
||||
}
|
||||
}
|
||||
|
||||
// autoMigrate 自动迁移数据库表
|
||||
func autoMigrate(db *gorm.DB) error {
|
||||
err := db.AutoMigrate(
|
||||
// 用户相关
|
||||
&User{},
|
||||
|
||||
// 帖子相关
|
||||
&Post{},
|
||||
&PostImage{},
|
||||
|
||||
// 评论相关
|
||||
&Comment{},
|
||||
&CommentLike{},
|
||||
|
||||
// 消息相关(使用雪花算法ID和seq机制)
|
||||
// 已读位置存储在 ConversationParticipant.LastReadSeq 中
|
||||
&Conversation{},
|
||||
&ConversationParticipant{},
|
||||
&Message{},
|
||||
|
||||
// 系统通知(独立表,每个用户只能看到自己的通知)
|
||||
&SystemNotification{},
|
||||
|
||||
// 通知
|
||||
&Notification{},
|
||||
|
||||
// 推送中心相关
|
||||
&PushRecord{}, // 推送记录
|
||||
&DeviceToken{}, // 设备Token
|
||||
|
||||
// 社交
|
||||
&Follow{},
|
||||
&UserBlock{},
|
||||
&PostLike{},
|
||||
&Favorite{},
|
||||
|
||||
// 投票
|
||||
&VoteOption{},
|
||||
&UserVote{},
|
||||
|
||||
// 敏感词和审核
|
||||
&SensitiveWord{},
|
||||
&AuditLog{},
|
||||
|
||||
// 群组相关
|
||||
&Group{},
|
||||
&GroupMember{},
|
||||
&GroupAnnouncement{},
|
||||
&GroupJoinRequest{},
|
||||
|
||||
// 自定义表情
|
||||
&UserSticker{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseDB 关闭数据库连接
|
||||
func CloseDB() {
|
||||
if sqlDB, err := DB.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}
|
||||
28
internal/model/like.go
Normal file
28
internal/model/like.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PostLike 帖子点赞
|
||||
type PostLike struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_post_like_user,priority:1"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_post_like_user,priority:2"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (pl *PostLike) BeforeCreate(tx *gorm.DB) error {
|
||||
if pl.ID == "" {
|
||||
pl.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (PostLike) TableName() string {
|
||||
return "post_likes"
|
||||
}
|
||||
205
internal/model/message.go
Normal file
205
internal/model/message.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
)
|
||||
|
||||
// 系统消息相关常量
|
||||
const (
|
||||
// SystemSenderID 系统消息发送者ID (类似QQ 10000)
|
||||
SystemSenderID int64 = 10000
|
||||
|
||||
// SystemSenderIDStr 系统消息发送者ID字符串版本
|
||||
SystemSenderIDStr string = "10000"
|
||||
|
||||
// SystemConversationID 系统通知会话ID (string类型)
|
||||
SystemConversationID string = "9999999999"
|
||||
)
|
||||
|
||||
// ContentType 消息内容类型
|
||||
type ContentType string
|
||||
|
||||
const (
|
||||
ContentTypeText ContentType = "text"
|
||||
ContentTypeImage ContentType = "image"
|
||||
ContentTypeVideo ContentType = "video"
|
||||
ContentTypeAudio ContentType = "audio"
|
||||
ContentTypeFile ContentType = "file"
|
||||
)
|
||||
|
||||
// MessageStatus 消息状态
|
||||
type MessageStatus string
|
||||
|
||||
const (
|
||||
MessageStatusNormal MessageStatus = "normal" // 正常
|
||||
MessageStatusRecalled MessageStatus = "recalled" // 已撤回
|
||||
MessageStatusDeleted MessageStatus = "deleted" // 已删除
|
||||
)
|
||||
|
||||
// MessageCategory 消息类别
|
||||
type MessageCategory string
|
||||
|
||||
const (
|
||||
CategoryChat MessageCategory = "chat" // 普通聊天
|
||||
CategoryNotification MessageCategory = "notification" // 通知类消息
|
||||
CategoryAnnouncement MessageCategory = "announcement" // 系统公告
|
||||
CategoryMarketing MessageCategory = "marketing" // 营销消息
|
||||
)
|
||||
|
||||
// SystemMessageType 系统消息类型 (对应原NotificationType)
|
||||
type SystemMessageType string
|
||||
|
||||
const (
|
||||
// 互动通知
|
||||
SystemTypeLikePost SystemMessageType = "like_post" // 点赞帖子
|
||||
SystemTypeLikeComment SystemMessageType = "like_comment" // 点赞评论
|
||||
SystemTypeComment SystemMessageType = "comment" // 评论
|
||||
SystemTypeReply SystemMessageType = "reply" // 回复
|
||||
SystemTypeFollow SystemMessageType = "follow" // 关注
|
||||
SystemTypeMention SystemMessageType = "mention" // @提及
|
||||
|
||||
// 系统消息
|
||||
SystemTypeSystem SystemMessageType = "system" // 系统通知
|
||||
SystemTypeAnnounce SystemMessageType = "announce" // 系统公告
|
||||
)
|
||||
|
||||
// ExtraData 消息额外数据,用于存储系统消息的相关信息
|
||||
type ExtraData struct {
|
||||
// 操作者信息
|
||||
ActorID int64 `json:"actor_id,omitempty"` // 操作者ID (数字格式,兼容旧数据)
|
||||
ActorIDStr string `json:"actor_id_str,omitempty"` // 操作者ID (UUID字符串格式)
|
||||
ActorName string `json:"actor_name,omitempty"` // 操作者名称
|
||||
AvatarURL string `json:"avatar_url,omitempty"` // 操作者头像
|
||||
|
||||
// 目标信息
|
||||
TargetID int64 `json:"target_id,omitempty"` // 目标ID(帖子ID、评论ID等)
|
||||
TargetTitle string `json:"target_title,omitempty"` // 目标标题
|
||||
TargetType string `json:"target_type,omitempty"` // 目标类型(post/comment等)
|
||||
|
||||
// 其他信息
|
||||
ActionURL string `json:"action_url,omitempty"` // 跳转链接
|
||||
ActionTime string `json:"action_time,omitempty"` // 操作时间
|
||||
}
|
||||
|
||||
// Value 实现driver.Valuer接口,用于数据库存储
|
||||
func (e ExtraData) Value() (driver.Value, error) {
|
||||
return json.Marshal(e)
|
||||
}
|
||||
|
||||
// Scan 实现sql.Scanner接口,用于数据库读取
|
||||
func (e *ExtraData) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("type assertion to []byte failed")
|
||||
}
|
||||
return json.Unmarshal(bytes, e)
|
||||
}
|
||||
|
||||
// MessageSegmentData 单个消息段的数据
|
||||
type MessageSegmentData map[string]interface{}
|
||||
|
||||
// MessageSegment 消息段
|
||||
type MessageSegment struct {
|
||||
Type string `json:"type"`
|
||||
Data MessageSegmentData `json:"data"`
|
||||
}
|
||||
|
||||
// MessageSegments 消息链类型
|
||||
type MessageSegments []MessageSegment
|
||||
|
||||
// Value 实现driver.Valuer接口,用于数据库存储
|
||||
func (s MessageSegments) Value() (driver.Value, error) {
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Scan 实现sql.Scanner接口,用于数据库读取
|
||||
func (s *MessageSegments) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*s = nil
|
||||
return nil
|
||||
}
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("type assertion to []byte failed")
|
||||
}
|
||||
return json.Unmarshal(bytes, s)
|
||||
}
|
||||
|
||||
// Message 消息实体
|
||||
// 使用雪花算法ID(string类型)和seq机制实现消息排序和增量同步
|
||||
type Message struct {
|
||||
ID string `gorm:"primaryKey;size:20" json:"id"` // 雪花算法ID(string类型)
|
||||
ConversationID string `gorm:"not null;size:20;index:idx_msg_conversation_seq,priority:1" json:"conversation_id"` // 会话ID(string类型)
|
||||
SenderID string `gorm:"column:sender_id;type:varchar(50);index;not null" json:"sender_id"` // 发送者ID (UUID格式)
|
||||
Seq int64 `gorm:"not null;index:idx_msg_conversation_seq,priority:2" json:"seq"` // 会话内序号,用于排序和增量同步
|
||||
Segments MessageSegments `gorm:"type:json" json:"segments"` // 消息链(结构体数组)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID(string类型)
|
||||
Status MessageStatus `gorm:"type:varchar(20);default:'normal'" json:"status"` // 消息状态
|
||||
|
||||
// 新增字段:消息分类和系统消息类型
|
||||
Category MessageCategory `gorm:"type:varchar(20);default:'chat'" json:"category"` // 消息分类
|
||||
SystemType SystemMessageType `gorm:"type:varchar(30)" json:"system_type,omitempty"` // 系统消息类型
|
||||
ExtraData *ExtraData `gorm:"type:json" json:"extra_data,omitempty"` // 额外数据(JSON格式)
|
||||
|
||||
// @相关字段
|
||||
MentionUsers string `gorm:"type:text" json:"mention_users"` // @的用户ID列表,JSON数组
|
||||
MentionAll bool `gorm:"default:false" json:"mention_all"` // 是否@所有人
|
||||
|
||||
// 软删除
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
// SenderIDStr 返回发送者ID字符串(保持兼容性)
|
||||
func (m *Message) SenderIDStr() string {
|
||||
return m.SenderID
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成雪花算法ID
|
||||
func (m *Message) BeforeCreate(tx *gorm.DB) error {
|
||||
if m.ID == "" {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ID = strconv.FormatInt(id, 10)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Message) TableName() string {
|
||||
return "messages"
|
||||
}
|
||||
|
||||
// IsSystemMessage 判断是否为系统消息
|
||||
func (m *Message) IsSystemMessage() bool {
|
||||
return m.SenderID == SystemSenderIDStr || m.Category == CategoryNotification || m.Category == CategoryAnnouncement
|
||||
}
|
||||
|
||||
// IsInteractionNotification 判断是否为互动通知
|
||||
func (m *Message) IsInteractionNotification() bool {
|
||||
if m.Category != CategoryNotification {
|
||||
return false
|
||||
}
|
||||
switch m.SystemType {
|
||||
case SystemTypeLikePost, SystemTypeLikeComment, SystemTypeComment,
|
||||
SystemTypeReply, SystemTypeFollow, SystemTypeMention:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
19
internal/model/message_read.go
Normal file
19
internal/model/message_read.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// MessageRead 消息已读状态
|
||||
// 记录每个用户在每个会话中的已读位置
|
||||
type MessageRead struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
ConversationID int64 `gorm:"uniqueIndex:idx_conversation_user;not null" json:"conversation_id"`
|
||||
UserID uint `gorm:"uniqueIndex:idx_conversation_user;not null" json:"user_id"`
|
||||
LastReadSeq int64 `gorm:"not null" json:"last_read_seq"` // 已读到的seq位置
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
func (MessageRead) TableName() string {
|
||||
return "message_reads"
|
||||
}
|
||||
53
internal/model/notification.go
Normal file
53
internal/model/notification.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// NotificationType 通知类型
|
||||
type NotificationType string
|
||||
|
||||
const (
|
||||
NotificationTypeLikePost NotificationType = "like_post"
|
||||
NotificationTypeLikeComment NotificationType = "like_comment"
|
||||
NotificationTypeComment NotificationType = "comment"
|
||||
NotificationTypeReply NotificationType = "reply"
|
||||
NotificationTypeFollow NotificationType = "follow"
|
||||
NotificationTypeMention NotificationType = "mention"
|
||||
NotificationTypeSystem NotificationType = "system"
|
||||
)
|
||||
|
||||
// Notification 通知实体
|
||||
type Notification struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index:idx_notifications_user_read_created,priority:1"` // 接收者
|
||||
Type NotificationType `json:"type" gorm:"type:varchar(30);not null"`
|
||||
Title string `json:"title" gorm:"type:varchar(200);not null"`
|
||||
Content string `json:"content" gorm:"type:text"`
|
||||
Data string `json:"data" gorm:"type:jsonb"` // 相关数据(JSON)
|
||||
|
||||
// 已读状态
|
||||
IsRead bool `json:"is_read" gorm:"default:false;index:idx_notifications_user_read_created,priority:2"`
|
||||
ReadAt *time.Time `json:"read_at" gorm:"type:timestamp"`
|
||||
|
||||
// 软删除
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_notifications_user_read_created,priority:3,sort:desc"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (n *Notification) BeforeCreate(tx *gorm.DB) error {
|
||||
if n.ID == "" {
|
||||
n.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Notification) TableName() string {
|
||||
return "notifications"
|
||||
}
|
||||
100
internal/model/post.go
Normal file
100
internal/model/post.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PostStatus 帖子状态
|
||||
type PostStatus string
|
||||
|
||||
const (
|
||||
PostStatusDraft PostStatus = "draft"
|
||||
PostStatusPending PostStatus = "pending" // 待审核
|
||||
PostStatusPublished PostStatus = "published"
|
||||
PostStatusRejected PostStatus = "rejected"
|
||||
PostStatusDeleted PostStatus = "deleted"
|
||||
)
|
||||
|
||||
// Post 帖子实体
|
||||
type Post struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);index;index:idx_posts_user_status_created,priority:1;not null"`
|
||||
CommunityID string `json:"community_id" gorm:"type:varchar(36);index"`
|
||||
Title string `json:"title" gorm:"type:varchar(200);not null"`
|
||||
Content string `json:"content" gorm:"type:text;not null"`
|
||||
|
||||
// 关联
|
||||
// User 需要参与缓存序列化;否则列表命中缓存后会丢失作者信息,前端退化为“匿名用户”
|
||||
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
|
||||
Images []PostImage `json:"images" gorm:"foreignKey:PostID"`
|
||||
|
||||
// 审核状态
|
||||
Status PostStatus `json:"status" gorm:"type:varchar(20);default:published;index:idx_posts_status_created,priority:1;index:idx_posts_user_status_created,priority:2"`
|
||||
ReviewedAt *time.Time `json:"reviewed_at" gorm:"type:timestamp"`
|
||||
ReviewedBy string `json:"reviewed_by" gorm:"type:varchar(50)"`
|
||||
RejectReason string `json:"reject_reason" gorm:"type:varchar(500)"`
|
||||
|
||||
// 统计
|
||||
LikesCount int `json:"likes_count" gorm:"column:likes_count;default:0"`
|
||||
CommentsCount int `json:"comments_count" gorm:"column:comments_count;default:0"`
|
||||
FavoritesCount int `json:"favorites_count" gorm:"column:favorites_count;default:0"`
|
||||
SharesCount int `json:"shares_count" gorm:"column:shares_count;default:0"`
|
||||
ViewsCount int `json:"views_count" gorm:"column:views_count;default:0"`
|
||||
HotScore float64 `json:"hot_score" gorm:"column:hot_score;default:0;index:idx_posts_hot_score_created,priority:1"`
|
||||
|
||||
// 置顶/锁定
|
||||
IsPinned bool `json:"is_pinned" gorm:"default:false"`
|
||||
IsLocked bool `json:"is_locked" gorm:"default:false"`
|
||||
IsDeleted bool `json:"-" gorm:"default:false"`
|
||||
|
||||
// 投票
|
||||
IsVote bool `json:"is_vote" gorm:"column:is_vote;default:false"`
|
||||
|
||||
// 软删除
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (p *Post) BeforeCreate(tx *gorm.DB) error {
|
||||
if p.ID == "" {
|
||||
p.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (Post) TableName() string {
|
||||
return "posts"
|
||||
}
|
||||
|
||||
// PostImage 帖子图片
|
||||
type PostImage struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
PostID string `json:"post_id" gorm:"type:varchar(36);index;not null"`
|
||||
URL string `json:"url" gorm:"type:text;not null"`
|
||||
ThumbnailURL string `json:"thumbnail_url" gorm:"type:text"`
|
||||
Width int `json:"width" gorm:"default:0"`
|
||||
Height int `json:"height" gorm:"default:0"`
|
||||
Size int64 `json:"size" gorm:"default:0"` // 文件大小(字节)
|
||||
MimeType string `json:"mime_type" gorm:"type:varchar(50)"`
|
||||
SortOrder int `json:"sort_order" gorm:"default:0"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (pi *PostImage) BeforeCreate(tx *gorm.DB) error {
|
||||
if pi.ID == "" {
|
||||
pi.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (PostImage) TableName() string {
|
||||
return "post_images"
|
||||
}
|
||||
129
internal/model/push_record.go
Normal file
129
internal/model/push_record.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
)
|
||||
|
||||
// PushChannel 推送通道类型
|
||||
type PushChannel string
|
||||
|
||||
const (
|
||||
PushChannelWebSocket PushChannel = "websocket" // WebSocket推送
|
||||
PushChannelFCM PushChannel = "fcm" // Firebase Cloud Messaging
|
||||
PushChannelAPNs PushChannel = "apns" // Apple Push Notification service
|
||||
PushChannelHuawei PushChannel = "huawei" // 华为推送
|
||||
)
|
||||
|
||||
// PushStatus 推送状态
|
||||
type PushStatus string
|
||||
|
||||
const (
|
||||
PushStatusPending PushStatus = "pending" // 待推送
|
||||
PushStatusPushing PushStatus = "pushing" // 推送中
|
||||
PushStatusPushed PushStatus = "pushed" // 已推送(成功发送到推送服务)
|
||||
PushStatusDelivered PushStatus = "delivered" // 已送达(客户端确认)
|
||||
PushStatusFailed PushStatus = "failed" // 推送失败
|
||||
PushStatusExpired PushStatus = "expired" // 消息过期
|
||||
)
|
||||
|
||||
// PushRecord 推送记录实体
|
||||
// 用于跟踪消息的推送状态,支持多设备推送和重试机制
|
||||
type PushRecord struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` // 雪花算法ID
|
||||
UserID string `gorm:"column:user_id;type:varchar(50);index;not null" json:"user_id"` // 目标用户ID (UUID格式)
|
||||
MessageID string `gorm:"index;not null;size:20" json:"message_id"` // 关联的消息ID (string类型)
|
||||
PushChannel PushChannel `gorm:"type:varchar(20);not null" json:"push_channel"` // 推送通道
|
||||
PushStatus PushStatus `gorm:"type:varchar(20);not null;default:'pending'" json:"push_status"` // 推送状态
|
||||
|
||||
// 设备信息
|
||||
DeviceToken string `gorm:"type:varchar(256)" json:"device_token,omitempty"` // 设备Token(用于手机推送)
|
||||
DeviceType string `gorm:"type:varchar(20)" json:"device_type,omitempty"` // 设备类型 (ios/android/web)
|
||||
|
||||
// 重试机制
|
||||
RetryCount int `gorm:"default:0" json:"retry_count"` // 重试次数
|
||||
MaxRetry int `gorm:"default:3" json:"max_retry"` // 最大重试次数
|
||||
|
||||
// 时间戳
|
||||
PushedAt *time.Time `json:"pushed_at,omitempty"` // 推送时间
|
||||
DeliveredAt *time.Time `json:"delivered_at,omitempty"` // 送达时间
|
||||
ExpiredAt *time.Time `gorm:"index" json:"expired_at,omitempty"` // 过期时间
|
||||
|
||||
// 错误信息
|
||||
ErrorMessage string `gorm:"type:varchar(500)" json:"error_message,omitempty"` // 错误信息
|
||||
|
||||
// 软删除
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成雪花算法ID
|
||||
func (r *PushRecord) BeforeCreate(tx *gorm.DB) error {
|
||||
if r.ID == 0 {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.ID = id
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (PushRecord) TableName() string {
|
||||
return "push_records"
|
||||
}
|
||||
|
||||
// CanRetry 判断是否可以重试
|
||||
func (r *PushRecord) CanRetry() bool {
|
||||
return r.RetryCount < r.MaxRetry && r.PushStatus != PushStatusDelivered && r.PushStatus != PushStatusExpired
|
||||
}
|
||||
|
||||
// IsExpired 判断是否已过期
|
||||
func (r *PushRecord) IsExpired() bool {
|
||||
if r.ExpiredAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(*r.ExpiredAt)
|
||||
}
|
||||
|
||||
// MarkPushing 标记为推送中
|
||||
func (r *PushRecord) MarkPushing() {
|
||||
r.PushStatus = PushStatusPushing
|
||||
}
|
||||
|
||||
// MarkPushed 标记为已推送
|
||||
func (r *PushRecord) MarkPushed() {
|
||||
now := time.Now()
|
||||
r.PushStatus = PushStatusPushed
|
||||
r.PushedAt = &now
|
||||
}
|
||||
|
||||
// MarkDelivered 标记为已送达
|
||||
func (r *PushRecord) MarkDelivered() {
|
||||
now := time.Now()
|
||||
r.PushStatus = PushStatusDelivered
|
||||
r.DeliveredAt = &now
|
||||
}
|
||||
|
||||
// MarkFailed 标记为推送失败
|
||||
func (r *PushRecord) MarkFailed(errMsg string) {
|
||||
r.PushStatus = PushStatusFailed
|
||||
r.ErrorMessage = errMsg
|
||||
r.RetryCount++
|
||||
}
|
||||
|
||||
// MarkExpired 标记为已过期
|
||||
func (r *PushRecord) MarkExpired() {
|
||||
r.PushStatus = PushStatusExpired
|
||||
}
|
||||
|
||||
// IncrementRetry 增加重试次数
|
||||
func (r *PushRecord) IncrementRetry() {
|
||||
r.RetryCount++
|
||||
}
|
||||
77
internal/model/sensitive_word.go
Normal file
77
internal/model/sensitive_word.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SensitiveWordLevel 敏感词级别
|
||||
type SensitiveWordLevel int
|
||||
|
||||
const (
|
||||
SensitiveWordLevelLow SensitiveWordLevel = 1 // 低危
|
||||
SensitiveWordLevelMedium SensitiveWordLevel = 2 // 中危
|
||||
SensitiveWordLevelHigh SensitiveWordLevel = 3 // 高危
|
||||
)
|
||||
|
||||
// SensitiveWordCategory 敏感词分类
|
||||
type SensitiveWordCategory string
|
||||
|
||||
const (
|
||||
SensitiveWordCategoryPolitical SensitiveWordCategory = "political" // 政治
|
||||
SensitiveWordCategoryPorn SensitiveWordCategory = "porn" // 色情
|
||||
SensitiveWordCategoryViolence SensitiveWordCategory = "violence" // 暴力
|
||||
SensitiveWordCategoryAd SensitiveWordCategory = "ad" // 广告
|
||||
SensitiveWordCategoryGamble SensitiveWordCategory = "gamble" // 赌博
|
||||
SensitiveWordCategoryFraud SensitiveWordCategory = "fraud" // 诈骗
|
||||
SensitiveWordCategoryOther SensitiveWordCategory = "other" // 其他
|
||||
)
|
||||
|
||||
// SensitiveWord 敏感词实体
|
||||
type SensitiveWord struct {
|
||||
ID string `gorm:"type:varchar(36);primaryKey"`
|
||||
Word string `gorm:"type:varchar(255);uniqueIndex;not null"`
|
||||
Category SensitiveWordCategory `gorm:"type:varchar(50);index"`
|
||||
Level SensitiveWordLevel `gorm:"type:int;default:1"`
|
||||
IsActive bool `gorm:"default:true"`
|
||||
CreatedBy string `gorm:"type:varchar(255)"`
|
||||
UpdatedBy string `gorm:"type:varchar(255)"`
|
||||
Remark string `gorm:"type:text"`
|
||||
CreatedAt time.Time `gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (sw *SensitiveWord) BeforeCreate(tx *gorm.DB) error {
|
||||
if sw.ID == "" {
|
||||
sw.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (SensitiveWord) TableName() string {
|
||||
return "sensitive_words"
|
||||
}
|
||||
|
||||
// SensitiveWordRequest 创建/更新敏感词请求
|
||||
type SensitiveWordRequest struct {
|
||||
Word string `json:"word" validate:"required,min=1,max=255"`
|
||||
Category SensitiveWordCategory `json:"category"`
|
||||
Level SensitiveWordLevel `json:"level"`
|
||||
Remark string `json:"remark"`
|
||||
CreatedBy string `json:"-"`
|
||||
}
|
||||
|
||||
// SensitiveWordListItem 敏感词列表项(用于列表展示)
|
||||
type SensitiveWordListItem struct {
|
||||
ID string `json:"id"`
|
||||
Word string `json:"word"`
|
||||
Category SensitiveWordCategory `json:"category"`
|
||||
Level SensitiveWordLevel `json:"level"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
33
internal/model/sticker.go
Normal file
33
internal/model/sticker.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserSticker 用户自定义表情
|
||||
type UserSticker struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index:idx_user_stickers"`
|
||||
URL string `json:"url" gorm:"type:text;not null"`
|
||||
Width int `json:"width" gorm:"default:0"`
|
||||
Height int `json:"height" gorm:"default:0"`
|
||||
SortOrder int `json:"sort_order" gorm:"default:0;index:idx_user_stickers_sort"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
// TableName 表名
|
||||
func (UserSticker) TableName() string {
|
||||
return "user_stickers"
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (s *UserSticker) BeforeCreate(tx *gorm.DB) error {
|
||||
if s.ID == "" {
|
||||
s.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
127
internal/model/system_notification.go
Normal file
127
internal/model/system_notification.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
)
|
||||
|
||||
// SystemNotificationType 系统通知类型
|
||||
type SystemNotificationType string
|
||||
|
||||
const (
|
||||
// 互动通知
|
||||
SysNotifyLikePost SystemNotificationType = "like_post" // 点赞帖子
|
||||
SysNotifyLikeComment SystemNotificationType = "like_comment" // 点赞评论
|
||||
SysNotifyComment SystemNotificationType = "comment" // 评论
|
||||
SysNotifyReply SystemNotificationType = "reply" // 回复
|
||||
SysNotifyFollow SystemNotificationType = "follow" // 关注
|
||||
SysNotifyMention SystemNotificationType = "mention" // @提及
|
||||
SysNotifyFavoritePost SystemNotificationType = "favorite_post" // 收藏帖子
|
||||
SysNotifyLikeReply SystemNotificationType = "like_reply" // 点赞回复
|
||||
|
||||
// 系统消息
|
||||
SysNotifySystem SystemNotificationType = "system" // 系统通知
|
||||
SysNotifyAnnounce SystemNotificationType = "announce" // 系统公告
|
||||
SysNotifyGroupInvite SystemNotificationType = "group_invite" // 群邀请
|
||||
SysNotifyGroupJoinApply SystemNotificationType = "group_join_apply" // 加群申请待审批
|
||||
SysNotifyGroupJoinApproved SystemNotificationType = "group_join_approved" // 加群申请通过
|
||||
SysNotifyGroupJoinRejected SystemNotificationType = "group_join_rejected" // 加群申请拒绝
|
||||
)
|
||||
|
||||
// SystemNotificationExtra 额外数据
|
||||
type SystemNotificationExtra struct {
|
||||
// 操作者信息
|
||||
ActorID int64 `json:"actor_id,omitempty"`
|
||||
ActorIDStr string `json:"actor_id_str,omitempty"`
|
||||
ActorName string `json:"actor_name,omitempty"`
|
||||
AvatarURL string `json:"avatar_url,omitempty"`
|
||||
|
||||
// 目标信息
|
||||
TargetID string `json:"target_id,omitempty"` // 改为string类型以支持UUID
|
||||
TargetTitle string `json:"target_title,omitempty"`
|
||||
TargetType string `json:"target_type,omitempty"`
|
||||
|
||||
// 其他信息
|
||||
ActionURL string `json:"action_url,omitempty"`
|
||||
ActionTime string `json:"action_time,omitempty"`
|
||||
|
||||
// 群邀请/加群申请扩展字段
|
||||
GroupID string `json:"group_id,omitempty"`
|
||||
GroupName string `json:"group_name,omitempty"`
|
||||
GroupAvatar string `json:"group_avatar,omitempty"`
|
||||
GroupDescription string `json:"group_description,omitempty"`
|
||||
Flag string `json:"flag,omitempty"`
|
||||
RequestType string `json:"request_type,omitempty"`
|
||||
RequestStatus string `json:"request_status,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
TargetUserID string `json:"target_user_id,omitempty"`
|
||||
TargetUserName string `json:"target_user_name,omitempty"`
|
||||
TargetUserAvatar string `json:"target_user_avatar,omitempty"`
|
||||
}
|
||||
|
||||
// Value 实现driver.Valuer接口
|
||||
func (e SystemNotificationExtra) Value() (driver.Value, error) {
|
||||
return json.Marshal(e)
|
||||
}
|
||||
|
||||
// Scan 实现sql.Scanner接口
|
||||
func (e *SystemNotificationExtra) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("type assertion to []byte failed")
|
||||
}
|
||||
return json.Unmarshal(bytes, e)
|
||||
}
|
||||
|
||||
// SystemNotification 系统通知(独立表,与消息完全分离)
|
||||
// 每个用户只能看到自己的系统通知
|
||||
type SystemNotification struct {
|
||||
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"`
|
||||
ReceiverID string `gorm:"column:receiver_id;type:varchar(50);not null;index:idx_sys_notifications_receiver_read_created,priority:1" json:"receiver_id"` // 接收者ID (UUID)
|
||||
Type SystemNotificationType `gorm:"type:varchar(30);not null" json:"type"` // 通知类型
|
||||
Title string `gorm:"type:varchar(200)" json:"title,omitempty"` // 标题
|
||||
Content string `gorm:"type:text;not null" json:"content"` // 内容
|
||||
ExtraData *SystemNotificationExtra `gorm:"type:json" json:"extra_data,omitempty"` // 额外数据
|
||||
IsRead bool `gorm:"default:false;index:idx_sys_notifications_receiver_read_created,priority:2" json:"is_read"` // 是否已读
|
||||
ReadAt *time.Time `json:"read_at,omitempty"` // 阅读时间
|
||||
|
||||
// 软删除
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_sys_notifications_receiver_read_created,priority:3,sort:desc"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成雪花算法ID
|
||||
func (n *SystemNotification) BeforeCreate(tx *gorm.DB) error {
|
||||
if n.ID == 0 {
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.ID = id
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SystemNotification) TableName() string {
|
||||
return "system_notifications"
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
func (n *SystemNotification) MarkAsRead() {
|
||||
now := time.Now()
|
||||
n.IsRead = true
|
||||
n.ReadAt = &now
|
||||
}
|
||||
66
internal/model/user.go
Normal file
66
internal/model/user.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserStatus 用户状态
|
||||
type UserStatus string
|
||||
|
||||
const (
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusBanned UserStatus = "banned"
|
||||
UserStatusInactive UserStatus = "inactive"
|
||||
)
|
||||
|
||||
// User 用户实体
|
||||
type User struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
Username string `json:"username" gorm:"type:varchar(50);uniqueIndex;not null"`
|
||||
Nickname string `json:"nickname" gorm:"type:varchar(100);not null"`
|
||||
Email *string `json:"email" gorm:"type:varchar(255);uniqueIndex"`
|
||||
Phone *string `json:"phone" gorm:"type:varchar(20);uniqueIndex"`
|
||||
EmailVerified bool `json:"email_verified" gorm:"default:false"`
|
||||
PasswordHash string `json:"-" gorm:"type:varchar(255);not null"`
|
||||
Avatar string `json:"avatar" gorm:"type:text"`
|
||||
CoverURL string `json:"cover_url" gorm:"type:text"` // 头图URL
|
||||
Bio string `json:"bio" gorm:"type:text"`
|
||||
Website string `json:"website" gorm:"type:varchar(255)"`
|
||||
Location string `json:"location" gorm:"type:varchar(100)"`
|
||||
|
||||
// 实名认证信息(可选)
|
||||
RealName string `json:"real_name" gorm:"type:varchar(100)"` // 真实姓名
|
||||
IDCard string `json:"-" gorm:"type:varchar(18)"` // 身份证号(加密存储)
|
||||
IsVerified bool `json:"is_verified" gorm:"default:false"` // 是否实名认证
|
||||
VerifiedAt *time.Time `json:"verified_at" gorm:"type:timestamp"`
|
||||
|
||||
// 统计计数
|
||||
PostsCount int `json:"posts_count" gorm:"default:0"`
|
||||
FollowersCount int `json:"followers_count" gorm:"default:0"`
|
||||
FollowingCount int `json:"following_count" gorm:"default:0"`
|
||||
|
||||
// 状态
|
||||
Status UserStatus `json:"status" gorm:"type:varchar(20);default:active"`
|
||||
LastLoginAt *time.Time `json:"last_login_at" gorm:"type:timestamp"`
|
||||
LastLoginIP string `json:"last_login_ip" gorm:"type:varchar(45)"`
|
||||
|
||||
// 时间戳
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (u *User) BeforeCreate(tx *gorm.DB) error {
|
||||
if u.ID == "" {
|
||||
u.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
27
internal/model/user_block.go
Normal file
27
internal/model/user_block.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// UserBlock 用户拉黑关系
|
||||
type UserBlock struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
BlockerID string `json:"blocker_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_blocker_blocked"` // 拉黑人
|
||||
BlockedID string `json:"blocked_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_blocker_blocked"` // 被拉黑人
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
func (b *UserBlock) BeforeCreate(tx *gorm.DB) error {
|
||||
if b.ID == "" {
|
||||
b.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (UserBlock) TableName() string {
|
||||
return "user_blocks"
|
||||
}
|
||||
52
internal/model/vote.go
Normal file
52
internal/model/vote.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// VoteOption 投票选项
|
||||
type VoteOption struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
PostID string `json:"post_id" gorm:"type:varchar(36);index:idx_vote_option_post_sort,priority:1;not null"`
|
||||
Content string `json:"content" gorm:"type:varchar(200);not null"`
|
||||
SortOrder int `json:"sort_order" gorm:"default:0;index:idx_vote_option_post_sort,priority:2"`
|
||||
VotesCount int `json:"votes_count" gorm:"default:0"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (vo *VoteOption) BeforeCreate(tx *gorm.DB) error {
|
||||
if vo.ID == "" {
|
||||
vo.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (VoteOption) TableName() string {
|
||||
return "vote_options"
|
||||
}
|
||||
|
||||
// UserVote 用户投票记录
|
||||
type UserVote struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
PostID string `json:"post_id" gorm:"type:varchar(36);index;uniqueIndex:idx_user_vote_post_user,priority:1;not null"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(36);index;uniqueIndex:idx_user_vote_post_user,priority:2;not null"`
|
||||
OptionID string `json:"option_id" gorm:"type:varchar(36);index;not null"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
|
||||
}
|
||||
|
||||
// BeforeCreate 创建前生成UUID
|
||||
func (uv *UserVote) BeforeCreate(tx *gorm.DB) error {
|
||||
if uv.ID == "" {
|
||||
uv.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (UserVote) TableName() string {
|
||||
return "user_votes"
|
||||
}
|
||||
115
internal/pkg/avatar/avatar.go
Normal file
115
internal/pkg/avatar/avatar.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package avatar
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// 预定义一组好看的颜色
|
||||
var colors = []string{
|
||||
"#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
|
||||
"#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
|
||||
"#BB8FCE", "#85C1E9", "#F8B500", "#00CED1",
|
||||
"#E74C3C", "#3498DB", "#2ECC71", "#9B59B6",
|
||||
"#1ABC9C", "#F39C12", "#E67E22", "#16A085",
|
||||
}
|
||||
|
||||
// SVG模板
|
||||
const svgTemplate = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100" width="%d" height="%d">
|
||||
<rect width="100" height="100" fill="%s"/>
|
||||
<text x="50" y="50" font-family="Arial, sans-serif" font-size="40" font-weight="bold" fill="#ffffff" text-anchor="middle" dominant-baseline="central">%s</text>
|
||||
</svg>`
|
||||
|
||||
// GenerateSVGAvatar 根据用户名生成SVG头像
|
||||
// username: 用户名
|
||||
// size: 头像尺寸(像素)
|
||||
func GenerateSVGAvatar(username string, size int) string {
|
||||
initials := getInitials(username)
|
||||
color := stringToColor(username)
|
||||
return fmt.Sprintf(svgTemplate, size, size, color, initials)
|
||||
}
|
||||
|
||||
// GenerateAvatarDataURI 生成Data URI格式的头像
|
||||
// 可以直接在HTML img标签或CSS background-image中使用
|
||||
func GenerateAvatarDataURI(username string, size int) string {
|
||||
svg := GenerateSVGAvatar(username, size)
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(svg))
|
||||
return fmt.Sprintf("data:image/svg+xml;base64,%s", encoded)
|
||||
}
|
||||
|
||||
// getInitials 获取用户名首字母
|
||||
// 中文取第一个字,英文取首字母(最多2个)
|
||||
func getInitials(username string) string {
|
||||
if username == "" {
|
||||
return "?"
|
||||
}
|
||||
|
||||
// 检查是否是中文字符
|
||||
firstRune, _ := utf8.DecodeRuneInString(username)
|
||||
if isChinese(firstRune) {
|
||||
// 中文直接返回第一个字符
|
||||
return string(firstRune)
|
||||
}
|
||||
|
||||
// 英文处理:取前两个单词的首字母
|
||||
// 例如: "John Doe" -> "JD", "john" -> "J"
|
||||
result := []rune{}
|
||||
for i, r := range username {
|
||||
if i == 0 {
|
||||
result = append(result, toUpper(r))
|
||||
} else if r == ' ' || r == '_' || r == '-' {
|
||||
// 找到下一个字符作为第二个首字母
|
||||
nextIdx := i + 1
|
||||
if nextIdx < len(username) {
|
||||
nextRune, _ := utf8.DecodeRuneInString(username[nextIdx:])
|
||||
if nextRune != utf8.RuneError && nextRune != ' ' {
|
||||
result = append(result, toUpper(nextRune))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return "?"
|
||||
}
|
||||
|
||||
// 最多返回2个字符
|
||||
if len(result) > 2 {
|
||||
result = result[:2]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// isChinese 判断是否是中文字符
|
||||
func isChinese(r rune) bool {
|
||||
return r >= 0x4E00 && r <= 0x9FFF
|
||||
}
|
||||
|
||||
// toUpper 将字母转换为大写
|
||||
func toUpper(r rune) rune {
|
||||
if r >= 'a' && r <= 'z' {
|
||||
return r - 32
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// stringToColor 根据字符串生成颜色
|
||||
// 使用简单的哈希算法确保同一用户名每次生成的颜色一致
|
||||
func stringToColor(s string) string {
|
||||
if s == "" {
|
||||
return colors[0]
|
||||
}
|
||||
|
||||
hash := 0
|
||||
for _, r := range s {
|
||||
hash = (hash*31 + int(r)) % len(colors)
|
||||
}
|
||||
if hash < 0 {
|
||||
hash = -hash
|
||||
}
|
||||
|
||||
return colors[hash%len(colors)]
|
||||
}
|
||||
118
internal/pkg/avatar/avatar_test.go
Normal file
118
internal/pkg/avatar/avatar_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package avatar
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetInitials(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
want string
|
||||
}{
|
||||
{"中文用户名", "张三", "张"},
|
||||
{"英文用户名", "John", "J"},
|
||||
{"英文全名", "John Doe", "JD"},
|
||||
{"带下划线", "john_doe", "JD"},
|
||||
{"带连字符", "john-doe", "JD"},
|
||||
{"空字符串", "", "?"},
|
||||
{"小写英文", "alice", "A"},
|
||||
{"中文复合", "李小龙", "李"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getInitials(tt.username)
|
||||
if got != tt.want {
|
||||
t.Errorf("getInitials(%q) = %q, want %q", tt.username, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToColor(t *testing.T) {
|
||||
// 测试同一用户名生成的颜色一致
|
||||
color1 := stringToColor("张三")
|
||||
color2 := stringToColor("张三")
|
||||
if color1 != color2 {
|
||||
t.Errorf("stringToColor should return consistent colors for the same input")
|
||||
}
|
||||
|
||||
// 测试不同用户名生成不同颜色(大概率)
|
||||
color3 := stringToColor("李四")
|
||||
if color1 == color3 {
|
||||
t.Logf("Warning: different usernames generated the same color (possible but unlikely)")
|
||||
}
|
||||
|
||||
// 测试空字符串
|
||||
color4 := stringToColor("")
|
||||
if color4 == "" {
|
||||
t.Errorf("stringToColor should return a color for empty string")
|
||||
}
|
||||
|
||||
// 验证颜色格式
|
||||
if !strings.HasPrefix(color4, "#") {
|
||||
t.Errorf("stringToColor should return hex color format starting with #")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSVGAvatar(t *testing.T) {
|
||||
svg := GenerateSVGAvatar("张三", 100)
|
||||
|
||||
// 验证SVG结构
|
||||
if !strings.Contains(svg, "<svg") {
|
||||
t.Errorf("SVG should contain <svg tag")
|
||||
}
|
||||
if !strings.Contains(svg, "</svg>") {
|
||||
t.Errorf("SVG should contain </svg> tag")
|
||||
}
|
||||
if !strings.Contains(svg, "width=\"100\"") {
|
||||
t.Errorf("SVG should have width=100")
|
||||
}
|
||||
if !strings.Contains(svg, "height=\"100\"") {
|
||||
t.Errorf("SVG should have height=100")
|
||||
}
|
||||
if !strings.Contains(svg, "张") {
|
||||
t.Errorf("SVG should contain the initial character")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAvatarDataURI(t *testing.T) {
|
||||
dataURI := GenerateAvatarDataURI("张三", 100)
|
||||
|
||||
// 验证Data URI格式
|
||||
if !strings.HasPrefix(dataURI, "data:image/svg+xml;base64,") {
|
||||
t.Errorf("Data URI should start with data:image/svg+xml;base64,")
|
||||
}
|
||||
|
||||
// 验证base64部分不为空
|
||||
parts := strings.Split(dataURI, ",")
|
||||
if len(parts) != 2 {
|
||||
t.Errorf("Data URI should have two parts separated by comma")
|
||||
}
|
||||
if parts[1] == "" {
|
||||
t.Errorf("Base64 part should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsChinese(t *testing.T) {
|
||||
tests := []struct {
|
||||
r rune
|
||||
want bool
|
||||
}{
|
||||
{'中', true},
|
||||
{'文', true},
|
||||
{'a', false},
|
||||
{'Z', false},
|
||||
{'0', false},
|
||||
{'_', false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := isChinese(tt.r)
|
||||
if got != tt.want {
|
||||
t.Errorf("isChinese(%q) = %v, want %v", tt.r, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
131
internal/pkg/email/client.go
Normal file
131
internal/pkg/email/client.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gomail "gopkg.in/gomail.v2"
|
||||
)
|
||||
|
||||
// Message 发信参数
|
||||
type Message struct {
|
||||
To []string
|
||||
Cc []string
|
||||
Bcc []string
|
||||
ReplyTo []string
|
||||
Subject string
|
||||
TextBody string
|
||||
HTMLBody string
|
||||
Attachments []string
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
IsEnabled() bool
|
||||
Config() Config
|
||||
Send(ctx context.Context, msg Message) error
|
||||
}
|
||||
|
||||
type clientImpl struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func NewClient(cfg Config) Client {
|
||||
return &clientImpl{cfg: cfg}
|
||||
}
|
||||
|
||||
func (c *clientImpl) IsEnabled() bool {
|
||||
return c.cfg.Enabled &&
|
||||
strings.TrimSpace(c.cfg.Host) != "" &&
|
||||
c.cfg.Port > 0 &&
|
||||
strings.TrimSpace(c.cfg.FromAddress) != ""
|
||||
}
|
||||
|
||||
func (c *clientImpl) Config() Config {
|
||||
return c.cfg
|
||||
}
|
||||
|
||||
func (c *clientImpl) Send(ctx context.Context, msg Message) error {
|
||||
if !c.IsEnabled() {
|
||||
return fmt.Errorf("email client is disabled or misconfigured")
|
||||
}
|
||||
if len(msg.To) == 0 {
|
||||
return fmt.Errorf("email recipient is empty")
|
||||
}
|
||||
if strings.TrimSpace(msg.Subject) == "" {
|
||||
return fmt.Errorf("email subject is empty")
|
||||
}
|
||||
if strings.TrimSpace(msg.TextBody) == "" && strings.TrimSpace(msg.HTMLBody) == "" {
|
||||
return fmt.Errorf("email body is empty")
|
||||
}
|
||||
|
||||
m := gomail.NewMessage()
|
||||
m.SetAddressHeader("From", c.cfg.FromAddress, c.cfg.FromName)
|
||||
m.SetHeader("To", msg.To...)
|
||||
if len(msg.Cc) > 0 {
|
||||
m.SetHeader("Cc", msg.Cc...)
|
||||
}
|
||||
if len(msg.Bcc) > 0 {
|
||||
m.SetHeader("Bcc", msg.Bcc...)
|
||||
}
|
||||
if len(msg.ReplyTo) > 0 {
|
||||
m.SetHeader("Reply-To", msg.ReplyTo...)
|
||||
}
|
||||
m.SetHeader("Subject", msg.Subject)
|
||||
|
||||
if strings.TrimSpace(msg.TextBody) != "" && strings.TrimSpace(msg.HTMLBody) != "" {
|
||||
m.SetBody("text/plain", msg.TextBody)
|
||||
m.AddAlternative("text/html", msg.HTMLBody)
|
||||
} else if strings.TrimSpace(msg.HTMLBody) != "" {
|
||||
m.SetBody("text/html", msg.HTMLBody)
|
||||
} else {
|
||||
m.SetBody("text/plain", msg.TextBody)
|
||||
}
|
||||
|
||||
for _, attachment := range msg.Attachments {
|
||||
if strings.TrimSpace(attachment) == "" {
|
||||
continue
|
||||
}
|
||||
m.Attach(attachment)
|
||||
}
|
||||
|
||||
timeout := c.cfg.TimeoutSeconds
|
||||
if timeout <= 0 {
|
||||
timeout = 15
|
||||
}
|
||||
dialer := gomail.NewDialer(c.cfg.Host, c.cfg.Port, c.cfg.Username, c.cfg.Password)
|
||||
if c.cfg.UseTLS {
|
||||
dialer.TLSConfig = &tls.Config{
|
||||
ServerName: c.cfg.Host,
|
||||
InsecureSkipVerify: c.cfg.InsecureSkipVerify,
|
||||
}
|
||||
// 465 端口通常要求直接 TLS(Implicit TLS)。
|
||||
if c.cfg.Port == 465 {
|
||||
dialer.SSL = true
|
||||
}
|
||||
}
|
||||
|
||||
sendCtx := ctx
|
||||
cancel := func() {}
|
||||
if timeout > 0 {
|
||||
sendCtx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- dialer.DialAndSend(m)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sendCtx.Done():
|
||||
return fmt.Errorf("send email canceled: %w", sendCtx.Err())
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("send email failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
33
internal/pkg/email/config.go
Normal file
33
internal/pkg/email/config.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package email
|
||||
|
||||
import "carrot_bbs/internal/config"
|
||||
|
||||
// Config SMTP 邮件配置(由应用配置转换)
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
FromAddress string
|
||||
FromName string
|
||||
UseTLS bool
|
||||
InsecureSkipVerify bool
|
||||
TimeoutSeconds int
|
||||
}
|
||||
|
||||
// ConfigFromAppConfig 从应用配置转换
|
||||
func ConfigFromAppConfig(cfg *config.EmailConfig) Config {
|
||||
return Config{
|
||||
Enabled: cfg.Enabled,
|
||||
Host: cfg.Host,
|
||||
Port: cfg.Port,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
FromAddress: cfg.FromAddress,
|
||||
FromName: cfg.FromName,
|
||||
UseTLS: cfg.UseTLS,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
TimeoutSeconds: cfg.Timeout,
|
||||
}
|
||||
}
|
||||
286
internal/pkg/gorse/client.go
Normal file
286
internal/pkg/gorse/client.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package gorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
gorseio "github.com/gorse-io/gorse-go"
|
||||
)
|
||||
|
||||
// FeedbackType 反馈类型
|
||||
type FeedbackType string
|
||||
|
||||
const (
|
||||
FeedbackTypeLike FeedbackType = "like" // 点赞
|
||||
FeedbackTypeStar FeedbackType = "star" // 收藏
|
||||
FeedbackTypeComment FeedbackType = "comment" // 评论
|
||||
FeedbackTypeRead FeedbackType = "read" // 浏览
|
||||
)
|
||||
|
||||
// Score 非个性化推荐返回的评分项
|
||||
type Score struct {
|
||||
Id string `json:"Id"`
|
||||
Score float64 `json:"Score"`
|
||||
}
|
||||
|
||||
// Client Gorse客户端接口
|
||||
type Client interface {
|
||||
// InsertFeedback 插入用户反馈
|
||||
InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error
|
||||
// DeleteFeedback 删除用户反馈
|
||||
DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error
|
||||
// GetRecommend 获取个性化推荐列表
|
||||
GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error)
|
||||
// GetNonPersonalized 获取非个性化推荐(通过名称)
|
||||
GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error)
|
||||
// UpsertItem 插入或更新物品(无embedding)
|
||||
UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error
|
||||
// UpsertItemWithEmbedding 插入或更新物品(带embedding)
|
||||
UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error
|
||||
// DeleteItem 删除物品
|
||||
DeleteItem(ctx context.Context, itemID string) error
|
||||
// UpsertUser 插入或更新用户
|
||||
UpsertUser(ctx context.Context, userID string, labels map[string]any) error
|
||||
// IsEnabled 检查是否启用
|
||||
IsEnabled() bool
|
||||
}
|
||||
|
||||
// client Gorse客户端实现
|
||||
type client struct {
|
||||
config Config
|
||||
gorse *gorseio.GorseClient
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient 创建新的Gorse客户端
|
||||
func NewClient(cfg Config) Client {
|
||||
if !cfg.Enabled {
|
||||
return &noopClient{}
|
||||
}
|
||||
|
||||
gorse := gorseio.NewGorseClient(cfg.Address, cfg.APIKey)
|
||||
return &client{
|
||||
config: cfg,
|
||||
gorse: gorse,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled 检查是否启用
|
||||
func (c *client) IsEnabled() bool {
|
||||
return c.config.Enabled
|
||||
}
|
||||
|
||||
// InsertFeedback 插入用户反馈
|
||||
func (c *client) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.InsertFeedback(ctx, []gorseio.Feedback{
|
||||
{
|
||||
FeedbackType: string(feedbackType),
|
||||
UserId: userID,
|
||||
ItemId: itemID,
|
||||
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteFeedback 删除用户反馈
|
||||
func (c *client) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.DeleteFeedback(ctx, string(feedbackType), userID, itemID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRecommend 获取个性化推荐列表
|
||||
func (c *client) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) {
|
||||
if !c.config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, err := c.gorse.GetRecommend(ctx, userID, "", n, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetNonPersonalized 获取非个性化推荐
|
||||
// name: 推荐器名称,如 "most_liked_weekly"
|
||||
// n: 返回数量
|
||||
// offset: 偏移量
|
||||
// userID: 可选,用于排除用户已读物品
|
||||
func (c *client) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) {
|
||||
if !c.config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 构建URL
|
||||
url := fmt.Sprintf("%s/api/non-personalized/%s?n=%d&offset=%d", c.config.Address, name, n, offset)
|
||||
if userID != "" {
|
||||
url += fmt.Sprintf("&user-id=%s", userID)
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// 设置API Key
|
||||
if c.config.APIKey != "" {
|
||||
req.Header.Set("X-API-Key", c.config.APIKey)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("gorse api error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var scores []Score
|
||||
if err := json.Unmarshal(body, &scores); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
// 提取ID
|
||||
ids := make([]string, len(scores))
|
||||
for i, score := range scores {
|
||||
ids[i] = score.Id
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// UpsertItem 插入或更新物品
|
||||
func (c *client) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.InsertItem(ctx, gorseio.Item{
|
||||
ItemId: itemID,
|
||||
IsHidden: false,
|
||||
Categories: categories,
|
||||
Comment: comment,
|
||||
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// UpsertItemWithEmbedding 插入或更新物品(带embedding)
|
||||
func (c *client) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成embedding
|
||||
var embedding []float64
|
||||
if textToEmbed != "" {
|
||||
var err error
|
||||
embedding, err = GetEmbedding(textToEmbed)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to get embedding for item %s: %v, using zero vector", itemID, err)
|
||||
embedding = make([]float64, 1024)
|
||||
}
|
||||
} else {
|
||||
embedding = make([]float64, 1024)
|
||||
}
|
||||
|
||||
_, err := c.gorse.InsertItem(ctx, gorseio.Item{
|
||||
ItemId: itemID,
|
||||
IsHidden: false,
|
||||
Categories: categories,
|
||||
Comment: comment,
|
||||
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||
Labels: map[string]any{
|
||||
"embedding": embedding,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteItem 删除物品
|
||||
func (c *client) DeleteItem(ctx context.Context, itemID string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.DeleteItem(ctx, itemID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpsertUser 插入或更新用户
|
||||
func (c *client) UpsertUser(ctx context.Context, userID string, labels map[string]any) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.InsertUser(ctx, gorseio.User{
|
||||
UserId: userID,
|
||||
Labels: labels,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// noopClient 空操作客户端(用于未启用推荐功能时)
|
||||
type noopClient struct{}
|
||||
|
||||
func (c *noopClient) IsEnabled() bool { return false }
|
||||
func (c *noopClient) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *noopClient) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *noopClient) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *noopClient) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *noopClient) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *noopClient) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *noopClient) DeleteItem(ctx context.Context, itemID string) error { return nil }
|
||||
func (c *noopClient) UpsertUser(ctx context.Context, userID string, labels map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 确保实现了接口
|
||||
var _ Client = (*client)(nil)
|
||||
var _ Client = (*noopClient)(nil)
|
||||
|
||||
// log 用于内部日志
|
||||
func init() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
23
internal/pkg/gorse/config.go
Normal file
23
internal/pkg/gorse/config.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package gorse
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/config"
|
||||
)
|
||||
|
||||
// Config Gorse客户端配置(从config.GorseConfig转换)
|
||||
type Config struct {
|
||||
Address string
|
||||
APIKey string
|
||||
Enabled bool
|
||||
Dashboard string
|
||||
}
|
||||
|
||||
// ConfigFromAppConfig 从应用配置创建Gorse配置
|
||||
func ConfigFromAppConfig(cfg *config.GorseConfig) Config {
|
||||
return Config{
|
||||
Address: cfg.Address,
|
||||
APIKey: cfg.APIKey,
|
||||
Enabled: cfg.Enabled,
|
||||
Dashboard: cfg.Dashboard,
|
||||
}
|
||||
}
|
||||
106
internal/pkg/gorse/embedding.go
Normal file
106
internal/pkg/gorse/embedding.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package gorse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EmbeddingConfig embedding服务配置
|
||||
type EmbeddingConfig struct {
|
||||
APIKey string
|
||||
URL string
|
||||
Model string
|
||||
}
|
||||
|
||||
var defaultEmbeddingConfig = EmbeddingConfig{
|
||||
APIKey: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD",
|
||||
URL: "https://api.littlelan.cn/v1/embeddings",
|
||||
Model: "BAAI/bge-m3",
|
||||
}
|
||||
|
||||
// SetEmbeddingConfig 设置embedding配置
|
||||
func SetEmbeddingConfig(apiKey, url, model string) {
|
||||
if apiKey != "" {
|
||||
defaultEmbeddingConfig.APIKey = apiKey
|
||||
}
|
||||
if url != "" {
|
||||
defaultEmbeddingConfig.URL = url
|
||||
}
|
||||
if model != "" {
|
||||
defaultEmbeddingConfig.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmbedding 获取文本的embedding
|
||||
func GetEmbedding(text string) ([]float64, error) {
|
||||
type embeddingRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type embeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
reqBody := embeddingRequest{
|
||||
Input: text,
|
||||
Model: defaultEmbeddingConfig.Model,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", defaultEmbeddingConfig.URL, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+defaultEmbeddingConfig.APIKey)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("embedding API error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result embeddingResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Data) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
|
||||
return result.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
// InitEmbeddingWithConfig 从应用配置初始化embedding
|
||||
func InitEmbeddingWithConfig(apiKey, url, model string) {
|
||||
if apiKey == "" {
|
||||
log.Println("[WARN] Gorse embedding API key not set, using default")
|
||||
}
|
||||
defaultEmbeddingConfig.APIKey = apiKey
|
||||
if url != "" {
|
||||
defaultEmbeddingConfig.URL = url
|
||||
}
|
||||
if model != "" {
|
||||
defaultEmbeddingConfig.Model = model
|
||||
}
|
||||
}
|
||||
105
internal/pkg/jwt/jwt.go
Normal file
105
internal/pkg/jwt/jwt.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrExpiredToken = errors.New("token has expired")
|
||||
)
|
||||
|
||||
// Claims JWT 声明
|
||||
type Claims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// JWT JWT工具
|
||||
type JWT struct {
|
||||
secretKey string
|
||||
accessTokenExpire time.Duration
|
||||
refreshTokenExpire time.Duration
|
||||
}
|
||||
|
||||
// New 创建JWT实例
|
||||
func New(secret string, accessExpire, refreshExpire time.Duration) *JWT {
|
||||
return &JWT{
|
||||
secretKey: secret,
|
||||
accessTokenExpire: accessExpire,
|
||||
refreshTokenExpire: refreshExpire,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAccessToken 生成访问令牌
|
||||
func (j *JWT) GenerateAccessToken(userID, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: "carrot_bbs",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(j.secretKey))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新令牌
|
||||
func (j *JWT) GenerateRefreshToken(userID, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: "carrot_bbs",
|
||||
ID: "refresh",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(j.secretKey))
|
||||
}
|
||||
|
||||
// ParseToken 解析令牌
|
||||
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(j.secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// ValidateToken 验证令牌
|
||||
func (j *JWT) ValidateToken(tokenString string) error {
|
||||
claims, err := j.ParseToken(tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查是否是刷新令牌
|
||||
if claims.ID == "refresh" {
|
||||
return errors.New("cannot use refresh token as access token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
438
internal/pkg/openai/client.go
Normal file
438
internal/pkg/openai/client.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
"image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xdraw "golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
const moderationSystemPrompt = "你是中文社区的内容审核助手,负责对“帖子标题、正文、配图”做联合审核。目标是平衡社区安全与正常交流:必须拦截高风险违规内容,但不要误伤正常玩梗、二创、吐槽和轻度调侃。请只输出指定JSON。\n\n审核流程:\n1) 先判断是否命中硬性违规;\n2) 再判断语境(玩笑/自嘲/朋友间互动/作品讨论);\n3) 做文图交叉判断(文本+图片合并理解);\n4) 给出 approved 与简短 reason。\n\n硬性违规(命中任一项必须 approved=false):\nA. 宣传对立与煽动撕裂:\n- 明确煽动群体对立、地域对立、性别对立、民族宗教对立,鼓动仇恨、排斥、报复。\nB. 严重人身攻击与网暴引导:\n- 持续性侮辱贬损、羞辱人格、号召围攻/骚扰/挂人/线下冲突。\nC. 开盒/人肉/隐私暴露:\n- 故意公开、拼接、索取他人可识别隐私信息(姓名+联系方式、身份证号、住址、学校单位、车牌、定位轨迹等);\n- 图片/截图中出现可识别隐私信息并伴随曝光意图,也按违规处理。\nD. 其他高危违规:\n- 违法犯罪、暴力威胁、极端仇恨、色情低俗、诈骗引流、恶意广告等。\n\n放行规则(以下通常 approved=true):\n- 正常玩梗、表情包、谐音梗、二次创作、无恶意的吐槽;\n- 非定向、轻度口语化吐槽(无明确攻击对象、无网暴号召、无隐私暴露);\n- 对社会事件/作品的理性讨论、观点争论(即使语气尖锐,但未煽动对立或人身攻击)。\n\n边界判定:\n- 若只是“梗文化表达”且不指向现实伤害,优先通过;\n- 若存在明确伤害意图(煽动、围攻、曝光隐私),必须拒绝;\n- 对模糊内容不因个别粗口直接拒绝,需结合对象、意图、号召性和可执行性综合判断。\n\nreason 要求:\n- approved=false 时:中文10-30字,说明核心违规点;\n- approved=true 时:reason 为空字符串。\n\n输出格式(严格):\n仅输出一行JSON对象,不要Markdown,不要额外解释:\n{\"approved\": true/false, \"reason\": \"...\"}"
|
||||
|
||||
const (
|
||||
defaultMaxImagesPerModerationRequest = 1
|
||||
maxModerationResultRetries = 3
|
||||
maxChatCompletionRetries = 3
|
||||
initialRetryBackoff = 500 * time.Millisecond
|
||||
maxDownloadImageBytes = 10 * 1024 * 1024
|
||||
maxModerationImageSide = 1280
|
||||
compressedJPEGQuality = 72
|
||||
maxCompressedPayloadBytes = 1536 * 1024
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
IsEnabled() bool
|
||||
Config() Config
|
||||
ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error)
|
||||
ModerateComment(ctx context.Context, content string, images []string) (bool, string, error)
|
||||
}
|
||||
|
||||
type clientImpl struct {
|
||||
cfg Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(cfg Config) Client {
|
||||
timeout := cfg.RequestTimeoutSeconds
|
||||
if timeout <= 0 {
|
||||
timeout = 30
|
||||
}
|
||||
return &clientImpl{
|
||||
cfg: cfg,
|
||||
httpClient: &http.Client{
|
||||
Timeout: time.Duration(timeout) * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientImpl) IsEnabled() bool {
|
||||
return c.cfg.Enabled && c.cfg.APIKey != "" && c.cfg.BaseURL != ""
|
||||
}
|
||||
|
||||
func (c *clientImpl) Config() Config {
|
||||
return c.cfg
|
||||
}
|
||||
|
||||
func (c *clientImpl) ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error) {
|
||||
if !c.IsEnabled() {
|
||||
return true, "", nil
|
||||
}
|
||||
return c.moderateContentInBatches(ctx, fmt.Sprintf("帖子标题:%s\n帖子内容:%s", title, content), images)
|
||||
}
|
||||
|
||||
func (c *clientImpl) ModerateComment(ctx context.Context, content string, images []string) (bool, string, error) {
|
||||
if !c.IsEnabled() {
|
||||
return true, "", nil
|
||||
}
|
||||
return c.moderateContentInBatches(ctx, fmt.Sprintf("评论内容:%s", content), images)
|
||||
}
|
||||
|
||||
func (c *clientImpl) moderateContentInBatches(ctx context.Context, contentPrompt string, images []string) (bool, string, error) {
|
||||
cleanImages := normalizeImageURLs(images)
|
||||
optimizedImages := c.optimizeImagesForModeration(ctx, cleanImages)
|
||||
maxImagesPerRequest := c.maxImagesPerModerationRequest()
|
||||
totalBatches := 1
|
||||
if len(optimizedImages) > 0 {
|
||||
totalBatches = (len(optimizedImages) + maxImagesPerRequest - 1) / maxImagesPerRequest
|
||||
}
|
||||
|
||||
// 图片超过单批上限时分批审核,任一批次拒绝即整体拒绝
|
||||
for i := 0; i < totalBatches; i++ {
|
||||
start := i * maxImagesPerRequest
|
||||
end := start + maxImagesPerRequest
|
||||
if end > len(optimizedImages) {
|
||||
end = len(optimizedImages)
|
||||
}
|
||||
|
||||
batchImages := []string{}
|
||||
if len(optimizedImages) > 0 {
|
||||
batchImages = optimizedImages[start:end]
|
||||
}
|
||||
|
||||
approved, reason, err := c.moderateSingleBatch(ctx, contentPrompt, batchImages, i+1, totalBatches)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if !approved {
|
||||
if strings.TrimSpace(reason) != "" && totalBatches > 1 {
|
||||
reason = fmt.Sprintf("第%d/%d批图片未通过:%s", i+1, totalBatches, reason)
|
||||
}
|
||||
return false, reason, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func (c *clientImpl) moderateSingleBatch(
|
||||
ctx context.Context,
|
||||
contentPrompt string,
|
||||
images []string,
|
||||
batchNo, totalBatches int,
|
||||
) (bool, string, error) {
|
||||
userPrompt := fmt.Sprintf(
|
||||
"%s\n图片批次:%d/%d(本次仅提供当前批次图片)",
|
||||
contentPrompt,
|
||||
batchNo,
|
||||
totalBatches,
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxModerationResultRetries; attempt++ {
|
||||
replyText, err := c.chatCompletion(ctx, c.cfg.ModerationModel, moderationSystemPrompt, userPrompt, images, 0.1, 220)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else {
|
||||
parsed := struct {
|
||||
Approved bool `json:"approved"`
|
||||
Reason string `json:"reason"`
|
||||
}{}
|
||||
if err := json.Unmarshal([]byte(extractJSONObject(replyText)), &parsed); err != nil {
|
||||
lastErr = fmt.Errorf("failed to parse moderation result: %w", err)
|
||||
} else {
|
||||
return parsed.Approved, parsed.Reason, nil
|
||||
}
|
||||
}
|
||||
|
||||
if attempt == maxModerationResultRetries-1 {
|
||||
break
|
||||
}
|
||||
if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
|
||||
return false, "", sleepErr
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", fmt.Errorf(
|
||||
"moderation batch %d/%d failed after %d attempts: %w",
|
||||
batchNo,
|
||||
totalBatches,
|
||||
maxModerationResultRetries,
|
||||
lastErr,
|
||||
)
|
||||
}
|
||||
|
||||
type chatCompletionsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type contentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *imageURLPart `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type imageURLPart struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type chatCompletionsResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func (c *clientImpl) chatCompletion(
|
||||
ctx context.Context,
|
||||
model string,
|
||||
systemPrompt string,
|
||||
userPrompt string,
|
||||
images []string,
|
||||
temperature float64,
|
||||
maxTokens int,
|
||||
) (string, error) {
|
||||
if model == "" {
|
||||
return "", fmt.Errorf("model is empty")
|
||||
}
|
||||
|
||||
cleanImages := normalizeImageURLs(images)
|
||||
|
||||
userParts := []contentPart{
|
||||
{Type: "text", Text: userPrompt},
|
||||
}
|
||||
for _, image := range cleanImages {
|
||||
userParts = append(userParts, contentPart{
|
||||
Type: "image_url",
|
||||
ImageURL: &imageURLPart{URL: image},
|
||||
})
|
||||
}
|
||||
|
||||
reqBody := chatCompletionsRequest{
|
||||
Model: model,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userParts},
|
||||
},
|
||||
Temperature: temperature,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(c.cfg.BaseURL, "/")
|
||||
endpoint := baseURL + "/v1/chat/completions"
|
||||
if strings.HasSuffix(baseURL, "/v1") {
|
||||
endpoint = baseURL + "/chat/completions"
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxChatCompletionRetries; attempt++ {
|
||||
body, statusCode, err := c.doChatCompletionRequest(ctx, endpoint, data)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else if statusCode >= 400 {
|
||||
lastErr = fmt.Errorf("openai error status=%d body=%s", statusCode, string(body))
|
||||
if !isRetryableStatusCode(statusCode) {
|
||||
return "", lastErr
|
||||
}
|
||||
} else {
|
||||
var parsed chatCompletionsResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return "", fmt.Errorf("empty response choices")
|
||||
}
|
||||
return parsed.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
if attempt == maxChatCompletionRetries-1 {
|
||||
break
|
||||
}
|
||||
if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
|
||||
return "", sleepErr
|
||||
}
|
||||
}
|
||||
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
func (c *clientImpl) doChatCompletionRequest(ctx context.Context, endpoint string, data []byte) ([]byte, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("request openai: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
return body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func isRetryableStatusCode(statusCode int) bool {
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
return statusCode >= 500 && statusCode <= 599
|
||||
}
|
||||
|
||||
func sleepWithBackoff(ctx context.Context, attempt int) error {
|
||||
delay := initialRetryBackoff * time.Duration(1<<attempt)
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("request cancelled: %w", ctx.Err())
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeImageURLs(images []string) []string {
|
||||
clean := make([]string, 0, len(images))
|
||||
for _, image := range images {
|
||||
trimmed := strings.TrimSpace(image)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
clean = append(clean, trimmed)
|
||||
}
|
||||
return clean
|
||||
}
|
||||
|
||||
func extractJSONObject(raw string) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start >= 0 && end > start {
|
||||
return text[start : end+1]
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func (c *clientImpl) maxImagesPerModerationRequest() int {
|
||||
// 审核固定单图请求,降低单次payload体积,减少超时风险。
|
||||
if c.cfg.ModerationMaxImagesPerRequest <= 0 {
|
||||
return defaultMaxImagesPerModerationRequest
|
||||
}
|
||||
if c.cfg.ModerationMaxImagesPerRequest > 1 {
|
||||
return 1
|
||||
}
|
||||
return c.cfg.ModerationMaxImagesPerRequest
|
||||
}
|
||||
|
||||
func (c *clientImpl) optimizeImagesForModeration(ctx context.Context, images []string) []string {
|
||||
if len(images) == 0 {
|
||||
return images
|
||||
}
|
||||
|
||||
optimized := make([]string, 0, len(images))
|
||||
for _, imageURL := range images {
|
||||
optimized = append(optimized, c.tryCompressImageForModeration(ctx, imageURL))
|
||||
}
|
||||
return optimized
|
||||
}
|
||||
|
||||
func (c *clientImpl) tryCompressImageForModeration(ctx context.Context, imageURL string) string {
|
||||
parsed, err := url.Parse(imageURL)
|
||||
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil)
|
||||
if err != nil {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return imageURL
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return imageURL
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(resp.Header.Get("Content-Type")), "image/") {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
originBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxDownloadImageBytes))
|
||||
if err != nil || len(originBytes) == 0 {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
srcImg, _, err := image.Decode(bytes.NewReader(originBytes))
|
||||
if err != nil {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
dstImg := resizeIfNeeded(srcImg, maxModerationImageSide)
|
||||
var buf bytes.Buffer
|
||||
if err := jpeg.Encode(&buf, dstImg, &jpeg.Options{Quality: compressedJPEGQuality}); err != nil {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
compressed := buf.Bytes()
|
||||
if len(compressed) == 0 || len(compressed) > maxCompressedPayloadBytes {
|
||||
return imageURL
|
||||
}
|
||||
// 压缩效果不明显时直接使用原图URL,避免增大请求体。
|
||||
if len(compressed) >= int(float64(len(originBytes))*0.95) {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(compressed)
|
||||
}
|
||||
|
||||
func resizeIfNeeded(src image.Image, maxSide int) image.Image {
|
||||
bounds := src.Bounds()
|
||||
w := bounds.Dx()
|
||||
h := bounds.Dy()
|
||||
if w <= 0 || h <= 0 || max(w, h) <= maxSide {
|
||||
return src
|
||||
}
|
||||
|
||||
newW, newH := w, h
|
||||
if w >= h {
|
||||
newW = maxSide
|
||||
newH = int(float64(h) * (float64(maxSide) / float64(w)))
|
||||
} else {
|
||||
newH = maxSide
|
||||
newW = int(float64(w) * (float64(maxSide) / float64(h)))
|
||||
}
|
||||
if newW < 1 {
|
||||
newW = 1
|
||||
}
|
||||
if newH < 1 {
|
||||
newH = 1
|
||||
}
|
||||
|
||||
dst := image.NewRGBA(image.Rect(0, 0, newW, newH))
|
||||
xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, bounds, xdraw.Over, nil)
|
||||
return dst
|
||||
}
|
||||
27
internal/pkg/openai/config.go
Normal file
27
internal/pkg/openai/config.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package openai
|
||||
|
||||
import "carrot_bbs/internal/config"
|
||||
|
||||
// Config OpenAI 兼容接口配置
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
BaseURL string
|
||||
APIKey string
|
||||
ModerationModel string
|
||||
ModerationMaxImagesPerRequest int
|
||||
RequestTimeoutSeconds int
|
||||
StrictModeration bool
|
||||
}
|
||||
|
||||
// ConfigFromAppConfig 从应用配置转换
|
||||
func ConfigFromAppConfig(cfg *config.OpenAIConfig) Config {
|
||||
return Config{
|
||||
Enabled: cfg.Enabled,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKey: cfg.APIKey,
|
||||
ModerationModel: cfg.ModerationModel,
|
||||
ModerationMaxImagesPerRequest: cfg.ModerationMaxImagesPerRequest,
|
||||
RequestTimeoutSeconds: cfg.RequestTimeout,
|
||||
StrictModeration: cfg.StrictModeration,
|
||||
}
|
||||
}
|
||||
119
internal/pkg/redis/redis.go
Normal file
119
internal/pkg/redis/redis.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"carrot_bbs/internal/config"
|
||||
)
|
||||
|
||||
// Client Redis客户端
|
||||
type Client struct {
|
||||
rdb *redis.Client
|
||||
isMiniRedis bool
|
||||
mr *miniredis.Miniredis
|
||||
}
|
||||
|
||||
// New 创建Redis客户端
|
||||
func New(cfg *config.RedisConfig) (*Client, error) {
|
||||
switch cfg.Type {
|
||||
case "miniredis":
|
||||
// 启动内嵌Redis模拟
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start miniredis: %w", err)
|
||||
}
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
Password: "",
|
||||
DB: 0,
|
||||
})
|
||||
return &Client{
|
||||
rdb: rdb,
|
||||
isMiniRedis: true,
|
||||
mr: mr,
|
||||
}, nil
|
||||
case "redis":
|
||||
// 使用真实Redis
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Addr(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
PoolSize: cfg.PoolSize,
|
||||
})
|
||||
ctx := context.Background()
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||||
}
|
||||
return &Client{rdb: rdb, isMiniRedis: false}, nil
|
||||
default:
|
||||
// 默认使用miniredis
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start miniredis: %w", err)
|
||||
}
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
return &Client{
|
||||
rdb: rdb,
|
||||
isMiniRedis: true,
|
||||
mr: mr,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取值
|
||||
func (c *Client) Get(ctx context.Context, key string) (string, error) {
|
||||
return c.rdb.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
// Set 设置值
|
||||
func (c *Client) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
|
||||
return c.rdb.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
||||
// Del 删除键
|
||||
func (c *Client) Del(ctx context.Context, keys ...string) error {
|
||||
return c.rdb.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (c *Client) Exists(ctx context.Context, keys ...string) (int64, error) {
|
||||
return c.rdb.Exists(ctx, keys...).Result()
|
||||
}
|
||||
|
||||
// Incr 递增
|
||||
func (c *Client) Incr(ctx context.Context, key string) (int64, error) {
|
||||
return c.rdb.Incr(ctx, key).Result()
|
||||
}
|
||||
|
||||
// Expire 设置过期时间
|
||||
func (c *Client) Expire(ctx context.Context, key string, expiration time.Duration) (bool, error) {
|
||||
return c.rdb.Expire(ctx, key, expiration).Result()
|
||||
}
|
||||
|
||||
// GetClient 获取原生客户端
|
||||
func (c *Client) GetClient() *redis.Client {
|
||||
return c.rdb
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *Client) Close() error {
|
||||
if err := c.rdb.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.mr != nil {
|
||||
c.mr.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsMiniRedis 返回是否是miniredis
|
||||
func (c *Client) IsMiniRedis() bool {
|
||||
return c.isMiniRedis
|
||||
}
|
||||
117
internal/pkg/response/response.go
Normal file
117
internal/pkg/response/response.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 统一响应结构
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseSnakeCase 统一响应结构(snake_case)
|
||||
type ResponseSnakeCase struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Success 成功响应
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// SuccessWithMessage 成功响应带消息
|
||||
func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: message,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 错误响应
|
||||
func Error(c *gin.Context, code int, message string) {
|
||||
c.JSON(http.StatusBadRequest, Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorWithStatusCode 带状态码的错误响应
|
||||
func ErrorWithStatusCode(c *gin.Context, statusCode int, code int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// BadRequest 参数错误
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
ErrorWithStatusCode(c, http.StatusBadRequest, 400, message)
|
||||
}
|
||||
|
||||
// Unauthorized 未授权
|
||||
func Unauthorized(c *gin.Context, message string) {
|
||||
if message == "" {
|
||||
message = "unauthorized"
|
||||
}
|
||||
ErrorWithStatusCode(c, http.StatusUnauthorized, 401, message)
|
||||
}
|
||||
|
||||
// Forbidden 禁止访问
|
||||
func Forbidden(c *gin.Context, message string) {
|
||||
if message == "" {
|
||||
message = "forbidden"
|
||||
}
|
||||
ErrorWithStatusCode(c, http.StatusForbidden, 403, message)
|
||||
}
|
||||
|
||||
// NotFound 资源不存在
|
||||
func NotFound(c *gin.Context, message string) {
|
||||
if message == "" {
|
||||
message = "resource not found"
|
||||
}
|
||||
ErrorWithStatusCode(c, http.StatusNotFound, 404, message)
|
||||
}
|
||||
|
||||
// InternalServerError 服务器内部错误
|
||||
func InternalServerError(c *gin.Context, message string) {
|
||||
if message == "" {
|
||||
message = "internal server error"
|
||||
}
|
||||
ErrorWithStatusCode(c, http.StatusInternalServerError, 500, message)
|
||||
}
|
||||
|
||||
// PaginatedResponse 分页响应
|
||||
type PaginatedResponse struct {
|
||||
List interface{} `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// Paginated 分页成功响应
|
||||
func Paginated(c *gin.Context, list interface{}, total int64, page, pageSize int) {
|
||||
totalPages := int(total) / pageSize
|
||||
if int(total)%pageSize > 0 {
|
||||
totalPages++
|
||||
}
|
||||
|
||||
Success(c, PaginatedResponse{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
119
internal/pkg/s3/s3.go
Normal file
119
internal/pkg/s3/s3.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
|
||||
"carrot_bbs/internal/config"
|
||||
)
|
||||
|
||||
// Client S3客户端
|
||||
type Client struct {
|
||||
client *minio.Client
|
||||
bucket string
|
||||
domain string
|
||||
}
|
||||
|
||||
// New 创建S3客户端
|
||||
func New(cfg *config.S3Config) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
client, err := minio.New(cfg.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
|
||||
Secure: cfg.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create S3 client: %w", err)
|
||||
}
|
||||
|
||||
// 检查bucket是否存在
|
||||
exists, err := client.BucketExists(ctx, cfg.Bucket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check bucket: %w", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{
|
||||
Region: cfg.Region,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("failed to create bucket: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有配置domain,则使用默认的endpoint
|
||||
domain := cfg.Domain
|
||||
if domain == "" {
|
||||
domain = cfg.Endpoint
|
||||
}
|
||||
|
||||
return &Client{
|
||||
client: client,
|
||||
bucket: cfg.Bucket,
|
||||
domain: domain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Upload 上传文件
|
||||
func (c *Client) Upload(ctx context.Context, objectName string, filePath string, contentType string) (string, error) {
|
||||
_, err := c.client.FPutObject(ctx, c.bucket, objectName, filePath, minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload file: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s", c.bucket, objectName), nil
|
||||
}
|
||||
|
||||
// UploadData 上传数据
|
||||
func (c *Client) UploadData(ctx context.Context, objectName string, data []byte, contentType string) (string, error) {
|
||||
_, err := c.client.PutObject(ctx, c.bucket, objectName, bytes.NewReader(data), int64(len(data)), minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload data: %w", err)
|
||||
}
|
||||
|
||||
// 返回完整URL,包含bucket名称
|
||||
scheme := "https"
|
||||
if c.domain == c.bucket || c.domain == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil
|
||||
}
|
||||
|
||||
// GetURL 获取文件URL - 使用自定义域名
|
||||
func (c *Client) GetURL(ctx context.Context, objectName string) (string, error) {
|
||||
// 使用自定义域名构建URL,包含bucket名称
|
||||
scheme := "https"
|
||||
if c.domain == c.bucket || c.domain == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil
|
||||
}
|
||||
|
||||
// GetPresignedURL 获取预签名URL(用于私有桶)
|
||||
func (c *Client) GetPresignedURL(ctx context.Context, objectName string) (string, error) {
|
||||
url, err := c.client.PresignedGetObject(ctx, c.bucket, objectName, time.Hour*24, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get presigned URL: %w", err)
|
||||
}
|
||||
|
||||
return url.String(), nil
|
||||
}
|
||||
|
||||
// Delete 删除文件
|
||||
func (c *Client) Delete(ctx context.Context, objectName string) error {
|
||||
return c.client.RemoveObject(ctx, c.bucket, objectName, minio.RemoveObjectOptions{})
|
||||
}
|
||||
|
||||
// GetClient 获取原生客户端
|
||||
func (c *Client) GetClient() *minio.Client {
|
||||
return c.client
|
||||
}
|
||||
52
internal/pkg/utils/avatar.go
Normal file
52
internal/pkg/utils/avatar.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// AvatarServiceBaseURL 默认头像服务基础URL (使用 UI Avatars API)
|
||||
const AvatarServiceBaseURL = "https://ui-avatars.com/api"
|
||||
|
||||
// DefaultAvatarSize 默认头像尺寸
|
||||
const DefaultAvatarSize = 100
|
||||
|
||||
// AvatarInfo 头像信息
|
||||
type AvatarInfo struct {
|
||||
Username string
|
||||
Nickname string
|
||||
Avatar string
|
||||
}
|
||||
|
||||
// GetAvatarOrDefault 获取头像URL,如果为空则返回在线头像生成服务的URL
|
||||
// 优先使用已有的头像,否则使用昵称或用户名生成默认头像
|
||||
func GetAvatarOrDefault(username, nickname, avatar string) string {
|
||||
if avatar != "" {
|
||||
return avatar
|
||||
}
|
||||
// 使用用户名生成默认头像URL(优先使用昵称)
|
||||
displayName := nickname
|
||||
if displayName == "" {
|
||||
displayName = username
|
||||
}
|
||||
return GenerateDefaultAvatarURL(displayName)
|
||||
}
|
||||
|
||||
// GetAvatarOrDefaultFromInfo 从 AvatarInfo 获取头像URL
|
||||
func GetAvatarOrDefaultFromInfo(info AvatarInfo) string {
|
||||
return GetAvatarOrDefault(info.Username, info.Nickname, info.Avatar)
|
||||
}
|
||||
|
||||
// GenerateDefaultAvatarURL 生成默认头像URL
|
||||
// 使用 UI Avatars API 生成基于用户名首字母的头像
|
||||
func GenerateDefaultAvatarURL(name string) string {
|
||||
if name == "" {
|
||||
name = "?"
|
||||
}
|
||||
// 使用 UI Avatars API 生成头像
|
||||
params := url.Values{}
|
||||
params.Set("name", url.QueryEscape(name))
|
||||
params.Set("size", "100")
|
||||
params.Set("background", "0D8ABC") // 默认蓝色背景
|
||||
params.Set("color", "ffffff") // 白色文字
|
||||
return AvatarServiceBaseURL + "?" + params.Encode()
|
||||
}
|
||||
17
internal/pkg/utils/hash.go
Normal file
17
internal/pkg/utils/hash.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// HashPassword 密码哈希
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
// CheckPasswordHash 验证密码
|
||||
func CheckPasswordHash(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
261
internal/pkg/utils/snowflake.go
Normal file
261
internal/pkg/utils/snowflake.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 雪花算法常量定义
|
||||
const (
|
||||
// 64位ID结构:1位符号位 + 41位时间戳 + 10位机器ID + 12位序列号
|
||||
|
||||
// 机器ID占用的位数
|
||||
nodeIDBits uint64 = 10
|
||||
// 序列号占用的位数
|
||||
sequenceBits uint64 = 12
|
||||
|
||||
// 机器ID的最大值 (0-1023)
|
||||
maxNodeID int64 = -1 ^ (-1 << nodeIDBits)
|
||||
// 序列号的最大值 (0-4095)
|
||||
maxSequence int64 = -1 ^ (-1 << sequenceBits)
|
||||
|
||||
// 机器ID左移位数
|
||||
nodeIDShift uint64 = sequenceBits
|
||||
// 时间戳左移位数
|
||||
timestampShift uint64 = sequenceBits + nodeIDBits
|
||||
|
||||
// 自定义纪元时间:2024-01-01 00:00:00 UTC
|
||||
// 使用自定义纪元可以延长ID有效期约24年(从2024年开始)
|
||||
customEpoch int64 = 1704067200000 // 2024-01-01 00:00:00 UTC 的毫秒时间戳
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
// ErrInvalidNodeID 机器ID无效
|
||||
ErrInvalidNodeID = errors.New("node ID must be between 0 and 1023")
|
||||
// ErrClockBackwards 时钟回拨
|
||||
ErrClockBackwards = errors.New("clock moved backwards, refusing to generate ID")
|
||||
)
|
||||
|
||||
// IDInfo 解析后的ID信息
|
||||
type IDInfo struct {
|
||||
Timestamp int64 // 生成ID时的时间戳(毫秒)
|
||||
NodeID int64 // 机器ID
|
||||
Sequence int64 // 序列号
|
||||
}
|
||||
|
||||
// Snowflake 雪花算法ID生成器
|
||||
type Snowflake struct {
|
||||
mu sync.Mutex // 互斥锁,保证线程安全
|
||||
nodeID int64 // 机器ID (0-1023)
|
||||
sequence int64 // 当前序列号 (0-4095)
|
||||
lastTimestamp int64 // 上次生成ID的时间戳
|
||||
}
|
||||
|
||||
// 全局雪花算法实例
|
||||
var (
|
||||
globalSnowflake *Snowflake
|
||||
globalSnowflakeOnce sync.Once
|
||||
globalSnowflakeErr error
|
||||
)
|
||||
|
||||
// InitSnowflake 初始化全局雪花算法实例
|
||||
// nodeID: 机器ID,范围0-1023,可以通过环境变量 NODE_ID 配置
|
||||
func InitSnowflake(nodeID int64) error {
|
||||
globalSnowflake, globalSnowflakeErr = NewSnowflake(nodeID)
|
||||
return globalSnowflakeErr
|
||||
}
|
||||
|
||||
// GetSnowflake 获取全局雪花算法实例
|
||||
// 如果未初始化,会自动使用默认配置初始化
|
||||
func GetSnowflake() *Snowflake {
|
||||
globalSnowflakeOnce.Do(func() {
|
||||
if globalSnowflake == nil {
|
||||
globalSnowflake, globalSnowflakeErr = NewSnowflake(-1)
|
||||
}
|
||||
})
|
||||
return globalSnowflake
|
||||
}
|
||||
|
||||
// NewSnowflake 创建雪花算法ID生成器实例
|
||||
// nodeID: 机器ID,范围0-1023,可以通过环境变量 NODE_ID 配置
|
||||
// 如果nodeID为-1,则尝试从环境变量 NODE_ID 读取
|
||||
func NewSnowflake(nodeID int64) (*Snowflake, error) {
|
||||
// 如果传入-1,尝试从环境变量读取
|
||||
if nodeID == -1 {
|
||||
nodeIDStr := os.Getenv("NODE_ID")
|
||||
if nodeIDStr != "" {
|
||||
// 解析环境变量
|
||||
parsedID, err := parseInt(nodeIDStr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidNodeID
|
||||
}
|
||||
nodeID = parsedID
|
||||
} else {
|
||||
// 默认使用0
|
||||
nodeID = 0
|
||||
}
|
||||
}
|
||||
|
||||
// 验证机器ID范围
|
||||
if nodeID < 0 || nodeID > maxNodeID {
|
||||
return nil, ErrInvalidNodeID
|
||||
}
|
||||
|
||||
return &Snowflake{
|
||||
nodeID: nodeID,
|
||||
sequence: 0,
|
||||
lastTimestamp: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseInt 辅助函数:解析整数
|
||||
func parseInt(s string) (int64, error) {
|
||||
var result int64
|
||||
var negative bool
|
||||
|
||||
if len(s) == 0 {
|
||||
return 0, errors.New("empty string")
|
||||
}
|
||||
|
||||
i := 0
|
||||
if s[0] == '-' {
|
||||
negative = true
|
||||
i = 1
|
||||
}
|
||||
|
||||
for ; i < len(s); i++ {
|
||||
if s[i] < '0' || s[i] > '9' {
|
||||
return 0, errors.New("invalid character")
|
||||
}
|
||||
result = result*10 + int64(s[i]-'0')
|
||||
}
|
||||
|
||||
if negative {
|
||||
result = -result
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateID 生成唯一的雪花算法ID
|
||||
// 返回值:生成的ID,以及可能的错误(如时钟回拨)
|
||||
// 线程安全:使用互斥锁保证并发安全
|
||||
func (s *Snowflake) GenerateID() (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 获取当前时间戳(毫秒)
|
||||
now := currentTimestamp()
|
||||
|
||||
// 处理时钟回拨
|
||||
if now < s.lastTimestamp {
|
||||
return 0, ErrClockBackwards
|
||||
}
|
||||
|
||||
// 同一毫秒内
|
||||
if now == s.lastTimestamp {
|
||||
// 序列号递增
|
||||
s.sequence = (s.sequence + 1) & maxSequence
|
||||
// 序列号溢出,等待下一毫秒
|
||||
if s.sequence == 0 {
|
||||
now = s.waitNextMillis(now)
|
||||
}
|
||||
} else {
|
||||
// 不同毫秒,序列号重置为0
|
||||
s.sequence = 0
|
||||
}
|
||||
|
||||
// 更新上次生成时间
|
||||
s.lastTimestamp = now
|
||||
|
||||
// 组装ID
|
||||
// ID结构:时间戳部分 | 机器ID部分 | 序列号部分
|
||||
id := ((now - customEpoch) << timestampShift) |
|
||||
(s.nodeID << nodeIDShift) |
|
||||
s.sequence
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// waitNextMillis 等待到下一毫秒
|
||||
// 参数:当前时间戳
|
||||
// 返回值:下一毫秒的时间戳
|
||||
func (s *Snowflake) waitNextMillis(timestamp int64) int64 {
|
||||
now := currentTimestamp()
|
||||
for now <= timestamp {
|
||||
now = currentTimestamp()
|
||||
}
|
||||
return now
|
||||
}
|
||||
|
||||
// ParseID 解析雪花算法ID,提取其中的信息
|
||||
// id: 要解析的雪花算法ID
|
||||
// 返回值:包含时间戳、机器ID、序列号的结构体
|
||||
func ParseID(id int64) *IDInfo {
|
||||
// 提取序列号(低12位)
|
||||
sequence := id & maxSequence
|
||||
|
||||
// 提取机器ID(中间10位)
|
||||
nodeID := (id >> nodeIDShift) & maxNodeID
|
||||
|
||||
// 提取时间戳(高41位)
|
||||
timestamp := (id >> timestampShift) + customEpoch
|
||||
|
||||
return &IDInfo{
|
||||
Timestamp: timestamp,
|
||||
NodeID: nodeID,
|
||||
Sequence: sequence,
|
||||
}
|
||||
}
|
||||
|
||||
// currentTimestamp 获取当前时间戳(毫秒)
|
||||
func currentTimestamp() int64 {
|
||||
return time.Now().UnixNano() / 1e6
|
||||
}
|
||||
|
||||
// GetNodeID 获取当前机器ID
|
||||
func (s *Snowflake) GetNodeID() int64 {
|
||||
return s.nodeID
|
||||
}
|
||||
|
||||
// GetCustomEpoch 获取自定义纪元时间
|
||||
func GetCustomEpoch() int64 {
|
||||
return customEpoch
|
||||
}
|
||||
|
||||
// IDToTime 将雪花算法ID转换为生成时间
|
||||
// 这是一个便捷方法,等价于 ParseID(id).Timestamp
|
||||
func IDToTime(id int64) time.Time {
|
||||
info := ParseID(id)
|
||||
return time.Unix(0, info.Timestamp*1e6) // 毫秒转纳秒
|
||||
}
|
||||
|
||||
// ValidateID 验证ID是否为有效的雪花算法ID
|
||||
// 检查时间戳是否在合理范围内
|
||||
func ValidateID(id int64) bool {
|
||||
if id <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
info := ParseID(id)
|
||||
|
||||
// 检查时间戳是否在合理范围内
|
||||
// 不能早于纪元时间,不能晚于当前时间太多(允许1分钟的时钟偏差)
|
||||
now := currentTimestamp()
|
||||
if info.Timestamp < customEpoch || info.Timestamp > now+60000 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查机器ID和序列号是否在有效范围内
|
||||
if info.NodeID < 0 || info.NodeID > maxNodeID {
|
||||
return false
|
||||
}
|
||||
if info.Sequence < 0 || info.Sequence > maxSequence {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
46
internal/pkg/utils/validator.go
Normal file
46
internal/pkg/utils/validator.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidateEmail 验证邮箱
|
||||
func ValidateEmail(email string) bool {
|
||||
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
|
||||
matched, _ := regexp.MatchString(pattern, email)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ValidateUsername 验证用户名
|
||||
func ValidateUsername(username string) bool {
|
||||
if len(username) < 3 || len(username) > 30 {
|
||||
return false
|
||||
}
|
||||
pattern := `^[a-zA-Z0-9_]+$`
|
||||
matched, _ := regexp.MatchString(pattern, username)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ValidatePassword 验证密码强度
|
||||
func ValidatePassword(password string) bool {
|
||||
if len(password) < 6 || len(password) > 50 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidatePhone 验证手机号
|
||||
func ValidatePhone(phone string) bool {
|
||||
pattern := `^1[3-9]\d{9}$`
|
||||
matched, _ := regexp.MatchString(pattern, phone)
|
||||
return matched
|
||||
}
|
||||
|
||||
// SanitizeHTML 清理HTML
|
||||
func SanitizeHTML(input string) string {
|
||||
// 简单实现,实际使用建议用专门库
|
||||
input = strings.ReplaceAll(input, "<", "<")
|
||||
input = strings.ReplaceAll(input, ">", ">")
|
||||
return input
|
||||
}
|
||||
440
internal/pkg/websocket/websocket.go
Normal file
440
internal/pkg/websocket/websocket.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WebSocket消息类型常量
|
||||
const (
|
||||
MessageTypePing = "ping"
|
||||
MessageTypePong = "pong"
|
||||
MessageTypeMessage = "message"
|
||||
MessageTypeTyping = "typing"
|
||||
MessageTypeRead = "read"
|
||||
MessageTypeAck = "ack"
|
||||
MessageTypeError = "error"
|
||||
MessageTypeRecall = "recall" // 撤回消息
|
||||
MessageTypeSystem = "system" // 系统消息
|
||||
MessageTypeNotification = "notification" // 通知消息
|
||||
MessageTypeAnnouncement = "announcement" // 公告消息
|
||||
|
||||
// 群组相关消息类型
|
||||
MessageTypeGroupMessage = "group_message" // 群消息
|
||||
MessageTypeGroupTyping = "group_typing" // 群输入状态
|
||||
MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等)
|
||||
MessageTypeGroupMention = "group_mention" // @提及通知
|
||||
MessageTypeGroupRead = "group_read" // 群消息已读
|
||||
MessageTypeGroupRecall = "group_recall" // 群消息撤回
|
||||
|
||||
// Meta事件详细类型
|
||||
MetaDetailTypeHeartbeat = "heartbeat"
|
||||
MetaDetailTypeTyping = "typing"
|
||||
MetaDetailTypeAck = "ack" // 消息发送确认
|
||||
MetaDetailTypeRead = "read" // 已读回执
|
||||
)
|
||||
|
||||
// WSMessage WebSocket消息结构
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// ChatMessage 聊天消息结构
|
||||
type ChatMessage struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
SenderID string `json:"sender_id"`
|
||||
Seq int64 `json:"seq"`
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
// SystemMessage 系统消息结构
|
||||
type SystemMessage struct {
|
||||
ID string `json:"id"` // 消息ID
|
||||
Type string `json:"type"` // 消息子类型(如:account_banned, post_approved等)
|
||||
Title string `json:"title"` // 消息标题
|
||||
Content string `json:"content"` // 消息内容
|
||||
Data map[string]interface{} `json:"data"` // 额外数据
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// NotificationMessage 通知消息结构
|
||||
type NotificationMessage struct {
|
||||
ID string `json:"id"` // 通知ID
|
||||
Type string `json:"type"` // 通知类型(like, comment, follow, mention等)
|
||||
Title string `json:"title"` // 通知标题
|
||||
Content string `json:"content"` // 通知内容
|
||||
TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户
|
||||
ResourceType string `json:"resource_type"` // 资源类型(post, comment等)
|
||||
ResourceID string `json:"resource_id"` // 资源ID
|
||||
Extra map[string]interface{} `json:"extra"` // 额外数据
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// NotificationUser 通知中的用户信息
|
||||
type NotificationUser struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Avatar string `json:"avatar"`
|
||||
}
|
||||
|
||||
// AnnouncementMessage 公告消息结构
|
||||
type AnnouncementMessage struct {
|
||||
ID string `json:"id"` // 公告ID
|
||||
Title string `json:"title"` // 公告标题
|
||||
Content string `json:"content"` // 公告内容
|
||||
Priority int `json:"priority"` // 优先级(1-10)
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// GroupMessage 群消息结构
|
||||
type GroupMessage struct {
|
||||
ID string `json:"id"` // 消息ID
|
||||
ConversationID string `json:"conversation_id"` // 会话ID(群聊会话)
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
SenderID string `json:"sender_id"` // 发送者ID
|
||||
Seq int64 `json:"seq"` // 消息序号
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
|
||||
MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表
|
||||
MentionAll bool `json:"mention_all"` // 是否@所有人
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// GroupTypingMessage 群输入状态消息
|
||||
type GroupTypingMessage struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
UserID string `json:"user_id"` // 用户ID
|
||||
Username string `json:"username"` // 用户名
|
||||
IsTyping bool `json:"is_typing"` // 是否正在输入
|
||||
}
|
||||
|
||||
// GroupNoticeMessage 群组通知消息
|
||||
type GroupNoticeMessage struct {
|
||||
NoticeType string `json:"notice_type"` // 通知类型:member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Data interface{} `json:"data"` // 通知数据
|
||||
Timestamp int64 `json:"timestamp"` // 时间戳
|
||||
MessageID string `json:"message_id,omitempty"` // 消息ID(如果通知保存为消息)
|
||||
Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息)
|
||||
}
|
||||
|
||||
// GroupNoticeData 通知数据结构
|
||||
type GroupNoticeData struct {
|
||||
// 成员变动
|
||||
UserID string `json:"user_id,omitempty"` // 变动的用户ID
|
||||
Username string `json:"username,omitempty"` // 用户名
|
||||
OperatorID string `json:"operator_id,omitempty"` // 操作者ID
|
||||
OpName string `json:"op_name,omitempty"` // 操作者名称
|
||||
NewRole string `json:"new_role,omitempty"` // 新角色
|
||||
OldRole string `json:"old_role,omitempty"` // 旧角色
|
||||
MemberCount int `json:"member_count,omitempty"` // 当前成员数
|
||||
|
||||
// 群设置变更
|
||||
MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态
|
||||
}
|
||||
|
||||
// GroupMentionMessage @提及通知消息
|
||||
type GroupMentionMessage struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
MessageID string `json:"message_id"` // 消息ID
|
||||
FromUserID string `json:"from_user_id"` // 发送者ID
|
||||
FromName string `json:"from_name"` // 发送者名称
|
||||
Content string `json:"content"` // 消息内容预览
|
||||
MentionAll bool `json:"mention_all"` // 是否@所有人
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// AckMessage 消息发送确认结构
|
||||
type AckMessage struct {
|
||||
ConversationID string `json:"conversation_id"` // 会话ID
|
||||
GroupID string `json:"group_id,omitempty"` // 群组ID(群聊时)
|
||||
ID string `json:"id"` // 消息ID
|
||||
SenderID string `json:"sender_id"` // 发送者ID
|
||||
Seq int64 `json:"seq"` // 消息序号
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// Client WebSocket客户端
|
||||
type Client struct {
|
||||
ID string
|
||||
UserID string
|
||||
Conn *websocket.Conn
|
||||
Send chan []byte
|
||||
Manager *WebSocketManager
|
||||
IsClosed bool
|
||||
Mu sync.Mutex
|
||||
closeOnce sync.Once // 确保 Send channel 只关闭一次
|
||||
}
|
||||
|
||||
// WebSocketManager WebSocket连接管理器
|
||||
type WebSocketManager struct {
|
||||
clients map[string]*Client // userID -> Client
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
broadcast chan *BroadcastMessage
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// BroadcastMessage 广播消息
|
||||
type BroadcastMessage struct {
|
||||
Message *WSMessage
|
||||
ExcludeUser string // 排除的用户ID,为空表示不排除
|
||||
TargetUser string // 目标用户ID,为空表示广播给所有用户
|
||||
}
|
||||
|
||||
// NewWebSocketManager 创建WebSocket管理器
|
||||
func NewWebSocketManager() *WebSocketManager {
|
||||
return &WebSocketManager{
|
||||
clients: make(map[string]*Client),
|
||||
register: make(chan *Client, 100),
|
||||
unregister: make(chan *Client, 100),
|
||||
broadcast: make(chan *BroadcastMessage, 100),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动管理器
|
||||
func (m *WebSocketManager) Start() {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case client := <-m.register:
|
||||
m.mutex.Lock()
|
||||
m.clients[client.UserID] = client
|
||||
m.mutex.Unlock()
|
||||
log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients))
|
||||
|
||||
case client := <-m.unregister:
|
||||
m.mutex.Lock()
|
||||
if _, ok := m.clients[client.UserID]; ok {
|
||||
delete(m.clients, client.UserID)
|
||||
// 使用 closeOnce 确保 channel 只关闭一次,避免 panic
|
||||
client.closeOnce.Do(func() {
|
||||
close(client.Send)
|
||||
})
|
||||
log.Printf("WebSocket client disconnected: userID=%s", client.UserID)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
case broadcast := <-m.broadcast:
|
||||
m.sendMessage(broadcast)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Register 注册客户端
|
||||
func (m *WebSocketManager) Register(client *Client) {
|
||||
m.register <- client
|
||||
}
|
||||
|
||||
// Unregister 注销客户端
|
||||
func (m *WebSocketManager) Unregister(client *Client) {
|
||||
m.unregister <- client
|
||||
}
|
||||
|
||||
// Broadcast 广播消息给所有用户
|
||||
func (m *WebSocketManager) Broadcast(msg *WSMessage) {
|
||||
m.broadcast <- &BroadcastMessage{
|
||||
Message: msg,
|
||||
TargetUser: "",
|
||||
}
|
||||
}
|
||||
|
||||
// SendToUser 发送消息给指定用户
|
||||
func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) {
|
||||
m.broadcast <- &BroadcastMessage{
|
||||
Message: msg,
|
||||
TargetUser: userID,
|
||||
}
|
||||
}
|
||||
|
||||
// SendToUsers 发送消息给指定用户列表
|
||||
func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) {
|
||||
for _, userID := range userIDs {
|
||||
m.SendToUser(userID, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClient 获取客户端
|
||||
func (m *WebSocketManager) GetClient(userID string) (*Client, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
client, ok := m.clients[userID]
|
||||
return client, ok
|
||||
}
|
||||
|
||||
// GetAllClients 获取所有客户端
|
||||
func (m *WebSocketManager) GetAllClients() map[string]*Client {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.clients
|
||||
}
|
||||
|
||||
// GetClientCount 获取在线用户数量
|
||||
func (m *WebSocketManager) GetClientCount() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return len(m.clients)
|
||||
}
|
||||
|
||||
// IsUserOnline 检查用户是否在线
|
||||
func (m *WebSocketManager) IsUserOnline(userID string) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
_, ok := m.clients[userID]
|
||||
log.Printf("[DEBUG IsUserOnline] 检查用户 %s, 结果=%v, 当前在线用户=%v", userID, ok, m.clients)
|
||||
return ok
|
||||
}
|
||||
|
||||
// sendMessage 发送消息
|
||||
func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) {
|
||||
msgBytes, err := json.Marshal(broadcast.Message)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG WebSocket] sendMessage: 目标用户=%s, 当前在线用户数=%d, 消息类型=%s",
|
||||
broadcast.TargetUser, len(m.clients), broadcast.Message.Type)
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for userID, client := range m.clients {
|
||||
// 如果指定了目标用户,只发送给目标用户
|
||||
if broadcast.TargetUser != "" && userID != broadcast.TargetUser {
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果指定了排除用户,跳过
|
||||
if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case client.Send <- msgBytes:
|
||||
log.Printf("[DEBUG WebSocket] 成功发送消息到用户 %s, 消息类型=%s", userID, broadcast.Message.Type)
|
||||
default:
|
||||
log.Printf("Failed to send message to user %s: channel full", userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendPing 发送心跳
|
||||
func (c *Client) SendPing() error {
|
||||
c.Mu.Lock()
|
||||
defer c.Mu.Unlock()
|
||||
if c.IsClosed {
|
||||
return nil
|
||||
}
|
||||
msg := WSMessage{
|
||||
Type: MessageTypePing,
|
||||
Data: nil,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
msgBytes, _ := json.Marshal(msg)
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
||||
}
|
||||
|
||||
// SendPong 发送Pong响应
|
||||
func (c *Client) SendPong() error {
|
||||
c.Mu.Lock()
|
||||
defer c.Mu.Unlock()
|
||||
if c.IsClosed {
|
||||
return nil
|
||||
}
|
||||
msg := WSMessage{
|
||||
Type: MessageTypePong,
|
||||
Data: nil,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
msgBytes, _ := json.Marshal(msg)
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
||||
}
|
||||
|
||||
// WritePump 写入泵,将消息从Manager发送到客户端
|
||||
func (c *Client) WritePump() {
|
||||
defer func() {
|
||||
c.Conn.Close()
|
||||
c.Mu.Lock()
|
||||
c.IsClosed = true
|
||||
c.Mu.Unlock()
|
||||
}()
|
||||
|
||||
for {
|
||||
message, ok := <-c.Send
|
||||
if !ok {
|
||||
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
c.Mu.Lock()
|
||||
if c.IsClosed {
|
||||
c.Mu.Unlock()
|
||||
return
|
||||
}
|
||||
err := c.Conn.WriteMessage(websocket.TextMessage, message)
|
||||
c.Mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Write error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPump 读取泵,从客户端读取消息
|
||||
func (c *Client) ReadPump(handler func(msg *WSMessage)) {
|
||||
defer func() {
|
||||
c.Manager.Unregister(c)
|
||||
c.Conn.Close()
|
||||
c.Mu.Lock()
|
||||
c.IsClosed = true
|
||||
c.Mu.Unlock()
|
||||
}()
|
||||
|
||||
c.Conn.SetReadLimit(512 * 1024) // 512KB
|
||||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
c.Conn.SetPongHandler(func(string) error {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, message, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var wsMsg WSMessage
|
||||
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
||||
log.Printf("Failed to unmarshal message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
handler(&wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWSMessage 创建WebSocket消息
|
||||
func CreateWSMessage(msgType string, data interface{}) *WSMessage {
|
||||
return &WSMessage{
|
||||
Type: msgType,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
}
|
||||
296
internal/repository/comment_repo.go
Normal file
296
internal/repository/comment_repo.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CommentRepository 评论仓储
|
||||
type CommentRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewCommentRepository 创建评论仓储
|
||||
func NewCommentRepository(db *gorm.DB) *CommentRepository {
|
||||
return &CommentRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建评论
|
||||
func (r *CommentRepository) Create(comment *model.Comment) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建评论
|
||||
err := tx.Create(comment).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 增加帖子的评论数并同步热度分
|
||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||
Updates(map[string]interface{}{
|
||||
"comments_count": gorm.Expr("comments_count + 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是回复,增加父评论的回复数
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取评论
|
||||
func (r *CommentRepository) GetByID(id string) (*model.Comment, error) {
|
||||
var comment model.Comment
|
||||
err := r.db.Preload("User").First(&comment, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &comment, nil
|
||||
}
|
||||
|
||||
// Update 更新评论
|
||||
func (r *CommentRepository) Update(comment *model.Comment) error {
|
||||
return r.db.Save(comment).Error
|
||||
}
|
||||
|
||||
// UpdateModerationStatus 更新评论审核状态
|
||||
func (r *CommentRepository) UpdateModerationStatus(commentID string, status model.CommentStatus) error {
|
||||
return r.db.Model(&model.Comment{}).
|
||||
Where("id = ?", commentID).
|
||||
Update("status", status).Error
|
||||
}
|
||||
|
||||
// Delete 删除评论(软删除,同时清理关联数据)
|
||||
func (r *CommentRepository) Delete(id string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 先查询评论获取post_id和parent_id
|
||||
var comment model.Comment
|
||||
if err := tx.First(&comment, "id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除评论点赞记录
|
||||
if err := tx.Where("comment_id = ?", id).Delete(&model.CommentLike{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除评论(软删除)
|
||||
if err := tx.Delete(&model.Comment{}, "id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 减少帖子的评论数并同步热度分
|
||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||
Updates(map[string]interface{}{
|
||||
"comments_count": gorm.Expr("comments_count - 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是回复,减少父评论的回复数
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetByPostID 获取帖子评论
|
||||
func (r *CommentRepository) GetByPostID(postID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||
var comments []*model.Comment
|
||||
var total int64
|
||||
|
||||
r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
|
||||
Preload("User").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Order("created_at ASC").
|
||||
Find(&comments).Error
|
||||
|
||||
return comments, total, err
|
||||
}
|
||||
|
||||
// GetByPostIDWithReplies 获取帖子评论(包含回复,扁平化结构)
|
||||
// 所有层级的回复都扁平化展示在顶级评论的 replies 中
|
||||
func (r *CommentRepository) GetByPostIDWithReplies(postID string, page, pageSize, replyLimit int) ([]*model.Comment, int64, error) {
|
||||
var comments []*model.Comment
|
||||
var total int64
|
||||
|
||||
r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
|
||||
Preload("User").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Order("created_at ASC").
|
||||
Find(&comments).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if len(comments) == 0 {
|
||||
return comments, total, nil
|
||||
}
|
||||
|
||||
rootIDs := make([]string, 0, len(comments))
|
||||
commentsByID := make(map[string]*model.Comment, len(comments))
|
||||
for _, comment := range comments {
|
||||
rootIDs = append(rootIDs, comment.ID)
|
||||
commentsByID[comment.ID] = comment
|
||||
}
|
||||
|
||||
// 批量加载所有回复,内存中按 root_id 分组并裁剪每个根评论的返回条数
|
||||
var allReplies []*model.Comment
|
||||
if err := r.db.Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished).
|
||||
Preload("User").
|
||||
Order("created_at ASC").
|
||||
Find(&allReplies).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
repliesByRoot := make(map[string][]*model.Comment, len(rootIDs))
|
||||
for _, reply := range allReplies {
|
||||
if reply.RootID == nil {
|
||||
continue
|
||||
}
|
||||
rootID := *reply.RootID
|
||||
if replyLimit <= 0 || len(repliesByRoot[rootID]) < replyLimit {
|
||||
repliesByRoot[rootID] = append(repliesByRoot[rootID], reply)
|
||||
}
|
||||
}
|
||||
|
||||
type replyCountRow struct {
|
||||
RootID string
|
||||
Total int64
|
||||
}
|
||||
var replyCountRows []replyCountRow
|
||||
if err := r.db.Model(&model.Comment{}).
|
||||
Select("root_id, COUNT(*) AS total").
|
||||
Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished).
|
||||
Group("root_id").
|
||||
Scan(&replyCountRows).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
replyCountMap := make(map[string]int64, len(replyCountRows))
|
||||
for _, row := range replyCountRows {
|
||||
replyCountMap[row.RootID] = row.Total
|
||||
}
|
||||
|
||||
for _, rootID := range rootIDs {
|
||||
comment := commentsByID[rootID]
|
||||
comment.Replies = repliesByRoot[rootID]
|
||||
comment.RepliesCount = int(replyCountMap[rootID])
|
||||
}
|
||||
|
||||
return comments, total, nil
|
||||
}
|
||||
|
||||
// loadFlatReplies 加载评论的所有回复(扁平化,所有层级都在同一层)
|
||||
func (r *CommentRepository) loadFlatReplies(rootComment *model.Comment, limit int) {
|
||||
var allReplies []*model.Comment
|
||||
|
||||
// 查询所有以该评论为根评论的回复(不包括顶级评论本身)
|
||||
r.db.Where("root_id = ? AND status = ?", rootComment.ID, model.CommentStatusPublished).
|
||||
Preload("User").
|
||||
Order("created_at ASC").
|
||||
Limit(limit).
|
||||
Find(&allReplies)
|
||||
|
||||
rootComment.Replies = allReplies
|
||||
}
|
||||
|
||||
// GetRepliesByRootID 根据根评论ID分页获取回复(扁平化)
|
||||
func (r *CommentRepository) GetRepliesByRootID(rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||
var replies []*model.Comment
|
||||
var total int64
|
||||
|
||||
// 统计总数
|
||||
r.db.Model(&model.Comment{}).Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).Count(&total)
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).
|
||||
Preload("User").
|
||||
Order("created_at ASC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&replies).Error
|
||||
|
||||
return replies, total, err
|
||||
}
|
||||
|
||||
// GetReplies 获取回复
|
||||
func (r *CommentRepository) GetReplies(parentID string) ([]*model.Comment, error) {
|
||||
var comments []*model.Comment
|
||||
err := r.db.Where("parent_id = ? AND status = ?", parentID, model.CommentStatusPublished).
|
||||
Preload("User").
|
||||
Order("created_at ASC").
|
||||
Find(&comments).Error
|
||||
return comments, err
|
||||
}
|
||||
|
||||
// Like 点赞评论
|
||||
func (r *CommentRepository) Like(commentID, userID string) error {
|
||||
// 检查是否已经点赞
|
||||
var existing model.CommentLike
|
||||
err := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).First(&existing).Error
|
||||
if err == nil {
|
||||
// 已经点赞
|
||||
return nil
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建点赞记录
|
||||
like := &model.CommentLike{
|
||||
CommentID: commentID,
|
||||
UserID: userID,
|
||||
}
|
||||
err = r.db.Create(like).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 增加评论点赞数
|
||||
return r.db.Model(&model.Comment{}).Where("id = ?", commentID).
|
||||
UpdateColumn("likes_count", gorm.Expr("likes_count + 1")).Error
|
||||
}
|
||||
|
||||
// Unlike 取消点赞评论
|
||||
func (r *CommentRepository) Unlike(commentID, userID string) error {
|
||||
result := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).Delete(&model.CommentLike{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected > 0 {
|
||||
// 减少评论点赞数
|
||||
return r.db.Model(&model.Comment{}).Where("id = ?", commentID).
|
||||
UpdateColumn("likes_count", gorm.Expr("likes_count - 1")).Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsLiked 检查是否已点赞
|
||||
func (r *CommentRepository) IsLiked(commentID, userID string) bool {
|
||||
var count int64
|
||||
r.db.Model(&model.CommentLike{}).Where("comment_id = ? AND user_id = ?", commentID, userID).Count(&count)
|
||||
return count > 0
|
||||
}
|
||||
166
internal/repository/device_token_repo.go
Normal file
166
internal/repository/device_token_repo.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DeviceTokenRepository 设备Token仓储
|
||||
type DeviceTokenRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewDeviceTokenRepository 创建设备Token仓储
|
||||
func NewDeviceTokenRepository(db *gorm.DB) *DeviceTokenRepository {
|
||||
return &DeviceTokenRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建设备Token
|
||||
func (r *DeviceTokenRepository) Create(token *model.DeviceToken) error {
|
||||
return r.db.Create(token).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取设备Token
|
||||
func (r *DeviceTokenRepository) GetByID(id int64) (*model.DeviceToken, error) {
|
||||
var token model.DeviceToken
|
||||
err := r.db.First(&token, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// Update 更新设备Token
|
||||
func (r *DeviceTokenRepository) Update(token *model.DeviceToken) error {
|
||||
return r.db.Save(token).Error
|
||||
}
|
||||
|
||||
// Delete 删除设备Token(软删除)
|
||||
func (r *DeviceTokenRepository) Delete(id int64) error {
|
||||
return r.db.Delete(&model.DeviceToken{}, id).Error
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户所有设备
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *DeviceTokenRepository) GetByUserID(userID string) ([]*model.DeviceToken, error) {
|
||||
var tokens []*model.DeviceToken
|
||||
err := r.db.Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
// GetActiveByUserID 获取用户活跃设备
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *DeviceTokenRepository) GetActiveByUserID(userID string) ([]*model.DeviceToken, error) {
|
||||
var tokens []*model.DeviceToken
|
||||
err := r.db.Where("user_id = ? AND is_active = ?", userID, true).
|
||||
Order("last_used_at DESC").
|
||||
Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
// GetByDeviceID 根据设备ID获取设备Token
|
||||
func (r *DeviceTokenRepository) GetByDeviceID(deviceID string) (*model.DeviceToken, error) {
|
||||
var token model.DeviceToken
|
||||
err := r.db.Where("device_id = ?", deviceID).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetByPushToken 根据推送Token获取设备信息
|
||||
func (r *DeviceTokenRepository) GetByPushToken(pushToken string) (*model.DeviceToken, error) {
|
||||
var token model.DeviceToken
|
||||
err := r.db.Where("push_token = ?", pushToken).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeactivateAllExcept 登出其他设备(停用除指定设备外的所有设备)
|
||||
func (r *DeviceTokenRepository) DeactivateAllExcept(userID int64, deviceID string) error {
|
||||
return r.db.Model(&model.DeviceToken{}).
|
||||
Where("user_id = ? AND device_id != ?", userID, deviceID).
|
||||
Update("is_active", false).Error
|
||||
}
|
||||
|
||||
// Upsert 创建或更新设备Token
|
||||
// 如果设备ID已存在,则更新Token和激活状态;否则创建新记录
|
||||
func (r *DeviceTokenRepository) Upsert(token *model.DeviceToken) error {
|
||||
var existing model.DeviceToken
|
||||
err := r.db.Where("device_id = ?", token.DeviceID).First(&existing).Error
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// 创建新记录
|
||||
return r.db.Create(token).Error
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新现有记录
|
||||
return r.db.Model(&existing).Updates(map[string]interface{}{
|
||||
"push_token": token.PushToken,
|
||||
"is_active": true,
|
||||
"device_name": token.DeviceName,
|
||||
"last_used_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateLastUsed 更新最后使用时间
|
||||
func (r *DeviceTokenRepository) UpdateLastUsed(deviceID string) error {
|
||||
return r.db.Model(&model.DeviceToken{}).
|
||||
Where("device_id = ?", deviceID).
|
||||
Update("last_used_at", time.Now()).Error
|
||||
}
|
||||
|
||||
// Deactivate 停用设备
|
||||
func (r *DeviceTokenRepository) Deactivate(deviceID string) error {
|
||||
return r.db.Model(&model.DeviceToken{}).
|
||||
Where("device_id = ?", deviceID).
|
||||
Update("is_active", false).Error
|
||||
}
|
||||
|
||||
// Activate 激活设备
|
||||
func (r *DeviceTokenRepository) Activate(deviceID string) error {
|
||||
return r.db.Model(&model.DeviceToken{}).
|
||||
Where("device_id = ?", deviceID).
|
||||
Updates(map[string]interface{}{
|
||||
"is_active": true,
|
||||
"last_used_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除用户所有设备Token
|
||||
func (r *DeviceTokenRepository) DeleteByUserID(userID int64) error {
|
||||
return r.db.Where("user_id = ?", userID).Delete(&model.DeviceToken{}).Error
|
||||
}
|
||||
|
||||
// GetDeviceCountByUserID 获取用户设备数量
|
||||
func (r *DeviceTokenRepository) GetDeviceCountByUserID(userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.DeviceToken{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetActiveDeviceCountByUserID 获取用户活跃设备数量
|
||||
func (r *DeviceTokenRepository) GetActiveDeviceCountByUserID(userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.DeviceToken{}).
|
||||
Where("user_id = ? AND is_active = ?", userID, true).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteInactiveDevices 删除长时间未使用的设备
|
||||
func (r *DeviceTokenRepository) DeleteInactiveDevices(before time.Time) error {
|
||||
return r.db.Where("is_active = ? AND last_used_at < ?", false, before).
|
||||
Delete(&model.DeviceToken{}).Error
|
||||
}
|
||||
50
internal/repository/group_join_request_repo.go
Normal file
50
internal/repository/group_join_request_repo.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupJoinRequestRepository interface {
|
||||
Create(req *model.GroupJoinRequest) error
|
||||
GetByFlag(flag string) (*model.GroupJoinRequest, error)
|
||||
Update(req *model.GroupJoinRequest) error
|
||||
GetPendingByGroupAndTarget(groupID, targetUserID string, reqType model.GroupJoinRequestType) (*model.GroupJoinRequest, error)
|
||||
}
|
||||
|
||||
type groupJoinRequestRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewGroupJoinRequestRepository(db *gorm.DB) GroupJoinRequestRepository {
|
||||
return &groupJoinRequestRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *groupJoinRequestRepository) Create(req *model.GroupJoinRequest) error {
|
||||
return r.db.Create(req).Error
|
||||
}
|
||||
|
||||
func (r *groupJoinRequestRepository) GetByFlag(flag string) (*model.GroupJoinRequest, error) {
|
||||
var req model.GroupJoinRequest
|
||||
if err := r.db.Where("flag = ?", flag).First(&req).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func (r *groupJoinRequestRepository) Update(req *model.GroupJoinRequest) error {
|
||||
return r.db.Save(req).Error
|
||||
}
|
||||
|
||||
func (r *groupJoinRequestRepository) GetPendingByGroupAndTarget(groupID, targetUserID string, reqType model.GroupJoinRequestType) (*model.GroupJoinRequest, error) {
|
||||
var req model.GroupJoinRequest
|
||||
err := r.db.Where("group_id = ? AND target_user_id = ? AND request_type = ? AND status = ?",
|
||||
groupID, targetUserID, reqType, model.GroupJoinRequestStatusPending).
|
||||
Order("created_at DESC").
|
||||
First(&req).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
242
internal/repository/group_repo.go
Normal file
242
internal/repository/group_repo.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GroupRepository 群组仓库接口
|
||||
type GroupRepository interface {
|
||||
// 群组操作
|
||||
Create(group *model.Group) error
|
||||
GetByID(id string) (*model.Group, error)
|
||||
Update(group *model.Group) error
|
||||
Delete(id string) error
|
||||
GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error)
|
||||
|
||||
// 群成员操作
|
||||
AddMember(member *model.GroupMember) error
|
||||
GetMember(groupID string, userID string) (*model.GroupMember, error)
|
||||
GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error)
|
||||
UpdateMember(member *model.GroupMember) error
|
||||
RemoveMember(groupID string, userID string) error
|
||||
GetMemberCount(groupID string) (int64, error)
|
||||
IsMember(groupID string, userID string) (bool, error)
|
||||
GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error)
|
||||
|
||||
// 角色相关
|
||||
GetMemberRole(groupID string, userID string) (string, error)
|
||||
SetMemberRole(groupID string, userID string, role string) error
|
||||
GetAdmins(groupID string) ([]model.GroupMember, error)
|
||||
|
||||
// 群公告操作
|
||||
CreateAnnouncement(announcement *model.GroupAnnouncement) error
|
||||
GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error)
|
||||
GetAnnouncementByID(id string) (*model.GroupAnnouncement, error)
|
||||
DeleteAnnouncement(id string) error
|
||||
}
|
||||
|
||||
// groupRepository 群组仓库实现
|
||||
type groupRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGroupRepository 创建群组仓库
|
||||
func NewGroupRepository(db *gorm.DB) GroupRepository {
|
||||
return &groupRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建群组
|
||||
func (r *groupRepository) Create(group *model.Group) error {
|
||||
return r.db.Create(group).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取群组
|
||||
func (r *groupRepository) GetByID(id string) (*model.Group, error) {
|
||||
var group model.Group
|
||||
err := r.db.First(&group, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
// Update 更新群组
|
||||
func (r *groupRepository) Update(group *model.Group) error {
|
||||
return r.db.Save(group).Error
|
||||
}
|
||||
|
||||
// Delete 删除群组
|
||||
func (r *groupRepository) Delete(id string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除群成员
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.GroupMember{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除群公告
|
||||
if err := tx.Where("group_id = ?", id).Delete(&model.GroupAnnouncement{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除群组
|
||||
if err := tx.Delete(&model.Group{}, "id = ?", id).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetByOwnerID 根据群主ID获取群组列表
|
||||
func (r *groupRepository) GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error) {
|
||||
var groups []model.Group
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Group{}).Where("owner_id = ?", ownerID)
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error
|
||||
return groups, total, err
|
||||
}
|
||||
|
||||
// AddMember 添加群成员
|
||||
func (r *groupRepository) AddMember(member *model.GroupMember) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(member).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 更新群组成员数量
|
||||
return tx.Model(&model.Group{}).Where("id = ?", member.GroupID).
|
||||
Update("member_count", gorm.Expr("member_count + ?", 1)).Error
|
||||
})
|
||||
}
|
||||
|
||||
// GetMember 获取群成员
|
||||
func (r *groupRepository) GetMember(groupID string, userID string) (*model.GroupMember, error) {
|
||||
var member model.GroupMember
|
||||
err := r.db.First(&member, "group_id = ? AND user_id = ?", groupID, userID).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &member, nil
|
||||
}
|
||||
|
||||
// GetMembers 获取群成员列表
|
||||
func (r *groupRepository) GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) {
|
||||
var members []model.GroupMember
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID)
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Offset(offset).Limit(pageSize).Order("created_at ASC").Find(&members).Error
|
||||
return members, total, err
|
||||
}
|
||||
|
||||
// UpdateMember 更新群成员
|
||||
func (r *groupRepository) UpdateMember(member *model.GroupMember) error {
|
||||
return r.db.Save(member).Error
|
||||
}
|
||||
|
||||
// RemoveMember 移除群成员
|
||||
func (r *groupRepository) RemoveMember(groupID string, userID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除成员
|
||||
if err := tx.Where("group_id = ? AND user_id = ?", groupID, userID).Delete(&model.GroupMember{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 更新群组成员数量
|
||||
return tx.Model(&model.Group{}).Where("id = ?", groupID).
|
||||
Update("member_count", gorm.Expr("member_count - ?", 1)).Error
|
||||
})
|
||||
}
|
||||
|
||||
// GetMemberCount 获取群成员数量
|
||||
func (r *groupRepository) GetMemberCount(groupID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// IsMember 检查是否是群成员
|
||||
func (r *groupRepository) IsMember(groupID string, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.GroupMember{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// GetUserGroups 获取用户加入的群组列表
|
||||
func (r *groupRepository) GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) {
|
||||
var groups []model.Group
|
||||
var total int64
|
||||
|
||||
// 通过群成员表查询用户加入的群组
|
||||
subQuery := r.db.Model(&model.GroupMember{}).
|
||||
Select("group_id").
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
query := r.db.Model(&model.Group{}).Where("id IN (?)", subQuery)
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error
|
||||
return groups, total, err
|
||||
}
|
||||
|
||||
// GetMemberRole 获取成员角色
|
||||
func (r *groupRepository) GetMemberRole(groupID string, userID string) (string, error) {
|
||||
member, err := r.GetMember(groupID, userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return member.Role, nil
|
||||
}
|
||||
|
||||
// SetMemberRole 设置成员角色
|
||||
func (r *groupRepository) SetMemberRole(groupID string, userID string, role string) error {
|
||||
return r.db.Model(&model.GroupMember{}).
|
||||
Where("group_id = ? AND user_id = ?", groupID, userID).
|
||||
Update("role", role).Error
|
||||
}
|
||||
|
||||
// GetAdmins 获取群管理员列表
|
||||
func (r *groupRepository) GetAdmins(groupID string) ([]model.GroupMember, error) {
|
||||
var admins []model.GroupMember
|
||||
err := r.db.Where("group_id = ? AND role = ?", groupID, model.GroupRoleAdmin).Find(&admins).Error
|
||||
return admins, err
|
||||
}
|
||||
|
||||
// CreateAnnouncement 创建群公告
|
||||
func (r *groupRepository) CreateAnnouncement(announcement *model.GroupAnnouncement) error {
|
||||
return r.db.Create(announcement).Error
|
||||
}
|
||||
|
||||
// GetAnnouncements 获取群公告列表
|
||||
func (r *groupRepository) GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) {
|
||||
var announcements []model.GroupAnnouncement
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.GroupAnnouncement{}).Where("group_id = ?", groupID)
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
// 置顶的排在前面,然后按时间倒序
|
||||
err := query.Offset(offset).Limit(pageSize).Order("is_pinned DESC, created_at DESC").Find(&announcements).Error
|
||||
return announcements, total, err
|
||||
}
|
||||
|
||||
// GetAnnouncementByID 根据ID获取群公告
|
||||
func (r *groupRepository) GetAnnouncementByID(id string) (*model.GroupAnnouncement, error) {
|
||||
var announcement model.GroupAnnouncement
|
||||
err := r.db.First(&announcement, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &announcement, nil
|
||||
}
|
||||
|
||||
// DeleteAnnouncement 删除群公告
|
||||
func (r *groupRepository) DeleteAnnouncement(id string) error {
|
||||
return r.db.Delete(&model.GroupAnnouncement{}, "id = ?", id).Error
|
||||
}
|
||||
543
internal/repository/message_repo.go
Normal file
543
internal/repository/message_repo.go
Normal file
@@ -0,0 +1,543 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// MessageRepository 消息仓储
|
||||
type MessageRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewMessageRepository 创建消息仓储
|
||||
func NewMessageRepository(db *gorm.DB) *MessageRepository {
|
||||
return &MessageRepository{db: db}
|
||||
}
|
||||
|
||||
// CreateMessage 创建消息
|
||||
func (r *MessageRepository) CreateMessage(msg *model.Message) error {
|
||||
return r.db.Create(msg).Error
|
||||
}
|
||||
|
||||
// GetConversation 获取会话
|
||||
func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) {
|
||||
var conv model.Conversation
|
||||
err := r.db.Preload("Group").First(&conv, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// GetOrCreatePrivateConversation 获取或创建私聊会话
|
||||
// 使用参与者关系表来管理会话
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) {
|
||||
var conv model.Conversation
|
||||
|
||||
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: user1ID=%s, user2ID=%s\n", user1ID, user2ID)
|
||||
|
||||
// 查找两个用户共同参与的私聊会话
|
||||
err := r.db.Table("conversations c").
|
||||
Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID).
|
||||
Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID).
|
||||
Where("c.type = ?", model.ConversationTypePrivate).
|
||||
First(&conv).Error
|
||||
|
||||
if err == nil {
|
||||
_ = r.db.Model(&model.ConversationParticipant{}).
|
||||
Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}).
|
||||
Update("hidden_at", nil).Error
|
||||
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: found existing conversation, ID=%s\n", conv.ID)
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 没找到会话,创建新会话
|
||||
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: no existing conversation found, creating new one\n")
|
||||
conv = model.Conversation{
|
||||
Type: model.ConversationTypePrivate,
|
||||
}
|
||||
|
||||
// 使用事务创建会话和参与者
|
||||
err = r.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(&conv).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建参与者记录 - UserID 存储为 string (UUID)
|
||||
participants := []model.ConversationParticipant{
|
||||
{ConversationID: conv.ID, UserID: user1ID},
|
||||
{ConversationID: conv.ID, UserID: user2ID},
|
||||
}
|
||||
if err := tx.Create(&participants).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: created new conversation, ID=%s\n", conv.ID)
|
||||
}
|
||||
|
||||
return &conv, err
|
||||
}
|
||||
|
||||
// GetConversations 获取用户会话列表
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||
var convs []*model.Conversation
|
||||
var total int64
|
||||
|
||||
// 获取总数
|
||||
r.db.Model(&model.ConversationParticipant{}).
|
||||
Where("user_id = ? AND hidden_at IS NULL", userID).
|
||||
Count(&total)
|
||||
|
||||
if total == 0 {
|
||||
return convs, total, nil
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
// 查询会话列表并预加载关联数据:
|
||||
// 当前用户维度先按置顶排序,再按更新时间排序
|
||||
err := r.db.Model(&model.Conversation{}).
|
||||
Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id").
|
||||
Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID).
|
||||
Preload("Group").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Order("cp.is_pinned DESC").
|
||||
Order("conversations.updated_at DESC").
|
||||
Find(&convs).Error
|
||||
|
||||
return convs, total, err
|
||||
}
|
||||
|
||||
// GetMessages 获取会话消息
|
||||
func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||
var messages []*model.Message
|
||||
var total int64
|
||||
|
||||
r.db.Model(&model.Message{}).Where("conversation_id = ?", conversationID).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("conversation_id = ?", conversationID).
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Order("seq DESC").
|
||||
Find(&messages).Error
|
||||
|
||||
return messages, total, err
|
||||
}
|
||||
|
||||
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
||||
func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||
var messages []*model.Message
|
||||
err := r.db.Where("conversation_id = ? AND seq > ?", conversationID, afterSeq).
|
||||
Order("seq ASC").
|
||||
Limit(limit).
|
||||
Find(&messages).Error
|
||||
return messages, err
|
||||
}
|
||||
|
||||
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
||||
func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||||
var messages []*model.Message
|
||||
fmt.Printf("[DEBUG] GetMessagesBeforeSeq: conversationID=%s, beforeSeq=%d, limit=%d\n", conversationID, beforeSeq, limit)
|
||||
err := r.db.Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq).
|
||||
Order("seq DESC"). // 降序获取最新消息在前
|
||||
Limit(limit).
|
||||
Find(&messages).Error
|
||||
fmt.Printf("[DEBUG] GetMessagesBeforeSeq: found %d messages, seq range: ", len(messages))
|
||||
for i, m := range messages {
|
||||
if i < 5 || i >= len(messages)-2 {
|
||||
fmt.Printf("%d ", m.Seq)
|
||||
} else if i == 5 {
|
||||
fmt.Printf("... ")
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
// 反转回正序
|
||||
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
||||
messages[i], messages[j] = messages[j], messages[i]
|
||||
}
|
||||
return messages, err
|
||||
}
|
||||
|
||||
// GetConversationParticipants 获取会话参与者
|
||||
func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
|
||||
var participants []*model.ConversationParticipant
|
||||
err := r.db.Where("conversation_id = ?", conversationID).Find(&participants).Error
|
||||
return participants, err
|
||||
}
|
||||
|
||||
// GetParticipant 获取用户在会话中的参与者信息
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *MessageRepository) GetParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
|
||||
var participant model.ConversationParticipant
|
||||
err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
|
||||
if err != nil {
|
||||
// 如果找不到参与者,尝试添加(修复没有参与者记录的问题)
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// 检查会话是否存在
|
||||
var conv model.Conversation
|
||||
if err := r.db.First(&conv, conversationID).Error; err == nil {
|
||||
// 会话存在,添加参与者
|
||||
participant = model.ConversationParticipant{
|
||||
ConversationID: conversationID,
|
||||
UserID: userID,
|
||||
}
|
||||
if err := r.db.Create(&participant).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &participant, nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &participant, nil
|
||||
}
|
||||
|
||||
// UpdateLastReadSeq 更新已读位置
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error {
|
||||
result := r.db.Model(&model.ConversationParticipant{}).
|
||||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||
Update("last_read_seq", lastReadSeq)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// 如果没有更新任何记录,说明参与者记录不存在,需要插入
|
||||
if result.RowsAffected == 0 {
|
||||
// 尝试插入新记录(跨数据库 upsert)
|
||||
err := r.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "conversation_id"},
|
||||
{Name: "user_id"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"last_read_seq": lastReadSeq,
|
||||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
}),
|
||||
}).Create(&model.ConversationParticipant{
|
||||
ConversationID: conversationID,
|
||||
UserID: userID,
|
||||
LastReadSeq: lastReadSeq,
|
||||
}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePinned 更新会话置顶状态(用户维度)
|
||||
func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error {
|
||||
result := r.db.Model(&model.ConversationParticipant{}).
|
||||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||
Update("is_pinned", isPinned)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return r.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "conversation_id"},
|
||||
{Name: "user_id"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"is_pinned": isPinned,
|
||||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
}),
|
||||
}).Create(&model.ConversationParticipant{
|
||||
ConversationID: conversationID,
|
||||
UserID: userID,
|
||||
IsPinned: isPinned,
|
||||
}).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读消息数
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) {
|
||||
var participant model.ConversationParticipant
|
||||
err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var count int64
|
||||
err = r.db.Model(&model.Message{}).
|
||||
Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, participant.LastReadSeq).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// UpdateConversationLastSeq 更新会话的最后消息seq和时间
|
||||
func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error {
|
||||
return r.db.Model(&model.Conversation{}).
|
||||
Where("id = ?", conversationID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_seq": seq,
|
||||
"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// GetNextSeq 获取会话的下一个seq值
|
||||
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
|
||||
var conv model.Conversation
|
||||
err := r.db.Select("last_seq").First(&conv, conversationID).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return conv.LastSeq + 1, nil
|
||||
}
|
||||
|
||||
// CreateMessageWithSeq 创建消息并更新seq(事务操作)
|
||||
func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 获取当前seq并+1
|
||||
var conv model.Conversation
|
||||
if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg.Seq = conv.LastSeq + 1
|
||||
|
||||
// 创建消息
|
||||
if err := tx.Create(msg).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新会话的last_seq
|
||||
if err := tx.Model(&model.Conversation{}).
|
||||
Where("id = ?", msg.ConversationID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_seq": msg.Seq,
|
||||
"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 新消息到达后,自动恢复被“仅自己删除”的会话
|
||||
if err := tx.Model(&model.ConversationParticipant{}).
|
||||
Where("conversation_id = ?", msg.ConversationID).
|
||||
Update("hidden_at", nil).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetAllUnreadCount 获取用户所有会话的未读消息总数
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) {
|
||||
var totalUnread int64
|
||||
err := r.db.Table("conversation_participants AS cp").
|
||||
Joins("LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL", userID).
|
||||
Where("cp.user_id = ?", userID).
|
||||
Select("COALESCE(COUNT(m.id), 0)").
|
||||
Scan(&totalUnread).Error
|
||||
return totalUnread, err
|
||||
}
|
||||
|
||||
// GetMessageByID 根据ID获取消息
|
||||
func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) {
|
||||
var message model.Message
|
||||
err := r.db.First(&message, "id = ?", messageID).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &message, nil
|
||||
}
|
||||
|
||||
// CountMessagesBySenderInConversation 统计会话中某用户已发送消息数
|
||||
func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Message{}).
|
||||
Where("conversation_id = ? AND sender_id = ?", conversationID, senderID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// UpdateMessageStatus 更新消息状态
|
||||
func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error {
|
||||
return r.db.Model(&model.Message{}).
|
||||
Where("id = ?", messageID).
|
||||
Update("status", status).Error
|
||||
}
|
||||
|
||||
// GetOrCreateSystemParticipant 获取或创建用户在系统会话中的参与者记录
|
||||
// 系统会话是虚拟会话,但需要参与者记录来跟踪已读状态
|
||||
func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) {
|
||||
var participant model.ConversationParticipant
|
||||
err := r.db.Where("conversation_id = ? AND user_id = ?",
|
||||
model.SystemConversationID, userID).First(&participant).Error
|
||||
|
||||
if err == nil {
|
||||
return &participant, nil
|
||||
}
|
||||
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 自动创建参与者记录
|
||||
participant = model.ConversationParticipant{
|
||||
ConversationID: model.SystemConversationID,
|
||||
UserID: userID,
|
||||
LastReadSeq: 0,
|
||||
}
|
||||
|
||||
if err := r.db.Create(&participant).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &participant, nil
|
||||
}
|
||||
|
||||
// GetSystemMessagesUnreadCount 获取系统消息未读数
|
||||
func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) {
|
||||
// 获取或创建参与者记录
|
||||
participant, err := r.GetOrCreateSystemParticipant(userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 计算未读数:查询 seq > last_read_seq 的消息
|
||||
var count int64
|
||||
err = r.db.Model(&model.Message{}).
|
||||
Where("conversation_id = ? AND seq > ?",
|
||||
model.SystemConversationID, participant.LastReadSeq).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
// MarkAllSystemMessagesAsRead 标记所有系统消息已读
|
||||
func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error {
|
||||
// 获取系统会话的最新 seq
|
||||
var maxSeq int64
|
||||
err := r.db.Model(&model.Message{}).
|
||||
Where("conversation_id = ?", model.SystemConversationID).
|
||||
Select("COALESCE(MAX(seq), 0)").
|
||||
Scan(&maxSeq).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用跨数据库 upsert 方式更新或创建参与者记录
|
||||
return r.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "conversation_id"},
|
||||
{Name: "user_id"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"last_read_seq": maxSeq,
|
||||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
}),
|
||||
}).Create(&model.ConversationParticipant{
|
||||
ConversationID: model.SystemConversationID,
|
||||
UserID: userID,
|
||||
LastReadSeq: maxSeq,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// GetConversationByGroupID 通过群组ID获取会话
|
||||
func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) {
|
||||
var conv model.Conversation
|
||||
err := r.db.Where("group_id = ?", groupID).First(&conv).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conv, nil
|
||||
}
|
||||
|
||||
// RemoveParticipant 移除会话参与者
|
||||
// 当用户退出群聊时,需要同时移除其在对应会话中的参与者记录
|
||||
func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error {
|
||||
return r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||
Delete(&model.ConversationParticipant{}).Error
|
||||
}
|
||||
|
||||
// AddParticipant 添加会话参与者
|
||||
// 当用户加入群聊时,需要同时将其添加到对应会话的参与者记录
|
||||
func (r *MessageRepository) AddParticipant(conversationID string, userID string) error {
|
||||
// 先检查是否已经是参与者
|
||||
var count int64
|
||||
err := r.db.Model(&model.ConversationParticipant{}).
|
||||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果已经是参与者,直接返回
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 添加参与者
|
||||
participant := model.ConversationParticipant{
|
||||
ConversationID: conversationID,
|
||||
UserID: userID,
|
||||
LastReadSeq: 0,
|
||||
}
|
||||
return r.db.Create(&participant).Error
|
||||
}
|
||||
|
||||
// DeleteConversationByGroupID 删除群组对应的会话及其参与者
|
||||
// 当解散群组时调用
|
||||
func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error {
|
||||
// 获取群组对应的会话
|
||||
conv, err := r.GetConversationByGroupID(groupID)
|
||||
if err != nil {
|
||||
// 如果会话不存在,直接返回
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除会话参与者
|
||||
if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除会话中的消息
|
||||
if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 删除会话
|
||||
if err := tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// HideConversationForUser 仅对当前用户隐藏会话(私聊删除)
|
||||
func (r *MessageRepository) HideConversationForUser(conversationID, userID string) error {
|
||||
now := time.Now()
|
||||
return r.db.Model(&model.ConversationParticipant{}).
|
||||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||
Update("hidden_at", &now).Error
|
||||
}
|
||||
78
internal/repository/notification_repo.go
Normal file
78
internal/repository/notification_repo.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// NotificationRepository 通知仓储
|
||||
type NotificationRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewNotificationRepository 创建通知仓储
|
||||
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
|
||||
return &NotificationRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建通知
|
||||
func (r *NotificationRepository) Create(notification *model.Notification) error {
|
||||
return r.db.Create(notification).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取通知
|
||||
func (r *NotificationRepository) GetByID(id string) (*model.Notification, error) {
|
||||
var notification model.Notification
|
||||
err := r.db.First(¬ification, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ¬ification, nil
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户通知
|
||||
func (r *NotificationRepository) GetByUserID(userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
|
||||
var notifications []*model.Notification
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Notification{}).Where("user_id = ?", userID)
|
||||
|
||||
if unreadOnly {
|
||||
query = query.Where("is_read = ?", false)
|
||||
}
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(¬ifications).Error
|
||||
|
||||
return notifications, total, err
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
func (r *NotificationRepository) MarkAsRead(id string) error {
|
||||
return r.db.Model(&model.Notification{}).Where("id = ?", id).Update("is_read", true).Error
|
||||
}
|
||||
|
||||
// MarkAllAsRead 标记所有为已读
|
||||
func (r *NotificationRepository) MarkAllAsRead(userID string) error {
|
||||
return r.db.Model(&model.Notification{}).Where("user_id = ?", userID).Update("is_read", true).Error
|
||||
}
|
||||
|
||||
// Delete 删除通知
|
||||
func (r *NotificationRepository) Delete(id string) error {
|
||||
return r.db.Delete(&model.Notification{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读数量
|
||||
func (r *NotificationRepository) GetUnreadCount(userID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Notification{}).Where("user_id = ? AND is_read = ?", userID, false).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DeleteAllByUserID 删除用户所有通知
|
||||
func (r *NotificationRepository) DeleteAllByUserID(userID string) error {
|
||||
return r.db.Where("user_id = ?", userID).Delete(&model.Notification{}).Error
|
||||
}
|
||||
360
internal/repository/post_repo.go
Normal file
360
internal/repository/post_repo.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PostRepository 帖子仓储
|
||||
type PostRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewPostRepository 创建帖子仓储
|
||||
func NewPostRepository(db *gorm.DB) *PostRepository {
|
||||
return &PostRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建帖子
|
||||
func (r *PostRepository) Create(post *model.Post, images []string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建帖子
|
||||
if err := tx.Create(post).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建图片记录
|
||||
for i, url := range images {
|
||||
image := &model.PostImage{
|
||||
PostID: post.ID,
|
||||
URL: url,
|
||||
SortOrder: i,
|
||||
}
|
||||
if err := tx.Create(image).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取帖子
|
||||
func (r *PostRepository) GetByID(id string) (*model.Post, error) {
|
||||
var post model.Post
|
||||
err := r.db.Preload("User").Preload("Images").First(&post, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &post, nil
|
||||
}
|
||||
|
||||
// Update 更新帖子
|
||||
func (r *PostRepository) Update(post *model.Post) error {
|
||||
return r.db.Save(post).Error
|
||||
}
|
||||
|
||||
// UpdateModerationStatus 更新帖子审核状态
|
||||
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
"reviewed_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||||
"reviewed_by": reviewedBy,
|
||||
"reject_reason": rejectReason,
|
||||
}
|
||||
return r.db.Model(&model.Post{}).Where("id = ?", postID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// Delete 删除帖子(软删除,同时清理关联数据)
|
||||
func (r *PostRepository) Delete(id string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除帖子图片
|
||||
if err := tx.Where("post_id = ?", id).Delete(&model.PostImage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除帖子点赞记录
|
||||
if err := tx.Where("post_id = ?", id).Delete(&model.PostLike{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除帖子收藏记录
|
||||
if err := tx.Where("post_id = ?", id).Delete(&model.Favorite{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除评论点赞记录(子查询获取该帖子所有评论ID)
|
||||
if err := tx.Where("comment_id IN (SELECT id FROM comments WHERE post_id = ?)", id).Delete(&model.CommentLike{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除帖子评论
|
||||
if err := tx.Where("post_id = ?", id).Delete(&model.Comment{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 最后删除帖子本身(软删除)
|
||||
return tx.Delete(&model.Post{}, "id = ?", id).Error
|
||||
})
|
||||
}
|
||||
|
||||
// List 分页获取帖子列表
|
||||
func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
|
||||
|
||||
if userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||
|
||||
return posts, total, err
|
||||
}
|
||||
|
||||
// GetUserPosts 获取用户帖子
|
||||
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||
|
||||
return posts, total, err
|
||||
}
|
||||
|
||||
// GetFavorites 获取用户收藏
|
||||
func (r *PostRepository) GetFavorites(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
subQuery := r.db.Model(&model.Favorite{}).Where("user_id = ?", userID).Select("post_id")
|
||||
r.db.Model(&model.Post{}).Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||
|
||||
return posts, total, err
|
||||
}
|
||||
|
||||
// Like 点赞帖子
|
||||
func (r *PostRepository) Like(postID, userID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 检查是否已经点赞
|
||||
var existing model.PostLike
|
||||
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
|
||||
if err == nil {
|
||||
// 已经点赞,直接返回
|
||||
return nil
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建点赞记录
|
||||
if err := tx.Create(&model.PostLike{
|
||||
PostID: postID,
|
||||
UserID: userID,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 增加帖子点赞数并同步热度分
|
||||
return tx.Model(&model.Post{}).Where("id = ?", postID).
|
||||
Updates(map[string]interface{}{
|
||||
"likes_count": gorm.Expr("likes_count + 1"),
|
||||
"hot_score": gorm.Expr("(likes_count + 1) * 2 + comments_count * 3 + views_count * 0.1"),
|
||||
}).Error
|
||||
})
|
||||
}
|
||||
|
||||
// Unlike 取消点赞
|
||||
func (r *PostRepository) Unlike(postID, userID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.PostLike{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected > 0 {
|
||||
// 减少帖子点赞数并同步热度分
|
||||
return tx.Model(&model.Post{}).Where("id = ?", postID).
|
||||
Updates(map[string]interface{}{
|
||||
"likes_count": gorm.Expr("likes_count - 1"),
|
||||
"hot_score": gorm.Expr("(likes_count - 1) * 2 + comments_count * 3 + views_count * 0.1"),
|
||||
}).Error
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// IsLiked 检查是否点赞
|
||||
func (r *PostRepository) IsLiked(postID, userID string) bool {
|
||||
var count int64
|
||||
r.db.Model(&model.PostLike{}).Where("post_id = ? AND user_id = ?", postID, userID).Count(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// Favorite 收藏帖子
|
||||
func (r *PostRepository) Favorite(postID, userID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 检查是否已经收藏
|
||||
var existing model.Favorite
|
||||
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
|
||||
if err == nil {
|
||||
// 已经收藏,直接返回
|
||||
return nil
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建收藏记录
|
||||
if err := tx.Create(&model.Favorite{
|
||||
PostID: postID,
|
||||
UserID: userID,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 增加帖子收藏数
|
||||
return tx.Model(&model.Post{}).Where("id = ?", postID).
|
||||
UpdateColumn("favorites_count", gorm.Expr("favorites_count + 1")).Error
|
||||
})
|
||||
}
|
||||
|
||||
// Unfavorite 取消收藏
|
||||
func (r *PostRepository) Unfavorite(postID, userID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.Favorite{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected > 0 {
|
||||
// 减少帖子收藏数
|
||||
return tx.Model(&model.Post{}).Where("id = ?", postID).
|
||||
UpdateColumn("favorites_count", gorm.Expr("favorites_count - 1")).Error
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// IsFavorited 检查是否收藏
|
||||
func (r *PostRepository) IsFavorited(postID, userID string) bool {
|
||||
var count int64
|
||||
r.db.Model(&model.Favorite{}).Where("post_id = ? AND user_id = ?", postID, userID).Count(&count)
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// IncrementViews 增加帖子观看量
|
||||
func (r *PostRepository) IncrementViews(postID string) error {
|
||||
return r.db.Model(&model.Post{}).Where("id = ?", postID).
|
||||
Updates(map[string]interface{}{
|
||||
"views_count": gorm.Expr("views_count + 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// Search 搜索帖子
|
||||
func (r *PostRepository) Search(keyword string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
|
||||
|
||||
// 搜索标题和内容
|
||||
if keyword != "" {
|
||||
if r.db.Dialector.Name() == "postgres" {
|
||||
// PostgreSQL 使用全文检索表达式,为 pg_trgm/GIN 索引升级预留路径
|
||||
query = query.Where(
|
||||
"to_tsvector('simple', COALESCE(title, '') || ' ' || COALESCE(content, '')) @@ plainto_tsquery('simple', ?)",
|
||||
keyword,
|
||||
)
|
||||
} else {
|
||||
searchPattern := "%" + keyword + "%"
|
||||
query = query.Where("title LIKE ? OR content LIKE ?", searchPattern, searchPattern)
|
||||
}
|
||||
}
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||
|
||||
return posts, total, err
|
||||
}
|
||||
|
||||
// GetFollowingPosts 获取关注用户的帖子
|
||||
func (r *PostRepository) GetFollowingPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
// 子查询:获取当前用户关注的所有用户ID
|
||||
subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
|
||||
|
||||
// 统计总数
|
||||
r.db.Model(&model.Post{}).Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).
|
||||
Preload("User").Preload("Images").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Order("created_at DESC").
|
||||
Find(&posts).Error
|
||||
|
||||
return posts, total, err
|
||||
}
|
||||
|
||||
// GetHotPosts 获取热门帖子(按点赞数和评论数排序)
|
||||
func (r *PostRepository) GetHotPosts(page, pageSize int) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
// 热门排序使用预计算热度分,避免每次请求进行表达式排序计算
|
||||
err := r.db.Where("status = ?", model.PostStatusPublished).Preload("User").Preload("Images").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Order("hot_score DESC, created_at DESC").
|
||||
Find(&posts).Error
|
||||
|
||||
return posts, total, err
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取帖子(保持传入顺序)
|
||||
func (r *PostRepository) GetByIDs(ids []string) ([]*model.Post, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*model.Post{}, nil
|
||||
}
|
||||
|
||||
var posts []*model.Post
|
||||
err := r.db.Preload("User").Preload("Images").
|
||||
Where("id IN ? AND status = ?", ids, model.PostStatusPublished).
|
||||
Find(&posts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 按传入ID顺序排序
|
||||
postMap := make(map[string]*model.Post)
|
||||
for _, post := range posts {
|
||||
postMap[post.ID] = post
|
||||
}
|
||||
|
||||
ordered := make([]*model.Post, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if post, ok := postMap[id]; ok {
|
||||
ordered = append(ordered, post)
|
||||
}
|
||||
}
|
||||
|
||||
return ordered, nil
|
||||
}
|
||||
172
internal/repository/push_repo.go
Normal file
172
internal/repository/push_repo.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// PushRecordRepository 推送记录仓储
|
||||
type PushRecordRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewPushRecordRepository 创建推送记录仓储
|
||||
func NewPushRecordRepository(db *gorm.DB) *PushRecordRepository {
|
||||
return &PushRecordRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建推送记录
|
||||
func (r *PushRecordRepository) Create(record *model.PushRecord) error {
|
||||
return r.db.Create(record).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取推送记录
|
||||
func (r *PushRecordRepository) GetByID(id int64) (*model.PushRecord, error) {
|
||||
var record model.PushRecord
|
||||
err := r.db.First(&record, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// Update 更新推送记录
|
||||
func (r *PushRecordRepository) Update(record *model.PushRecord) error {
|
||||
return r.db.Save(record).Error
|
||||
}
|
||||
|
||||
// GetPendingPushes 获取待推送记录
|
||||
func (r *PushRecordRepository) GetPendingPushes(limit int) ([]*model.PushRecord, error) {
|
||||
var records []*model.PushRecord
|
||||
err := r.db.Where("push_status = ?", model.PushStatusPending).
|
||||
Where("expired_at IS NULL OR expired_at > ?", time.Now()).
|
||||
Order("created_at ASC").
|
||||
Limit(limit).
|
||||
Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取推送记录
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (r *PushRecordRepository) GetByUserID(userID string, limit, offset int) ([]*model.PushRecord, error) {
|
||||
var records []*model.PushRecord
|
||||
err := r.db.Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(limit).
|
||||
Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// GetByMessageID 根据消息ID获取推送记录
|
||||
func (r *PushRecordRepository) GetByMessageID(messageID int64) ([]*model.PushRecord, error) {
|
||||
var records []*model.PushRecord
|
||||
err := r.db.Where("message_id = ?", messageID).
|
||||
Order("created_at DESC").
|
||||
Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// GetFailedPushesForRetry 获取失败待重试的推送
|
||||
func (r *PushRecordRepository) GetFailedPushesForRetry(limit int) ([]*model.PushRecord, error) {
|
||||
var records []*model.PushRecord
|
||||
err := r.db.Where("push_status = ?", model.PushStatusFailed).
|
||||
Where("retry_count < max_retry").
|
||||
Where("expired_at IS NULL OR expired_at > ?", time.Now()).
|
||||
Order("created_at ASC").
|
||||
Limit(limit).
|
||||
Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建推送记录
|
||||
func (r *PushRecordRepository) BatchCreate(records []*model.PushRecord) error {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
return r.db.Create(&records).Error
|
||||
}
|
||||
|
||||
// BatchUpdateStatus 批量更新推送状态
|
||||
func (r *PushRecordRepository) BatchUpdateStatus(ids []int64, status model.PushStatus) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
updates := map[string]interface{}{
|
||||
"push_status": status,
|
||||
}
|
||||
if status == model.PushStatusPushed {
|
||||
updates["pushed_at"] = time.Now()
|
||||
}
|
||||
return r.db.Model(&model.PushRecord{}).
|
||||
Where("id IN ?", ids).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateStatus 更新单条记录状态
|
||||
func (r *PushRecordRepository) UpdateStatus(id int64, status model.PushStatus) error {
|
||||
updates := map[string]interface{}{
|
||||
"push_status": status,
|
||||
}
|
||||
if status == model.PushStatusPushed {
|
||||
updates["pushed_at"] = time.Now()
|
||||
}
|
||||
return r.db.Model(&model.PushRecord{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// MarkAsFailed 标记为失败
|
||||
func (r *PushRecordRepository) MarkAsFailed(id int64, errMsg string) error {
|
||||
return r.db.Model(&model.PushRecord{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"push_status": model.PushStatusFailed,
|
||||
"error_message": errMsg,
|
||||
"retry_count": gorm.Expr("retry_count + 1"),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// MarkAsDelivered 标记为已送达
|
||||
func (r *PushRecordRepository) MarkAsDelivered(id int64) error {
|
||||
return r.db.Model(&model.PushRecord{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"push_status": model.PushStatusDelivered,
|
||||
"delivered_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// DeleteExpiredRecords 删除过期的推送记录(软删除)
|
||||
func (r *PushRecordRepository) DeleteExpiredRecords() error {
|
||||
return r.db.Where("expired_at IS NOT NULL AND expired_at < ?", time.Now()).
|
||||
Delete(&model.PushRecord{}).Error
|
||||
}
|
||||
|
||||
// GetStatsByUserID 获取用户推送统计
|
||||
func (r *PushRecordRepository) GetStatsByUserID(userID int64) (map[model.PushStatus]int64, error) {
|
||||
type statusCount struct {
|
||||
Status model.PushStatus
|
||||
Count int64
|
||||
}
|
||||
var results []statusCount
|
||||
|
||||
err := r.db.Model(&model.PushRecord{}).
|
||||
Select("push_status as status, count(*) as count").
|
||||
Where("user_id = ?", userID).
|
||||
Group("push_status").
|
||||
Scan(&results).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := make(map[model.PushStatus]int64)
|
||||
for _, r := range results {
|
||||
stats[r.Status] = r.Count
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
112
internal/repository/sticker_repo.go
Normal file
112
internal/repository/sticker_repo.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// StickerRepository 自定义表情仓库接口
|
||||
type StickerRepository interface {
|
||||
// 获取用户的所有表情
|
||||
GetByUserID(userID string) ([]model.UserSticker, error)
|
||||
// 根据ID获取表情
|
||||
GetByID(id string) (*model.UserSticker, error)
|
||||
// 创建表情
|
||||
Create(sticker *model.UserSticker) error
|
||||
// 删除表情
|
||||
Delete(id string) error
|
||||
// 删除用户的所有表情
|
||||
DeleteByUserID(userID string) error
|
||||
// 检查表情是否存在
|
||||
Exists(userID string, url string) (bool, error)
|
||||
// 更新排序
|
||||
UpdateSortOrder(id string, sortOrder int) error
|
||||
// 批量更新排序
|
||||
BatchUpdateSortOrder(userID string, orders map[string]int) error
|
||||
// 获取用户表情数量
|
||||
CountByUserID(userID string) (int64, error)
|
||||
}
|
||||
|
||||
// stickerRepository 自定义表情仓库实现
|
||||
type stickerRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewStickerRepository 创建自定义表情仓库
|
||||
func NewStickerRepository(db *gorm.DB) StickerRepository {
|
||||
return &stickerRepository{db: db}
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的所有表情
|
||||
func (r *stickerRepository) GetByUserID(userID string) ([]model.UserSticker, error) {
|
||||
var stickers []model.UserSticker
|
||||
err := r.db.Where("user_id = ?", userID).
|
||||
Order("sort_order ASC, created_at DESC").
|
||||
Find(&stickers).Error
|
||||
return stickers, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取表情
|
||||
func (r *stickerRepository) GetByID(id string) (*model.UserSticker, error) {
|
||||
var sticker model.UserSticker
|
||||
err := r.db.Where("id = ?", id).First(&sticker).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sticker, nil
|
||||
}
|
||||
|
||||
// Create 创建表情
|
||||
func (r *stickerRepository) Create(sticker *model.UserSticker) error {
|
||||
return r.db.Create(sticker).Error
|
||||
}
|
||||
|
||||
// Delete 删除表情
|
||||
func (r *stickerRepository) Delete(id string) error {
|
||||
return r.db.Where("id = ?", id).Delete(&model.UserSticker{}).Error
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除用户的所有表情
|
||||
func (r *stickerRepository) DeleteByUserID(userID string) error {
|
||||
return r.db.Where("user_id = ?", userID).Delete(&model.UserSticker{}).Error
|
||||
}
|
||||
|
||||
// Exists 检查表情是否存在
|
||||
func (r *stickerRepository) Exists(userID string, url string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserSticker{}).
|
||||
Where("user_id = ? AND url = ?", userID, url).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// UpdateSortOrder 更新排序
|
||||
func (r *stickerRepository) UpdateSortOrder(id string, sortOrder int) error {
|
||||
return r.db.Model(&model.UserSticker{}).
|
||||
Where("id = ?", id).
|
||||
Update("sort_order", sortOrder).Error
|
||||
}
|
||||
|
||||
// BatchUpdateSortOrder 批量更新排序
|
||||
func (r *stickerRepository) BatchUpdateSortOrder(userID string, orders map[string]int) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
for id, sortOrder := range orders {
|
||||
if err := tx.Model(&model.UserSticker{}).
|
||||
Where("id = ? AND user_id = ?", id, userID).
|
||||
Update("sort_order", sortOrder).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// CountByUserID 获取用户表情数量
|
||||
func (r *stickerRepository) CountByUserID(userID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserSticker{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
114
internal/repository/system_notification_repo.go
Normal file
114
internal/repository/system_notification_repo.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SystemNotificationRepository 系统通知仓储
|
||||
type SystemNotificationRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSystemNotificationRepository 创建系统通知仓储
|
||||
func NewSystemNotificationRepository(db *gorm.DB) *SystemNotificationRepository {
|
||||
return &SystemNotificationRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建系统通知
|
||||
func (r *SystemNotificationRepository) Create(notification *model.SystemNotification) error {
|
||||
return r.db.Create(notification).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取通知
|
||||
func (r *SystemNotificationRepository) GetByID(id int64) (*model.SystemNotification, error) {
|
||||
var notification model.SystemNotification
|
||||
err := r.db.First(¬ification, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ¬ification, nil
|
||||
}
|
||||
|
||||
// GetByReceiverID 获取用户的通知列表
|
||||
func (r *SystemNotificationRepository) GetByReceiverID(receiverID string, page, pageSize int) ([]*model.SystemNotification, int64, error) {
|
||||
var notifications []*model.SystemNotification
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.SystemNotification{}).Where("receiver_id = ?", receiverID)
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Offset(offset).
|
||||
Limit(pageSize).
|
||||
Order("created_at DESC").
|
||||
Find(¬ifications).Error
|
||||
|
||||
return notifications, total, err
|
||||
}
|
||||
|
||||
// GetUnreadByReceiverID 获取用户的未读通知列表
|
||||
func (r *SystemNotificationRepository) GetUnreadByReceiverID(receiverID string, limit int) ([]*model.SystemNotification, error) {
|
||||
var notifications []*model.SystemNotification
|
||||
err := r.db.Where("receiver_id = ? AND is_read = ?", receiverID, false).
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Find(¬ifications).Error
|
||||
return notifications, err
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取用户未读通知数量
|
||||
func (r *SystemNotificationRepository) GetUnreadCount(receiverID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.SystemNotification{}).
|
||||
Where("receiver_id = ? AND is_read = ?", receiverID, false).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// MarkAsRead 标记单条通知为已读
|
||||
func (r *SystemNotificationRepository) MarkAsRead(id int64, receiverID string) error {
|
||||
now := model.SystemNotification{}.UpdatedAt
|
||||
return r.db.Model(&model.SystemNotification{}).
|
||||
Where("id = ? AND receiver_id = ?", id, receiverID).
|
||||
Updates(map[string]interface{}{
|
||||
"is_read": true,
|
||||
"read_at": now,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// MarkAllAsRead 标记用户所有通知为已读
|
||||
func (r *SystemNotificationRepository) MarkAllAsRead(receiverID string) error {
|
||||
now := model.SystemNotification{}.UpdatedAt
|
||||
return r.db.Model(&model.SystemNotification{}).
|
||||
Where("receiver_id = ? AND is_read = ?", receiverID, false).
|
||||
Updates(map[string]interface{}{
|
||||
"is_read": true,
|
||||
"read_at": now,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// Delete 删除通知(软删除)
|
||||
func (r *SystemNotificationRepository) Delete(id int64, receiverID string) error {
|
||||
return r.db.Where("id = ? AND receiver_id = ?", id, receiverID).
|
||||
Delete(&model.SystemNotification{}).Error
|
||||
}
|
||||
|
||||
// GetByType 获取用户指定类型的通知
|
||||
func (r *SystemNotificationRepository) GetByType(receiverID string, notifyType model.SystemNotificationType, page, pageSize int) ([]*model.SystemNotification, int64, error) {
|
||||
var notifications []*model.SystemNotification
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.SystemNotification{}).
|
||||
Where("receiver_id = ? AND type = ?", receiverID, notifyType)
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Offset(offset).
|
||||
Limit(pageSize).
|
||||
Order("created_at DESC").
|
||||
Find(¬ifications).Error
|
||||
|
||||
return notifications, total, err
|
||||
}
|
||||
404
internal/repository/user_repo.go
Normal file
404
internal/repository/user_repo.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// UserRepository 用户仓储
|
||||
type UserRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRepository 创建用户仓储
|
||||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (r *UserRepository) Create(user *model.User) error {
|
||||
return r.db.Create(user).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (r *UserRepository) GetByID(id string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.First(&user, "id = ?", id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByUsername 根据用户名获取用户
|
||||
func (r *UserRepository) GetByUsername(username string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.First(&user, "username = ?", username).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByEmail 根据邮箱获取用户
|
||||
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.First(&user, "email = ?", email).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByPhone 根据手机号获取用户
|
||||
func (r *UserRepository) GetByPhone(phone string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.First(&user, "phone = ?", phone).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (r *UserRepository) Update(user *model.User) error {
|
||||
return r.db.Save(user).Error
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (r *UserRepository) Delete(id string) error {
|
||||
return r.db.Delete(&model.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// List 分页获取用户列表
|
||||
func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) {
|
||||
var users []*model.User
|
||||
var total int64
|
||||
|
||||
r.db.Model(&model.User{}).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Order("created_at DESC, id DESC").Offset(offset).Limit(pageSize).Find(&users).Error
|
||||
|
||||
return users, total, err
|
||||
}
|
||||
|
||||
// GetFollowers 获取用户粉丝
|
||||
func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
var users []*model.User
|
||||
var total int64
|
||||
|
||||
subQuery := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id")
|
||||
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("id IN (?)", subQuery).
|
||||
Order("created_at DESC, id DESC").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Find(&users).Error
|
||||
|
||||
return users, total, err
|
||||
}
|
||||
|
||||
// GetFollowing 获取用户关注
|
||||
func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
var users []*model.User
|
||||
var total int64
|
||||
|
||||
subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
|
||||
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("id IN (?)", subQuery).
|
||||
Order("created_at DESC, id DESC").
|
||||
Offset(offset).Limit(pageSize).
|
||||
Find(&users).Error
|
||||
|
||||
return users, total, err
|
||||
}
|
||||
|
||||
// CreateFollow 创建关注关系
|
||||
func (r *UserRepository) CreateFollow(follow *model.Follow) error {
|
||||
return r.db.Create(follow).Error
|
||||
}
|
||||
|
||||
// DeleteFollow 删除关注关系
|
||||
func (r *UserRepository) DeleteFollow(followerID, followingID string) error {
|
||||
return r.db.Where("follower_id = ? AND following_id = ?", followerID, followingID).Delete(&model.Follow{}).Error
|
||||
}
|
||||
|
||||
// IsFollowing 检查是否关注了某用户
|
||||
func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Follow{}).Where("follower_id = ? AND following_id = ?", followerID, followingID).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// IncrementFollowersCount 增加用户粉丝数
|
||||
func (r *UserRepository) IncrementFollowersCount(userID string) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||||
UpdateColumn("followers_count", gorm.Expr("followers_count + 1")).Error
|
||||
}
|
||||
|
||||
// DecrementFollowersCount 减少用户粉丝数
|
||||
func (r *UserRepository) DecrementFollowersCount(userID string) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ? AND followers_count > 0", userID).
|
||||
UpdateColumn("followers_count", gorm.Expr("followers_count - 1")).Error
|
||||
}
|
||||
|
||||
// IncrementFollowingCount 增加用户关注数
|
||||
func (r *UserRepository) IncrementFollowingCount(userID string) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||||
UpdateColumn("following_count", gorm.Expr("following_count + 1")).Error
|
||||
}
|
||||
|
||||
// DecrementFollowingCount 减少用户关注数
|
||||
func (r *UserRepository) DecrementFollowingCount(userID string) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ? AND following_count > 0", userID).
|
||||
UpdateColumn("following_count", gorm.Expr("following_count - 1")).Error
|
||||
}
|
||||
|
||||
// RefreshFollowersCount 刷新用户粉丝数(通过实际计数)
|
||||
func (r *UserRepository) RefreshFollowersCount(userID string) error {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||||
UpdateColumn("followers_count", count).Error
|
||||
}
|
||||
|
||||
// GetPostsCount 获取用户帖子数(实时计算)
|
||||
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetPostsCountBatch 批量获取用户帖子数(实时计算)
|
||||
// 返回 map[userID]postsCount
|
||||
func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) {
|
||||
result := make(map[string]int64)
|
||||
if len(userIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 初始化所有用户ID的计数为0
|
||||
for _, userID := range userIDs {
|
||||
result[userID] = 0
|
||||
}
|
||||
|
||||
// 使用 GROUP BY 一次性查询所有用户的帖子数
|
||||
type CountResult struct {
|
||||
UserID string
|
||||
Count int64
|
||||
}
|
||||
var counts []CountResult
|
||||
err := r.db.Model(&model.Post{}).
|
||||
Select("user_id, count(*) as count").
|
||||
Where("user_id IN ?", userIDs).
|
||||
Group("user_id").
|
||||
Scan(&counts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 更新查询结果
|
||||
for _, c := range counts {
|
||||
result[c.UserID] = c.Count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RefreshFollowingCount 刷新用户关注数(通过实际计数)
|
||||
func (r *UserRepository) RefreshFollowingCount(userID string) error {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&count).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||||
UpdateColumn("following_count", count).Error
|
||||
}
|
||||
|
||||
// IsBlocked 检查拉黑关系是否存在(blocker -> blocked)
|
||||
func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserBlock{}).
|
||||
Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// IsBlockedEitherDirection 检查是否任一方向存在拉黑
|
||||
func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserBlock{}).
|
||||
Where("(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)",
|
||||
userA, userB, userB, userA).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// BlockUserAndCleanupRelations 拉黑用户并清理双向关注关系(事务)
|
||||
func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
block := &model.UserBlock{
|
||||
BlockerID: blockerID,
|
||||
BlockedID: blockedID,
|
||||
}
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(block).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Where("follower_id = ? AND following_id = ?", blockerID, blockedID).
|
||||
Delete(&model.Follow{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Where("follower_id = ? AND following_id = ?", blockedID, blockerID).
|
||||
Delete(&model.Follow{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, uid := range []string{blockerID, blockedID} {
|
||||
var followersCount int64
|
||||
if err := tx.Model(&model.Follow{}).Where("following_id = ?", uid).Count(&followersCount).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&model.User{}).Where("id = ?", uid).
|
||||
UpdateColumn("followers_count", followersCount).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var followingCount int64
|
||||
if err := tx.Model(&model.Follow{}).Where("follower_id = ?", uid).Count(&followingCount).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Model(&model.User{}).Where("id = ?", uid).
|
||||
UpdateColumn("following_count", followingCount).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UnblockUser 取消拉黑
|
||||
func (r *UserRepository) UnblockUser(blockerID, blockedID string) error {
|
||||
return r.db.Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
|
||||
Delete(&model.UserBlock{}).Error
|
||||
}
|
||||
|
||||
// GetBlockedUsers 获取用户黑名单列表
|
||||
func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
var users []*model.User
|
||||
var total int64
|
||||
|
||||
subQuery := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id")
|
||||
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("id IN (?)", subQuery).
|
||||
Order("created_at DESC, id DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&users).Error
|
||||
|
||||
return users, total, err
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
var users []*model.User
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.User{})
|
||||
|
||||
// 搜索用户名、昵称、简介
|
||||
if keyword != "" {
|
||||
if r.db.Dialector.Name() == "postgres" {
|
||||
query = query.Where(
|
||||
"to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)",
|
||||
keyword,
|
||||
)
|
||||
} else {
|
||||
searchPattern := "%" + keyword + "%"
|
||||
query = query.Where("username LIKE ? OR nickname LIKE ? OR bio LIKE ?", searchPattern, searchPattern, searchPattern)
|
||||
}
|
||||
}
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&users).Error
|
||||
|
||||
return users, total, err
|
||||
}
|
||||
|
||||
// GetMutualFollowStatus 批量获取双向关注状态
|
||||
// 返回 map[userID][isFollowing, isFollowingMe]
|
||||
func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
|
||||
result := make(map[string][2]bool)
|
||||
|
||||
if len(targetUserIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUserID=%s, targetUserIDs=%v\n", currentUserID, targetUserIDs)
|
||||
|
||||
// 初始化所有目标用户为未关注状态
|
||||
for _, userID := range targetUserIDs {
|
||||
result[userID] = [2]bool{false, false}
|
||||
}
|
||||
|
||||
// 查询当前用户关注了哪些目标用户 (isFollowing)
|
||||
var followingIDs []string
|
||||
err := r.db.Model(&model.Follow{}).
|
||||
Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs).
|
||||
Pluck("following_id", &followingIDs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUser follows these targets: %v\n", followingIDs)
|
||||
for _, id := range followingIDs {
|
||||
status := result[id]
|
||||
status[0] = true
|
||||
result[id] = status
|
||||
}
|
||||
|
||||
// 查询哪些目标用户关注了当前用户 (isFollowingMe)
|
||||
var followerIDs []string
|
||||
err = r.db.Model(&model.Follow{}).
|
||||
Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID).
|
||||
Pluck("follower_id", &followerIDs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("[DEBUG] GetMutualFollowStatus: these targets follow currentUser: %v\n", followerIDs)
|
||||
for _, id := range followerIDs {
|
||||
status := result[id]
|
||||
status[1] = true
|
||||
result[id] = status
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] GetMutualFollowStatus: final result=%v\n", result)
|
||||
return result, nil
|
||||
}
|
||||
141
internal/repository/vote_repo.go
Normal file
141
internal/repository/vote_repo.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// VoteRepository 投票仓储
|
||||
type VoteRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewVoteRepository 创建投票仓储
|
||||
func NewVoteRepository(db *gorm.DB) *VoteRepository {
|
||||
return &VoteRepository{db: db}
|
||||
}
|
||||
|
||||
// CreateOptions 批量创建投票选项
|
||||
func (r *VoteRepository) CreateOptions(postID string, options []string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
for i, content := range options {
|
||||
option := &model.VoteOption{
|
||||
PostID: postID,
|
||||
Content: content,
|
||||
SortOrder: i,
|
||||
}
|
||||
if err := tx.Create(option).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetOptionsByPostID 获取帖子的所有投票选项
|
||||
func (r *VoteRepository) GetOptionsByPostID(postID string) ([]model.VoteOption, error) {
|
||||
var options []model.VoteOption
|
||||
err := r.db.Where("post_id = ?", postID).Order("sort_order ASC").Find(&options).Error
|
||||
return options, err
|
||||
}
|
||||
|
||||
// Vote 用户投票
|
||||
func (r *VoteRepository) Vote(postID, userID, optionID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 检查用户是否已投票
|
||||
var existing model.UserVote
|
||||
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
|
||||
if err == nil {
|
||||
// 已经投票,返回错误
|
||||
return errors.New("user already voted")
|
||||
}
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证选项是否属于该帖子
|
||||
var option model.VoteOption
|
||||
if err := tx.Where("id = ? AND post_id = ?", optionID, postID).First(&option).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return errors.New("invalid option")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建投票记录
|
||||
if err := tx.Create(&model.UserVote{
|
||||
PostID: postID,
|
||||
UserID: userID,
|
||||
OptionID: optionID,
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 原子增加选项投票数
|
||||
return tx.Model(&model.VoteOption{}).Where("id = ?", optionID).
|
||||
UpdateColumn("votes_count", gorm.Expr("votes_count + 1")).Error
|
||||
})
|
||||
}
|
||||
|
||||
// Unvote 取消投票
|
||||
func (r *VoteRepository) Unvote(postID, userID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 获取用户的投票记录
|
||||
var userVote model.UserVote
|
||||
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil // 没有投票记录,直接返回
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除投票记录
|
||||
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.UserVote{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected > 0 {
|
||||
// 原子减少选项投票数
|
||||
return tx.Model(&model.VoteOption{}).Where("id = ?", userVote.OptionID).
|
||||
UpdateColumn("votes_count", gorm.Expr("votes_count - 1")).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserVote 获取用户在指定帖子的投票
|
||||
func (r *VoteRepository) GetUserVote(postID, userID string) (*model.UserVote, error) {
|
||||
var userVote model.UserVote
|
||||
err := r.db.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &userVote, nil
|
||||
}
|
||||
|
||||
// UpdateOption 更新选项内容
|
||||
func (r *VoteRepository) UpdateOption(optionID, content string) error {
|
||||
return r.db.Model(&model.VoteOption{}).Where("id = ?", optionID).
|
||||
Update("content", content).Error
|
||||
}
|
||||
|
||||
// DeleteOptionsByPostID 删除帖子的所有投票选项
|
||||
func (r *VoteRepository) DeleteOptionsByPostID(postID string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 删除关联的用户投票记录
|
||||
if err := tx.Where("post_id = ?", postID).Delete(&model.UserVote{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除投票选项
|
||||
return tx.Where("post_id = ?", postID).Delete(&model.VoteOption{}).Error
|
||||
})
|
||||
}
|
||||
334
internal/router/router.go
Normal file
334
internal/router/router.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"carrot_bbs/internal/handler"
|
||||
"carrot_bbs/internal/middleware"
|
||||
"carrot_bbs/internal/service"
|
||||
)
|
||||
|
||||
// Router 路由配置
|
||||
type Router struct {
|
||||
engine *gin.Engine
|
||||
userHandler *handler.UserHandler
|
||||
postHandler *handler.PostHandler
|
||||
commentHandler *handler.CommentHandler
|
||||
messageHandler *handler.MessageHandler
|
||||
notificationHandler *handler.NotificationHandler
|
||||
uploadHandler *handler.UploadHandler
|
||||
wsHandler *handler.WebSocketHandler
|
||||
pushHandler *handler.PushHandler
|
||||
systemMessageHandler *handler.SystemMessageHandler
|
||||
groupHandler *handler.GroupHandler
|
||||
stickerHandler *handler.StickerHandler
|
||||
gorseHandler *handler.GorseHandler
|
||||
voteHandler *handler.VoteHandler
|
||||
jwtService *service.JWTService
|
||||
}
|
||||
|
||||
// New 创建路由
|
||||
func New(
|
||||
userHandler *handler.UserHandler,
|
||||
postHandler *handler.PostHandler,
|
||||
commentHandler *handler.CommentHandler,
|
||||
messageHandler *handler.MessageHandler,
|
||||
notificationHandler *handler.NotificationHandler,
|
||||
uploadHandler *handler.UploadHandler,
|
||||
jwtService *service.JWTService,
|
||||
wsHandler *handler.WebSocketHandler,
|
||||
pushHandler *handler.PushHandler,
|
||||
systemMessageHandler *handler.SystemMessageHandler,
|
||||
groupHandler *handler.GroupHandler,
|
||||
stickerHandler *handler.StickerHandler,
|
||||
gorseHandler *handler.GorseHandler,
|
||||
voteHandler *handler.VoteHandler,
|
||||
) *Router {
|
||||
// 设置JWT服务
|
||||
userHandler.SetJWTService(jwtService)
|
||||
|
||||
r := &Router{
|
||||
engine: gin.Default(),
|
||||
userHandler: userHandler,
|
||||
postHandler: postHandler,
|
||||
commentHandler: commentHandler,
|
||||
messageHandler: messageHandler,
|
||||
notificationHandler: notificationHandler,
|
||||
uploadHandler: uploadHandler,
|
||||
wsHandler: wsHandler,
|
||||
pushHandler: pushHandler,
|
||||
systemMessageHandler: systemMessageHandler,
|
||||
groupHandler: groupHandler,
|
||||
stickerHandler: stickerHandler,
|
||||
gorseHandler: gorseHandler,
|
||||
voteHandler: voteHandler,
|
||||
jwtService: jwtService,
|
||||
}
|
||||
|
||||
r.setupRoutes()
|
||||
return r
|
||||
}
|
||||
|
||||
// setupRoutes 设置路由
|
||||
func (r *Router) setupRoutes() {
|
||||
// 中间件
|
||||
r.engine.Use(middleware.CORS())
|
||||
|
||||
// 健康检查
|
||||
r.engine.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// WebSocket 路由
|
||||
if r.wsHandler != nil {
|
||||
r.engine.GET("/ws", r.wsHandler.HandleWebSocket)
|
||||
}
|
||||
|
||||
// API v1
|
||||
v1 := r.engine.Group("/api/v1")
|
||||
{
|
||||
// 认证路由(公开)
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", r.userHandler.Register)
|
||||
auth.POST("/register/send-code", r.userHandler.SendRegisterCode)
|
||||
auth.POST("/login", r.userHandler.Login)
|
||||
auth.POST("/password/send-code", r.userHandler.SendPasswordResetCode)
|
||||
auth.POST("/password/reset", r.userHandler.ResetPassword)
|
||||
auth.POST("/refresh", r.userHandler.RefreshToken)
|
||||
}
|
||||
|
||||
// 需要认证的路由
|
||||
authMiddleware := middleware.Auth(r.jwtService)
|
||||
|
||||
// 用户路由
|
||||
users := v1.Group("/users")
|
||||
{
|
||||
// 当前用户
|
||||
users.GET("/me", authMiddleware, r.userHandler.GetCurrentUser)
|
||||
users.PUT("/me", authMiddleware, r.userHandler.UpdateUser)
|
||||
users.POST("/me/email/send-code", authMiddleware, r.userHandler.SendEmailVerifyCode)
|
||||
users.POST("/me/email/verify", authMiddleware, r.userHandler.VerifyEmail)
|
||||
users.POST("/me/avatar", authMiddleware, r.uploadHandler.UploadAvatar)
|
||||
users.POST("/me/cover", authMiddleware, r.uploadHandler.UploadCover)
|
||||
users.POST("/change-password/send-code", authMiddleware, r.userHandler.SendChangePasswordCode)
|
||||
users.POST("/change-password", authMiddleware, r.userHandler.ChangePassword)
|
||||
|
||||
// 搜索用户
|
||||
users.GET("/search", middleware.OptionalAuth(r.jwtService), r.userHandler.Search)
|
||||
|
||||
// 其他用户
|
||||
users.GET("/:id", middleware.OptionalAuth(r.jwtService), r.userHandler.GetUserByID)
|
||||
users.POST("/:id/follow", authMiddleware, r.userHandler.FollowUser)
|
||||
users.DELETE("/:id/follow", authMiddleware, r.userHandler.UnfollowUser)
|
||||
users.POST("/:id/block", authMiddleware, r.userHandler.BlockUser)
|
||||
users.DELETE("/:id/block", authMiddleware, r.userHandler.UnblockUser)
|
||||
users.GET("/:id/block-status", authMiddleware, r.userHandler.GetBlockStatus)
|
||||
users.GET("/blocks", authMiddleware, r.userHandler.GetBlockedUsers)
|
||||
users.GET("/:id/following", authMiddleware, r.userHandler.GetFollowingList)
|
||||
users.GET("/:id/followers", authMiddleware, r.userHandler.GetFollowersList)
|
||||
|
||||
// 用户帖子 - 使用 OptionalAuth 获取当前用户点赞/收藏状态
|
||||
users.GET("/:id/posts", middleware.OptionalAuth(r.jwtService), r.postHandler.GetUserPosts)
|
||||
users.GET("/:id/favorites", middleware.OptionalAuth(r.jwtService), r.postHandler.GetFavorites)
|
||||
}
|
||||
|
||||
// 认证路由(公开)
|
||||
authPublic := v1.Group("/auth")
|
||||
{
|
||||
authPublic.GET("/check-username", r.userHandler.CheckUsername)
|
||||
}
|
||||
|
||||
// 帖子路由
|
||||
posts := v1.Group("/posts")
|
||||
{
|
||||
// 使用 OptionalAuth 中间件来获取用户登录状态
|
||||
posts.GET("", middleware.OptionalAuth(r.jwtService), r.postHandler.List)
|
||||
posts.GET("/search", middleware.OptionalAuth(r.jwtService), r.postHandler.Search)
|
||||
posts.GET("/:id", middleware.OptionalAuth(r.jwtService), r.postHandler.GetByID)
|
||||
posts.POST("", authMiddleware, r.postHandler.Create)
|
||||
posts.PUT("/:id", authMiddleware, r.postHandler.Update)
|
||||
posts.DELETE("/:id", authMiddleware, r.postHandler.Delete)
|
||||
|
||||
// 浏览量记录(可选认证,允许游客浏览)
|
||||
posts.POST("/:id/view", middleware.OptionalAuth(r.jwtService), r.postHandler.RecordView)
|
||||
|
||||
// 点赞
|
||||
posts.POST("/:id/like", authMiddleware, r.postHandler.Like)
|
||||
posts.DELETE("/:id/like", authMiddleware, r.postHandler.Unlike)
|
||||
|
||||
// 收藏
|
||||
posts.POST("/:id/favorite", authMiddleware, r.postHandler.Favorite)
|
||||
posts.DELETE("/:id/favorite", authMiddleware, r.postHandler.Unfavorite)
|
||||
|
||||
// 投票相关路由
|
||||
posts.POST("/vote", authMiddleware, r.voteHandler.CreateVotePost) // 创建投票帖子
|
||||
posts.GET("/:id/vote", middleware.OptionalAuth(r.jwtService), r.voteHandler.GetVoteResult) // 获取投票结果
|
||||
posts.POST("/:id/vote", authMiddleware, r.voteHandler.Vote) // 投票
|
||||
posts.DELETE("/:id/vote", authMiddleware, r.voteHandler.Unvote) // 取消投票
|
||||
}
|
||||
|
||||
// 投票选项路由
|
||||
voteOptions := v1.Group("/vote-options")
|
||||
voteOptions.Use(authMiddleware)
|
||||
{
|
||||
voteOptions.PUT("/:id", r.voteHandler.UpdateVoteOption) // 更新选项
|
||||
}
|
||||
|
||||
// 评论路由
|
||||
comments := v1.Group("/comments")
|
||||
{
|
||||
comments.GET("/post/:id", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetByPostID)
|
||||
comments.GET("/:id", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetByID)
|
||||
comments.POST("", authMiddleware, r.commentHandler.Create)
|
||||
comments.PUT("/:id", authMiddleware, r.commentHandler.Update)
|
||||
comments.DELETE("/:id", authMiddleware, r.commentHandler.Delete)
|
||||
comments.GET("/:id/replies", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetReplies)
|
||||
comments.GET("/:id/replies/flat", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetRepliesByRootID) // 扁平化分页获取回复
|
||||
// 评论点赞
|
||||
comments.POST("/:id/like", authMiddleware, r.commentHandler.Like)
|
||||
comments.DELETE("/:id/like", authMiddleware, r.commentHandler.Unlike)
|
||||
}
|
||||
|
||||
// 会话路由(新版 RESTful action 风格)
|
||||
conversations := v1.Group("/conversations")
|
||||
conversations.Use(authMiddleware)
|
||||
{
|
||||
// 获取会话列表
|
||||
conversations.GET("/list", r.messageHandler.HandleGetConversationList)
|
||||
// 创建会话
|
||||
conversations.POST("/create", r.messageHandler.HandleCreateConversation)
|
||||
// 获取会话详情
|
||||
conversations.GET("/get", r.messageHandler.HandleGetConversation)
|
||||
// 获取会话消息
|
||||
conversations.GET("/get_messages", r.messageHandler.HandleGetMessages)
|
||||
// 发送消息
|
||||
conversations.POST("/send_message", r.messageHandler.HandleSendMessage)
|
||||
// 标记已读
|
||||
conversations.POST("/mark_read", r.messageHandler.HandleMarkRead)
|
||||
// 会话置顶
|
||||
conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned)
|
||||
// 获取未读消息总数
|
||||
conversations.GET("/unread/count", r.messageHandler.GetUnreadCount)
|
||||
// 仅自己删除会话
|
||||
conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf)
|
||||
}
|
||||
|
||||
// 消息操作路由
|
||||
messages := v1.Group("/messages")
|
||||
messages.Use(authMiddleware)
|
||||
{
|
||||
// 撤回/删除消息(统一接口)
|
||||
messages.POST("/delete_msg", r.messageHandler.HandleDeleteMsg)
|
||||
}
|
||||
|
||||
// 通知路由
|
||||
notifications := v1.Group("/notifications")
|
||||
{
|
||||
notifications.GET("", authMiddleware, r.notificationHandler.GetNotifications)
|
||||
notifications.POST("/:id/read", authMiddleware, r.notificationHandler.MarkAsRead)
|
||||
notifications.POST("/read-all", authMiddleware, r.notificationHandler.MarkAllAsRead)
|
||||
notifications.GET("/unread-count", authMiddleware, r.notificationHandler.GetUnreadCount)
|
||||
notifications.DELETE("/:id", authMiddleware, r.notificationHandler.DeleteNotification)
|
||||
notifications.DELETE("", authMiddleware, r.notificationHandler.ClearAllNotifications)
|
||||
}
|
||||
|
||||
// 上传路由
|
||||
uploads := v1.Group("/uploads")
|
||||
{
|
||||
uploads.POST("/images", authMiddleware, r.uploadHandler.UploadImage)
|
||||
}
|
||||
|
||||
// 推送相关路由
|
||||
if r.pushHandler != nil {
|
||||
pushGroup := v1.Group("/push")
|
||||
pushGroup.Use(authMiddleware)
|
||||
{
|
||||
pushGroup.POST("/devices", r.pushHandler.RegisterDevice)
|
||||
pushGroup.GET("/devices", r.pushHandler.GetMyDevices)
|
||||
pushGroup.DELETE("/devices/:device_id", r.pushHandler.UnregisterDevice)
|
||||
pushGroup.PUT("/devices/:device_id/token", r.pushHandler.UpdateDeviceToken)
|
||||
pushGroup.GET("/records", r.pushHandler.GetPushRecords)
|
||||
}
|
||||
}
|
||||
|
||||
// 系统消息路由
|
||||
if r.systemMessageHandler != nil {
|
||||
msgGroup := v1.Group("/messages")
|
||||
msgGroup.Use(authMiddleware)
|
||||
{
|
||||
msgGroup.GET("/system", r.systemMessageHandler.GetSystemMessages)
|
||||
msgGroup.GET("/system/unread-count", r.systemMessageHandler.GetUnreadCount)
|
||||
msgGroup.PUT("/system/:id/read", r.systemMessageHandler.MarkAsRead)
|
||||
msgGroup.PUT("/system/read-all", r.systemMessageHandler.MarkAllAsRead)
|
||||
}
|
||||
}
|
||||
|
||||
// 群组路由(新版 RESTful action 风格)
|
||||
if r.groupHandler != nil {
|
||||
groups := v1.Group("/groups")
|
||||
groups.Use(authMiddleware)
|
||||
{
|
||||
// 群组管理
|
||||
groups.POST("/create", r.groupHandler.HandleCreateGroup)
|
||||
groups.GET("/list", r.groupHandler.HandleGetUserGroups)
|
||||
groups.GET("/get", r.groupHandler.HandleGetGroupInfo)
|
||||
groups.GET("/get_my_info", r.groupHandler.HandleGetMyMemberInfo)
|
||||
groups.POST("/dissolve", r.groupHandler.HandleDissolveGroup)
|
||||
groups.POST("/transfer", r.groupHandler.HandleTransferOwner)
|
||||
|
||||
// 成员管理
|
||||
groups.POST("/invite_members", r.groupHandler.HandleInviteMembers)
|
||||
groups.POST("/join", r.groupHandler.HandleJoinGroup)
|
||||
groups.POST("/respond_invite", r.groupHandler.HandleRespondInvite)
|
||||
groups.POST("/set_group_leave", r.groupHandler.HandleSetGroupLeave)
|
||||
groups.GET("/get_members", r.groupHandler.HandleGetGroupMemberList)
|
||||
groups.POST("/set_group_kick", r.groupHandler.HandleSetGroupKick)
|
||||
groups.POST("/set_group_admin", r.groupHandler.HandleSetGroupAdmin)
|
||||
groups.POST("/set_nickname", r.groupHandler.HandleSetNickname)
|
||||
groups.POST("/set_group_ban", r.groupHandler.HandleSetGroupBan)
|
||||
|
||||
// 群设置
|
||||
groups.POST("/set_group_whole_ban", r.groupHandler.HandleSetGroupWholeBan)
|
||||
groups.POST("/set_join_type", r.groupHandler.HandleSetJoinType)
|
||||
groups.POST("/set_group_name", r.groupHandler.HandleSetGroupName)
|
||||
groups.POST("/set_group_avatar", r.groupHandler.HandleSetGroupAvatar)
|
||||
|
||||
// 群公告
|
||||
groups.POST("/create_announcement", r.groupHandler.HandleCreateAnnouncement)
|
||||
groups.GET("/get_announcements", r.groupHandler.HandleGetAnnouncements)
|
||||
groups.POST("/delete_announcement", r.groupHandler.HandleDeleteAnnouncement)
|
||||
|
||||
// 加群请求处理(预留)
|
||||
groups.POST("/set_group_add_request", r.groupHandler.HandleSetGroupAddRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// 自定义表情路由
|
||||
if r.stickerHandler != nil {
|
||||
stickers := v1.Group("/stickers")
|
||||
stickers.Use(authMiddleware)
|
||||
{
|
||||
stickers.GET("", r.stickerHandler.GetStickers)
|
||||
stickers.POST("", r.stickerHandler.AddSticker)
|
||||
stickers.DELETE("", r.stickerHandler.DeleteSticker)
|
||||
stickers.POST("/reorder", r.stickerHandler.ReorderStickers)
|
||||
stickers.GET("/check", r.stickerHandler.CheckStickerExists)
|
||||
}
|
||||
}
|
||||
|
||||
// Gorse 管理路由
|
||||
if r.gorseHandler != nil {
|
||||
gorseGroup := v1.Group("/gorse")
|
||||
{
|
||||
gorseGroup.GET("/status", r.gorseHandler.GetStatus)
|
||||
gorseGroup.POST("/import", r.gorseHandler.ImportData)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Engine 获取引擎
|
||||
func (r *Router) Engine() *gin.Engine {
|
||||
return r.engine
|
||||
}
|
||||
759
internal/service/audit_service.go
Normal file
759
internal/service/audit_service.go
Normal file
@@ -0,0 +1,759 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ==================== 内容审核服务接口和实现 ====================
|
||||
|
||||
// AuditServiceProvider 内容审核服务提供商接口
|
||||
type AuditServiceProvider interface {
|
||||
// AuditText 审核文本
|
||||
AuditText(ctx context.Context, text string, scene string) (*AuditResult, error)
|
||||
// AuditImage 审核图片
|
||||
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
|
||||
// GetName 获取提供商名称
|
||||
GetName() string
|
||||
}
|
||||
|
||||
// AuditResult 审核结果
|
||||
type AuditResult struct {
|
||||
Pass bool `json:"pass"` // 是否通过
|
||||
Risk string `json:"risk"` // 风险等级: low, medium, high
|
||||
Labels []string `json:"labels"` // 标签列表
|
||||
Suggest string `json:"suggest"` // 建议: pass, review, block
|
||||
Detail string `json:"detail"` // 详细说明
|
||||
Provider string `json:"provider"` // 服务提供商
|
||||
}
|
||||
|
||||
// AuditService 内容审核服务接口
|
||||
type AuditService interface {
|
||||
// AuditText 审核文本
|
||||
AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error)
|
||||
// AuditImage 审核图片
|
||||
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
|
||||
// GetAuditResult 获取审核结果
|
||||
GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error)
|
||||
// SetProvider 设置审核服务提供商
|
||||
SetProvider(provider AuditServiceProvider)
|
||||
// GetProvider 获取当前审核服务提供商
|
||||
GetProvider() AuditServiceProvider
|
||||
}
|
||||
|
||||
// auditServiceImpl 内容审核服务实现
|
||||
type auditServiceImpl struct {
|
||||
db *gorm.DB
|
||||
provider AuditServiceProvider
|
||||
config *AuditConfig
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// AuditConfig 内容审核服务配置
|
||||
type AuditConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
|
||||
// 审核服务提供商: local, aliyun, tencent, baidu
|
||||
Provider string `mapstructure:"provider" yaml:"provider"`
|
||||
// 阿里云配置
|
||||
AliyunAccessKey string `mapstructure:"aliyun_access_key" yaml:"aliyun_access_key"`
|
||||
AliyunSecretKey string `mapstructure:"aliyun_secret_key" yaml:"aliyun_secret_key"`
|
||||
AliyunRegion string `mapstructure:"aliyun_region" yaml:"aliyun_region"`
|
||||
// 腾讯云配置
|
||||
TencentSecretID string `mapstructure:"tencent_secret_id" yaml:"tencent_secret_id"`
|
||||
TencentSecretKey string `mapstructure:"tencent_secret_key" yaml:"tencent_secret_key"`
|
||||
// 百度云配置
|
||||
BaiduAPIKey string `mapstructure:"baidu_api_key" yaml:"baidu_api_key"`
|
||||
BaiduSecretKey string `mapstructure:"baidu_secret_key" yaml:"baidu_secret_key"`
|
||||
// 是否自动审核
|
||||
AutoAudit bool `mapstructure:"auto_audit" yaml:"auto_audit"`
|
||||
// 审核超时时间(秒)
|
||||
Timeout int `mapstructure:"timeout" yaml:"timeout"`
|
||||
}
|
||||
|
||||
// NewAuditService 创建内容审核服务
|
||||
func NewAuditService(db *gorm.DB, config *AuditConfig) AuditService {
|
||||
s := &auditServiceImpl{
|
||||
db: db,
|
||||
config: config,
|
||||
}
|
||||
|
||||
// 根据配置初始化提供商
|
||||
if config.Enabled {
|
||||
provider := s.initProvider(config.Provider)
|
||||
s.provider = provider
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// initProvider 根据配置初始化审核服务提供商
|
||||
func (s *auditServiceImpl) initProvider(providerType string) AuditServiceProvider {
|
||||
switch strings.ToLower(providerType) {
|
||||
case "aliyun":
|
||||
return NewAliyunAuditProvider(s.config.AliyunAccessKey, s.config.AliyunSecretKey, s.config.AliyunRegion)
|
||||
case "tencent":
|
||||
return NewTencentAuditProvider(s.config.TencentSecretID, s.config.TencentSecretKey)
|
||||
case "baidu":
|
||||
return NewBaiduAuditProvider(s.config.BaiduAPIKey, s.config.BaiduSecretKey)
|
||||
case "local":
|
||||
fallthrough
|
||||
default:
|
||||
// 默认使用本地审核服务
|
||||
return NewLocalAuditProvider()
|
||||
}
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (s *auditServiceImpl) AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error) {
|
||||
if !s.config.Enabled {
|
||||
// 如果审核服务未启用,直接返回通过
|
||||
return &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Detail: "Audit service disabled",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
return &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Detail: "Empty text",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result *AuditResult
|
||||
var err error
|
||||
|
||||
// 使用提供商审核
|
||||
if s.provider != nil {
|
||||
result, err = s.provider.AuditText(ctx, text, auditType)
|
||||
} else {
|
||||
// 如果没有设置提供商,使用本地审核
|
||||
localProvider := NewLocalAuditProvider()
|
||||
result, err = localProvider.AuditText(ctx, text, auditType)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Audit text error: %v", err)
|
||||
return &AuditResult{
|
||||
Pass: false,
|
||||
Risk: "high",
|
||||
Suggest: "review",
|
||||
Detail: fmt.Sprintf("Audit error: %v", err),
|
||||
}, err
|
||||
}
|
||||
|
||||
// 记录审核日志
|
||||
go s.saveAuditLog(ctx, "text", "", text, auditType, result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (s *auditServiceImpl) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
if !s.config.Enabled {
|
||||
return &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Detail: "Audit service disabled",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if imageURL == "" {
|
||||
return &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Detail: "Empty image URL",
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result *AuditResult
|
||||
var err error
|
||||
|
||||
// 使用提供商审核
|
||||
if s.provider != nil {
|
||||
result, err = s.provider.AuditImage(ctx, imageURL)
|
||||
} else {
|
||||
// 如果没有设置提供商,使用本地审核
|
||||
localProvider := NewLocalAuditProvider()
|
||||
result, err = localProvider.AuditImage(ctx, imageURL)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Audit image error: %v", err)
|
||||
return &AuditResult{
|
||||
Pass: false,
|
||||
Risk: "high",
|
||||
Suggest: "review",
|
||||
Detail: fmt.Sprintf("Audit error: %v", err),
|
||||
}, err
|
||||
}
|
||||
|
||||
// 记录审核日志
|
||||
go s.saveAuditLog(ctx, "image", "", "", "image", result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetAuditResult 获取审核结果
|
||||
func (s *auditServiceImpl) GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error) {
|
||||
if s.db == nil || auditID == "" {
|
||||
return nil, fmt.Errorf("invalid audit ID")
|
||||
}
|
||||
|
||||
var auditLog model.AuditLog
|
||||
if err := s.db.Where("id = ?", auditID).First(&auditLog).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &AuditResult{
|
||||
Pass: auditLog.Result == model.AuditResultPass,
|
||||
Risk: string(auditLog.RiskLevel),
|
||||
Suggest: auditLog.Suggestion,
|
||||
Detail: auditLog.Detail,
|
||||
}
|
||||
|
||||
// 解析标签
|
||||
if auditLog.Labels != "" {
|
||||
json.Unmarshal([]byte(auditLog.Labels), &result.Labels)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetProvider 设置审核服务提供商
|
||||
func (s *auditServiceImpl) SetProvider(provider AuditServiceProvider) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.provider = provider
|
||||
}
|
||||
|
||||
// GetProvider 获取当前审核服务提供商
|
||||
func (s *auditServiceImpl) GetProvider() AuditServiceProvider {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.provider
|
||||
}
|
||||
|
||||
// saveAuditLog 保存审核日志
|
||||
func (s *auditServiceImpl) saveAuditLog(ctx context.Context, contentType, content, imageURL, auditType string, result *AuditResult) {
|
||||
if s.db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
auditLog := model.AuditLog{
|
||||
ContentType: contentType,
|
||||
Content: content,
|
||||
ContentURL: imageURL,
|
||||
AuditType: auditType,
|
||||
Labels: strings.Join(result.Labels, ","),
|
||||
Suggestion: result.Suggest,
|
||||
Detail: result.Detail,
|
||||
Source: model.AuditSourceAuto,
|
||||
Status: "completed",
|
||||
}
|
||||
|
||||
if result.Pass {
|
||||
auditLog.Result = model.AuditResultPass
|
||||
} else if result.Suggest == "review" {
|
||||
auditLog.Result = model.AuditResultReview
|
||||
} else {
|
||||
auditLog.Result = model.AuditResultBlock
|
||||
}
|
||||
|
||||
switch result.Risk {
|
||||
case "low":
|
||||
auditLog.RiskLevel = model.AuditRiskLevelLow
|
||||
case "medium":
|
||||
auditLog.RiskLevel = model.AuditRiskLevelMedium
|
||||
case "high":
|
||||
auditLog.RiskLevel = model.AuditRiskLevelHigh
|
||||
default:
|
||||
auditLog.RiskLevel = model.AuditRiskLevelLow
|
||||
}
|
||||
|
||||
if err := s.db.Create(&auditLog).Error; err != nil {
|
||||
log.Printf("Failed to save audit log: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 本地审核服务提供商 ====================
|
||||
|
||||
// localAuditProvider 本地审核服务提供商
|
||||
type localAuditProvider struct {
|
||||
// 可以注入敏感词服务进行本地审核
|
||||
sensitiveService SensitiveService
|
||||
}
|
||||
|
||||
// NewLocalAuditProvider 创建本地审核服务提供商
|
||||
func NewLocalAuditProvider() AuditServiceProvider {
|
||||
return &localAuditProvider{
|
||||
sensitiveService: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取提供商名称
|
||||
func (p *localAuditProvider) GetName() string {
|
||||
return "local"
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (p *localAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||
// 本地审核逻辑
|
||||
// 1. 敏感词检查
|
||||
// 2. 规则匹配
|
||||
// 3. 简单的关键词检测
|
||||
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "local",
|
||||
}
|
||||
|
||||
// 如果有敏感词服务,使用它进行检测
|
||||
if p.sensitiveService != nil {
|
||||
hasSensitive, words := p.sensitiveService.Check(ctx, text)
|
||||
if hasSensitive {
|
||||
result.Pass = false
|
||||
result.Risk = "high"
|
||||
result.Suggest = "block"
|
||||
result.Detail = fmt.Sprintf("包含敏感词: %s", strings.Join(words, ","))
|
||||
result.Labels = append(result.Labels, "sensitive")
|
||||
}
|
||||
}
|
||||
|
||||
// 简单的关键词检测规则
|
||||
// 实际项目中应该从数据库加载
|
||||
suspiciousPatterns := []string{
|
||||
"诈骗",
|
||||
"钓鱼",
|
||||
"木马",
|
||||
"病毒",
|
||||
}
|
||||
|
||||
for _, pattern := range suspiciousPatterns {
|
||||
if strings.Contains(text, pattern) {
|
||||
result.Pass = false
|
||||
result.Risk = "high"
|
||||
result.Suggest = "block"
|
||||
result.Labels = append(result.Labels, "suspicious")
|
||||
if result.Detail == "" {
|
||||
result.Detail = fmt.Sprintf("包含可疑内容: %s", pattern)
|
||||
} else {
|
||||
result.Detail += fmt.Sprintf(", %s", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (p *localAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
// 本地图片审核逻辑
|
||||
// 1. 图片URL合法性检查
|
||||
// 2. 图片格式检查
|
||||
// 3. 可以扩展接入本地图片识别服务
|
||||
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "local",
|
||||
}
|
||||
|
||||
// 检查URL是否为空
|
||||
if imageURL == "" {
|
||||
result.Detail = "Empty image URL"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 检查是否为支持的图片URL格式
|
||||
validPrefixes := []string{"http://", "https://", "s3://", "oss://", "cos://"}
|
||||
isValid := false
|
||||
for _, prefix := range validPrefixes {
|
||||
if strings.HasPrefix(strings.ToLower(imageURL), prefix) {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
result.Pass = false
|
||||
result.Risk = "medium"
|
||||
result.Suggest = "review"
|
||||
result.Detail = "Invalid image URL format"
|
||||
result.Labels = append(result.Labels, "invalid_url")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetSensitiveService 设置敏感词服务
|
||||
func (p *localAuditProvider) SetSensitiveService(ss SensitiveService) {
|
||||
p.sensitiveService = ss
|
||||
}
|
||||
|
||||
// ==================== 阿里云审核服务提供商 ====================
|
||||
|
||||
// aliyunAuditProvider 阿里云审核服务提供商
|
||||
type aliyunAuditProvider struct {
|
||||
accessKey string
|
||||
secretKey string
|
||||
region string
|
||||
}
|
||||
|
||||
// NewAliyunAuditProvider 创建阿里云审核服务提供商
|
||||
func NewAliyunAuditProvider(accessKey, secretKey, region string) AuditServiceProvider {
|
||||
return &aliyunAuditProvider{
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
region: region,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取提供商名称
|
||||
func (p *aliyunAuditProvider) GetName() string {
|
||||
return "aliyun"
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (p *aliyunAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||
// 阿里云内容安全API调用
|
||||
// 实际项目中需要实现阿里云SDK调用
|
||||
// 这里预留接口
|
||||
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "aliyun",
|
||||
Detail: "Aliyun audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现阿里云内容安全API调用
|
||||
// 具体参考: https://help.aliyun.com/document_detail/28417.html
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (p *aliyunAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "aliyun",
|
||||
Detail: "Aliyun image audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现阿里云图片审核API调用
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ==================== 腾讯云审核服务提供商 ====================
|
||||
|
||||
// tencentAuditProvider 腾讯云审核服务提供商
|
||||
type tencentAuditProvider struct {
|
||||
secretID string
|
||||
secretKey string
|
||||
}
|
||||
|
||||
// NewTencentAuditProvider 创建腾讯云审核服务提供商
|
||||
func NewTencentAuditProvider(secretID, secretKey string) AuditServiceProvider {
|
||||
return &tencentAuditProvider{
|
||||
secretID: secretID,
|
||||
secretKey: secretKey,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取提供商名称
|
||||
func (p *tencentAuditProvider) GetName() string {
|
||||
return "tencent"
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (p *tencentAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "tencent",
|
||||
Detail: "Tencent audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现腾讯云内容审核API调用
|
||||
// 具体参考: https://cloud.tencent.com/document/product/1124/64508
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (p *tencentAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "tencent",
|
||||
Detail: "Tencent image audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现腾讯云图片审核API调用
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ==================== 百度云审核服务提供商 ====================
|
||||
|
||||
// baiduAuditProvider 百度云审核服务提供商
|
||||
type baiduAuditProvider struct {
|
||||
apiKey string
|
||||
secretKey string
|
||||
}
|
||||
|
||||
// NewBaiduAuditProvider 创建百度云审核服务提供商
|
||||
func NewBaiduAuditProvider(apiKey, secretKey string) AuditServiceProvider {
|
||||
return &baiduAuditProvider{
|
||||
apiKey: apiKey,
|
||||
secretKey: secretKey,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取提供商名称
|
||||
func (p *baiduAuditProvider) GetName() string {
|
||||
return "baidu"
|
||||
}
|
||||
|
||||
// AuditText 审核文本
|
||||
func (p *baiduAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "baidu",
|
||||
Detail: "Baidu audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现百度云内容审核API调用
|
||||
// 具体参考: https://cloud.baidu.com/doc/ANTISPAM/s/Jjw0r1iF6
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AuditImage 审核图片
|
||||
func (p *baiduAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
|
||||
result := &AuditResult{
|
||||
Pass: true,
|
||||
Risk: "low",
|
||||
Suggest: "pass",
|
||||
Labels: []string{},
|
||||
Provider: "baidu",
|
||||
Detail: "Baidu image audit not implemented, using pass",
|
||||
}
|
||||
|
||||
// TODO: 实现百度云图片审核API调用
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ==================== 审核结果回调处理 ====================
|
||||
|
||||
// AuditCallback 审核回调处理
|
||||
type AuditCallback struct {
|
||||
service AuditService
|
||||
}
|
||||
|
||||
// NewAuditCallback 创建审核回调处理
|
||||
func NewAuditCallback(service AuditService) *AuditCallback {
|
||||
return &AuditCallback{
|
||||
service: service,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleTextCallback 处理文本审核回调
|
||||
func (c *AuditCallback) HandleTextCallback(ctx context.Context, auditID string, result *AuditResult) error {
|
||||
if c.service == nil || auditID == "" || result == nil {
|
||||
return fmt.Errorf("invalid parameters")
|
||||
}
|
||||
|
||||
log.Printf("Processing text audit callback: auditID=%s, result=%+v", auditID, result)
|
||||
|
||||
// 根据审核结果执行相应操作
|
||||
// 例如: 更新帖子状态、发送通知等
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleImageCallback 处理图片审核回调
|
||||
func (c *AuditCallback) HandleImageCallback(ctx context.Context, auditID string, result *AuditResult) error {
|
||||
if c.service == nil || auditID == "" || result == nil {
|
||||
return fmt.Errorf("invalid parameters")
|
||||
}
|
||||
|
||||
log.Printf("Processing image audit callback: auditID=%s, result=%+v", auditID, result)
|
||||
|
||||
// 根据审核结果执行相应操作
|
||||
// 例如: 更新图片状态、删除违规图片等
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
// IsContentSafe 判断内容是否安全
|
||||
func IsContentSafe(result *AuditResult) bool {
|
||||
if result == nil {
|
||||
return true
|
||||
}
|
||||
return result.Pass && result.Suggest != "block"
|
||||
}
|
||||
|
||||
// NeedReview 判断内容是否需要人工复审
|
||||
func NeedReview(result *AuditResult) bool {
|
||||
if result == nil {
|
||||
return false
|
||||
}
|
||||
return result.Suggest == "review"
|
||||
}
|
||||
|
||||
// GetRiskLevel 获取风险等级
|
||||
func GetRiskLevel(result *AuditResult) string {
|
||||
if result == nil {
|
||||
return "low"
|
||||
}
|
||||
return result.Risk
|
||||
}
|
||||
|
||||
// FormatAuditResult 格式化审核结果为字符串
|
||||
func FormatAuditResult(result *AuditResult) string {
|
||||
if result == nil {
|
||||
return "{}"
|
||||
}
|
||||
data, _ := json.Marshal(result)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// ParseAuditResult 从字符串解析审核结果
|
||||
func ParseAuditResult(data string) (*AuditResult, error) {
|
||||
if data == "" {
|
||||
return nil, fmt.Errorf("empty data")
|
||||
}
|
||||
var result AuditResult
|
||||
if err := json.Unmarshal([]byte(data), &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ==================== 审核日志查询 ====================
|
||||
|
||||
// GetAuditLogs 获取审核日志列表
|
||||
func GetAuditLogs(db *gorm.DB, targetType string, targetID string, result string, page, pageSize int) ([]model.AuditLog, int64, error) {
|
||||
query := db.Model(&model.AuditLog{})
|
||||
|
||||
if targetType != "" {
|
||||
query = query.Where("target_type = ?", targetType)
|
||||
}
|
||||
if targetID != "" {
|
||||
query = query.Where("target_id = ?", targetID)
|
||||
}
|
||||
if result != "" {
|
||||
query = query.Where("result = ?", result)
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var logs []model.AuditLog
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// ==================== 定时任务 ====================
|
||||
|
||||
// AuditScheduler 审核调度器
|
||||
type AuditScheduler struct {
|
||||
db *gorm.DB
|
||||
service AuditService
|
||||
interval time.Duration
|
||||
stopCh chan bool
|
||||
}
|
||||
|
||||
// NewAuditScheduler 创建审核调度器
|
||||
func NewAuditScheduler(db *gorm.DB, service AuditService, interval time.Duration) *AuditScheduler {
|
||||
return &AuditScheduler{
|
||||
db: db,
|
||||
service: service,
|
||||
interval: interval,
|
||||
stopCh: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动调度器
|
||||
func (s *AuditScheduler) Start() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.processPendingAudits()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop 停止调度器
|
||||
func (s *AuditScheduler) Stop() {
|
||||
s.stopCh <- true
|
||||
}
|
||||
|
||||
// processPendingAudits 处理待审核内容
|
||||
func (s *AuditScheduler) processPendingAudits() {
|
||||
// 查询待审核的内容
|
||||
// 1. 查询审核状态为 pending 的记录
|
||||
// 2. 调用审核服务
|
||||
// 3. 更新审核状态
|
||||
|
||||
// 示例逻辑,实际需要根据业务需求实现
|
||||
log.Println("Processing pending audits...")
|
||||
}
|
||||
|
||||
// CleanupOldLogs 清理旧的审核日志
|
||||
func CleanupOldLogs(db *gorm.DB, days int) error {
|
||||
// 清理指定天数之前的审核日志
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||
return db.Where("created_at < ? AND result = ?", cutoffTime, model.AuditResultPass).Delete(&model.AuditLog{}).Error
|
||||
}
|
||||
622
internal/service/chat_service.go
Normal file
622
internal/service/chat_service.go
Normal file
@@ -0,0 +1,622 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/repository"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 撤回消息的时间限制(2分钟)
|
||||
const RecallMessageTimeout = 2 * time.Minute
|
||||
|
||||
// ChatService 聊天服务接口
|
||||
type ChatService interface {
|
||||
// 会话管理
|
||||
GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
|
||||
GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
|
||||
GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
|
||||
DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
|
||||
SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error
|
||||
|
||||
// 消息操作
|
||||
SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||||
GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
|
||||
GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
|
||||
GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)
|
||||
|
||||
// 已读管理
|
||||
MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
|
||||
GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
|
||||
GetAllUnreadCount(ctx context.Context, userID string) (int64, error)
|
||||
|
||||
// 消息扩展功能
|
||||
RecallMessage(ctx context.Context, messageID string, userID string) error
|
||||
DeleteMessage(ctx context.Context, messageID string, userID string) error
|
||||
|
||||
// WebSocket相关
|
||||
SendTyping(ctx context.Context, senderID string, conversationID string)
|
||||
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
|
||||
|
||||
// 系统消息推送
|
||||
IsUserOnline(userID string) bool
|
||||
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
|
||||
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
|
||||
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
|
||||
|
||||
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
|
||||
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
|
||||
}
|
||||
|
||||
// chatServiceImpl 聊天服务实现
|
||||
type chatServiceImpl struct {
|
||||
db *gorm.DB
|
||||
repo *repository.MessageRepository
|
||||
userRepo *repository.UserRepository
|
||||
sensitive SensitiveService
|
||||
wsManager *websocket.WebSocketManager
|
||||
}
|
||||
|
||||
// NewChatService 创建聊天服务
|
||||
func NewChatService(
|
||||
db *gorm.DB,
|
||||
repo *repository.MessageRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
sensitive SensitiveService,
|
||||
wsManager *websocket.WebSocketManager,
|
||||
) ChatService {
|
||||
return &chatServiceImpl{
|
||||
db: db,
|
||||
repo: repo,
|
||||
userRepo: userRepo,
|
||||
sensitive: sensitive,
|
||||
wsManager: wsManager,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreateConversation 获取或创建私聊会话
|
||||
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||
}
|
||||
|
||||
// GetConversationList 获取用户的会话列表
|
||||
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||
return s.repo.GetConversations(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetConversationByID 获取会话详情
|
||||
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
// 获取会话信息
|
||||
conv, err := s.repo.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||
}
|
||||
|
||||
// 填充用户的已读位置信息
|
||||
_ = participant // 可以用于返回已读位置等信息
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// DeleteConversationForSelf 仅自己删除会话
|
||||
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
|
||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
return fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
if participant.ConversationID == "" {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
|
||||
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
|
||||
return fmt.Errorf("failed to hide conversation: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConversationPinned 设置会话置顶(用户维度)
|
||||
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
|
||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
return fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
if participant.ConversationID == "" {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
|
||||
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
|
||||
return fmt.Errorf("failed to update pinned status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessage 发送消息(使用 segments)
|
||||
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||
// 首先验证会话是否存在
|
||||
conv, err := s.repo.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("会话不存在,请重新创建会话")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||
}
|
||||
|
||||
// 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向
|
||||
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
|
||||
participants, pErr := s.repo.GetConversationParticipants(conversationID)
|
||||
if pErr != nil {
|
||||
return nil, fmt.Errorf("failed to get participants: %w", pErr)
|
||||
}
|
||||
var sentCount *int64
|
||||
for _, p := range participants {
|
||||
if p.UserID == senderID {
|
||||
continue
|
||||
}
|
||||
blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
|
||||
if bErr != nil {
|
||||
return nil, fmt.Errorf("failed to check block status: %w", bErr)
|
||||
}
|
||||
if blocked {
|
||||
return nil, ErrUserBlocked
|
||||
}
|
||||
|
||||
// 陌生人限制:对方未回关前,只允许发送一条文本消息,且禁止发送图片
|
||||
isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
|
||||
if fErr != nil {
|
||||
return nil, fmt.Errorf("failed to check follow status: %w", fErr)
|
||||
}
|
||||
if !isFollowedBack {
|
||||
if containsImageSegment(segments) {
|
||||
return nil, errors.New("对方未关注你,暂不支持发送图片")
|
||||
}
|
||||
if sentCount == nil {
|
||||
c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
|
||||
if cErr != nil {
|
||||
return nil, fmt.Errorf("failed to count sender messages: %w", cErr)
|
||||
}
|
||||
sentCount = &c
|
||||
}
|
||||
if *sentCount >= 1 {
|
||||
return nil, errors.New("对方未关注你前,仅允许发送一条消息")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证用户是否是会话参与者
|
||||
participant, err := s.repo.GetParticipant(conversationID, senderID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("您不是该会话的参与者")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
// 创建消息
|
||||
message := &model.Message{
|
||||
ConversationID: conversationID,
|
||||
SenderID: senderID, // 直接使用string类型的UUID
|
||||
Segments: segments,
|
||||
ReplyToID: replyToID,
|
||||
Status: model.MessageStatusNormal,
|
||||
}
|
||||
|
||||
// 使用事务创建消息并更新seq
|
||||
if err := s.repo.CreateMessageWithSeq(message); err != nil {
|
||||
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||
}
|
||||
|
||||
// 发送消息给接收者
|
||||
log.Printf("[DEBUG SendMessage] 私聊消息 segments 类型: %T, 值: %+v", message.Segments, message.Segments)
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
|
||||
ID: message.ID,
|
||||
ConversationID: message.ConversationID,
|
||||
SenderID: senderID,
|
||||
Segments: message.Segments,
|
||||
Seq: message.Seq,
|
||||
CreatedAt: message.CreatedAt.UnixMilli(),
|
||||
})
|
||||
|
||||
// 获取会话中的其他参与者
|
||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||
if err == nil {
|
||||
for _, p := range participants {
|
||||
// 不发给自己
|
||||
if p.UserID == senderID {
|
||||
continue
|
||||
}
|
||||
// 如果接收者在线,发送实时消息
|
||||
if s.wsManager != nil {
|
||||
isOnline := s.wsManager.IsUserOnline(p.UserID)
|
||||
log.Printf("[DEBUG SendMessage] 接收者 UserID=%s, 在线状态=%v", p.UserID, isOnline)
|
||||
if isOnline {
|
||||
log.Printf("[DEBUG SendMessage] 发送WebSocket消息给 UserID=%s, 消息类型=%s", p.UserID, wsMsg.Type)
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("[DEBUG SendMessage] 获取参与者失败: %v", err)
|
||||
}
|
||||
|
||||
_ = participant // 避免未使用变量警告
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func containsImageSegment(segments model.MessageSegments) bool {
|
||||
for _, seg := range segments {
|
||||
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetMessages 获取消息历史(分页)
|
||||
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, 0, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
return s.repo.GetMessages(conversationID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
||||
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
||||
}
|
||||
|
||||
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
||||
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记已读
|
||||
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("conversation not found or no permission")
|
||||
}
|
||||
return fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
// 更新参与者的已读位置
|
||||
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update last read seq: %w", err)
|
||||
}
|
||||
|
||||
// 发送已读回执(作为 meta 事件)
|
||||
if s.wsManager != nil {
|
||||
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
|
||||
"detail_type": websocket.MetaDetailTypeRead,
|
||||
"conversation_id": conversationID,
|
||||
"seq": seq,
|
||||
"user_id": userID,
|
||||
})
|
||||
|
||||
// 获取会话中的所有参与者
|
||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||
if err == nil {
|
||||
// 推送给会话中的所有参与者(包括自己)
|
||||
for _, p := range participants {
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取指定会话的未读消息数
|
||||
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return 0, errors.New("conversation not found or no permission")
|
||||
}
|
||||
return 0, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
return s.repo.GetUnreadCount(conversationID, userID)
|
||||
}
|
||||
|
||||
// GetAllUnreadCount 获取所有会话的未读消息总数
|
||||
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
|
||||
return s.repo.GetAllUnreadCount(userID)
|
||||
}
|
||||
|
||||
// RecallMessage 撤回消息(2分钟内)
|
||||
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
|
||||
// 获取消息
|
||||
var message model.Message
|
||||
err := s.db.First(&message, "id = ?", messageID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("message not found")
|
||||
}
|
||||
return fmt.Errorf("failed to get message: %w", err)
|
||||
}
|
||||
|
||||
// 验证是否是消息发送者
|
||||
if message.SenderIDStr() != userID {
|
||||
return errors.New("can only recall your own messages")
|
||||
}
|
||||
|
||||
// 验证消息是否已被撤回
|
||||
if message.Status == model.MessageStatusRecalled {
|
||||
return errors.New("message already recalled")
|
||||
}
|
||||
|
||||
// 验证是否在2分钟内
|
||||
if time.Since(message.CreatedAt) > RecallMessageTimeout {
|
||||
return errors.New("message recall timeout (2 minutes)")
|
||||
}
|
||||
|
||||
// 更新消息状态为已撤回
|
||||
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to recall message: %w", err)
|
||||
}
|
||||
|
||||
// 发送撤回通知
|
||||
if s.wsManager != nil {
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
|
||||
"messageId": messageID,
|
||||
"conversationId": message.ConversationID,
|
||||
"senderId": userID,
|
||||
})
|
||||
|
||||
// 通知会话中的所有参与者
|
||||
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
|
||||
if err == nil {
|
||||
for _, p := range participants {
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteMessage 删除消息(仅对自己可见)
|
||||
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
|
||||
// 获取消息
|
||||
var message model.Message
|
||||
err := s.db.First(&message, "id = ?", messageID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("message not found")
|
||||
}
|
||||
return fmt.Errorf("failed to get message: %w", err)
|
||||
}
|
||||
|
||||
// 验证用户是否是会话参与者
|
||||
_, err = s.repo.GetParticipant(message.ConversationID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("no permission to delete this message")
|
||||
}
|
||||
return fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
// 对于删除消息,我们使用软删除,但需要确保只对当前用户隐藏
|
||||
// 这里简化处理:只有发送者可以删除自己的消息
|
||||
if message.SenderIDStr() != userID {
|
||||
return errors.New("can only delete your own messages")
|
||||
}
|
||||
|
||||
// 更新消息状态为已删除
|
||||
err = s.db.Model(&message).Update("status", model.MessageStatusDeleted).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete message: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendTyping 发送正在输入状态
|
||||
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
|
||||
if s.wsManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 验证用户是否是会话参与者
|
||||
_, err := s.repo.GetParticipant(conversationID, senderID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取会话中的其他参与者
|
||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range participants {
|
||||
if p.UserID == senderID {
|
||||
continue
|
||||
}
|
||||
// 发送正在输入状态
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
|
||||
"conversationId": conversationID,
|
||||
"senderId": senderID,
|
||||
})
|
||||
|
||||
if s.wsManager.IsUserOnline(p.UserID) {
|
||||
s.wsManager.SendToUser(p.UserID, wsMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastMessage 广播消息给用户
|
||||
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
|
||||
if s.wsManager != nil {
|
||||
s.wsManager.SendToUser(targetUser, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// IsUserOnline 检查用户是否在线
|
||||
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
return s.wsManager.IsUserOnline(userID)
|
||||
}
|
||||
|
||||
// PushSystemMessage 推送系统消息给指定用户
|
||||
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return errors.New("user is offline")
|
||||
}
|
||||
|
||||
sysMsg := &websocket.SystemMessage{
|
||||
ID: "", // 由调用方生成
|
||||
Type: msgType,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Data: data,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushNotificationMessage 推送通知消息给指定用户
|
||||
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return errors.New("user is offline")
|
||||
}
|
||||
|
||||
// 确保时间戳已设置
|
||||
if notification.CreatedAt == 0 {
|
||||
notification.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushAnnouncementMessage 广播公告消息给所有在线用户
|
||||
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
// 确保时间戳已设置
|
||||
if announcement.CreatedAt == 0 {
|
||||
announcement.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||||
s.wsManager.Broadcast(wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
|
||||
// 适用于群聊等由调用方自行负责推送的场景
|
||||
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||
// 验证会话是否存在
|
||||
_, err := s.repo.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("会话不存在,请重新创建会话")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||
}
|
||||
|
||||
// 验证用户是否是会话参与者
|
||||
_, err = s.repo.GetParticipant(conversationID, senderID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("您不是该会话的参与者")
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||
}
|
||||
|
||||
message := &model.Message{
|
||||
ConversationID: conversationID,
|
||||
SenderID: senderID,
|
||||
Segments: segments,
|
||||
ReplyToID: replyToID,
|
||||
Status: model.MessageStatusNormal,
|
||||
}
|
||||
|
||||
if err := s.repo.CreateMessageWithSeq(message); err != nil {
|
||||
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
273
internal/service/comment_service.go
Normal file
273
internal/service/comment_service.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/gorse"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// CommentService 评论服务
|
||||
type CommentService struct {
|
||||
commentRepo *repository.CommentRepository
|
||||
postRepo *repository.PostRepository
|
||||
systemMessageService SystemMessageService
|
||||
gorseClient gorse.Client
|
||||
postAIService *PostAIService
|
||||
}
|
||||
|
||||
// NewCommentService 创建评论服务
|
||||
func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *CommentService {
|
||||
return &CommentService{
|
||||
commentRepo: commentRepo,
|
||||
postRepo: postRepo,
|
||||
systemMessageService: systemMessageService,
|
||||
gorseClient: gorseClient,
|
||||
postAIService: postAIService,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建评论
|
||||
func (s *CommentService) Create(ctx context.Context, postID, userID, content string, parentID *string, images string, imageURLs []string) (*model.Comment, error) {
|
||||
if s.postAIService != nil {
|
||||
// 采用异步审核,前端先立即返回
|
||||
}
|
||||
|
||||
// 获取帖子信息用于发送通知
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
comment := &model.Comment{
|
||||
PostID: postID,
|
||||
UserID: userID,
|
||||
Content: content,
|
||||
ParentID: parentID,
|
||||
Images: images,
|
||||
Status: model.CommentStatusPending,
|
||||
}
|
||||
|
||||
// 如果有父评论,设置根评论ID
|
||||
var parentUserID string
|
||||
if parentID != nil {
|
||||
parent, err := s.commentRepo.GetByID(*parentID)
|
||||
if err == nil && parent != nil {
|
||||
if parent.RootID != nil {
|
||||
comment.RootID = parent.RootID
|
||||
} else {
|
||||
comment.RootID = parentID
|
||||
}
|
||||
parentUserID = parent.UserID
|
||||
}
|
||||
}
|
||||
|
||||
err = s.commentRepo.Create(comment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 重新查询以获取关联的 User
|
||||
comment, err = s.commentRepo.GetByID(comment.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go s.reviewCommentAsync(comment.ID, userID, postID, content, imageURLs, parentID, parentUserID, post.UserID)
|
||||
|
||||
return comment, nil
|
||||
}
|
||||
|
||||
func (s *CommentService) reviewCommentAsync(
|
||||
commentID, userID, postID, content string,
|
||||
imageURLs []string,
|
||||
parentID *string,
|
||||
parentUserID string,
|
||||
postOwnerID string,
|
||||
) {
|
||||
// 未启用AI时,直接通过审核并发送后续通知
|
||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||
if err := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); err != nil {
|
||||
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
|
||||
return
|
||||
}
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
return
|
||||
}
|
||||
|
||||
err := s.postAIService.ModerateComment(context.Background(), content, imageURLs)
|
||||
if err != nil {
|
||||
var rejectedErr *CommentModerationRejectedError
|
||||
if errors.As(err, &rejectedErr) {
|
||||
if delErr := s.commentRepo.Delete(commentID); delErr != nil {
|
||||
log.Printf("[WARN] Failed to delete rejected comment %s: %v", commentID, delErr)
|
||||
}
|
||||
s.notifyCommentModerationRejected(userID, rejectedErr.Reason)
|
||||
return
|
||||
}
|
||||
|
||||
// 审核服务异常时降级放行,避免评论长期pending
|
||||
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
|
||||
return
|
||||
}
|
||||
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
return
|
||||
}
|
||||
|
||||
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
|
||||
return
|
||||
}
|
||||
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
|
||||
}
|
||||
|
||||
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
|
||||
// 发送系统消息通知
|
||||
if s.systemMessageService != nil {
|
||||
go func() {
|
||||
if parentID != nil && parentUserID != "" {
|
||||
// 回复评论,通知被回复的人
|
||||
if parentUserID != userID {
|
||||
notifyErr := s.systemMessageService.SendReplyNotification(context.Background(), parentUserID, userID, postID, *parentID, commentID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending reply notification: %v\n", notifyErr)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 评论帖子,通知帖子作者
|
||||
if postOwnerID != userID {
|
||||
notifyErr := s.systemMessageService.SendCommentNotification(context.Background(), postOwnerID, userID, postID, commentID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending comment notification: %v\n", notifyErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 推送评论行为到Gorse(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeComment, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to insert comment feedback to Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *CommentService) notifyCommentModerationRejected(userID, reason string) {
|
||||
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
content := "您发布的评论未通过AI审核,请修改后重试。"
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
content = fmt.Sprintf("您发布的评论未通过AI审核,原因:%s。请修改后重试。", reason)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.systemMessageService.SendSystemAnnouncement(
|
||||
context.Background(),
|
||||
[]string{userID},
|
||||
"评论审核未通过",
|
||||
content,
|
||||
); err != nil {
|
||||
log.Printf("[WARN] Failed to send comment moderation reject notification: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取评论
|
||||
func (s *CommentService) GetByID(ctx context.Context, id string) (*model.Comment, error) {
|
||||
return s.commentRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// GetByPostID 获取帖子评论
|
||||
func (s *CommentService) GetByPostID(ctx context.Context, postID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||
// 使用带回复的查询,默认加载前3条回复
|
||||
return s.commentRepo.GetByPostIDWithReplies(postID, page, pageSize, 3)
|
||||
}
|
||||
|
||||
// GetRepliesByRootID 根据根评论ID分页获取回复
|
||||
func (s *CommentService) GetRepliesByRootID(ctx context.Context, rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
|
||||
return s.commentRepo.GetRepliesByRootID(rootID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetReplies 获取回复
|
||||
func (s *CommentService) GetReplies(ctx context.Context, parentID string) ([]*model.Comment, error) {
|
||||
return s.commentRepo.GetReplies(parentID)
|
||||
}
|
||||
|
||||
// Update 更新评论
|
||||
func (s *CommentService) Update(ctx context.Context, comment *model.Comment) error {
|
||||
return s.commentRepo.Update(comment)
|
||||
}
|
||||
|
||||
// Delete 删除评论
|
||||
func (s *CommentService) Delete(ctx context.Context, id string) error {
|
||||
return s.commentRepo.Delete(id)
|
||||
}
|
||||
|
||||
// Like 点赞评论
|
||||
func (s *CommentService) Like(ctx context.Context, commentID, userID string) error {
|
||||
// 获取评论信息用于发送通知
|
||||
comment, err := s.commentRepo.GetByID(commentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.commentRepo.Like(commentID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送评论/回复点赞通知(只有不是给自己点赞时才发送)
|
||||
if s.systemMessageService != nil && comment.UserID != userID {
|
||||
go func() {
|
||||
var notifyErr error
|
||||
if comment.ParentID != nil {
|
||||
notifyErr = s.systemMessageService.SendLikeReplyNotification(
|
||||
context.Background(),
|
||||
comment.UserID,
|
||||
userID,
|
||||
comment.PostID,
|
||||
commentID,
|
||||
comment.Content,
|
||||
)
|
||||
} else {
|
||||
notifyErr = s.systemMessageService.SendLikeCommentNotification(
|
||||
context.Background(),
|
||||
comment.UserID,
|
||||
userID,
|
||||
comment.PostID,
|
||||
commentID,
|
||||
comment.Content,
|
||||
)
|
||||
}
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] Like notification sent successfully\n")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlike 取消点赞评论
|
||||
func (s *CommentService) Unlike(ctx context.Context, commentID, userID string) error {
|
||||
return s.commentRepo.Unlike(commentID, userID)
|
||||
}
|
||||
|
||||
// IsLiked 检查是否已点赞
|
||||
func (s *CommentService) IsLiked(ctx context.Context, commentID, userID string) bool {
|
||||
return s.commentRepo.IsLiked(commentID, userID)
|
||||
}
|
||||
234
internal/service/email_code_service.go
Normal file
234
internal/service/email_code_service.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
verifyCodeTTL = 10 * time.Minute
|
||||
verifyCodeRateLimitTTL = 60 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
CodePurposeRegister = "register"
|
||||
CodePurposePasswordReset = "password_reset"
|
||||
CodePurposeEmailVerify = "email_verify"
|
||||
CodePurposeChangePassword = "change_password"
|
||||
)
|
||||
|
||||
type verificationCodePayload struct {
|
||||
Code string `json:"code"`
|
||||
Purpose string `json:"purpose"`
|
||||
Email string `json:"email"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
}
|
||||
|
||||
type EmailCodeService interface {
|
||||
SendCode(ctx context.Context, purpose, email string) error
|
||||
VerifyCode(purpose, email, code string) error
|
||||
}
|
||||
|
||||
type emailCodeServiceImpl struct {
|
||||
emailService EmailService
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
func NewEmailCodeService(emailService EmailService, cacheBackend cache.Cache) EmailCodeService {
|
||||
if cacheBackend == nil {
|
||||
cacheBackend = cache.GetCache()
|
||||
}
|
||||
return &emailCodeServiceImpl{
|
||||
emailService: emailService,
|
||||
cache: cacheBackend,
|
||||
}
|
||||
}
|
||||
|
||||
func verificationCodeCacheKey(purpose, email string) string {
|
||||
return fmt.Sprintf("auth:verify_code:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
|
||||
}
|
||||
|
||||
func verificationCodeRateLimitKey(purpose, email string) string {
|
||||
return fmt.Sprintf("auth:verify_code_rate_limit:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
|
||||
}
|
||||
|
||||
func generateNumericCode(length int) (string, error) {
|
||||
if length <= 0 {
|
||||
return "", fmt.Errorf("invalid code length")
|
||||
}
|
||||
max := big.NewInt(10)
|
||||
result := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
n, err := rand.Int(rand.Reader, max)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result[i] = byte('0' + n.Int64())
|
||||
}
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
func (s *emailCodeServiceImpl) SendCode(ctx context.Context, purpose, email string) error {
|
||||
if strings.TrimSpace(email) == "" || !utils.ValidateEmail(email) {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
if s.emailService == nil || !s.emailService.IsEnabled() {
|
||||
return ErrEmailServiceUnavailable
|
||||
}
|
||||
if s.cache == nil {
|
||||
return ErrVerificationCodeUnavailable
|
||||
}
|
||||
|
||||
rateLimitKey := verificationCodeRateLimitKey(purpose, email)
|
||||
if s.cache.Exists(rateLimitKey) {
|
||||
return ErrVerificationCodeTooFrequent
|
||||
}
|
||||
|
||||
code, err := generateNumericCode(6)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate verification code failed: %w", err)
|
||||
}
|
||||
payload := verificationCodePayload{
|
||||
Code: code,
|
||||
Purpose: purpose,
|
||||
Email: strings.ToLower(strings.TrimSpace(email)),
|
||||
ExpiresAt: time.Now().Add(verifyCodeTTL).Unix(),
|
||||
}
|
||||
cacheKey := verificationCodeCacheKey(purpose, email)
|
||||
s.cache.Set(cacheKey, payload, verifyCodeTTL)
|
||||
s.cache.Set(rateLimitKey, "1", verifyCodeRateLimitTTL)
|
||||
|
||||
subject, sceneText := verificationEmailMeta(purpose)
|
||||
textBody := fmt.Sprintf("【%s】验证码:%s\n有效期:10分钟\n请勿将验证码泄露给他人。", sceneText, code)
|
||||
htmlBody := buildVerificationEmailHTML(sceneText, code)
|
||||
if err := s.emailService.Send(ctx, SendEmailRequest{
|
||||
To: []string{email},
|
||||
Subject: subject,
|
||||
TextBody: textBody,
|
||||
HTMLBody: htmlBody,
|
||||
}); err != nil {
|
||||
s.cache.Delete(cacheKey)
|
||||
return fmt.Errorf("send verification email failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCodeServiceImpl) VerifyCode(purpose, email, code string) error {
|
||||
if strings.TrimSpace(email) == "" || strings.TrimSpace(code) == "" {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
if s.cache == nil {
|
||||
return ErrVerificationCodeUnavailable
|
||||
}
|
||||
|
||||
cacheKey := verificationCodeCacheKey(purpose, email)
|
||||
raw, ok := s.cache.Get(cacheKey)
|
||||
if !ok {
|
||||
return ErrVerificationCodeExpired
|
||||
}
|
||||
|
||||
var payload verificationCodePayload
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
if err := json.Unmarshal([]byte(v), &payload); err != nil {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
case []byte:
|
||||
if err := json.Unmarshal(v, &payload); err != nil {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
case verificationCodePayload:
|
||||
payload = v
|
||||
default:
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
}
|
||||
|
||||
if payload.Purpose != purpose || payload.Email != strings.ToLower(strings.TrimSpace(email)) {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
if payload.ExpiresAt > 0 && time.Now().Unix() > payload.ExpiresAt {
|
||||
s.cache.Delete(cacheKey)
|
||||
return ErrVerificationCodeExpired
|
||||
}
|
||||
if payload.Code != strings.TrimSpace(code) {
|
||||
return ErrVerificationCodeInvalid
|
||||
}
|
||||
|
||||
s.cache.Delete(cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func verificationEmailMeta(purpose string) (subject string, sceneText string) {
|
||||
switch purpose {
|
||||
case CodePurposeRegister:
|
||||
return "Carrot BBS 注册验证码", "注册账号"
|
||||
case CodePurposePasswordReset:
|
||||
return "Carrot BBS 找回密码验证码", "找回密码"
|
||||
case CodePurposeEmailVerify:
|
||||
return "Carrot BBS 邮箱验证验证码", "验证邮箱"
|
||||
case CodePurposeChangePassword:
|
||||
return "Carrot BBS 修改密码验证码", "修改密码"
|
||||
default:
|
||||
return "Carrot BBS 验证码", "身份验证"
|
||||
}
|
||||
}
|
||||
|
||||
func buildVerificationEmailHTML(sceneText, code string) string {
|
||||
return fmt.Sprintf(`<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Carrot BBS 验证码</title>
|
||||
</head>
|
||||
<body style="margin:0;padding:0;background:#f4f6fb;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,'PingFang SC','Microsoft YaHei',sans-serif;color:#1f2937;">
|
||||
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="background:#f4f6fb;padding:24px 12px;">
|
||||
<tr>
|
||||
<td align="center">
|
||||
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="max-width:560px;background:#ffffff;border-radius:14px;overflow:hidden;box-shadow:0 8px 30px rgba(15,23,42,0.08);">
|
||||
<tr>
|
||||
<td style="background:linear-gradient(135deg,#ff6b35,#ff8f66);padding:24px 28px;color:#ffffff;">
|
||||
<div style="font-size:22px;font-weight:700;line-height:1.2;">Carrot BBS</div>
|
||||
<div style="margin-top:6px;font-size:14px;opacity:0.95;">%s 验证</div>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding:28px;">
|
||||
<p style="margin:0 0 14px;font-size:15px;line-height:1.75;">你好,</p>
|
||||
<p style="margin:0 0 20px;font-size:15px;line-height:1.75;">你正在进行 <strong>%s</strong> 操作,请使用下方验证码完成验证:</p>
|
||||
<div style="margin:0 auto 18px;max-width:320px;border:1px dashed #ff8f66;background:#fff8f4;border-radius:12px;padding:14px 12px;text-align:center;">
|
||||
<div style="font-size:13px;color:#9a3412;letter-spacing:0.5px;">验证码(10分钟内有效)</div>
|
||||
<div style="margin-top:8px;font-size:34px;line-height:1;font-weight:800;letter-spacing:8px;color:#ea580c;">%s</div>
|
||||
</div>
|
||||
<p style="margin:0 0 8px;font-size:13px;color:#6b7280;line-height:1.7;">如果不是你本人操作,请忽略此邮件,并及时检查账号安全。</p>
|
||||
<p style="margin:0;font-size:13px;color:#6b7280;line-height:1.7;">请勿向任何人透露验证码,平台不会以任何理由索取验证码。</p>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="padding:14px 28px;background:#f8fafc;border-top:1px solid #e5e7eb;color:#94a3b8;font-size:12px;line-height:1.7;">
|
||||
此邮件由系统自动发送,请勿直接回复。<br/>
|
||||
© Carrot BBS
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>`, sceneText, sceneText, code)
|
||||
}
|
||||
82
internal/service/email_service.go
Normal file
82
internal/service/email_service.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
emailpkg "carrot_bbs/internal/pkg/email"
|
||||
)
|
||||
|
||||
// SendEmailRequest 发信请求
|
||||
type SendEmailRequest struct {
|
||||
To []string
|
||||
Cc []string
|
||||
Bcc []string
|
||||
ReplyTo []string
|
||||
Subject string
|
||||
TextBody string
|
||||
HTMLBody string
|
||||
Attachments []string
|
||||
}
|
||||
|
||||
type EmailService interface {
|
||||
IsEnabled() bool
|
||||
Send(ctx context.Context, req SendEmailRequest) error
|
||||
SendText(ctx context.Context, to []string, subject, body string) error
|
||||
SendHTML(ctx context.Context, to []string, subject, html string) error
|
||||
}
|
||||
|
||||
type emailServiceImpl struct {
|
||||
client emailpkg.Client
|
||||
}
|
||||
|
||||
func NewEmailService(client emailpkg.Client) EmailService {
|
||||
return &emailServiceImpl{client: client}
|
||||
}
|
||||
|
||||
func (s *emailServiceImpl) IsEnabled() bool {
|
||||
return s.client != nil && s.client.IsEnabled()
|
||||
}
|
||||
|
||||
func (s *emailServiceImpl) Send(ctx context.Context, req SendEmailRequest) error {
|
||||
if s.client == nil {
|
||||
return fmt.Errorf("email client is nil")
|
||||
}
|
||||
if !s.client.IsEnabled() {
|
||||
return fmt.Errorf("email service is disabled")
|
||||
}
|
||||
if len(req.To) == 0 {
|
||||
return fmt.Errorf("email recipient is empty")
|
||||
}
|
||||
if strings.TrimSpace(req.Subject) == "" {
|
||||
return fmt.Errorf("email subject is empty")
|
||||
}
|
||||
|
||||
return s.client.Send(ctx, emailpkg.Message{
|
||||
To: req.To,
|
||||
Cc: req.Cc,
|
||||
Bcc: req.Bcc,
|
||||
ReplyTo: req.ReplyTo,
|
||||
Subject: req.Subject,
|
||||
TextBody: req.TextBody,
|
||||
HTMLBody: req.HTMLBody,
|
||||
Attachments: req.Attachments,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *emailServiceImpl) SendText(ctx context.Context, to []string, subject, body string) error {
|
||||
return s.Send(ctx, SendEmailRequest{
|
||||
To: to,
|
||||
Subject: subject,
|
||||
TextBody: body,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *emailServiceImpl) SendHTML(ctx context.Context, to []string, subject, html string) error {
|
||||
return s.Send(ctx, SendEmailRequest{
|
||||
To: to,
|
||||
Subject: subject,
|
||||
HTMLBody: html,
|
||||
})
|
||||
}
|
||||
1491
internal/service/group_service.go
Normal file
1491
internal/service/group_service.go
Normal file
File diff suppressed because it is too large
Load Diff
38
internal/service/jwt_service.go
Normal file
38
internal/service/jwt_service.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/pkg/jwt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWTService JWT服务
|
||||
type JWTService struct {
|
||||
jwt *jwt.JWT
|
||||
}
|
||||
|
||||
// NewJWTService 创建JWT服务
|
||||
func NewJWTService(secret string, accessExpire, refreshExpire int64) *JWTService {
|
||||
return &JWTService{
|
||||
jwt: jwt.New(secret, time.Duration(accessExpire)*time.Second, time.Duration(refreshExpire)*time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAccessToken 生成访问令牌
|
||||
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
|
||||
return s.jwt.GenerateAccessToken(userID, username)
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新令牌
|
||||
func (s *JWTService) GenerateRefreshToken(userID, username string) (string, error) {
|
||||
return s.jwt.GenerateRefreshToken(userID, username)
|
||||
}
|
||||
|
||||
// ParseToken 解析令牌
|
||||
func (s *JWTService) ParseToken(tokenString string) (*jwt.Claims, error) {
|
||||
return s.jwt.ParseToken(tokenString)
|
||||
}
|
||||
|
||||
// ValidateToken 验证令牌
|
||||
func (s *JWTService) ValidateToken(tokenString string) error {
|
||||
return s.jwt.ValidateToken(tokenString)
|
||||
}
|
||||
215
internal/service/message_service.go
Normal file
215
internal/service/message_service.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// 缓存TTL常量
|
||||
const (
|
||||
ConversationListTTL = 60 * time.Second // 会话列表缓存60秒
|
||||
ConversationDetailTTL = 60 * time.Second // 会话详情缓存60秒
|
||||
UnreadCountTTL = 30 * time.Second // 未读数缓存30秒
|
||||
ConversationNullTTL = 5 * time.Second
|
||||
UnreadNullTTL = 5 * time.Second
|
||||
CacheJitterRatio = 0.1
|
||||
)
|
||||
|
||||
// MessageService 消息服务
|
||||
type MessageService struct {
|
||||
messageRepo *repository.MessageRepository
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
// NewMessageService 创建消息服务
|
||||
func NewMessageService(messageRepo *repository.MessageRepository) *MessageService {
|
||||
return &MessageService{
|
||||
messageRepo: messageRepo,
|
||||
cache: cache.GetCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// ConversationListResult 会话列表缓存结果
|
||||
type ConversationListResult struct {
|
||||
Conversations []*model.Conversation
|
||||
Total int64
|
||||
}
|
||||
|
||||
// SendMessage 发送消息(使用 segments)
|
||||
// senderID 和 receiverID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID string, segments model.MessageSegments) (*model.Message, error) {
|
||||
// 获取或创建会话
|
||||
conv, err := s.messageRepo.GetOrCreatePrivateConversation(senderID, receiverID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg := &model.Message{
|
||||
ConversationID: conv.ID,
|
||||
SenderID: senderID,
|
||||
Segments: segments,
|
||||
Status: model.MessageStatusNormal,
|
||||
}
|
||||
|
||||
// 使用事务创建消息并更新seq
|
||||
err = s.messageRepo.CreateMessageWithSeq(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效会话列表缓存(发送者和接收者)
|
||||
cache.InvalidateConversationList(s.cache, senderID)
|
||||
cache.InvalidateConversationList(s.cache, receiverID)
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadConversation(s.cache, receiverID)
|
||||
cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// GetConversations 获取会话列表(带缓存)
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
conversationTTL := cacheSettings.ConversationTTL
|
||||
if conversationTTL <= 0 {
|
||||
conversationTTL = ConversationListTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = ConversationNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = CacheJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
cacheKey := cache.ConversationListKey(userID, page, pageSize)
|
||||
result, err := cache.GetOrLoadTyped[*ConversationListResult](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
conversationTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (*ConversationListResult, error) {
|
||||
conversations, total, err := s.messageRepo.GetConversations(userID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ConversationListResult{
|
||||
Conversations: conversations,
|
||||
Total: total,
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result == nil {
|
||||
return []*model.Conversation{}, 0, nil
|
||||
}
|
||||
return result.Conversations, result.Total, nil
|
||||
}
|
||||
|
||||
// GetMessages 获取消息列表
|
||||
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||
return s.messageRepo.GetMessages(conversationID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetMessagesAfterSeq 获取指定seq之后的消息(增量同步)
|
||||
func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||
return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
|
||||
err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
||||
|
||||
// 失效会话列表缓存
|
||||
cache.InvalidateConversationList(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读消息数(带缓存)
|
||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
unreadTTL := cacheSettings.UnreadCountTTL
|
||||
if unreadTTL <= 0 {
|
||||
unreadTTL = UnreadCountTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = UnreadNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = CacheJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
cacheKey := cache.UnreadDetailKey(userID, conversationID)
|
||||
|
||||
return cache.GetOrLoadTyped[int64](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
unreadTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (int64, error) {
|
||||
return s.messageRepo.GetUnreadCount(conversationID, userID)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// GetOrCreateConversation 获取或创建私聊会话
|
||||
// user1ID 和 user2ID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||
func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||
conv, err := s.messageRepo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效会话列表缓存
|
||||
cache.InvalidateConversationList(s.cache, user1ID)
|
||||
cache.InvalidateConversationList(s.cache, user2ID)
|
||||
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// GetConversationParticipants 获取会话参与者列表
|
||||
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
|
||||
return s.messageRepo.GetConversationParticipants(conversationID)
|
||||
}
|
||||
|
||||
// ParseConversationID 辅助函数:直接返回字符串ID(已经是string类型)
|
||||
func ParseConversationID(idStr string) (string, error) {
|
||||
return idStr, nil
|
||||
}
|
||||
|
||||
// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用)
|
||||
func (s *MessageService) InvalidateUserConversationCache(userID string) {
|
||||
cache.InvalidateConversationList(s.cache, userID)
|
||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||
}
|
||||
|
||||
// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用)
|
||||
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
|
||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
||||
}
|
||||
169
internal/service/notification_service.go
Normal file
169
internal/service/notification_service.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// 缓存TTL常量
|
||||
const (
|
||||
NotificationUnreadCountTTL = 30 * time.Second // 通知未读数缓存30秒
|
||||
NotificationNullTTL = 5 * time.Second
|
||||
NotificationCacheJitter = 0.1
|
||||
)
|
||||
|
||||
// NotificationService 通知服务
|
||||
type NotificationService struct {
|
||||
notificationRepo *repository.NotificationRepository
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
// NewNotificationService 创建通知服务
|
||||
func NewNotificationService(notificationRepo *repository.NotificationRepository) *NotificationService {
|
||||
return &NotificationService{
|
||||
notificationRepo: notificationRepo,
|
||||
cache: cache.GetCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建通知
|
||||
func (s *NotificationService) Create(ctx context.Context, userID string, notificationType model.NotificationType, title, content string) (*model.Notification, error) {
|
||||
notification := &model.Notification{
|
||||
UserID: userID,
|
||||
Type: notificationType,
|
||||
Title: title,
|
||||
Content: content,
|
||||
IsRead: false,
|
||||
}
|
||||
|
||||
err := s.notificationRepo.Create(notification)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return notification, nil
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户通知
|
||||
func (s *NotificationService) GetByUserID(ctx context.Context, userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
|
||||
return s.notificationRepo.GetByUserID(userID, page, pageSize, unreadOnly)
|
||||
}
|
||||
|
||||
// MarkAsRead 标记为已读
|
||||
func (s *NotificationService) MarkAsRead(ctx context.Context, id string) error {
|
||||
err := s.notificationRepo.MarkAsRead(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 注意:这里无法获取userID,所以不在缓存中失效
|
||||
// 调用方应该使用MarkAsReadWithUserID方法
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsReadWithUserID 标记为已读(带用户ID,用于缓存失效)
|
||||
func (s *NotificationService) MarkAsReadWithUserID(ctx context.Context, id, userID string) error {
|
||||
err := s.notificationRepo.MarkAsRead(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAllAsRead 标记所有为已读
|
||||
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID string) error {
|
||||
err := s.notificationRepo.MarkAllAsRead(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除通知
|
||||
func (s *NotificationService) Delete(ctx context.Context, id string) error {
|
||||
return s.notificationRepo.Delete(id)
|
||||
}
|
||||
|
||||
// GetUnreadCount 获取未读数量(带缓存)
|
||||
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID string) (int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
unreadTTL := cacheSettings.UnreadCountTTL
|
||||
if unreadTTL <= 0 {
|
||||
unreadTTL = NotificationUnreadCountTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = NotificationNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = NotificationCacheJitter
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
cacheKey := cache.UnreadSystemKey(userID)
|
||||
return cache.GetOrLoadTyped[int64](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
unreadTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (int64, error) {
|
||||
return s.notificationRepo.GetUnreadCount(userID)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// DeleteNotification 删除通知(带用户验证)
|
||||
func (s *NotificationService) DeleteNotification(ctx context.Context, id, userID string) error {
|
||||
// 先检查通知是否属于该用户
|
||||
notification, err := s.notificationRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if notification.UserID != userID {
|
||||
return ErrUnauthorizedNotification
|
||||
}
|
||||
|
||||
err = s.notificationRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearAllNotifications 清空所有通知
|
||||
func (s *NotificationService) ClearAllNotifications(ctx context.Context, userID string) error {
|
||||
err := s.notificationRepo.DeleteAllByUserID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 错误定义
|
||||
var ErrUnauthorizedNotification = &ServiceError{Code: 403, Message: "unauthorized to delete this notification"}
|
||||
103
internal/service/post_ai_service.go
Normal file
103
internal/service/post_ai_service.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// PostModerationRejectedError 帖子审核拒绝错误
|
||||
type PostModerationRejectedError struct {
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e *PostModerationRejectedError) Error() string {
|
||||
if strings.TrimSpace(e.Reason) == "" {
|
||||
return "post rejected by moderation"
|
||||
}
|
||||
return "post rejected by moderation: " + e.Reason
|
||||
}
|
||||
|
||||
// UserMessage 返回给前端的用户可读文案
|
||||
func (e *PostModerationRejectedError) UserMessage() string {
|
||||
if strings.TrimSpace(e.Reason) == "" {
|
||||
return "内容未通过审核,请修改后重试"
|
||||
}
|
||||
return strings.TrimSpace(e.Reason)
|
||||
}
|
||||
|
||||
// CommentModerationRejectedError 评论审核拒绝错误
|
||||
type CommentModerationRejectedError struct {
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e *CommentModerationRejectedError) Error() string {
|
||||
if strings.TrimSpace(e.Reason) == "" {
|
||||
return "comment rejected by moderation"
|
||||
}
|
||||
return "comment rejected by moderation: " + e.Reason
|
||||
}
|
||||
|
||||
// UserMessage 返回给前端的用户可读文案
|
||||
func (e *CommentModerationRejectedError) UserMessage() string {
|
||||
if strings.TrimSpace(e.Reason) == "" {
|
||||
return "评论未通过审核,请修改后重试"
|
||||
}
|
||||
return strings.TrimSpace(e.Reason)
|
||||
}
|
||||
|
||||
type PostAIService struct {
|
||||
openAIClient openai.Client
|
||||
}
|
||||
|
||||
func NewPostAIService(openAIClient openai.Client) *PostAIService {
|
||||
return &PostAIService{
|
||||
openAIClient: openAIClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostAIService) IsEnabled() bool {
|
||||
return s != nil && s.openAIClient != nil && s.openAIClient.IsEnabled()
|
||||
}
|
||||
|
||||
// ModeratePost 审核帖子内容,返回 nil 表示通过
|
||||
func (s *PostAIService) ModeratePost(ctx context.Context, title, content string, images []string) error {
|
||||
if !s.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
approved, reason, err := s.openAIClient.ModeratePost(ctx, title, content, images)
|
||||
if err != nil {
|
||||
if s.openAIClient.Config().StrictModeration {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] AI moderation failed, fallback allow: %v", err)
|
||||
return nil
|
||||
}
|
||||
if !approved {
|
||||
return &PostModerationRejectedError{Reason: reason}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ModerateComment 审核评论内容,返回 nil 表示通过
|
||||
func (s *PostAIService) ModerateComment(ctx context.Context, content string, images []string) error {
|
||||
if !s.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
approved, reason, err := s.openAIClient.ModerateComment(ctx, content, images)
|
||||
if err != nil {
|
||||
if s.openAIClient.Config().StrictModeration {
|
||||
return err
|
||||
}
|
||||
log.Printf("[WARN] AI comment moderation failed, fallback allow: %v", err)
|
||||
return nil
|
||||
}
|
||||
if !approved {
|
||||
return &CommentModerationRejectedError{Reason: reason}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
593
internal/service/post_service.go
Normal file
593
internal/service/post_service.go
Normal file
@@ -0,0 +1,593 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/gorse"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// 缓存TTL常量
|
||||
const (
|
||||
PostListTTL = 30 * time.Second // 帖子列表缓存30秒
|
||||
PostListNullTTL = 5 * time.Second
|
||||
PostListJitterRatio = 0.15
|
||||
anonymousViewUserID = "_anon_view"
|
||||
)
|
||||
|
||||
// PostService 帖子服务
|
||||
type PostService struct {
|
||||
postRepo *repository.PostRepository
|
||||
systemMessageService SystemMessageService
|
||||
cache cache.Cache
|
||||
gorseClient gorse.Client
|
||||
postAIService *PostAIService
|
||||
}
|
||||
|
||||
// NewPostService 创建帖子服务
|
||||
func NewPostService(postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *PostService {
|
||||
return &PostService{
|
||||
postRepo: postRepo,
|
||||
systemMessageService: systemMessageService,
|
||||
cache: cache.GetCache(),
|
||||
gorseClient: gorseClient,
|
||||
postAIService: postAIService,
|
||||
}
|
||||
}
|
||||
|
||||
// PostListResult 帖子列表缓存结果
|
||||
type PostListResult struct {
|
||||
Posts []*model.Post
|
||||
Total int64
|
||||
}
|
||||
|
||||
// Create 创建帖子
|
||||
func (s *PostService) Create(ctx context.Context, userID, title, content string, images []string) (*model.Post, error) {
|
||||
post := &model.Post{
|
||||
UserID: userID,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Status: model.PostStatusPending,
|
||||
}
|
||||
|
||||
err := s.postRepo.Create(post, images)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 失效帖子列表缓存
|
||||
cache.InvalidatePostList(s.cache)
|
||||
|
||||
// 同步到Gorse推荐系统(异步)
|
||||
go s.reviewPostAsync(post.ID, userID, title, content, images)
|
||||
|
||||
// 重新查询以获取关联的 User 和 Images
|
||||
return s.postRepo.GetByID(post.ID)
|
||||
}
|
||||
|
||||
func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) {
|
||||
// 未启用AI时,直接发布
|
||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := s.postAIService.ModeratePost(context.Background(), title, content, images)
|
||||
if err != nil {
|
||||
var rejectedErr *PostModerationRejectedError
|
||||
if errors.As(err, &rejectedErr) {
|
||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
|
||||
}
|
||||
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
||||
return
|
||||
}
|
||||
|
||||
// 规则审核不可用时,降级为发布,避免长时间pending
|
||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
||||
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
|
||||
}
|
||||
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
|
||||
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
|
||||
return
|
||||
}
|
||||
|
||||
if s.gorseClient.IsEnabled() {
|
||||
post, getErr := s.postRepo.GetByID(postID)
|
||||
if getErr != nil {
|
||||
log.Printf("[WARN] Failed to load published post for gorse sync: %v", getErr)
|
||||
return
|
||||
}
|
||||
categories := s.buildPostCategories(post)
|
||||
comment := post.Title
|
||||
textToEmbed := post.Title + " " + post.Content
|
||||
if upsertErr := s.gorseClient.UpsertItemWithEmbedding(context.Background(), post.ID, categories, comment, textToEmbed); upsertErr != nil {
|
||||
log.Printf("[WARN] Failed to upsert item to Gorse: %v", upsertErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostService) notifyModerationRejected(userID, reason string) {
|
||||
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
content := "您发布的帖子未通过AI审核,请修改后重试。"
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
content = fmt.Sprintf("您发布的帖子未通过AI审核,原因:%s。请修改后重试。", reason)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.systemMessageService.SendSystemAnnouncement(
|
||||
context.Background(),
|
||||
[]string{userID},
|
||||
"帖子审核未通过",
|
||||
content,
|
||||
); err != nil {
|
||||
log.Printf("[WARN] Failed to send moderation reject notification: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取帖子
|
||||
func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, error) {
|
||||
return s.postRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// Update 更新帖子
|
||||
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
|
||||
err := s.postRepo.Update(post)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存和列表缓存
|
||||
cache.InvalidatePostDetail(s.cache, post.ID)
|
||||
cache.InvalidatePostList(s.cache)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除帖子
|
||||
func (s *PostService) Delete(ctx context.Context, id string) error {
|
||||
err := s.postRepo.Delete(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存和列表缓存
|
||||
cache.InvalidatePostDetail(s.cache, id)
|
||||
cache.InvalidatePostList(s.cache)
|
||||
|
||||
// 从Gorse中删除帖子(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.DeleteItem(context.Background(), id); err != nil {
|
||||
log.Printf("[WARN] Failed to delete item from Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取帖子列表(带缓存)
|
||||
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
postListTTL := cacheSettings.PostListTTL
|
||||
if postListTTL <= 0 {
|
||||
postListTTL = PostListTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = PostListNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = PostListJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染)
|
||||
cacheKey := cache.PostListKey("latest", userID, page, pageSize)
|
||||
|
||||
result, err := cache.GetOrLoadTyped[*PostListResult](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
postListTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (*PostListResult, error) {
|
||||
posts, total, err := s.postRepo.List(page, pageSize, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PostListResult{Posts: posts, Total: total}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result == nil {
|
||||
return []*model.Post{}, 0, nil
|
||||
}
|
||||
|
||||
// 兼容历史脏缓存:旧缓存序列化会丢失 Post.User,导致前端显示“匿名用户”
|
||||
// 这里检测并回源重建一次缓存,避免在 TTL 内持续返回缺失作者的数据。
|
||||
missingAuthor := false
|
||||
for _, post := range result.Posts {
|
||||
if post != nil && post.UserID != "" && post.User == nil {
|
||||
missingAuthor = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if missingAuthor {
|
||||
posts, total, loadErr := s.postRepo.List(page, pageSize, userID)
|
||||
if loadErr != nil {
|
||||
return nil, 0, loadErr
|
||||
}
|
||||
result = &PostListResult{Posts: posts, Total: total}
|
||||
cache.SetWithJitter(s.cache, cacheKey, result, postListTTL, jitter)
|
||||
}
|
||||
|
||||
return result.Posts, result.Total, nil
|
||||
}
|
||||
|
||||
// GetLatestPosts 获取最新帖子(语义化别名)
|
||||
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
return s.List(ctx, page, pageSize, userID)
|
||||
}
|
||||
|
||||
// GetUserPosts 获取用户帖子
|
||||
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
return s.postRepo.GetUserPosts(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// Like 点赞
|
||||
func (s *PostService) Like(ctx context.Context, postID, userID string) error {
|
||||
// 获取帖子信息用于发送通知
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.postRepo.Like(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
// 发送点赞通知(不给自己发通知)
|
||||
if s.systemMessageService != nil && post.UserID != userID {
|
||||
go func() {
|
||||
notifyErr := s.systemMessageService.SendLikeNotification(context.Background(), post.UserID, userID, postID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] Like notification sent successfully\n")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 推送点赞行为到Gorse(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to insert like feedback to Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlike 取消点赞
|
||||
func (s *PostService) Unlike(ctx context.Context, postID, userID string) error {
|
||||
err := s.postRepo.Unlike(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
// 删除Gorse中的点赞反馈(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to delete like feedback from Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsLiked 检查是否点赞
|
||||
func (s *PostService) IsLiked(ctx context.Context, postID, userID string) bool {
|
||||
return s.postRepo.IsLiked(postID, userID)
|
||||
}
|
||||
|
||||
// Favorite 收藏
|
||||
func (s *PostService) Favorite(ctx context.Context, postID, userID string) error {
|
||||
// 获取帖子信息用于发送通知
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.postRepo.Favorite(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
// 发送收藏通知(不给自己发通知)
|
||||
if s.systemMessageService != nil && post.UserID != userID {
|
||||
go func() {
|
||||
notifyErr := s.systemMessageService.SendFavoriteNotification(context.Background(), post.UserID, userID, postID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending favorite notification: %v\n", notifyErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] Favorite notification sent successfully\n")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 推送收藏行为到Gorse(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to insert favorite feedback to Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unfavorite 取消收藏
|
||||
func (s *PostService) Unfavorite(ctx context.Context, postID, userID string) error {
|
||||
err := s.postRepo.Unfavorite(postID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 失效帖子详情缓存
|
||||
cache.InvalidatePostDetail(s.cache, postID)
|
||||
|
||||
// 删除Gorse中的收藏反馈(异步)
|
||||
go func() {
|
||||
if s.gorseClient.IsEnabled() {
|
||||
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to delete favorite feedback from Gorse: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsFavorited 检查是否收藏
|
||||
func (s *PostService) IsFavorited(ctx context.Context, postID, userID string) bool {
|
||||
return s.postRepo.IsFavorited(postID, userID)
|
||||
}
|
||||
|
||||
// IncrementViews 增加帖子观看量并同步到Gorse
|
||||
func (s *PostService) IncrementViews(ctx context.Context, postID, userID string) error {
|
||||
if err := s.postRepo.IncrementViews(postID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 同步浏览行为到Gorse(异步)
|
||||
go func() {
|
||||
if !s.gorseClient.IsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
feedbackUserID := userID
|
||||
if feedbackUserID == "" {
|
||||
feedbackUserID = anonymousViewUserID
|
||||
}
|
||||
|
||||
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeRead, feedbackUserID, postID); err != nil {
|
||||
log.Printf("[WARN] Failed to insert read feedback to Gorse: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFavorites 获取收藏列表
|
||||
func (s *PostService) GetFavorites(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
return s.postRepo.GetFavorites(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// Search 搜索帖子
|
||||
func (s *PostService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
return s.postRepo.Search(keyword, page, pageSize)
|
||||
}
|
||||
|
||||
// GetFollowingPosts 获取关注用户的帖子(带缓存)
|
||||
func (s *PostService) GetFollowingPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
cacheSettings := cache.GetSettings()
|
||||
postListTTL := cacheSettings.PostListTTL
|
||||
if postListTTL <= 0 {
|
||||
postListTTL = PostListTTL
|
||||
}
|
||||
nullTTL := cacheSettings.NullTTL
|
||||
if nullTTL <= 0 {
|
||||
nullTTL = PostListNullTTL
|
||||
}
|
||||
jitter := cacheSettings.JitterRatio
|
||||
if jitter <= 0 {
|
||||
jitter = PostListJitterRatio
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
cacheKey := cache.PostListKey("follow", userID, page, pageSize)
|
||||
|
||||
result, err := cache.GetOrLoadTyped[*PostListResult](
|
||||
s.cache,
|
||||
cacheKey,
|
||||
postListTTL,
|
||||
jitter,
|
||||
nullTTL,
|
||||
func() (*PostListResult, error) {
|
||||
posts, total, err := s.postRepo.GetFollowingPosts(userID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PostListResult{Posts: posts, Total: total}, nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if result == nil {
|
||||
return []*model.Post{}, 0, nil
|
||||
}
|
||||
return result.Posts, result.Total, nil
|
||||
}
|
||||
|
||||
// GetHotPosts 获取热门帖子(使用Gorse非个性化推荐)
|
||||
func (s *PostService) GetHotPosts(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
// 如果Gorse启用,使用自定义的非个性化推荐器
|
||||
if s.gorseClient.IsEnabled() {
|
||||
offset := (page - 1) * pageSize
|
||||
// 使用 most_liked_weekly 推荐器获取周热门
|
||||
// 多取1条用于判断是否还有下一页
|
||||
itemIDs, err := s.gorseClient.GetNonPersonalized(ctx, "most_liked_weekly", pageSize+1, offset, "")
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Gorse GetNonPersonalized failed: %v, fallback to database", err)
|
||||
return s.getHotPostsFromDB(ctx, page, pageSize)
|
||||
}
|
||||
if len(itemIDs) > 0 {
|
||||
hasNext := len(itemIDs) > pageSize
|
||||
if hasNext {
|
||||
itemIDs = itemIDs[:pageSize]
|
||||
}
|
||||
posts, err := s.postRepo.GetByIDs(itemIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
// 近似 total:当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
|
||||
estimatedTotal := int64(offset + len(posts))
|
||||
if hasNext {
|
||||
estimatedTotal = int64(offset + pageSize + 1)
|
||||
}
|
||||
return posts, estimatedTotal, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 降级:从数据库获取
|
||||
return s.getHotPostsFromDB(ctx, page, pageSize)
|
||||
}
|
||||
|
||||
// getHotPostsFromDB 从数据库获取热门帖子(降级路径)
|
||||
func (s *PostService) getHotPostsFromDB(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
// 直接查询数据库,不再使用本地缓存(Gorse失败降级时使用)
|
||||
posts, total, err := s.postRepo.GetHotPosts(page, pageSize)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return posts, total, nil
|
||||
}
|
||||
|
||||
// GetRecommendedPosts 获取推荐帖子
|
||||
func (s *PostService) GetRecommendedPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
// 如果Gorse未启用或用户未登录,降级为热门帖子
|
||||
if !s.gorseClient.IsEnabled() || userID == "" {
|
||||
return s.GetHotPosts(ctx, page, pageSize)
|
||||
}
|
||||
|
||||
// 计算偏移量
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 从Gorse获取推荐列表
|
||||
// 多取1条用于判断是否还有下一页
|
||||
itemIDs, err := s.gorseClient.GetRecommend(ctx, userID, pageSize+1, offset)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Gorse recommendation failed: %v, fallback to hot posts", err)
|
||||
return s.GetHotPosts(ctx, page, pageSize)
|
||||
}
|
||||
|
||||
// 如果没有推荐结果,降级为热门帖子
|
||||
if len(itemIDs) == 0 {
|
||||
return s.GetHotPosts(ctx, page, pageSize)
|
||||
}
|
||||
|
||||
hasNext := len(itemIDs) > pageSize
|
||||
if hasNext {
|
||||
itemIDs = itemIDs[:pageSize]
|
||||
}
|
||||
|
||||
// 根据ID列表查询帖子详情
|
||||
posts, err := s.postRepo.GetByIDs(itemIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 近似 total:当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
|
||||
estimatedTotal := int64(offset + len(posts))
|
||||
if hasNext {
|
||||
estimatedTotal = int64(offset + pageSize + 1)
|
||||
}
|
||||
return posts, estimatedTotal, nil
|
||||
}
|
||||
|
||||
// buildPostCategories 构建帖子的类别标签
|
||||
func (s *PostService) buildPostCategories(post *model.Post) []string {
|
||||
var categories []string
|
||||
|
||||
// 热度标签
|
||||
if post.ViewsCount > 1000 {
|
||||
categories = append(categories, "hot_high")
|
||||
} else if post.ViewsCount > 100 {
|
||||
categories = append(categories, "hot_medium")
|
||||
}
|
||||
|
||||
// 点赞标签
|
||||
if post.LikesCount > 100 {
|
||||
categories = append(categories, "likes_100+")
|
||||
} else if post.LikesCount > 50 {
|
||||
categories = append(categories, "likes_50+")
|
||||
} else if post.LikesCount > 10 {
|
||||
categories = append(categories, "likes_10+")
|
||||
}
|
||||
|
||||
// 评论标签
|
||||
if post.CommentsCount > 50 {
|
||||
categories = append(categories, "comments_50+")
|
||||
} else if post.CommentsCount > 10 {
|
||||
categories = append(categories, "comments_10+")
|
||||
}
|
||||
|
||||
// 时间标签
|
||||
age := time.Since(post.CreatedAt)
|
||||
if age < 24*time.Hour {
|
||||
categories = append(categories, "today")
|
||||
} else if age < 7*24*time.Hour {
|
||||
categories = append(categories, "this_week")
|
||||
} else if age < 30*24*time.Hour {
|
||||
categories = append(categories, "this_month")
|
||||
}
|
||||
|
||||
return categories
|
||||
}
|
||||
575
internal/service/push_service.go
Normal file
575
internal/service/push_service.go
Normal file
@@ -0,0 +1,575 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/dto"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/websocket"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// 推送相关常量
|
||||
const (
|
||||
// DefaultPushTimeout 默认推送超时时间
|
||||
DefaultPushTimeout = 30 * time.Second
|
||||
// MaxRetryCount 最大重试次数
|
||||
MaxRetryCount = 3
|
||||
// DefaultExpiredTime 默认消息过期时间(24小时)
|
||||
DefaultExpiredTime = 24 * time.Hour
|
||||
// PushQueueSize 推送队列大小
|
||||
PushQueueSize = 1000
|
||||
)
|
||||
|
||||
// PushPriority 推送优先级
|
||||
type PushPriority int
|
||||
|
||||
const (
|
||||
PriorityLow PushPriority = 1 // 低优先级(营销消息等)
|
||||
PriorityNormal PushPriority = 5 // 普通优先级(系统通知)
|
||||
PriorityHigh PushPriority = 8 // 高优先级(聊天消息)
|
||||
PriorityCritical PushPriority = 10 // 最高优先级(重要系统通知)
|
||||
)
|
||||
|
||||
// PushService 推送服务接口
|
||||
type PushService interface {
|
||||
// 推送核心方法
|
||||
PushMessage(ctx context.Context, userID string, message *model.Message) error
|
||||
PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error
|
||||
|
||||
// 系统消息推送
|
||||
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
|
||||
PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error
|
||||
PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error
|
||||
|
||||
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
|
||||
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
|
||||
|
||||
// 设备管理
|
||||
RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error
|
||||
UnregisterDevice(ctx context.Context, deviceID string) error
|
||||
UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error
|
||||
|
||||
// 推送记录管理
|
||||
CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error)
|
||||
GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error)
|
||||
|
||||
// 后台任务
|
||||
StartPushWorker(ctx context.Context)
|
||||
StopPushWorker()
|
||||
}
|
||||
|
||||
// pushServiceImpl 推送服务实现
|
||||
type pushServiceImpl struct {
|
||||
pushRepo *repository.PushRecordRepository
|
||||
deviceRepo *repository.DeviceTokenRepository
|
||||
messageRepo *repository.MessageRepository
|
||||
wsManager *websocket.WebSocketManager
|
||||
|
||||
// 推送队列
|
||||
pushQueue chan *pushTask
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// pushTask 推送任务
|
||||
type pushTask struct {
|
||||
userID string
|
||||
message *model.Message
|
||||
priority int
|
||||
}
|
||||
|
||||
// NewPushService 创建推送服务
|
||||
func NewPushService(
|
||||
pushRepo *repository.PushRecordRepository,
|
||||
deviceRepo *repository.DeviceTokenRepository,
|
||||
messageRepo *repository.MessageRepository,
|
||||
wsManager *websocket.WebSocketManager,
|
||||
) PushService {
|
||||
return &pushServiceImpl{
|
||||
pushRepo: pushRepo,
|
||||
deviceRepo: deviceRepo,
|
||||
messageRepo: messageRepo,
|
||||
wsManager: wsManager,
|
||||
pushQueue: make(chan *pushTask, PushQueueSize),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// PushMessage 推送消息给用户
|
||||
func (s *pushServiceImpl) PushMessage(ctx context.Context, userID string, message *model.Message) error {
|
||||
return s.PushToUser(ctx, userID, message, int(PriorityNormal))
|
||||
}
|
||||
|
||||
// PushToUser 带优先级的推送
|
||||
func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error {
|
||||
// 首先尝试WebSocket推送(实时推送)
|
||||
if s.pushViaWebSocket(ctx, userID, message) {
|
||||
// WebSocket推送成功,记录推送状态
|
||||
record, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelWebSocket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create push record: %w", err)
|
||||
}
|
||||
record.MarkPushed()
|
||||
if err := s.pushRepo.Update(record); err != nil {
|
||||
return fmt.Errorf("failed to update push record: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WebSocket推送失败,加入推送队列等待移动端推送
|
||||
select {
|
||||
case s.pushQueue <- &pushTask{
|
||||
userID: userID,
|
||||
message: message,
|
||||
priority: priority,
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// 队列已满,直接创建待推送记录
|
||||
_, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelFCM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create pending push record: %w", err)
|
||||
}
|
||||
return errors.New("push queue is full, message queued for later delivery")
|
||||
}
|
||||
}
|
||||
|
||||
// pushViaWebSocket 通过WebSocket推送消息
|
||||
// 返回true表示推送成功,false表示用户不在线
|
||||
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 判断是否为系统消息/通知消息
|
||||
if message.IsSystemMessage() || message.Category == model.CategoryNotification {
|
||||
// 使用 NotificationMessage 格式推送系统通知
|
||||
// 从 segments 中提取文本内容
|
||||
content := dto.ExtractTextContentFromModel(message.Segments)
|
||||
|
||||
notification := &websocket.NotificationMessage{
|
||||
ID: fmt.Sprintf("%s", message.ID),
|
||||
Type: string(message.SystemType),
|
||||
Content: content,
|
||||
Extra: make(map[string]interface{}),
|
||||
CreatedAt: message.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 填充额外数据
|
||||
if message.ExtraData != nil {
|
||||
notification.Extra["actor_id"] = message.ExtraData.ActorID
|
||||
notification.Extra["actor_name"] = message.ExtraData.ActorName
|
||||
notification.Extra["avatar_url"] = message.ExtraData.AvatarURL
|
||||
notification.Extra["target_id"] = message.ExtraData.TargetID
|
||||
notification.Extra["target_type"] = message.ExtraData.TargetType
|
||||
notification.Extra["action_url"] = message.ExtraData.ActionURL
|
||||
notification.Extra["action_time"] = message.ExtraData.ActionTime
|
||||
|
||||
// 设置触发用户信息
|
||||
if message.ExtraData.ActorID > 0 {
|
||||
notification.TriggerUser = &websocket.NotificationUser{
|
||||
ID: fmt.Sprintf("%d", message.ExtraData.ActorID),
|
||||
Username: message.ExtraData.ActorName,
|
||||
Avatar: message.ExtraData.AvatarURL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// 构建普通聊天消息的 WebSocket 消息 - 使用新的 WSEventResponse 格式
|
||||
// 获取会话类型 (private/group)
|
||||
detailType := "private"
|
||||
if message.ConversationID != "" {
|
||||
// 从会话中获取类型,需要查询数据库或从消息中判断
|
||||
// 这里暂时默认为 private,group 类型需要额外逻辑
|
||||
}
|
||||
|
||||
// 直接使用 message.Segments
|
||||
segments := message.Segments
|
||||
|
||||
event := &dto.WSEventResponse{
|
||||
ID: fmt.Sprintf("%s", message.ID),
|
||||
Time: message.CreatedAt.UnixMilli(),
|
||||
Type: "message",
|
||||
DetailType: detailType,
|
||||
Seq: fmt.Sprintf("%d", message.Seq),
|
||||
Segments: segments,
|
||||
SenderID: message.SenderID,
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// pushViaFCM 通过FCM推送(预留接口)
|
||||
func (s *pushServiceImpl) pushViaFCM(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
|
||||
// TODO: 实现FCM推送
|
||||
// 1. 构建FCM消息
|
||||
// 2. 调用Firebase Admin SDK发送消息
|
||||
// 3. 处理发送结果
|
||||
return errors.New("FCM push not implemented")
|
||||
}
|
||||
|
||||
// pushViaAPNs 通过APNs推送(预留接口)
|
||||
func (s *pushServiceImpl) pushViaAPNs(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
|
||||
// TODO: 实现APNs推送
|
||||
// 1. 构建APNs消息
|
||||
// 2. 调用APNs SDK发送消息
|
||||
// 3. 处理发送结果
|
||||
return errors.New("APNs push not implemented")
|
||||
}
|
||||
|
||||
// RegisterDevice 注册设备
|
||||
func (s *pushServiceImpl) RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error {
|
||||
deviceToken := &model.DeviceToken{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
DeviceType: deviceType,
|
||||
PushToken: pushToken,
|
||||
IsActive: true,
|
||||
}
|
||||
deviceToken.UpdateLastUsed()
|
||||
|
||||
return s.deviceRepo.Upsert(deviceToken)
|
||||
}
|
||||
|
||||
// UnregisterDevice 注销设备
|
||||
func (s *pushServiceImpl) UnregisterDevice(ctx context.Context, deviceID string) error {
|
||||
return s.deviceRepo.Deactivate(deviceID)
|
||||
}
|
||||
|
||||
// UpdateDeviceToken 更新设备Token
|
||||
func (s *pushServiceImpl) UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error {
|
||||
deviceToken, err := s.deviceRepo.GetByDeviceID(deviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("device not found: %w", err)
|
||||
}
|
||||
|
||||
deviceToken.PushToken = newPushToken
|
||||
deviceToken.Activate()
|
||||
|
||||
return s.deviceRepo.Update(deviceToken)
|
||||
}
|
||||
|
||||
// CreatePushRecord 创建推送记录
|
||||
func (s *pushServiceImpl) CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error) {
|
||||
expiredAt := time.Now().Add(DefaultExpiredTime)
|
||||
record := &model.PushRecord{
|
||||
UserID: userID,
|
||||
MessageID: messageID,
|
||||
PushChannel: channel,
|
||||
PushStatus: model.PushStatusPending,
|
||||
MaxRetry: MaxRetryCount,
|
||||
ExpiredAt: &expiredAt,
|
||||
}
|
||||
|
||||
if err := s.pushRepo.Create(record); err != nil {
|
||||
return nil, fmt.Errorf("failed to create push record: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetPendingPushes 获取待推送记录
|
||||
func (s *pushServiceImpl) GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error) {
|
||||
return s.pushRepo.GetByUserID(userID, 100, 0)
|
||||
}
|
||||
|
||||
// StartPushWorker 启动推送工作协程
|
||||
func (s *pushServiceImpl) StartPushWorker(ctx context.Context) {
|
||||
go s.processPushQueue()
|
||||
go s.retryFailedPushes()
|
||||
}
|
||||
|
||||
// StopPushWorker 停止推送工作协程
|
||||
func (s *pushServiceImpl) StopPushWorker() {
|
||||
close(s.stopChan)
|
||||
}
|
||||
|
||||
// processPushQueue 处理推送队列
|
||||
func (s *pushServiceImpl) processPushQueue() {
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case task := <-s.pushQueue:
|
||||
s.processPushTask(task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processPushTask 处理单个推送任务
|
||||
func (s *pushServiceImpl) processPushTask(task *pushTask) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultPushTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 获取用户活跃设备
|
||||
devices, err := s.deviceRepo.GetActiveByUserID(task.userID)
|
||||
if err != nil || len(devices) == 0 {
|
||||
// 没有可用设备,创建待推送记录
|
||||
s.CreatePushRecord(ctx, task.userID, task.message.ID, model.PushChannelFCM)
|
||||
return
|
||||
}
|
||||
|
||||
// 对每个设备创建推送记录并尝试推送
|
||||
for _, device := range devices {
|
||||
record, err := s.CreatePushRecord(ctx, task.userID, task.message.ID, s.getChannelForDevice(device))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var pushErr error
|
||||
switch {
|
||||
case device.IsIOS():
|
||||
pushErr = s.pushViaAPNs(ctx, device, task.message)
|
||||
case device.IsAndroid():
|
||||
pushErr = s.pushViaFCM(ctx, device, task.message)
|
||||
default:
|
||||
// Web设备只支持WebSocket
|
||||
continue
|
||||
}
|
||||
|
||||
if pushErr != nil {
|
||||
record.MarkFailed(pushErr.Error())
|
||||
} else {
|
||||
record.MarkPushed()
|
||||
}
|
||||
s.pushRepo.Update(record)
|
||||
}
|
||||
}
|
||||
|
||||
// getChannelForDevice 根据设备类型获取推送通道
|
||||
func (s *pushServiceImpl) getChannelForDevice(device *model.DeviceToken) model.PushChannel {
|
||||
switch device.DeviceType {
|
||||
case model.DeviceTypeIOS:
|
||||
return model.PushChannelAPNs
|
||||
case model.DeviceTypeAndroid:
|
||||
return model.PushChannelFCM
|
||||
default:
|
||||
return model.PushChannelWebSocket
|
||||
}
|
||||
}
|
||||
|
||||
// retryFailedPushes 重试失败的推送
|
||||
func (s *pushServiceImpl) retryFailedPushes() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.doRetry()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// doRetry 执行重试
|
||||
func (s *pushServiceImpl) doRetry() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 获取失败待重试的推送
|
||||
records, err := s.pushRepo.GetFailedPushesForRetry(100)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
// 检查是否过期
|
||||
if record.IsExpired() {
|
||||
record.MarkExpired()
|
||||
s.pushRepo.Update(record)
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取消息
|
||||
message, err := s.messageRepo.GetMessageByID(record.MessageID)
|
||||
if err != nil {
|
||||
record.MarkFailed("message not found")
|
||||
s.pushRepo.Update(record)
|
||||
continue
|
||||
}
|
||||
|
||||
// 尝试WebSocket推送
|
||||
if s.pushViaWebSocket(ctx, record.UserID, message) {
|
||||
record.MarkDelivered()
|
||||
s.pushRepo.Update(record)
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取设备并尝试移动端推送
|
||||
if record.DeviceToken != "" {
|
||||
device, err := s.deviceRepo.GetByPushToken(record.DeviceToken)
|
||||
if err != nil {
|
||||
record.MarkFailed("device not found")
|
||||
s.pushRepo.Update(record)
|
||||
continue
|
||||
}
|
||||
|
||||
var pushErr error
|
||||
switch {
|
||||
case device.IsIOS():
|
||||
pushErr = s.pushViaAPNs(ctx, device, message)
|
||||
case device.IsAndroid():
|
||||
pushErr = s.pushViaFCM(ctx, device, message)
|
||||
}
|
||||
|
||||
if pushErr != nil {
|
||||
record.MarkFailed(pushErr.Error())
|
||||
} else {
|
||||
record.MarkPushed()
|
||||
}
|
||||
s.pushRepo.Update(record)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PushSystemMessage 推送系统消息
|
||||
func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error {
|
||||
// 首先尝试WebSocket推送
|
||||
if s.pushSystemViaWebSocket(ctx, userID, msgType, title, content, data) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 用户不在线,创建待推送记录(移动端上线后可通过其他方式获取)
|
||||
// 系统消息通常不需要离线推送,客户端上线后会主动拉取
|
||||
return errors.New("user is offline, system message will be available on next sync")
|
||||
}
|
||||
|
||||
// pushSystemViaWebSocket 通过WebSocket推送系统消息
|
||||
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
sysMsg := &websocket.SystemMessage{
|
||||
Type: msgType,
|
||||
Title: title,
|
||||
Content: content,
|
||||
Data: data,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// PushNotification 推送通知消息
|
||||
func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error {
|
||||
// 首先尝试WebSocket推送
|
||||
if s.pushNotificationViaWebSocket(ctx, userID, notification) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 用户不在线,创建待推送记录
|
||||
// 通知消息可以等用户上线后拉取
|
||||
return errors.New("user is offline, notification will be available on next sync")
|
||||
}
|
||||
|
||||
// pushNotificationViaWebSocket 通过WebSocket推送通知消息
|
||||
func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
if notification.CreatedAt == 0 {
|
||||
notification.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
|
||||
// PushAnnouncement 广播公告消息
|
||||
func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error {
|
||||
if s.wsManager == nil {
|
||||
return errors.New("websocket manager not available")
|
||||
}
|
||||
|
||||
if announcement.CreatedAt == 0 {
|
||||
announcement.CreatedAt = time.Now().UnixMilli()
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
|
||||
s.wsManager.Broadcast(wsMsg)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
|
||||
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
|
||||
// 首先尝试WebSocket推送
|
||||
if s.pushSystemNotificationViaWebSocket(ctx, userID, notification) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 用户不在线,系统通知已存储在数据库中,用户上线后会主动拉取
|
||||
return nil
|
||||
}
|
||||
|
||||
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
|
||||
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
|
||||
if s.wsManager == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !s.wsManager.IsUserOnline(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 构建 WebSocket 通知消息
|
||||
wsNotification := &websocket.NotificationMessage{
|
||||
ID: fmt.Sprintf("%d", notification.ID),
|
||||
Type: string(notification.Type),
|
||||
Title: notification.Title,
|
||||
Content: notification.Content,
|
||||
Extra: make(map[string]interface{}),
|
||||
CreatedAt: notification.CreatedAt.UnixMilli(),
|
||||
}
|
||||
|
||||
// 填充额外数据
|
||||
if notification.ExtraData != nil {
|
||||
wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr
|
||||
wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName
|
||||
wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL
|
||||
wsNotification.Extra["target_id"] = notification.ExtraData.TargetID
|
||||
wsNotification.Extra["target_type"] = notification.ExtraData.TargetType
|
||||
wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL
|
||||
wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime
|
||||
|
||||
// 设置触发用户信息
|
||||
if notification.ExtraData.ActorIDStr != "" {
|
||||
wsNotification.TriggerUser = &websocket.NotificationUser{
|
||||
ID: notification.ExtraData.ActorIDStr,
|
||||
Username: notification.ExtraData.ActorName,
|
||||
Avatar: notification.ExtraData.AvatarURL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification)
|
||||
s.wsManager.SendToUser(userID, wsMsg)
|
||||
return true
|
||||
}
|
||||
559
internal/service/sensitive_service.go
Normal file
559
internal/service/sensitive_service.go
Normal file
@@ -0,0 +1,559 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"carrot_bbs/internal/model"
|
||||
redisclient "carrot_bbs/internal/pkg/redis"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ==================== DFA 敏感词过滤实现 ====================
|
||||
|
||||
// SensitiveNode 敏感词树节点
|
||||
type SensitiveNode struct {
|
||||
// 子节点映射
|
||||
Children map[rune]*SensitiveNode
|
||||
// 是否为敏感词结尾
|
||||
IsEnd bool
|
||||
// 敏感词信息(仅在 IsEnd 为 true 时有效)
|
||||
Word string
|
||||
Level model.SensitiveWordLevel
|
||||
Category model.SensitiveWordCategory
|
||||
}
|
||||
|
||||
// NewSensitiveNode 创建新的敏感词节点
|
||||
func NewSensitiveNode() *SensitiveNode {
|
||||
return &SensitiveNode{
|
||||
Children: make(map[rune]*SensitiveNode),
|
||||
IsEnd: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SensitiveWordTree 敏感词树
|
||||
type SensitiveWordTree struct {
|
||||
root *SensitiveNode
|
||||
wordCount int
|
||||
mu sync.RWMutex
|
||||
lastReload time.Time
|
||||
}
|
||||
|
||||
// NewSensitiveWordTree 创建新的敏感词树
|
||||
func NewSensitiveWordTree() *SensitiveWordTree {
|
||||
return &SensitiveWordTree{
|
||||
root: NewSensitiveNode(),
|
||||
wordCount: 0,
|
||||
lastReload: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddWord 添加敏感词到树中
|
||||
func (t *SensitiveWordTree) AddWord(word string, level model.SensitiveWordLevel, category model.SensitiveWordCategory) {
|
||||
if word == "" {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
node := t.root
|
||||
// 转换为小写进行匹配(不区分大小写)
|
||||
lowerWord := strings.ToLower(word)
|
||||
runes := []rune(lowerWord)
|
||||
|
||||
for _, r := range runes {
|
||||
child, exists := node.Children[r]
|
||||
if !exists {
|
||||
child = NewSensitiveNode()
|
||||
node.Children[r] = child
|
||||
}
|
||||
node = child
|
||||
}
|
||||
|
||||
// 如果不是已存在的敏感词,则计数+1
|
||||
if !node.IsEnd {
|
||||
t.wordCount++
|
||||
}
|
||||
|
||||
node.IsEnd = true
|
||||
node.Word = word
|
||||
node.Level = level
|
||||
node.Category = category
|
||||
}
|
||||
|
||||
// RemoveWord 从树中移除敏感词
|
||||
func (t *SensitiveWordTree) RemoveWord(word string) {
|
||||
if word == "" {
|
||||
return
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
lowerWord := strings.ToLower(word)
|
||||
runes := []rune(lowerWord)
|
||||
|
||||
// 查找节点
|
||||
node := t.root
|
||||
for _, r := range runes {
|
||||
child, exists := node.Children[r]
|
||||
if !exists {
|
||||
return // 敏感词不存在
|
||||
}
|
||||
node = child
|
||||
}
|
||||
|
||||
if node.IsEnd {
|
||||
node.IsEnd = false
|
||||
node.Word = ""
|
||||
t.wordCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Check 检查文本是否包含敏感词,返回是否包含及敏感词列表
|
||||
func (t *SensitiveWordTree) Check(text string) (bool, []string) {
|
||||
if text == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
var foundWords []string
|
||||
runes := []rune(strings.ToLower(text))
|
||||
length := len(runes)
|
||||
|
||||
// 用于标记已找到的敏感词位置,避免重复计算
|
||||
marked := make([]bool, length)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
// 从当前位置开始搜索
|
||||
node := t.root
|
||||
matchEnd := -1
|
||||
matchWord := ""
|
||||
|
||||
for j := i; j < length; j++ {
|
||||
child, exists := node.Children[runes[j]]
|
||||
if !exists {
|
||||
break
|
||||
}
|
||||
node = child
|
||||
|
||||
if node.IsEnd {
|
||||
matchEnd = j
|
||||
matchWord = node.Word
|
||||
}
|
||||
}
|
||||
|
||||
// 标记找到的敏感词位置
|
||||
if matchEnd >= 0 && !marked[i] {
|
||||
for k := i; k <= matchEnd; k++ {
|
||||
marked[k] = true
|
||||
}
|
||||
foundWords = append(foundWords, matchWord)
|
||||
}
|
||||
}
|
||||
|
||||
return len(foundWords) > 0, foundWords
|
||||
}
|
||||
|
||||
// Replace 替换文本中的敏感词
|
||||
func (t *SensitiveWordTree) Replace(text string, repl string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
runes := []rune(text)
|
||||
length := len(runes)
|
||||
result := make([]rune, 0, length)
|
||||
|
||||
// 用于标记已替换的位置
|
||||
marked := make([]bool, length)
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
if marked[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
// 从当前位置开始搜索
|
||||
node := t.root
|
||||
matchEnd := -1
|
||||
|
||||
for j := i; j < length; j++ {
|
||||
child, exists := node.Children[runes[j]]
|
||||
if !exists {
|
||||
break
|
||||
}
|
||||
node = child
|
||||
|
||||
if node.IsEnd {
|
||||
matchEnd = j
|
||||
}
|
||||
}
|
||||
|
||||
if matchEnd >= 0 {
|
||||
// 标记已替换的位置
|
||||
for k := i; k <= matchEnd; k++ {
|
||||
marked[k] = true
|
||||
}
|
||||
// 追加替换符
|
||||
replRunes := []rune(repl)
|
||||
result = append(result, replRunes...)
|
||||
// 跳过已匹配的字符
|
||||
i = matchEnd
|
||||
} else {
|
||||
// 追加原字符
|
||||
result = append(result, runes[i])
|
||||
}
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// WordCount 获取敏感词数量
|
||||
func (t *SensitiveWordTree) WordCount() int {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.wordCount
|
||||
}
|
||||
|
||||
// ==================== 敏感词服务实现 ====================
|
||||
|
||||
// SensitiveService 敏感词服务接口
|
||||
type SensitiveService interface {
|
||||
// Check 检查文本是否包含敏感词
|
||||
Check(ctx context.Context, text string) (bool, []string)
|
||||
// Replace 替换敏感词
|
||||
Replace(ctx context.Context, text string, repl string) string
|
||||
// AddWord 添加敏感词
|
||||
AddWord(ctx context.Context, word string, category string, level int) error
|
||||
// RemoveWord 移除敏感词
|
||||
RemoveWord(ctx context.Context, word string) error
|
||||
// Reload 重新加载敏感词库
|
||||
Reload(ctx context.Context) error
|
||||
// GetWordCount 获取敏感词数量
|
||||
GetWordCount(ctx context.Context) int
|
||||
}
|
||||
|
||||
// sensitiveServiceImpl 敏感词服务实现
|
||||
type sensitiveServiceImpl struct {
|
||||
tree *SensitiveWordTree
|
||||
db *gorm.DB
|
||||
redis *redisclient.Client
|
||||
config *SensitiveConfig
|
||||
mu sync.RWMutex
|
||||
replaceStr string
|
||||
}
|
||||
|
||||
// SensitiveConfig 敏感词服务配置
|
||||
type SensitiveConfig struct {
|
||||
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
|
||||
ReplaceStr string `mapstructure:"replace_str" yaml:"replace_str"`
|
||||
// 最小匹配长度
|
||||
MinMatchLen int `mapstructure:"min_match_len" yaml:"min_match_len"`
|
||||
// 是否从数据库加载
|
||||
LoadFromDB bool `mapstructure:"load_from_db" yaml:"load_from_db"`
|
||||
// 是否从Redis加载
|
||||
LoadFromRedis bool `mapstructure:"load_from_redis" yaml:"load_from_redis"`
|
||||
// Redis键前缀
|
||||
RedisKeyPrefix string `mapstructure:"redis_key_prefix" yaml:"redis_key_prefix"`
|
||||
}
|
||||
|
||||
// NewSensitiveService 创建敏感词服务
|
||||
func NewSensitiveService(db *gorm.DB, redisClient *redisclient.Client, config *SensitiveConfig) SensitiveService {
|
||||
s := &sensitiveServiceImpl{
|
||||
tree: NewSensitiveWordTree(),
|
||||
db: db,
|
||||
redis: redisClient,
|
||||
config: config,
|
||||
replaceStr: config.ReplaceStr,
|
||||
}
|
||||
|
||||
// 如果未设置替换符,默认使用 ***
|
||||
if s.replaceStr == "" {
|
||||
s.replaceStr = "***"
|
||||
}
|
||||
|
||||
// 初始化加载敏感词
|
||||
if config.LoadFromDB {
|
||||
if err := s.loadFromDB(context.Background()); err != nil {
|
||||
log.Printf("Failed to load sensitive words from database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if config.LoadFromRedis && redisClient != nil {
|
||||
if err := s.loadFromRedis(context.Background()); err != nil {
|
||||
log.Printf("Failed to load sensitive words from redis: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Check 检查文本是否包含敏感词
|
||||
func (s *sensitiveServiceImpl) Check(ctx context.Context, text string) (bool, []string) {
|
||||
if !s.config.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
if text == "" {
|
||||
return false, nil
|
||||
}
|
||||
return s.tree.Check(text)
|
||||
}
|
||||
|
||||
// Replace 替换敏感词
|
||||
func (s *sensitiveServiceImpl) Replace(ctx context.Context, text string, repl string) string {
|
||||
if !s.config.Enabled {
|
||||
return text
|
||||
}
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
|
||||
// 如果未指定替换符,使用默认替换符
|
||||
if repl == "" {
|
||||
repl = s.replaceStr
|
||||
}
|
||||
|
||||
return s.tree.Replace(text, repl)
|
||||
}
|
||||
|
||||
// AddWord 添加敏感词
|
||||
func (s *sensitiveServiceImpl) AddWord(ctx context.Context, word string, category string, level int) error {
|
||||
if word == "" {
|
||||
return fmt.Errorf("word cannot be empty")
|
||||
}
|
||||
|
||||
// 转换为敏感词级别
|
||||
wordLevel := model.SensitiveWordLevel(level)
|
||||
if wordLevel < 1 || wordLevel > 3 {
|
||||
wordLevel = model.SensitiveWordLevelLow
|
||||
}
|
||||
|
||||
// 转换为敏感词分类
|
||||
wordCategory := model.SensitiveWordCategory(category)
|
||||
if wordCategory == "" {
|
||||
wordCategory = model.SensitiveWordCategoryOther
|
||||
}
|
||||
|
||||
// 添加到树
|
||||
s.tree.AddWord(word, wordLevel, wordCategory)
|
||||
|
||||
// 持久化到数据库
|
||||
if s.db != nil {
|
||||
sensitiveWord := model.SensitiveWord{
|
||||
Word: word,
|
||||
Category: wordCategory,
|
||||
Level: wordLevel,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
// 使用 upsert 逻辑
|
||||
var existing model.SensitiveWord
|
||||
result := s.db.Where("word = ?", word).First(&existing)
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
if err := s.db.Create(&sensitiveWord).Error; err != nil {
|
||||
log.Printf("Failed to save sensitive word to database: %v", err)
|
||||
}
|
||||
} else if result.Error == nil {
|
||||
// 更新已存在的记录
|
||||
existing.Category = wordCategory
|
||||
existing.Level = wordLevel
|
||||
existing.IsActive = true
|
||||
if err := s.db.Save(&existing).Error; err != nil {
|
||||
log.Printf("Failed to update sensitive word in database: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 同步到 Redis
|
||||
if s.redis != nil && s.config.RedisKeyPrefix != "" {
|
||||
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
|
||||
data := map[string]interface{}{
|
||||
"word": word,
|
||||
"category": category,
|
||||
"level": level,
|
||||
}
|
||||
jsonData, _ := json.Marshal(data)
|
||||
s.redis.Set(ctx, key, jsonData, 0)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveWord 移除敏感词
|
||||
func (s *sensitiveServiceImpl) RemoveWord(ctx context.Context, word string) error {
|
||||
if word == "" {
|
||||
return fmt.Errorf("word cannot be empty")
|
||||
}
|
||||
|
||||
// 从树中移除
|
||||
s.tree.RemoveWord(word)
|
||||
|
||||
// 从数据库中标记为不活跃
|
||||
if s.db != nil {
|
||||
result := s.db.Model(&model.SensitiveWord{}).Where("word = ?", word).Update("is_active", false)
|
||||
if result.Error != nil {
|
||||
log.Printf("Failed to deactivate sensitive word in database: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// 从 Redis 中删除
|
||||
if s.redis != nil && s.config.RedisKeyPrefix != "" {
|
||||
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
|
||||
s.redis.Del(ctx, key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reload 重新加载敏感词库
|
||||
func (s *sensitiveServiceImpl) Reload(ctx context.Context) error {
|
||||
// 清空现有树
|
||||
s.tree = NewSensitiveWordTree()
|
||||
|
||||
// 从数据库加载
|
||||
if s.config.LoadFromDB {
|
||||
if err := s.loadFromDB(ctx); err != nil {
|
||||
return fmt.Errorf("failed to load from database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 从 Redis 加载
|
||||
if s.config.LoadFromRedis && s.redis != nil {
|
||||
if err := s.loadFromRedis(ctx); err != nil {
|
||||
return fmt.Errorf("failed to load from redis: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWordCount 获取敏感词数量
|
||||
func (s *sensitiveServiceImpl) GetWordCount(ctx context.Context) int {
|
||||
return s.tree.WordCount()
|
||||
}
|
||||
|
||||
// loadFromDB 从数据库加载敏感词
|
||||
func (s *sensitiveServiceImpl) loadFromDB(ctx context.Context) error {
|
||||
if s.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var words []model.SensitiveWord
|
||||
if err := s.db.Where("is_active = ?", true).Find(&words).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, word := range words {
|
||||
s.tree.AddWord(word.Word, word.Level, word.Category)
|
||||
}
|
||||
|
||||
log.Printf("Loaded %d sensitive words from database", len(words))
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadFromRedis 从 Redis 加载敏感词
|
||||
func (s *sensitiveServiceImpl) loadFromRedis(ctx context.Context) error {
|
||||
if s.redis == nil || s.config.RedisKeyPrefix == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 SCAN 命令代替 KEYS,避免阻塞
|
||||
pattern := fmt.Sprintf("%s:*", s.config.RedisKeyPrefix)
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := s.redis.GetClient().Scan(ctx, cursor, pattern, 100).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
data, err := s.redis.Get(ctx, key)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var wordData map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &wordData); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
word, _ := wordData["word"].(string)
|
||||
category, _ := wordData["category"].(string)
|
||||
level, _ := wordData["level"].(float64)
|
||||
|
||||
if word != "" {
|
||||
s.tree.AddWord(word, model.SensitiveWordLevel(int(level)), model.SensitiveWordCategory(category))
|
||||
}
|
||||
}
|
||||
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
// ContainsSensitiveWord 快速检查文本是否包含敏感词
|
||||
func ContainsSensitiveWord(text string, tree *SensitiveWordTree) bool {
|
||||
if tree == nil || text == "" {
|
||||
return false
|
||||
}
|
||||
hasSensitive, _ := tree.Check(text)
|
||||
return hasSensitive
|
||||
}
|
||||
|
||||
// FilterSensitiveWords 过滤敏感词并返回替换后的文本
|
||||
func FilterSensitiveWords(text string, tree *SensitiveWordTree, repl string) string {
|
||||
if tree == nil || text == "" {
|
||||
return text
|
||||
}
|
||||
if repl == "" {
|
||||
repl = "***"
|
||||
}
|
||||
return tree.Replace(text, repl)
|
||||
}
|
||||
|
||||
// ValidateTextLength 验证文本长度是否合法
|
||||
func ValidateTextLength(text string, minLen, maxLen int) bool {
|
||||
length := utf8.RuneCountInString(text)
|
||||
return length >= minLen && length <= maxLen
|
||||
}
|
||||
|
||||
// SanitizeText 清理文本,移除多余空白字符
|
||||
func SanitizeText(text string) string {
|
||||
// 替换多个连续空白字符为单个空格
|
||||
spaceReg := regexp.MustCompile(`\s+`)
|
||||
text = spaceReg.ReplaceAllString(text, " ")
|
||||
// 去除首尾空白
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// ==================== 默认敏感词列表 ====================
|
||||
|
||||
// DefaultSensitiveWords 返回默认敏感词列表(示例)
|
||||
func DefaultSensitiveWords() map[string]struct{} {
|
||||
return map[string]struct{}{
|
||||
// 示例敏感词,实际需要从数据库或配置加载
|
||||
"测试敏感词1": {},
|
||||
"测试敏感词2": {},
|
||||
"测试敏感词3": {},
|
||||
}
|
||||
}
|
||||
139
internal/service/sticker_service.go
Normal file
139
internal/service/sticker_service.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/repository"
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrStickerAlreadyExists = errors.New("sticker already exists")
|
||||
ErrInvalidStickerURL = errors.New("invalid sticker url")
|
||||
)
|
||||
|
||||
// StickerService 自定义表情服务接口
|
||||
type StickerService interface {
|
||||
// 获取用户的所有表情
|
||||
GetUserStickers(userID string) ([]model.UserSticker, error)
|
||||
// 添加表情
|
||||
AddSticker(userID string, url string, width, height int) (*model.UserSticker, error)
|
||||
// 删除表情
|
||||
DeleteSticker(userID string, stickerID string) error
|
||||
// 检查表情是否已存在
|
||||
CheckExists(userID string, url string) (bool, error)
|
||||
// 重新排序
|
||||
ReorderStickers(userID string, orders map[string]int) error
|
||||
// 获取用户表情数量
|
||||
GetStickerCount(userID string) (int64, error)
|
||||
}
|
||||
|
||||
// stickerService 自定义表情服务实现
|
||||
type stickerService struct {
|
||||
stickerRepo repository.StickerRepository
|
||||
}
|
||||
|
||||
// NewStickerService 创建自定义表情服务
|
||||
func NewStickerService(stickerRepo repository.StickerRepository) StickerService {
|
||||
return &stickerService{
|
||||
stickerRepo: stickerRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserStickers 获取用户的所有表情
|
||||
func (s *stickerService) GetUserStickers(userID string) ([]model.UserSticker, error) {
|
||||
stickers, err := s.stickerRepo.GetByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 兼容历史脏数据:过滤本地文件 URI,避免客户端加载 file:// 报错
|
||||
filtered := make([]model.UserSticker, 0, len(stickers))
|
||||
for _, sticker := range stickers {
|
||||
if isValidStickerURL(sticker.URL) {
|
||||
filtered = append(filtered, sticker)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// AddSticker 添加表情
|
||||
func (s *stickerService) AddSticker(userID string, url string, width, height int) (*model.UserSticker, error) {
|
||||
if !isValidStickerURL(url) {
|
||||
return nil, ErrInvalidStickerURL
|
||||
}
|
||||
|
||||
// 检查是否已存在
|
||||
exists, err := s.stickerRepo.Exists(userID, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrStickerAlreadyExists
|
||||
}
|
||||
|
||||
// 获取当前数量用于设置排序
|
||||
count, err := s.stickerRepo.CountByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sticker := &model.UserSticker{
|
||||
UserID: userID,
|
||||
URL: url,
|
||||
Width: width,
|
||||
Height: height,
|
||||
SortOrder: int(count), // 新表情添加到末尾
|
||||
}
|
||||
|
||||
if err := s.stickerRepo.Create(sticker); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sticker, nil
|
||||
}
|
||||
|
||||
func isValidStickerURL(raw string) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
return scheme == "http" || scheme == "https"
|
||||
}
|
||||
|
||||
// DeleteSticker 删除表情
|
||||
func (s *stickerService) DeleteSticker(userID string, stickerID string) error {
|
||||
// 先检查表情是否属于该用户
|
||||
sticker, err := s.stickerRepo.GetByID(stickerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sticker.UserID != userID {
|
||||
return errors.New("sticker not found")
|
||||
}
|
||||
|
||||
return s.stickerRepo.Delete(stickerID)
|
||||
}
|
||||
|
||||
// CheckExists 检查表情是否已存在
|
||||
func (s *stickerService) CheckExists(userID string, url string) (bool, error) {
|
||||
return s.stickerRepo.Exists(userID, url)
|
||||
}
|
||||
|
||||
// ReorderStickers 重新排序
|
||||
func (s *stickerService) ReorderStickers(userID string, orders map[string]int) error {
|
||||
return s.stickerRepo.BatchUpdateSortOrder(userID, orders)
|
||||
}
|
||||
|
||||
// GetStickerCount 获取用户表情数量
|
||||
func (s *stickerService) GetStickerCount(userID string) (int64, error) {
|
||||
return s.stickerRepo.CountByUserID(userID)
|
||||
}
|
||||
462
internal/service/system_message_service.go
Normal file
462
internal/service/system_message_service.go
Normal file
@@ -0,0 +1,462 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// SystemMessageService 系统消息服务接口
|
||||
type SystemMessageService interface {
|
||||
// 发送互动通知
|
||||
SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||
SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error
|
||||
SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error
|
||||
SendFollowNotification(ctx context.Context, userID string, operatorID string) error
|
||||
SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||
SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error
|
||||
SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error
|
||||
SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error
|
||||
|
||||
// 发送系统公告
|
||||
SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error
|
||||
SendBroadcastAnnouncement(ctx context.Context, title string, content string) error
|
||||
}
|
||||
|
||||
type systemMessageServiceImpl struct {
|
||||
notifyRepo *repository.SystemNotificationRepository
|
||||
pushService PushService
|
||||
userRepo *repository.UserRepository
|
||||
postRepo *repository.PostRepository
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
// NewSystemMessageService 创建系统消息服务
|
||||
func NewSystemMessageService(
|
||||
notifyRepo *repository.SystemNotificationRepository,
|
||||
pushService PushService,
|
||||
userRepo *repository.UserRepository,
|
||||
postRepo *repository.PostRepository,
|
||||
) SystemMessageService {
|
||||
return &systemMessageServiceImpl{
|
||||
notifyRepo: notifyRepo,
|
||||
pushService: pushService,
|
||||
userRepo: userRepo,
|
||||
postRepo: postRepo,
|
||||
cache: cache.GetCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// SendLikeNotification 发送点赞通知
|
||||
func (s *systemMessageServiceImpl) SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "post",
|
||||
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 赞了「%s」", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikePost, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create like notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendCommentNotification 发送评论通知
|
||||
func (s *systemMessageServiceImpl) SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "comment",
|
||||
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 评论了「%s」", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyComment, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create comment notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendReplyNotification 发送回复通知
|
||||
func (s *systemMessageServiceImpl) SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: replyID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "reply",
|
||||
ActionURL: fmt.Sprintf("/posts/%s?comment=%s&reply=%s", postID, commentID, replyID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 回复了您在「%s」的评论", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyReply, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create reply notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendFollowNotification 发送关注通知
|
||||
func (s *systemMessageServiceImpl) SendFollowNotification(ctx context.Context, userID string, operatorID string) error {
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: userID=%s, operatorID=%s\n", userID, operatorID)
|
||||
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: getActorInfo error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: actorName=%s, avatarURL=%s\n", actorName, avatarURL)
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: "",
|
||||
TargetType: "user",
|
||||
ActionURL: fmt.Sprintf("/users/%s", operatorID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 关注了你", actorName)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyFollow, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create follow notification: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: notification created, ID=%d, Content=%s\n", notification.ID, notification.Content)
|
||||
|
||||
// 推送通知
|
||||
pushErr := s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
if pushErr != nil {
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification error: %v\n", pushErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification success\n")
|
||||
}
|
||||
return pushErr
|
||||
}
|
||||
|
||||
// SendFavoriteNotification 发送收藏通知
|
||||
func (s *systemMessageServiceImpl) SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "post",
|
||||
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 收藏了「%s」", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyFavoritePost, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create favorite notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendLikeCommentNotification 发送评论点赞通知
|
||||
func (s *systemMessageServiceImpl) SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 截取评论内容预览(最多50字)
|
||||
preview := commentContent
|
||||
runes := []rune(preview)
|
||||
if len(runes) > 50 {
|
||||
preview = string(runes[:50]) + "..."
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: preview,
|
||||
TargetType: "comment",
|
||||
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 赞了您的评论", actorName)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeComment, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create like comment notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendLikeReplyNotification 发送回复点赞通知
|
||||
func (s *systemMessageServiceImpl) SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 截取回复内容预览(最多50字)
|
||||
preview := replyContent
|
||||
runes := []rune(preview)
|
||||
if len(runes) > 50 {
|
||||
preview = string(runes[:50]) + "..."
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: preview,
|
||||
TargetType: "reply",
|
||||
ActionURL: fmt.Sprintf("/posts/%s?reply=%s", postID, replyID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 赞了您的回复", actorName)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeReply, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create like reply notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendMentionNotification 发送@提及通知
|
||||
func (s *systemMessageServiceImpl) SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error {
|
||||
// 获取操作者信息
|
||||
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取帖子标题
|
||||
postTitle, err := s.getPostTitle(postID)
|
||||
if err != nil {
|
||||
postTitle = "您的帖子"
|
||||
}
|
||||
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
ActorIDStr: operatorID,
|
||||
ActorName: actorName,
|
||||
AvatarURL: avatarURL,
|
||||
TargetID: postID,
|
||||
TargetTitle: postTitle,
|
||||
TargetType: "post",
|
||||
ActionURL: fmt.Sprintf("/posts/%s", postID),
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
content := fmt.Sprintf("%s 在「%s」中提到了你", actorName, postTitle)
|
||||
|
||||
// 创建通知
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyMention, content, extraData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create mention notification: %w", err)
|
||||
}
|
||||
|
||||
// 推送通知
|
||||
return s.pushService.PushSystemNotification(ctx, userID, notification)
|
||||
}
|
||||
|
||||
// SendSystemAnnouncement 发送系统公告给指定用户
|
||||
func (s *systemMessageServiceImpl) SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error {
|
||||
for _, userID := range userIDs {
|
||||
extraData := &model.SystemNotificationExtra{
|
||||
TargetType: "announcement",
|
||||
ActionTime: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
notification, err := s.createNotification(ctx, userID, model.SysNotifyAnnounce, fmt.Sprintf("【%s】%s", title, content), extraData)
|
||||
if err != nil {
|
||||
continue // 单个失败不影响其他用户
|
||||
}
|
||||
|
||||
// 推送通知(使用高优先级)
|
||||
if err := s.pushService.PushSystemNotification(ctx, userID, notification); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendBroadcastAnnouncement 发送广播公告给所有在线用户
|
||||
func (s *systemMessageServiceImpl) SendBroadcastAnnouncement(ctx context.Context, title string, content string) error {
|
||||
// TODO: 实现广播公告
|
||||
// 1. 获取所有在线用户
|
||||
// 2. 批量发送公告
|
||||
// 3. 对于离线用户,存储为待推送记录
|
||||
return fmt.Errorf("broadcast announcement not implemented")
|
||||
}
|
||||
|
||||
// createNotification 创建系统通知(存储到独立表)
|
||||
func (s *systemMessageServiceImpl) createNotification(ctx context.Context, userID string, notifyType model.SystemNotificationType, content string, extraData *model.SystemNotificationExtra) (*model.SystemNotification, error) {
|
||||
fmt.Printf("[DEBUG] createNotification: userID=%s, notifyType=%s\n", userID, notifyType)
|
||||
|
||||
// 生成雪花算法ID
|
||||
id, err := utils.GetSnowflake().GenerateID()
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] createNotification: failed to generate ID: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to generate notification ID: %w", err)
|
||||
}
|
||||
|
||||
notification := &model.SystemNotification{
|
||||
ID: id,
|
||||
ReceiverID: userID,
|
||||
Type: notifyType,
|
||||
Content: content,
|
||||
ExtraData: extraData,
|
||||
IsRead: false,
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] createNotification: notification created with ID=%d, ReceiverID=%s\n", id, userID)
|
||||
|
||||
// 保存通知到数据库
|
||||
if err := s.notifyRepo.Create(notification); err != nil {
|
||||
fmt.Printf("[DEBUG] createNotification: failed to save notification: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to save notification: %w", err)
|
||||
}
|
||||
|
||||
// 失效系统消息未读数缓存
|
||||
cache.InvalidateUnreadSystem(s.cache, userID)
|
||||
|
||||
fmt.Printf("[DEBUG] createNotification: notification saved successfully, ID=%d\n", notification.ID)
|
||||
return notification, nil
|
||||
}
|
||||
|
||||
// getActorInfo 获取操作者信息
|
||||
func (s *systemMessageServiceImpl) getActorInfo(ctx context.Context, operatorID string) (string, string, error) {
|
||||
// 从用户仓储获取用户信息
|
||||
if s.userRepo != nil {
|
||||
user, err := s.userRepo.GetByID(operatorID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] getActorInfo: failed to get user %s: %v\n", operatorID, err)
|
||||
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil // 返回默认值,不阻断流程
|
||||
}
|
||||
avatar := utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar)
|
||||
return user.Nickname, avatar, nil
|
||||
}
|
||||
// 如果没有用户仓储,返回默认值
|
||||
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil
|
||||
}
|
||||
|
||||
// getPostTitle 获取帖子标题
|
||||
func (s *systemMessageServiceImpl) getPostTitle(postID string) (string, error) {
|
||||
if s.postRepo == nil {
|
||||
if len(postID) >= 8 {
|
||||
return fmt.Sprintf("帖子#%s", postID[:8]), nil
|
||||
}
|
||||
return fmt.Sprintf("帖子#%s", postID), nil
|
||||
}
|
||||
post, err := s.postRepo.GetByID(postID)
|
||||
if err != nil {
|
||||
if len(postID) >= 8 {
|
||||
return fmt.Sprintf("帖子#%s", postID[:8]), nil
|
||||
}
|
||||
return fmt.Sprintf("帖子#%s", postID), nil
|
||||
}
|
||||
if post.Title != "" {
|
||||
return post.Title, nil
|
||||
}
|
||||
// 如果没有标题,返回内容前20个字符
|
||||
if len(post.Content) > 20 {
|
||||
return post.Content[:20] + "...", nil
|
||||
}
|
||||
return post.Content, nil
|
||||
}
|
||||
273
internal/service/upload_service.go
Normal file
273
internal/service/upload_service.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/pkg/s3"
|
||||
_ "golang.org/x/image/bmp"
|
||||
_ "golang.org/x/image/tiff"
|
||||
)
|
||||
|
||||
// UploadService 上传服务
|
||||
type UploadService struct {
|
||||
s3Client *s3.Client
|
||||
userService *UserService
|
||||
}
|
||||
|
||||
// NewUploadService 创建上传服务
|
||||
func NewUploadService(s3Client *s3.Client, userService *UserService) *UploadService {
|
||||
return &UploadService{
|
||||
s3Client: s3Client,
|
||||
userService: userService,
|
||||
}
|
||||
}
|
||||
|
||||
// UploadImage 上传图片
|
||||
func (s *UploadService) UploadImage(ctx context.Context, file *multipart.FileHeader) (string, error) {
|
||||
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 压缩后再计算哈希,确保同一压缩结果映射同一对象名
|
||||
hash := sha256.Sum256(processedData)
|
||||
hashStr := fmt.Sprintf("%x", hash)
|
||||
|
||||
objectName := fmt.Sprintf("images/%s%s", hashStr, ext)
|
||||
|
||||
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// getExtFromContentType 根据Content-Type获取文件扩展名
|
||||
func getExtFromContentType(contentType string) string {
|
||||
baseType, _, err := mime.ParseMediaType(contentType)
|
||||
if err == nil && baseType != "" {
|
||||
contentType = baseType
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case "image/jpg", "image/jpeg":
|
||||
return ".jpg"
|
||||
case "image/png":
|
||||
return ".png"
|
||||
case "image/gif":
|
||||
return ".gif"
|
||||
case "image/webp":
|
||||
return ".webp"
|
||||
case "image/bmp", "image/x-ms-bmp":
|
||||
return ".bmp"
|
||||
case "image/tiff":
|
||||
return ".tiff"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// UploadAvatar 上传头像
|
||||
func (s *UploadService) UploadAvatar(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
|
||||
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 压缩后再计算哈希
|
||||
hash := sha256.Sum256(processedData)
|
||||
hashStr := fmt.Sprintf("%x", hash)
|
||||
|
||||
objectName := fmt.Sprintf("avatars/%s%s", hashStr, ext)
|
||||
|
||||
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||
}
|
||||
|
||||
// 更新用户头像
|
||||
if s.userService != nil {
|
||||
user, err := s.userService.GetUserByID(ctx, userID)
|
||||
if err == nil && user != nil {
|
||||
user.Avatar = url
|
||||
err = s.userService.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
// 更新失败不影响上传结果,只记录日志
|
||||
fmt.Printf("[UploadAvatar] failed to update user avatar: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// UploadCover 上传头图(个人主页封面)
|
||||
func (s *UploadService) UploadCover(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
|
||||
processedData, contentType, ext, err := prepareImageForUpload(file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 压缩后再计算哈希
|
||||
hash := sha256.Sum256(processedData)
|
||||
hashStr := fmt.Sprintf("%x", hash)
|
||||
|
||||
objectName := fmt.Sprintf("covers/%s%s", hashStr, ext)
|
||||
|
||||
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload to S3: %w", err)
|
||||
}
|
||||
|
||||
// 更新用户头图
|
||||
if s.userService != nil {
|
||||
user, err := s.userService.GetUserByID(ctx, userID)
|
||||
if err == nil && user != nil {
|
||||
user.CoverURL = url
|
||||
err = s.userService.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
// 更新失败不影响上传结果,只记录日志
|
||||
fmt.Printf("[UploadCover] failed to update user cover: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// GetURL 获取文件URL
|
||||
func (s *UploadService) GetURL(ctx context.Context, objectName string) (string, error) {
|
||||
return s.s3Client.GetURL(ctx, objectName)
|
||||
}
|
||||
|
||||
// Delete 删除文件
|
||||
func (s *UploadService) Delete(ctx context.Context, objectName string) error {
|
||||
return s.s3Client.Delete(ctx, objectName)
|
||||
}
|
||||
|
||||
func prepareImageForUpload(file *multipart.FileHeader) ([]byte, string, string, error) {
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
originalData, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
// 优先从文件字节探测真实类型,避免前端压缩/转码后 header 与实际格式不一致
|
||||
detectedType := normalizeImageContentType(http.DetectContentType(originalData))
|
||||
headerType := normalizeImageContentType(file.Header.Get("Content-Type"))
|
||||
contentType := detectedType
|
||||
if contentType == "" || contentType == "application/octet-stream" {
|
||||
contentType = headerType
|
||||
}
|
||||
|
||||
compressedData, compressedType, err := compressImageData(originalData, contentType)
|
||||
if err != nil {
|
||||
// 压缩失败时回退到原图,保证上传可用性
|
||||
compressedData = originalData
|
||||
compressedType = contentType
|
||||
}
|
||||
|
||||
if compressedType == "" {
|
||||
compressedType = contentType
|
||||
}
|
||||
if compressedType == "" {
|
||||
compressedType = http.DetectContentType(compressedData)
|
||||
}
|
||||
|
||||
ext := getExtFromContentType(compressedType)
|
||||
if ext == "" {
|
||||
ext = strings.ToLower(filepath.Ext(file.Filename))
|
||||
}
|
||||
if ext == "" {
|
||||
// 最终兜底,避免对象名无扩展名导致 URL 语义不明确
|
||||
ext = ".jpg"
|
||||
}
|
||||
|
||||
return compressedData, compressedType, ext, nil
|
||||
}
|
||||
|
||||
func compressImageData(data []byte, contentType string) ([]byte, string, error) {
|
||||
contentType = normalizeImageContentType(contentType)
|
||||
|
||||
// GIF/WebP 等格式先保留原图,避免动画和透明通道丢失
|
||||
if contentType == "image/gif" || contentType == "image/webp" {
|
||||
return data, contentType, nil
|
||||
}
|
||||
|
||||
if contentType != "image/jpeg" &&
|
||||
contentType != "image/png" &&
|
||||
contentType != "image/bmp" &&
|
||||
contentType != "image/x-ms-bmp" &&
|
||||
contentType != "image/tiff" {
|
||||
return data, contentType, nil
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to decode image: %w", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
switch contentType {
|
||||
case "image/png":
|
||||
encoder := png.Encoder{CompressionLevel: png.BestCompression}
|
||||
if err := encoder.Encode(&buf, img); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to encode png: %w", err)
|
||||
}
|
||||
return buf.Bytes(), "image/png", nil
|
||||
default:
|
||||
// BMP/TIFF 等无损大图统一压缩为 JPEG,控制体积
|
||||
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 82}); err != nil {
|
||||
return nil, "", fmt.Errorf("failed to encode jpeg: %w", err)
|
||||
}
|
||||
return buf.Bytes(), "image/jpeg", nil
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeImageContentType(contentType string) string {
|
||||
if contentType == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseType, _, err := mime.ParseMediaType(contentType)
|
||||
if err == nil && baseType != "" {
|
||||
contentType = baseType
|
||||
}
|
||||
|
||||
switch strings.ToLower(contentType) {
|
||||
case "image/jpg":
|
||||
return "image/jpeg"
|
||||
case "image/jpeg":
|
||||
return "image/jpeg"
|
||||
case "image/png":
|
||||
return "image/png"
|
||||
case "image/gif":
|
||||
return "image/gif"
|
||||
case "image/webp":
|
||||
return "image/webp"
|
||||
case "image/bmp", "image/x-ms-bmp":
|
||||
return "image/bmp"
|
||||
case "image/tiff":
|
||||
return "image/tiff"
|
||||
default:
|
||||
return contentType
|
||||
}
|
||||
}
|
||||
592
internal/service/user_service.go
Normal file
592
internal/service/user_service.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"carrot_bbs/internal/cache"
|
||||
"carrot_bbs/internal/model"
|
||||
"carrot_bbs/internal/pkg/utils"
|
||||
"carrot_bbs/internal/repository"
|
||||
)
|
||||
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo *repository.UserRepository
|
||||
systemMessageService SystemMessageService
|
||||
emailCodeService EmailCodeService
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务
|
||||
func NewUserService(
|
||||
userRepo *repository.UserRepository,
|
||||
systemMessageService SystemMessageService,
|
||||
emailService EmailService,
|
||||
cacheBackend cache.Cache,
|
||||
) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
systemMessageService: systemMessageService,
|
||||
emailCodeService: NewEmailCodeService(emailService, cacheBackend),
|
||||
}
|
||||
}
|
||||
|
||||
// SendRegisterCode 发送注册验证码
|
||||
func (s *UserService) SendRegisterCode(ctx context.Context, email string) error {
|
||||
user, err := s.userRepo.GetByEmail(email)
|
||||
if err == nil && user != nil {
|
||||
return ErrEmailExists
|
||||
}
|
||||
return s.emailCodeService.SendCode(ctx, CodePurposeRegister, email)
|
||||
}
|
||||
|
||||
// SendPasswordResetCode 发送找回密码验证码
|
||||
func (s *UserService) SendPasswordResetCode(ctx context.Context, email string) error {
|
||||
user, err := s.userRepo.GetByEmail(email)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return s.emailCodeService.SendCode(ctx, CodePurposePasswordReset, email)
|
||||
}
|
||||
|
||||
// SendCurrentUserEmailVerifyCode 发送当前用户邮箱验证验证码
|
||||
func (s *UserService) SendCurrentUserEmailVerifyCode(ctx context.Context, userID, email string) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
targetEmail := strings.TrimSpace(email)
|
||||
if targetEmail == "" && user.Email != nil {
|
||||
targetEmail = strings.TrimSpace(*user.Email)
|
||||
}
|
||||
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
|
||||
if user.EmailVerified && user.Email != nil && strings.EqualFold(strings.TrimSpace(*user.Email), targetEmail) {
|
||||
return ErrEmailAlreadyVerified
|
||||
}
|
||||
|
||||
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
|
||||
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
|
||||
return ErrEmailExists
|
||||
}
|
||||
|
||||
return s.emailCodeService.SendCode(ctx, CodePurposeEmailVerify, targetEmail)
|
||||
}
|
||||
|
||||
// VerifyCurrentUserEmail 验证当前用户邮箱
|
||||
func (s *UserService) VerifyCurrentUserEmail(ctx context.Context, userID, email, verificationCode string) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
targetEmail := strings.TrimSpace(email)
|
||||
if targetEmail == "" && user.Email != nil {
|
||||
targetEmail = strings.TrimSpace(*user.Email)
|
||||
}
|
||||
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
|
||||
if err := s.emailCodeService.VerifyCode(CodePurposeEmailVerify, targetEmail, verificationCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
|
||||
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
|
||||
return ErrEmailExists
|
||||
}
|
||||
|
||||
user.Email = &targetEmail
|
||||
user.EmailVerified = true
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
// SendChangePasswordCode 发送修改密码验证码
|
||||
func (s *UserService) SendChangePasswordCode(ctx context.Context, userID string) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
|
||||
return ErrEmailNotBound
|
||||
}
|
||||
return s.emailCodeService.SendCode(ctx, CodePurposeChangePassword, *user.Email)
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func (s *UserService) Register(ctx context.Context, username, email, password, nickname, phone, verificationCode string) (*model.User, error) {
|
||||
// 验证用户名
|
||||
if !utils.ValidateUsername(username) {
|
||||
return nil, ErrInvalidUsername
|
||||
}
|
||||
|
||||
// 注册必须提供邮箱并完成验证码校验
|
||||
if email == "" || !utils.ValidateEmail(email) {
|
||||
return nil, ErrInvalidEmail
|
||||
}
|
||||
if err := s.emailCodeService.VerifyCode(CodePurposeRegister, email, verificationCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !utils.ValidatePassword(password) {
|
||||
return nil, ErrWeakPassword
|
||||
}
|
||||
|
||||
// 验证手机号(如果提供)
|
||||
if phone != "" && !utils.ValidatePhone(phone) {
|
||||
return nil, ErrInvalidPhone
|
||||
}
|
||||
|
||||
// 检查用户名是否已存在
|
||||
existingUser, err := s.userRepo.GetByUsername(username)
|
||||
if err == nil && existingUser != nil {
|
||||
return nil, ErrUsernameExists
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在(如果提供)
|
||||
if email != "" {
|
||||
existingUser, err = s.userRepo.GetByEmail(email)
|
||||
if err == nil && existingUser != nil {
|
||||
return nil, ErrEmailExists
|
||||
}
|
||||
}
|
||||
|
||||
// 检查手机号是否已存在(如果提供)
|
||||
if phone != "" {
|
||||
existingUser, err = s.userRepo.GetByPhone(phone)
|
||||
if err == nil && existingUser != nil {
|
||||
return nil, ErrPhoneExists
|
||||
}
|
||||
}
|
||||
|
||||
// 密码哈希
|
||||
hashedPassword, err := utils.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: username,
|
||||
Nickname: nickname,
|
||||
EmailVerified: true,
|
||||
PasswordHash: hashedPassword,
|
||||
Status: model.UserStatusActive,
|
||||
}
|
||||
|
||||
// 如果提供了邮箱,设置指针值
|
||||
if email != "" {
|
||||
user.Email = &email
|
||||
}
|
||||
|
||||
// 如果提供了手机号,设置指针值
|
||||
if phone != "" {
|
||||
user.Phone = &phone
|
||||
}
|
||||
|
||||
err = s.userRepo.Create(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
func (s *UserService) Login(ctx context.Context, account, password string) (*model.User, error) {
|
||||
account = strings.TrimSpace(account)
|
||||
var (
|
||||
user *model.User
|
||||
err error
|
||||
)
|
||||
if utils.ValidateEmail(account) {
|
||||
user, err = s.userRepo.GetByEmail(account)
|
||||
} else if utils.ValidatePhone(account) {
|
||||
user, err = s.userRepo.GetByPhone(account)
|
||||
} else {
|
||||
user, err = s.userRepo.GetByUsername(account)
|
||||
}
|
||||
if err != nil || user == nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
if !utils.CheckPasswordHash(password, user.PasswordHash) {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
if user.Status != model.UserStatusActive {
|
||||
return nil, ErrUserBanned
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
func (s *UserService) GetUserByID(ctx context.Context, id string) (*model.User, error) {
|
||||
return s.userRepo.GetByID(id)
|
||||
}
|
||||
|
||||
// GetUserPostCount 获取用户帖子数(实时计算)
|
||||
func (s *UserService) GetUserPostCount(ctx context.Context, userID string) (int64, error) {
|
||||
return s.userRepo.GetPostsCount(userID)
|
||||
}
|
||||
|
||||
// GetUserPostCountBatch 批量获取用户帖子数(实时计算)
|
||||
func (s *UserService) GetUserPostCountBatch(ctx context.Context, userIDs []string) (map[string]int64, error) {
|
||||
return s.userRepo.GetPostsCountBatch(userIDs)
|
||||
}
|
||||
|
||||
// GetUserByIDWithFollowingStatus 根据ID获取用户(包含当前用户是否关注的状态)
|
||||
func (s *UserService) GetUserByIDWithFollowingStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, error) {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// 如果查询的是当前用户自己,不需要检查关注状态
|
||||
if userID == currentUserID {
|
||||
return user, false, nil
|
||||
}
|
||||
|
||||
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
|
||||
if err != nil {
|
||||
return user, false, err
|
||||
}
|
||||
|
||||
return user, isFollowing, nil
|
||||
}
|
||||
|
||||
// GetUserByIDWithMutualFollowStatus 根据ID获取用户(包含双向关注状态)
|
||||
func (s *UserService) GetUserByIDWithMutualFollowStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, bool, error) {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
|
||||
// 如果查询的是当前用户自己,不需要检查关注状态
|
||||
if userID == currentUserID {
|
||||
return user, false, false, nil
|
||||
}
|
||||
|
||||
// 当前用户是否关注了该用户
|
||||
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
|
||||
if err != nil {
|
||||
return user, false, false, err
|
||||
}
|
||||
|
||||
// 该用户是否关注了当前用户
|
||||
isFollowingMe, err := s.userRepo.IsFollowing(userID, currentUserID)
|
||||
if err != nil {
|
||||
return user, isFollowing, false, err
|
||||
}
|
||||
|
||||
return user, isFollowing, isFollowingMe, nil
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户
|
||||
func (s *UserService) UpdateUser(ctx context.Context, user *model.User) error {
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
// GetFollowers 获取粉丝
|
||||
func (s *UserService) GetFollowers(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
return s.userRepo.GetFollowers(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// GetFollowing 获取关注
|
||||
func (s *UserService) GetFollowing(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
return s.userRepo.GetFollowing(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// FollowUser 关注用户
|
||||
func (s *UserService) FollowUser(ctx context.Context, followerID, followeeID string) error {
|
||||
fmt.Printf("[DEBUG] FollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||
|
||||
blocked, err := s.userRepo.IsBlockedEitherDirection(followerID, followeeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if blocked {
|
||||
return ErrUserBlocked
|
||||
}
|
||||
|
||||
// 检查是否已经关注
|
||||
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
|
||||
return err
|
||||
}
|
||||
if isFollowing {
|
||||
fmt.Printf("[DEBUG] Already following, skip creation\n")
|
||||
return nil // 已经关注,直接返回成功
|
||||
}
|
||||
|
||||
// 创建关注关系
|
||||
follow := &model.Follow{
|
||||
FollowerID: followerID,
|
||||
FollowingID: followeeID,
|
||||
}
|
||||
|
||||
err = s.userRepo.CreateFollow(follow)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] CreateFollow error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] Follow record created successfully\n")
|
||||
|
||||
// 刷新关注者的关注数(通过实际计数,更可靠)
|
||||
err = s.userRepo.RefreshFollowingCount(followerID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
|
||||
// 不回滚,计数可以通过其他方式修复
|
||||
}
|
||||
|
||||
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
|
||||
err = s.userRepo.RefreshFollowersCount(followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
|
||||
// 不回滚,计数可以通过其他方式修复
|
||||
}
|
||||
|
||||
// 发送关注通知给被关注者
|
||||
if s.systemMessageService != nil {
|
||||
// 异步发送通知,不阻塞主流程
|
||||
go func() {
|
||||
notifyErr := s.systemMessageService.SendFollowNotification(context.Background(), followeeID, followerID)
|
||||
if notifyErr != nil {
|
||||
fmt.Printf("[DEBUG] Error sending follow notification: %v\n", notifyErr)
|
||||
} else {
|
||||
fmt.Printf("[DEBUG] Follow notification sent successfully to %s\n", followeeID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] FollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnfollowUser 取消关注用户
|
||||
func (s *UserService) UnfollowUser(ctx context.Context, followerID, followeeID string) error {
|
||||
fmt.Printf("[DEBUG] UnfollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||
|
||||
// 检查是否已经关注
|
||||
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
|
||||
return err
|
||||
}
|
||||
if !isFollowing {
|
||||
fmt.Printf("[DEBUG] Not following, skip deletion\n")
|
||||
return nil // 没有关注,直接返回成功
|
||||
}
|
||||
|
||||
// 删除关注关系
|
||||
err = s.userRepo.DeleteFollow(followerID, followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] DeleteFollow error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] Follow record deleted successfully\n")
|
||||
|
||||
// 刷新关注者的关注数(通过实际计数,更可靠)
|
||||
err = s.userRepo.RefreshFollowingCount(followerID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
|
||||
}
|
||||
|
||||
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
|
||||
err = s.userRepo.RefreshFollowersCount(followeeID)
|
||||
if err != nil {
|
||||
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[DEBUG] UnfollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BlockUser 拉黑用户,并自动清理双向关注/粉丝关系
|
||||
func (s *UserService) BlockUser(ctx context.Context, blockerID, blockedID string) error {
|
||||
if blockerID == blockedID {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
return s.userRepo.BlockUserAndCleanupRelations(blockerID, blockedID)
|
||||
}
|
||||
|
||||
// UnblockUser 取消拉黑
|
||||
func (s *UserService) UnblockUser(ctx context.Context, blockerID, blockedID string) error {
|
||||
if blockerID == blockedID {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
return s.userRepo.UnblockUser(blockerID, blockedID)
|
||||
}
|
||||
|
||||
// GetBlockedUsers 获取黑名单列表
|
||||
func (s *UserService) GetBlockedUsers(ctx context.Context, blockerID string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
return s.userRepo.GetBlockedUsers(blockerID, page, pageSize)
|
||||
}
|
||||
|
||||
// IsBlocked 检查当前用户是否已拉黑目标用户
|
||||
func (s *UserService) IsBlocked(ctx context.Context, blockerID, blockedID string) (bool, error) {
|
||||
return s.userRepo.IsBlocked(blockerID, blockedID)
|
||||
}
|
||||
|
||||
// GetFollowingList 获取关注列表(字符串参数版本)
|
||||
func (s *UserService) GetFollowingList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
|
||||
// 转换字符串参数为整数
|
||||
pageInt := 1
|
||||
pageSizeInt := 20
|
||||
if page != "" {
|
||||
_, err := fmt.Sscanf(page, "%d", &pageInt)
|
||||
if err != nil {
|
||||
pageInt = 1
|
||||
}
|
||||
}
|
||||
if pageSize != "" {
|
||||
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
|
||||
if err != nil {
|
||||
pageSizeInt = 20
|
||||
}
|
||||
}
|
||||
|
||||
users, _, err := s.userRepo.GetFollowing(userID, pageInt, pageSizeInt)
|
||||
return users, err
|
||||
}
|
||||
|
||||
// GetFollowersList 获取粉丝列表(字符串参数版本)
|
||||
func (s *UserService) GetFollowersList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
|
||||
// 转换字符串参数为整数
|
||||
pageInt := 1
|
||||
pageSizeInt := 20
|
||||
if page != "" {
|
||||
_, err := fmt.Sscanf(page, "%d", &pageInt)
|
||||
if err != nil {
|
||||
pageInt = 1
|
||||
}
|
||||
}
|
||||
if pageSize != "" {
|
||||
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
|
||||
if err != nil {
|
||||
pageSizeInt = 20
|
||||
}
|
||||
}
|
||||
|
||||
users, _, err := s.userRepo.GetFollowers(userID, pageInt, pageSizeInt)
|
||||
return users, err
|
||||
}
|
||||
|
||||
// GetMutualFollowStatus 批量获取双向关注状态
|
||||
func (s *UserService) GetMutualFollowStatus(ctx context.Context, currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
|
||||
return s.userRepo.GetMutualFollowStatus(currentUserID, targetUserIDs)
|
||||
}
|
||||
|
||||
// CheckUsernameAvailable 检查用户名是否可用
|
||||
func (s *UserService) CheckUsernameAvailable(ctx context.Context, username string) (bool, error) {
|
||||
user, err := s.userRepo.GetByUsername(username)
|
||||
if err != nil {
|
||||
return true, nil // 用户不存在,可用
|
||||
}
|
||||
return user == nil, nil
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *UserService) ChangePassword(ctx context.Context, userID, oldPassword, newPassword, verificationCode string) error {
|
||||
// 获取用户
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
|
||||
return ErrEmailNotBound
|
||||
}
|
||||
if err := s.emailCodeService.VerifyCode(CodePurposeChangePassword, *user.Email, verificationCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !utils.CheckPasswordHash(oldPassword, user.PasswordHash) {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// 哈希新密码
|
||||
hashedPassword, err := utils.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
user.PasswordHash = hashedPassword
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
// ResetPasswordByEmail 通过邮箱重置密码
|
||||
func (s *UserService) ResetPasswordByEmail(ctx context.Context, email, verificationCode, newPassword string) error {
|
||||
email = strings.TrimSpace(email)
|
||||
if !utils.ValidateEmail(email) {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
if !utils.ValidatePassword(newPassword) {
|
||||
return ErrWeakPassword
|
||||
}
|
||||
if err := s.emailCodeService.VerifyCode(CodePurposePasswordReset, email, verificationCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(email)
|
||||
if err != nil || user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
hashedPassword, err := utils.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.PasswordHash = hashedPassword
|
||||
return s.userRepo.Update(user)
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (s *UserService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.User, int64, error) {
|
||||
return s.userRepo.Search(keyword, page, pageSize)
|
||||
}
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrInvalidUsername = &ServiceError{Code: 400, Message: "invalid username"}
|
||||
ErrInvalidEmail = &ServiceError{Code: 400, Message: "invalid email"}
|
||||
ErrInvalidPhone = &ServiceError{Code: 400, Message: "invalid phone number"}
|
||||
ErrWeakPassword = &ServiceError{Code: 400, Message: "password too weak"}
|
||||
ErrUsernameExists = &ServiceError{Code: 400, Message: "username already exists"}
|
||||
ErrEmailExists = &ServiceError{Code: 400, Message: "email already exists"}
|
||||
ErrPhoneExists = &ServiceError{Code: 400, Message: "phone number already exists"}
|
||||
ErrUserNotFound = &ServiceError{Code: 404, Message: "user not found"}
|
||||
ErrUserBanned = &ServiceError{Code: 403, Message: "user is banned"}
|
||||
ErrUserBlocked = &ServiceError{Code: 403, Message: "blocked relationship exists"}
|
||||
ErrInvalidOperation = &ServiceError{Code: 400, Message: "invalid operation"}
|
||||
ErrEmailServiceUnavailable = &ServiceError{Code: 503, Message: "email service unavailable"}
|
||||
ErrVerificationCodeTooFrequent = &ServiceError{Code: 429, Message: "verification code sent too frequently"}
|
||||
ErrVerificationCodeInvalid = &ServiceError{Code: 400, Message: "invalid verification code"}
|
||||
ErrVerificationCodeExpired = &ServiceError{Code: 400, Message: "verification code expired"}
|
||||
ErrVerificationCodeUnavailable = &ServiceError{Code: 500, Message: "verification code storage unavailable"}
|
||||
ErrEmailAlreadyVerified = &ServiceError{Code: 400, Message: "email already verified"}
|
||||
ErrEmailNotBound = &ServiceError{Code: 400, Message: "email not bound"}
|
||||
)
|
||||
|
||||
// ServiceError 服务错误
|
||||
type ServiceError struct {
|
||||
Code int
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ServiceError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
var ErrInvalidCredentials = &ServiceError{Code: 401, Message: "invalid username or password"}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user