Initial backend repository commit.

Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
This commit is contained in:
2026-03-09 21:28:58 +08:00
commit 4d8f2ec997
102 changed files with 25022 additions and 0 deletions

16
.gitignore vendored Normal file
View File

@@ -0,0 +1,16 @@
# Build artifacts
server
*.tar
# Runtime files
data/
logs/
*.log
# Local docker metadata
.last_docker_tag
.last_docker_tar
# Local environment files
.env
.env.*

28
Dockerfile Normal file
View File

@@ -0,0 +1,28 @@
# 使用Debian bookworm-slim作为生产镜像
FROM debian:bookworm-slim
# 安装必要的运行时依赖
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
tzdata \
libsqlite3-0 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
# 设置工作目录
WORKDIR /app
# 从宿主机复制编译好的二进制文件
COPY server ./carrot_bbs
# 复制配置文件
COPY configs ./configs
# 给可执行文件添加执行权限
RUN chmod +x ./carrot_bbs
# 暴露端口
EXPOSE 8080
# 默认命令
CMD ["./carrot_bbs"]

172
configs/config.yaml Normal file
View File

@@ -0,0 +1,172 @@
# 服务器配置
# 环境变量: APP_SERVER_HOST, APP_SERVER_PORT, APP_SERVER_MODE
server:
host: 0.0.0.0
port: 8080
mode: debug
# 数据库配置
# 环境变量:
# SQLite: APP_DATABASE_SQLITE_PATH
# Postgres: APP_DATABASE_POSTGRES_HOST, APP_DATABASE_POSTGRES_PORT, APP_DATABASE_POSTGRES_USER,
# APP_DATABASE_POSTGRES_PASSWORD, APP_DATABASE_POSTGRES_DBNAME
database:
type: sqlite # sqlite 或 postgres
sqlite:
path: ./data/carrot_bbs.db
postgres:
host: localhost
port: 5432
user: postgres
password: postgres
dbname: carrot_bbs
sslmode: disable
max_idle_conns: 10
max_open_conns: 100
log_level: warn
slow_threshold_ms: 200
ignore_record_not_found: true
parameterized_queries: true
# Redis配置
# 环境变量:
# 类型: APP_REDIS_TYPE (miniredis/redis)
# Redis: APP_REDIS_REDIS_HOST, APP_REDIS_REDIS_PORT, APP_REDIS_REDIS_PASSWORD, APP_REDIS_REDIS_DB
# Miniredis: APP_REDIS_MINIREDIS_HOST, APP_REDIS_MINIREDIS_PORT
redis:
type: miniredis # miniredis 或 redis
redis:
host: localhost
port: 6379
password: ""
db: 0
miniredis:
host: localhost
port: 6379
pool_size: 10
# 缓存配置
# 环境变量:
# APP_CACHE_ENABLED, APP_CACHE_KEY_PREFIX, APP_CACHE_DEFAULT_TTL, APP_CACHE_NULL_TTL
# APP_CACHE_JITTER_RATIO, APP_CACHE_DISABLE_FLUSHDB
# APP_CACHE_MODULES_POST_LIST_TTL, APP_CACHE_MODULES_CONVERSATION_TTL
# APP_CACHE_MODULES_UNREAD_COUNT_TTL, APP_CACHE_MODULES_GROUP_MEMBERS_TTL
cache:
enabled: true
key_prefix: ""
default_ttl: 30
null_ttl: 5
jitter_ratio: 0.1
disable_flushdb: true
modules:
post_list_ttl: 30
conversation_ttl: 60
unread_count_ttl: 30
group_members_ttl: 120
# S3对象存储配置
# 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN
s3:
endpoint: ""
access_key: ""
secret_key: ""
bucket: ""
use_ssl: true
region: us-east-1
domain: ""
# JWT配置
# 环境变量: APP_JWT_SECRET
jwt:
secret: your-jwt-secret-key-change-in-production
access_token_expire: 86400 # 24 hours in seconds
refresh_token_expire: 604800 # 7 days in seconds
log:
level: info
encoding: json
output_paths:
- stdout
- ./logs/app.log
rate_limit:
enabled: true
requests_per_minute: 60
upload:
max_file_size: 10485760 # 10MB
allowed_types:
- image/jpeg
- image/png
- image/gif
- image/webp
# 敏感词过滤配置
sensitive:
enabled: true
replace_str: "***"
min_match_len: 1
load_from_db: true
load_from_redis: false
redis_key_prefix: "sensitive_words"
# 内容审核服务配置
audit:
enabled: false # 暂时关闭第三方审核
# 审核服务提供商: local, aliyun, tencent, baidu
provider: "local"
auto_audit: true
timeout: 30
# 阿里云配置
aliyun_access_key: ""
aliyun_secret_key: ""
aliyun_region: "cn-shanghai"
# 腾讯云配置
tencent_secret_id: ""
tencent_secret_key: ""
# 百度云配置
baidu_api_key: ""
baidu_secret_key: ""
# Gorse推荐系统配置
# 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD
gorse:
enabled: false
address: "" # Gorse server地址
api_key: "" # API密钥
dashboard: "" # Gorse dashboard地址
import_password: "" # 导入数据密码
embedding_api_key: ""
embedding_url: "https://api.littlelan.cn/v1/embeddings"
embedding_model: "BAAI/bge-m3"
# OpenAI兼容接口配置用于帖子审核支持图文
# 环境变量:
# APP_OPENAI_ENABLED, APP_OPENAI_BASE_URL, APP_OPENAI_API_KEY
# APP_OPENAI_MODERATION_MODEL, APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST
# APP_OPENAI_REQUEST_TIMEOUT, APP_OPENAI_STRICT_MODERATION
openai:
enabled: true
base_url: "https://api.littlelan.cn/"
api_key: ""
moderation_model: "qwen3.5-122b"
moderation_max_images_per_request: 1
request_timeout: 30
strict_moderation: false
# SMTP发信配置gomail.v2
# 环境变量:
# APP_EMAIL_ENABLED, APP_EMAIL_HOST, APP_EMAIL_PORT
# APP_EMAIL_USERNAME, APP_EMAIL_PASSWORD
# APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME
# APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT
email:
enabled: false
host: ""
port: 587
username: ""
password: ""
from_address: ""
from_name: "Carrot BBS"
use_tls: true
insecure_skip_verify: false
timeout: 15

78
docker-compose.yml Normal file
View File

@@ -0,0 +1,78 @@
version: '3.8'
services:
# 开发环境使用 SQLite 和 miniredis不需要外部数据库服务
# 如需使用 PostgreSQL 和 Redis切换 config.yaml 中的配置并取消注释以下服务
# PostgreSQL (生产环境使用)
# postgres:
# image: postgres:15-alpine
# container_name: carrot_bbs_postgres
# environment:
# POSTGRES_USER: postgres
# POSTGRES_PASSWORD: postgres
# POSTGRES_DB: carrot_bbs
# ports:
# - "5432:5432"
# volumes:
# - postgres_data:/var/lib/postgresql/data
# healthcheck:
# test: ["CMD-SHELL", "pg_isready -U postgres"]
# interval: 10s
# timeout: 5s
# retries: 5
# Redis (生产环境使用)
# redis:
# image: redis:7-alpine
# container_name: carrot_bbs_redis
# ports:
# - "6379:6379"
# volumes:
# - redis_data:/data
# healthcheck:
# test: ["CMD", "redis-cli", "ping"]
# interval: 10s
# timeout: 5s
# retries: 5
# MinIO (对象存储,可选)
minio:
image: minio/minio:latest
container_name: carrot_bbs_minio
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin
ports:
- "9000:9000"
- "9001:9001"
volumes:
- minio_data:/data
command: server /data --console-address ":9001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
interval: 30s
timeout: 20s
retries: 3
app:
build:
context: .
dockerfile: Dockerfile
container_name: carrot_bbs_app
ports:
- "8080:8080"
depends_on:
minio:
condition: service_healthy
environment:
- CONFIG_PATH=/app/configs/config.yaml
volumes:
- ./configs:/app/configs
- ./logs:/app/logs
- ./data:/app/data
volumes:
# postgres_data:
# redis_data:
minio_data:

86
go.mod Normal file
View File

@@ -0,0 +1,86 @@
module carrot_bbs
go 1.25
require (
github.com/alicebob/miniredis/v2 v2.31.0
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.0
github.com/google/uuid v1.5.0
github.com/gorilla/websocket v1.5.3
github.com/gorse-io/gorse-go v0.5.0-alpha.3
github.com/minio/minio-go/v7 v7.0.66
github.com/redis/go-redis/v9 v9.3.1
github.com/spf13/viper v1.18.2
go.uber.org/zap v1.26.0
golang.org/x/crypto v0.17.0
golang.org/x/image v0.24.0
gorm.io/driver/postgres v1.5.4
gorm.io/driver/sqlite v1.5.4
gorm.io/gorm v1.25.5
)
require (
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/felixge/httpsnoop v1.0.3 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.4 // indirect
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/yuin/gopher-lua v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 // indirect
go.opentelemetry.io/otel v1.11.1 // indirect
go.opentelemetry.io/otel/metric v0.33.0 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

216
go.sum Normal file
View File

@@ -0,0 +1,216 @@
github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.31.0 h1:ObEFUNlJwoIiyjxdrYF0QIDE7qXcLc7D3WpSH4c22PU=
github.com/alicebob/miniredis/v2 v2.31.0/go.mod h1:UB/T2Uztp7MlFSDakaX1sTXUv5CASoprx0wulRT6HBg=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorse-io/gorse-go v0.5.0-alpha.3 h1:GR/OWzq016VyFyTTxgQWeayGahRVzB1cGFIW/AaShC4=
github.com/gorse-io/gorse-go v0.5.0-alpha.3/go.mod h1:ZxmVHzZPKm5pmEIlqaRDwK0rkfTRHlfziO033XZ+RW0=
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs=
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds=
github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE=
github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 h1:aUEBEdCa6iamGzg6fuYxDA8ThxvOG240mAvWDU+XLio=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4/go.mod h1:l2MdsbKTocpPS5nQZscqTR9jd8u96VYZdcpF8Sye7mA=
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E=
go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI=
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ=
golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk=
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE=
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

604
internal/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,604 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
redisPkg "carrot_bbs/internal/pkg/redis"
)
// Cache 缓存接口
type Cache interface {
// Set 设置缓存值支持TTL
Set(key string, value interface{}, ttl time.Duration)
// Get 获取缓存值
Get(key string) (interface{}, bool)
// Delete 删除缓存
Delete(key string)
// DeleteByPrefix 根据前缀删除缓存
DeleteByPrefix(prefix string)
// Clear 清空所有缓存
Clear()
// Exists 检查键是否存在
Exists(key string) bool
// Increment 增加计数器的值
Increment(key string) int64
// IncrementBy 增加指定值
IncrementBy(key string, value int64) int64
}
// cacheItem 缓存项(用于内存缓存降级)
type cacheItem struct {
value interface{}
expiration int64 // 过期时间戳(纳秒)
}
const nullMarkerValue = "__carrot_cache_null__"
type cacheMetrics struct {
hit atomic.Int64
miss atomic.Int64
decodeError atomic.Int64
setError atomic.Int64
invalidate atomic.Int64
}
var metrics cacheMetrics
var loadLocks sync.Map
type MetricsSnapshot struct {
Hit int64
Miss int64
DecodeError int64
SetError int64
Invalidate int64
}
type Settings struct {
Enabled bool
KeyPrefix string
DefaultTTL time.Duration
NullTTL time.Duration
JitterRatio float64
PostListTTL time.Duration
ConversationTTL time.Duration
UnreadCountTTL time.Duration
GroupMembersTTL time.Duration
DisableFlushDB bool
}
var settings = Settings{
Enabled: true,
DefaultTTL: 30 * time.Second,
NullTTL: 5 * time.Second,
JitterRatio: 0.1,
PostListTTL: 30 * time.Second,
ConversationTTL: 60 * time.Second,
UnreadCountTTL: 30 * time.Second,
GroupMembersTTL: 120 * time.Second,
DisableFlushDB: true,
}
func Configure(s Settings) {
settings.Enabled = s.Enabled
if s.KeyPrefix != "" {
settings.KeyPrefix = s.KeyPrefix
}
if s.DefaultTTL > 0 {
settings.DefaultTTL = s.DefaultTTL
}
if s.NullTTL > 0 {
settings.NullTTL = s.NullTTL
}
if s.JitterRatio > 0 {
settings.JitterRatio = s.JitterRatio
}
if s.PostListTTL > 0 {
settings.PostListTTL = s.PostListTTL
}
if s.ConversationTTL > 0 {
settings.ConversationTTL = s.ConversationTTL
}
if s.UnreadCountTTL > 0 {
settings.UnreadCountTTL = s.UnreadCountTTL
}
if s.GroupMembersTTL > 0 {
settings.GroupMembersTTL = s.GroupMembersTTL
}
settings.DisableFlushDB = s.DisableFlushDB
}
func GetSettings() Settings {
return settings
}
func normalizeKey(key string) string {
if settings.KeyPrefix == "" {
return key
}
return settings.KeyPrefix + ":" + key
}
func normalizePrefix(prefix string) string {
if settings.KeyPrefix == "" {
return prefix
}
return settings.KeyPrefix + ":" + prefix
}
func GetMetricsSnapshot() MetricsSnapshot {
return MetricsSnapshot{
Hit: metrics.hit.Load(),
Miss: metrics.miss.Load(),
DecodeError: metrics.decodeError.Load(),
SetError: metrics.setError.Load(),
Invalidate: metrics.invalidate.Load(),
}
}
// isExpired 检查是否过期
func (item *cacheItem) isExpired() bool {
if item.expiration == 0 {
return false
}
return time.Now().UnixNano() > item.expiration
}
// MemoryCache 内存缓存实现(降级使用)
type MemoryCache struct {
items sync.Map
// cleanupInterval 清理过期缓存的间隔
cleanupInterval time.Duration
// stopCleanup 停止清理协程的通道
stopCleanup chan struct{}
}
// NewMemoryCache 创建内存缓存
func NewMemoryCache() *MemoryCache {
c := &MemoryCache{
cleanupInterval: 1 * time.Minute,
stopCleanup: make(chan struct{}),
}
// 启动后台清理协程
go c.cleanup()
return c
}
// Set 设置缓存值
func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) {
key = normalizeKey(key)
var expiration int64
if ttl > 0 {
expiration = time.Now().Add(ttl).UnixNano()
}
c.items.Store(key, &cacheItem{
value: value,
expiration: expiration,
})
}
// Get 获取缓存值
func (c *MemoryCache) Get(key string) (interface{}, bool) {
key = normalizeKey(key)
val, ok := c.items.Load(key)
if !ok {
return nil, false
}
item := val.(*cacheItem)
if item.isExpired() {
c.items.Delete(key)
return nil, false
}
return item.value, true
}
// Delete 删除缓存
func (c *MemoryCache) Delete(key string) {
key = normalizeKey(key)
metrics.invalidate.Add(1)
c.items.Delete(key)
}
// DeleteByPrefix 根据前缀删除缓存
func (c *MemoryCache) DeleteByPrefix(prefix string) {
prefix = normalizePrefix(prefix)
c.items.Range(func(key, value interface{}) bool {
if keyStr, ok := key.(string); ok {
if strings.HasPrefix(keyStr, prefix) {
metrics.invalidate.Add(1)
c.items.Delete(key)
}
}
return true
})
}
// Clear 清空所有缓存
func (c *MemoryCache) Clear() {
c.items.Range(func(key, value interface{}) bool {
metrics.invalidate.Add(1)
c.items.Delete(key)
return true
})
}
// Exists 检查键是否存在
func (c *MemoryCache) Exists(key string) bool {
_, ok := c.Get(key)
return ok
}
// Increment 增加计数器的值
func (c *MemoryCache) Increment(key string) int64 {
return c.IncrementBy(key, 1)
}
// IncrementBy 增加指定值
func (c *MemoryCache) IncrementBy(key string, value int64) int64 {
key = normalizeKey(key)
for {
val, ok := c.items.Load(key)
if !ok {
// 键不存在,创建新值
c.items.Store(key, &cacheItem{
value: value,
expiration: 0,
})
return value
}
item := val.(*cacheItem)
if item.isExpired() {
// 已过期,创建新值
c.items.Store(key, &cacheItem{
value: value,
expiration: 0,
})
return value
}
// 尝试更新
currentValue, ok := item.value.(int64)
if !ok {
// 类型不匹配,覆盖为新值
c.items.Store(key, &cacheItem{
value: value,
expiration: item.expiration,
})
return value
}
newValue := currentValue + value
// 使用 CAS 操作确保并发安全
if c.items.CompareAndSwap(key, val, &cacheItem{
value: newValue,
expiration: item.expiration,
}) {
return newValue
}
// CAS 失败,重试
}
}
// cleanup 定期清理过期缓存
func (c *MemoryCache) cleanup() {
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.cleanExpired()
case <-c.stopCleanup:
return
}
}
}
// cleanExpired 清理过期缓存
func (c *MemoryCache) cleanExpired() {
count := 0
c.items.Range(func(key, value interface{}) bool {
item := value.(*cacheItem)
if item.isExpired() {
c.items.Delete(key)
count++
}
return true
})
if count > 0 {
log.Printf("[Cache] Cleaned %d expired items", count)
}
}
// Stop 停止缓存清理协程
func (c *MemoryCache) Stop() {
close(c.stopCleanup)
}
// RedisCache Redis缓存实现
type RedisCache struct {
client *redisPkg.Client
ctx context.Context
}
// NewRedisCache 创建Redis缓存
func NewRedisCache(client *redisPkg.Client) *RedisCache {
return &RedisCache{
client: client,
ctx: context.Background(),
}
}
// Set 设置缓存值
func (c *RedisCache) Set(key string, value interface{}, ttl time.Duration) {
key = normalizeKey(key)
// 将值序列化为JSON
data, err := json.Marshal(value)
if err != nil {
metrics.setError.Add(1)
log.Printf("[RedisCache] Failed to marshal value for key %s: %v", key, err)
return
}
if err := c.client.Set(c.ctx, key, data, ttl); err != nil {
metrics.setError.Add(1)
log.Printf("[RedisCache] Failed to set key %s: %v", key, err)
}
}
// Get 获取缓存值
func (c *RedisCache) Get(key string) (interface{}, bool) {
key = normalizeKey(key)
data, err := c.client.Get(c.ctx, key)
if err != nil {
if err == redis.Nil {
return nil, false
}
log.Printf("[RedisCache] Failed to get key %s: %v", key, err)
return nil, false
}
// 返回原始字符串,由调用侧决定如何解码为目标类型
return data, true
}
// Delete 删除缓存
func (c *RedisCache) Delete(key string) {
key = normalizeKey(key)
metrics.invalidate.Add(1)
if err := c.client.Del(c.ctx, key); err != nil {
log.Printf("[RedisCache] Failed to delete key %s: %v", key, err)
}
}
// DeleteByPrefix 根据前缀删除缓存
func (c *RedisCache) DeleteByPrefix(prefix string) {
prefix = normalizePrefix(prefix)
// 使用原生客户端执行SCAN命令
rdb := c.client.GetClient()
var cursor uint64
for {
keys, nextCursor, err := rdb.Scan(c.ctx, cursor, prefix+"*", 100).Result()
if err != nil {
log.Printf("[RedisCache] Failed to scan keys with prefix %s: %v", prefix, err)
return
}
if len(keys) > 0 {
metrics.invalidate.Add(int64(len(keys)))
if err := c.client.Del(c.ctx, keys...); err != nil {
log.Printf("[RedisCache] Failed to delete keys with prefix %s: %v", prefix, err)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
}
// Clear 清空所有缓存
func (c *RedisCache) Clear() {
if settings.DisableFlushDB {
log.Printf("[RedisCache] Skip FlushDB because cache.disable_flushdb=true")
return
}
metrics.invalidate.Add(1)
rdb := c.client.GetClient()
if err := rdb.FlushDB(c.ctx).Err(); err != nil {
log.Printf("[RedisCache] Failed to clear cache: %v", err)
}
}
// Exists 检查键是否存在
func (c *RedisCache) Exists(key string) bool {
key = normalizeKey(key)
n, err := c.client.Exists(c.ctx, key)
if err != nil {
log.Printf("[RedisCache] Failed to check existence of key %s: %v", key, err)
return false
}
return n > 0
}
// Increment 增加计数器的值
func (c *RedisCache) Increment(key string) int64 {
return c.IncrementBy(key, 1)
}
// IncrementBy 增加指定值
func (c *RedisCache) IncrementBy(key string, value int64) int64 {
key = normalizeKey(key)
rdb := c.client.GetClient()
result, err := rdb.IncrBy(c.ctx, key, value).Result()
if err != nil {
log.Printf("[RedisCache] Failed to increment key %s: %v", key, err)
return 0
}
return result
}
// 全局缓存实例
var globalCache Cache
var once sync.Once
// InitCache 初始化全局缓存实例使用Redis
func InitCache(redisClient *redisPkg.Client) {
once.Do(func() {
if redisClient != nil {
globalCache = NewRedisCache(redisClient)
log.Println("[Cache] Initialized Redis cache")
} else {
globalCache = NewMemoryCache()
log.Println("[Cache] Initialized Memory cache (Redis not available)")
}
})
}
// GetCache 获取全局缓存实例
func GetCache() Cache {
if globalCache == nil {
// 如果未初始化,返回内存缓存作为降级
log.Println("[Cache] Warning: Cache not initialized, using Memory cache")
return NewMemoryCache()
}
return globalCache
}
// GetRedisClient 从缓存中获取Redis客户端仅在Redis模式下有效
func GetRedisClient() (*redisPkg.Client, error) {
if redisCache, ok := globalCache.(*RedisCache); ok {
return redisCache.client, nil
}
return nil, fmt.Errorf("cache is not using Redis backend")
}
func SetWithJitter(c Cache, key string, value interface{}, ttl time.Duration, jitterRatio float64) {
if !settings.Enabled {
return
}
c.Set(key, value, ApplyTTLJitter(ttl, jitterRatio))
}
func SetNull(c Cache, key string, ttl time.Duration) {
if !settings.Enabled {
return
}
c.Set(key, nullMarkerValue, ttl)
}
func ApplyTTLJitter(ttl time.Duration, jitterRatio float64) time.Duration {
if ttl <= 0 || jitterRatio <= 0 {
return ttl
}
if jitterRatio > 1 {
jitterRatio = 1
}
maxJitter := int64(float64(ttl) * jitterRatio)
if maxJitter <= 0 {
return ttl
}
delta := rand.Int63n(maxJitter + 1)
return ttl + time.Duration(delta)
}
func GetTyped[T any](c Cache, key string) (T, bool) {
var zero T
if !settings.Enabled {
return zero, false
}
raw, ok := c.Get(key)
if !ok {
metrics.miss.Add(1)
return zero, false
}
if str, ok := raw.(string); ok && str == nullMarkerValue {
metrics.hit.Add(1)
return zero, false
}
if typed, ok := raw.(T); ok {
metrics.hit.Add(1)
return typed, true
}
var out T
switch v := raw.(type) {
case string:
if err := json.Unmarshal([]byte(v), &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
case []byte:
if err := json.Unmarshal(v, &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
default:
data, err := json.Marshal(v)
if err != nil {
metrics.decodeError.Add(1)
return zero, false
}
if err := json.Unmarshal(data, &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
}
}
func GetOrLoadTyped[T any](
c Cache,
key string,
ttl time.Duration,
jitterRatio float64,
nullTTL time.Duration,
loader func() (T, error),
) (T, error) {
if cached, ok := GetTyped[T](c, key); ok {
return cached, nil
}
lockValue, _ := loadLocks.LoadOrStore(key, &sync.Mutex{})
lock := lockValue.(*sync.Mutex)
lock.Lock()
defer lock.Unlock()
if cached, ok := GetTyped[T](c, key); ok {
return cached, nil
}
loaded, err := loader()
if err != nil {
var zero T
return zero, err
}
encoded, marshalErr := json.Marshal(loaded)
if marshalErr == nil && string(encoded) == "null" && nullTTL > 0 {
SetNull(c, key, nullTTL)
return loaded, nil
}
SetWithJitter(c, key, loaded, ttl, jitterRatio)
return loaded, nil
}

147
internal/cache/keys.go vendored Normal file
View File

@@ -0,0 +1,147 @@
package cache
import (
"fmt"
)
// 缓存键前缀常量
const (
// 帖子相关
PrefixPostList = "posts:list"
PrefixPost = "posts:detail"
// 会话相关
PrefixConversationList = "conversations:list"
PrefixConversationDetail = "conversations:detail"
// 群组相关
PrefixGroupMembers = "groups:members"
PrefixGroupInfo = "groups:info"
// 未读数相关
PrefixUnreadSystem = "unread:system"
PrefixUnreadConversation = "unread:conversation"
PrefixUnreadDetail = "unread:detail"
// 用户相关
PrefixUserInfo = "users:info"
PrefixUserMe = "users:me"
)
// PostListKey 生成帖子列表缓存键
// postType: 帖子类型 (recommend, hot, follow, latest)
// page: 页码
// pageSize: 每页数量
// userID: 用户维度(仅在个性化列表如 follow 场景使用)
func PostListKey(postType string, userID string, page, pageSize int) string {
if userID == "" {
return fmt.Sprintf("%s:%s:%d:%d", PrefixPostList, postType, page, pageSize)
}
return fmt.Sprintf("%s:%s:%s:%d:%d", PrefixPostList, postType, userID, page, pageSize)
}
// PostDetailKey 生成帖子详情缓存键
func PostDetailKey(postID string) string {
return fmt.Sprintf("%s:%s", PrefixPost, postID)
}
// ConversationListKey 生成会话列表缓存键
func ConversationListKey(userID string, page, pageSize int) string {
return fmt.Sprintf("%s:%s:%d:%d", PrefixConversationList, userID, page, pageSize)
}
// ConversationDetailKey 生成会话详情缓存键
func ConversationDetailKey(conversationID, userID string) string {
return fmt.Sprintf("%s:%s:%s", PrefixConversationDetail, conversationID, userID)
}
// GroupMembersKey 生成群组成员缓存键
func GroupMembersKey(groupID string, page, pageSize int) string {
return fmt.Sprintf("%s:%s:page:%d:size:%d", PrefixGroupMembers, groupID, page, pageSize)
}
// GroupMembersAllKey 生成群组全量成员ID列表缓存键
func GroupMembersAllKey(groupID string) string {
return fmt.Sprintf("%s:all:%s", PrefixGroupMembers, groupID)
}
// GroupInfoKey 生成群组信息缓存键
func GroupInfoKey(groupID string) string {
return fmt.Sprintf("%s:%s", PrefixGroupInfo, groupID)
}
// UnreadSystemKey 生成系统消息未读数缓存键
func UnreadSystemKey(userID string) string {
return fmt.Sprintf("%s:%s", PrefixUnreadSystem, userID)
}
// UnreadConversationKey 生成会话未读总数缓存键
func UnreadConversationKey(userID string) string {
return fmt.Sprintf("%s:%s", PrefixUnreadConversation, userID)
}
// UnreadDetailKey 生成单个会话未读数缓存键
func UnreadDetailKey(userID, conversationID string) string {
return fmt.Sprintf("%s:%s:%s", PrefixUnreadDetail, userID, conversationID)
}
// UserInfoKey 生成用户信息缓存键
func UserInfoKey(userID string) string {
return fmt.Sprintf("%s:%s", PrefixUserInfo, userID)
}
// UserMeKey 生成当前用户信息缓存键
func UserMeKey(userID string) string {
return fmt.Sprintf("%s:%s", PrefixUserMe, userID)
}
// InvalidatePostList 失效帖子列表缓存
func InvalidatePostList(cache Cache) {
cache.DeleteByPrefix(PrefixPostList)
}
// InvalidatePostDetail 失效帖子详情缓存
func InvalidatePostDetail(cache Cache, postID string) {
cache.Delete(PostDetailKey(postID))
}
// InvalidateConversationList 失效会话列表缓存
func InvalidateConversationList(cache Cache, userID string) {
cache.DeleteByPrefix(PrefixConversationList + ":" + userID + ":")
}
// InvalidateConversationDetail 失效会话详情缓存
func InvalidateConversationDetail(cache Cache, conversationID, userID string) {
cache.Delete(ConversationDetailKey(conversationID, userID))
}
// InvalidateGroupMembers 失效群组成员缓存
func InvalidateGroupMembers(cache Cache, groupID string) {
cache.DeleteByPrefix(PrefixGroupMembers + ":" + groupID)
}
// InvalidateGroupInfo 失效群组信息缓存
func InvalidateGroupInfo(cache Cache, groupID string) {
cache.Delete(GroupInfoKey(groupID))
}
// InvalidateUnreadSystem 失效系统消息未读数缓存
func InvalidateUnreadSystem(cache Cache, userID string) {
cache.Delete(UnreadSystemKey(userID))
}
// InvalidateUnreadConversation 失效会话未读数缓存
func InvalidateUnreadConversation(cache Cache, userID string) {
cache.Delete(UnreadConversationKey(userID))
}
// InvalidateUnreadDetail 失效单个会话未读数缓存
func InvalidateUnreadDetail(cache Cache, userID, conversationID string) {
cache.Delete(UnreadDetailKey(userID, conversationID))
}
// InvalidateUserInfo 失效用户信息缓存
func InvalidateUserInfo(cache Cache, userID string) {
cache.Delete(UserInfoKey(userID))
cache.Delete(UserMeKey(userID))
}

393
internal/config/config.go Normal file
View File

@@ -0,0 +1,393 @@
package config
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/redis/go-redis/v9"
"github.com/spf13/viper"
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Cache CacheConfig `mapstructure:"cache"`
S3 S3Config `mapstructure:"s3"`
JWT JWTConfig `mapstructure:"jwt"`
Log LogConfig `mapstructure:"log"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Upload UploadConfig `mapstructure:"upload"`
Gorse GorseConfig `mapstructure:"gorse"`
OpenAI OpenAIConfig `mapstructure:"openai"`
Email EmailConfig `mapstructure:"email"`
}
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"`
}
type DatabaseConfig struct {
Type string `mapstructure:"type"`
SQLite SQLiteConfig `mapstructure:"sqlite"`
Postgres PostgresConfig `mapstructure:"postgres"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
MaxOpenConns int `mapstructure:"max_open_conns"`
LogLevel string `mapstructure:"log_level"`
SlowThresholdMs int `mapstructure:"slow_threshold_ms"`
IgnoreRecordNotFound bool `mapstructure:"ignore_record_not_found"`
ParameterizedQueries bool `mapstructure:"parameterized_queries"`
}
type SQLiteConfig struct {
Path string `mapstructure:"path"`
}
type PostgresConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
}
func (d PostgresConfig) DSN() string {
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
)
}
type RedisConfig struct {
Type string `mapstructure:"type"`
Redis RedisServerConfig `mapstructure:"redis"`
Miniredis MiniredisConfig `mapstructure:"miniredis"`
PoolSize int `mapstructure:"pool_size"`
}
type CacheConfig struct {
Enabled bool `mapstructure:"enabled"`
KeyPrefix string `mapstructure:"key_prefix"`
DefaultTTL int `mapstructure:"default_ttl"`
NullTTL int `mapstructure:"null_ttl"`
JitterRatio float64 `mapstructure:"jitter_ratio"`
DisableFlushDB bool `mapstructure:"disable_flushdb"`
Modules CacheModuleTTL `mapstructure:"modules"`
}
type CacheModuleTTL struct {
PostList int `mapstructure:"post_list_ttl"`
Conversation int `mapstructure:"conversation_ttl"`
UnreadCount int `mapstructure:"unread_count_ttl"`
GroupMembers int `mapstructure:"group_members_ttl"`
}
type RedisServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
}
func (r RedisServerConfig) Addr() string {
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}
type MiniredisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
}
type S3Config struct {
Endpoint string `mapstructure:"endpoint"`
AccessKey string `mapstructure:"access_key"`
SecretKey string `mapstructure:"secret_key"`
Bucket string `mapstructure:"bucket"`
UseSSL bool `mapstructure:"use_ssl"`
Region string `mapstructure:"region"`
Domain string `mapstructure:"domain"` // 自定义域名,如 s3.carrot.skin
}
type JWTConfig struct {
Secret string `mapstructure:"secret"`
AccessTokenExpire time.Duration `mapstructure:"access_token_expire"`
RefreshTokenExpire time.Duration `mapstructure:"refresh_token_expire"`
}
type LogConfig struct {
Level string `mapstructure:"level"`
Encoding string `mapstructure:"encoding"`
OutputPaths []string `mapstructure:"output_paths"`
}
type RateLimitConfig struct {
Enabled bool `mapstructure:"enabled"`
RequestsPerMinute int `mapstructure:"requests_per_minute"`
}
type UploadConfig struct {
MaxFileSize int64 `mapstructure:"max_file_size"`
AllowedTypes []string `mapstructure:"allowed_types"`
}
type GorseConfig struct {
Address string `mapstructure:"address"`
APIKey string `mapstructure:"api_key"`
Enabled bool `mapstructure:"enabled"`
Dashboard string `mapstructure:"dashboard"`
ImportPassword string `mapstructure:"import_password"`
EmbeddingAPIKey string `mapstructure:"embedding_api_key"`
EmbeddingURL string `mapstructure:"embedding_url"`
EmbeddingModel string `mapstructure:"embedding_model"`
}
type OpenAIConfig struct {
Enabled bool `mapstructure:"enabled"`
BaseURL string `mapstructure:"base_url"`
APIKey string `mapstructure:"api_key"`
ModerationModel string `mapstructure:"moderation_model"`
ModerationMaxImagesPerRequest int `mapstructure:"moderation_max_images_per_request"`
RequestTimeout int `mapstructure:"request_timeout"`
StrictModeration bool `mapstructure:"strict_moderation"`
}
type EmailConfig struct {
Enabled bool `mapstructure:"enabled"`
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
FromAddress string `mapstructure:"from_address"`
FromName string `mapstructure:"from_name"`
UseTLS bool `mapstructure:"use_tls"`
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"`
Timeout int `mapstructure:"timeout"`
}
func Load(configPath string) (*Config, error) {
viper.SetConfigFile(configPath)
viper.SetConfigType("yaml")
// 启用环境变量支持
viper.SetEnvPrefix("APP")
viper.AutomaticEnv()
// 允许环境变量使用下划线或连字符
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_"))
// Set default values
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("database.type", "sqlite")
viper.SetDefault("database.sqlite.path", "./data/carrot_bbs.db")
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.max_open_conns", 100)
viper.SetDefault("database.log_level", "warn")
viper.SetDefault("database.slow_threshold_ms", 200)
viper.SetDefault("database.ignore_record_not_found", true)
viper.SetDefault("database.parameterized_queries", true)
viper.SetDefault("redis.type", "miniredis")
viper.SetDefault("redis.redis.host", "localhost")
viper.SetDefault("redis.redis.port", 6379)
viper.SetDefault("redis.redis.password", "")
viper.SetDefault("redis.redis.db", 0)
viper.SetDefault("redis.miniredis.host", "localhost")
viper.SetDefault("redis.miniredis.port", 6379)
viper.SetDefault("redis.pool_size", 10)
viper.SetDefault("cache.enabled", true)
viper.SetDefault("cache.key_prefix", "")
viper.SetDefault("cache.default_ttl", 30)
viper.SetDefault("cache.null_ttl", 5)
viper.SetDefault("cache.jitter_ratio", 0.1)
viper.SetDefault("cache.disable_flushdb", true)
viper.SetDefault("cache.modules.post_list_ttl", 30)
viper.SetDefault("cache.modules.conversation_ttl", 60)
viper.SetDefault("cache.modules.unread_count_ttl", 30)
viper.SetDefault("cache.modules.group_members_ttl", 120)
viper.SetDefault("jwt.secret", "your-jwt-secret-key-change-in-production")
viper.SetDefault("jwt.access_token_expire", 86400)
viper.SetDefault("jwt.refresh_token_expire", 604800)
viper.SetDefault("log.level", "info")
viper.SetDefault("log.encoding", "json")
viper.SetDefault("log.output_paths", []string{"stdout", "./logs/app.log"})
viper.SetDefault("rate_limit.enabled", true)
viper.SetDefault("rate_limit.requests_per_minute", 60)
viper.SetDefault("upload.max_file_size", 10485760)
viper.SetDefault("upload.allowed_types", []string{"image/jpeg", "image/png", "image/gif", "image/webp"})
viper.SetDefault("s3.endpoint", "")
viper.SetDefault("s3.access_key", "")
viper.SetDefault("s3.secret_key", "")
viper.SetDefault("s3.bucket", "")
viper.SetDefault("s3.use_ssl", true)
viper.SetDefault("s3.region", "us-east-1")
viper.SetDefault("s3.domain", "")
viper.SetDefault("sensitive.enabled", true)
viper.SetDefault("sensitive.replace_str", "***")
viper.SetDefault("audit.enabled", false)
viper.SetDefault("audit.provider", "local")
viper.SetDefault("gorse.enabled", false)
viper.SetDefault("gorse.address", "http://localhost:8087")
viper.SetDefault("gorse.api_key", "")
viper.SetDefault("gorse.dashboard", "http://localhost:8088")
viper.SetDefault("gorse.import_password", "")
viper.SetDefault("gorse.embedding_api_key", "")
viper.SetDefault("gorse.embedding_url", "https://api.littlelan.cn/v1/embeddings")
viper.SetDefault("gorse.embedding_model", "BAAI/bge-m3")
viper.SetDefault("openai.enabled", true)
viper.SetDefault("openai.base_url", "https://api.littlelan.cn/")
viper.SetDefault("openai.api_key", "")
viper.SetDefault("openai.moderation_model", "qwen3.5-122b")
viper.SetDefault("openai.moderation_max_images_per_request", 1)
viper.SetDefault("openai.request_timeout", 30)
viper.SetDefault("openai.strict_moderation", false)
viper.SetDefault("email.enabled", false)
viper.SetDefault("email.host", "")
viper.SetDefault("email.port", 587)
viper.SetDefault("email.username", "")
viper.SetDefault("email.password", "")
viper.SetDefault("email.from_address", "")
viper.SetDefault("email.from_name", "Carrot BBS")
viper.SetDefault("email.use_tls", true)
viper.SetDefault("email.insecure_skip_verify", false)
viper.SetDefault("email.timeout", 15)
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)
}
var cfg Config
if err := viper.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// Convert seconds to duration
cfg.JWT.AccessTokenExpire = time.Duration(viper.GetInt("jwt.access_token_expire")) * time.Second
cfg.JWT.RefreshTokenExpire = time.Duration(viper.GetInt("jwt.refresh_token_expire")) * time.Second
// 环境变量覆盖(显式处理敏感配置)
cfg.JWT.Secret = getEnvOrDefault("APP_JWT_SECRET", cfg.JWT.Secret)
cfg.Database.SQLite.Path = getEnvOrDefault("APP_DATABASE_SQLITE_PATH", cfg.Database.SQLite.Path)
cfg.Database.Postgres.Host = getEnvOrDefault("APP_DATABASE_POSTGRES_HOST", cfg.Database.Postgres.Host)
cfg.Database.Postgres.Port, _ = strconv.Atoi(getEnvOrDefault("APP_DATABASE_POSTGRES_PORT", fmt.Sprintf("%d", cfg.Database.Postgres.Port)))
cfg.Database.Postgres.User = getEnvOrDefault("APP_DATABASE_POSTGRES_USER", cfg.Database.Postgres.User)
cfg.Database.Postgres.Password = getEnvOrDefault("APP_DATABASE_POSTGRES_PASSWORD", cfg.Database.Postgres.Password)
cfg.Database.Postgres.DBName = getEnvOrDefault("APP_DATABASE_POSTGRES_DBNAME", cfg.Database.Postgres.DBName)
cfg.Database.LogLevel = getEnvOrDefault("APP_DATABASE_LOG_LEVEL", cfg.Database.LogLevel)
cfg.Database.SlowThresholdMs, _ = strconv.Atoi(getEnvOrDefault("APP_DATABASE_SLOW_THRESHOLD_MS", fmt.Sprintf("%d", cfg.Database.SlowThresholdMs)))
cfg.Database.IgnoreRecordNotFound, _ = strconv.ParseBool(getEnvOrDefault("APP_DATABASE_IGNORE_RECORD_NOT_FOUND", fmt.Sprintf("%t", cfg.Database.IgnoreRecordNotFound)))
cfg.Database.ParameterizedQueries, _ = strconv.ParseBool(getEnvOrDefault("APP_DATABASE_PARAMETERIZED_QUERIES", fmt.Sprintf("%t", cfg.Database.ParameterizedQueries)))
cfg.Redis.Redis.Host = getEnvOrDefault("APP_REDIS_REDIS_HOST", cfg.Redis.Redis.Host)
cfg.Redis.Redis.Port, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_REDIS_PORT", fmt.Sprintf("%d", cfg.Redis.Redis.Port)))
cfg.Redis.Redis.Password = getEnvOrDefault("APP_REDIS_REDIS_PASSWORD", cfg.Redis.Redis.Password)
cfg.Redis.Redis.DB, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_REDIS_DB", fmt.Sprintf("%d", cfg.Redis.Redis.DB)))
cfg.Redis.Miniredis.Host = getEnvOrDefault("APP_REDIS_MINIREDIS_HOST", cfg.Redis.Miniredis.Host)
cfg.Redis.Miniredis.Port, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_MINIREDIS_PORT", fmt.Sprintf("%d", cfg.Redis.Miniredis.Port)))
cfg.Redis.Type = getEnvOrDefault("APP_REDIS_TYPE", cfg.Redis.Type)
cfg.Cache.KeyPrefix = getEnvOrDefault("APP_CACHE_KEY_PREFIX", cfg.Cache.KeyPrefix)
cfg.Cache.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_CACHE_ENABLED", fmt.Sprintf("%t", cfg.Cache.Enabled)))
cfg.Cache.DisableFlushDB, _ = strconv.ParseBool(getEnvOrDefault("APP_CACHE_DISABLE_FLUSHDB", fmt.Sprintf("%t", cfg.Cache.DisableFlushDB)))
cfg.Cache.DefaultTTL, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_DEFAULT_TTL", fmt.Sprintf("%d", cfg.Cache.DefaultTTL)))
cfg.Cache.NullTTL, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_NULL_TTL", fmt.Sprintf("%d", cfg.Cache.NullTTL)))
cfg.Cache.JitterRatio, _ = strconv.ParseFloat(getEnvOrDefault("APP_CACHE_JITTER_RATIO", fmt.Sprintf("%.2f", cfg.Cache.JitterRatio)), 64)
cfg.Cache.Modules.PostList, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_POST_LIST_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.PostList)))
cfg.Cache.Modules.Conversation, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_CONVERSATION_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.Conversation)))
cfg.Cache.Modules.UnreadCount, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_UNREAD_COUNT_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.UnreadCount)))
cfg.Cache.Modules.GroupMembers, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_GROUP_MEMBERS_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.GroupMembers)))
cfg.S3.Endpoint = getEnvOrDefault("APP_S3_ENDPOINT", cfg.S3.Endpoint)
cfg.S3.AccessKey = getEnvOrDefault("APP_S3_ACCESS_KEY", cfg.S3.AccessKey)
cfg.S3.SecretKey = getEnvOrDefault("APP_S3_SECRET_KEY", cfg.S3.SecretKey)
cfg.S3.Bucket = getEnvOrDefault("APP_S3_BUCKET", cfg.S3.Bucket)
cfg.S3.Domain = getEnvOrDefault("APP_S3_DOMAIN", cfg.S3.Domain)
cfg.Server.Host = getEnvOrDefault("APP_SERVER_HOST", cfg.Server.Host)
cfg.Server.Port, _ = strconv.Atoi(getEnvOrDefault("APP_SERVER_PORT", fmt.Sprintf("%d", cfg.Server.Port)))
cfg.Server.Mode = getEnvOrDefault("APP_SERVER_MODE", cfg.Server.Mode)
cfg.Gorse.Address = getEnvOrDefault("APP_GORSE_ADDRESS", cfg.Gorse.Address)
cfg.Gorse.APIKey = getEnvOrDefault("APP_GORSE_API_KEY", cfg.Gorse.APIKey)
cfg.Gorse.Dashboard = getEnvOrDefault("APP_GORSE_DASHBOARD", cfg.Gorse.Dashboard)
cfg.Gorse.ImportPassword = getEnvOrDefault("APP_GORSE_IMPORT_PASSWORD", cfg.Gorse.ImportPassword)
cfg.Gorse.EmbeddingAPIKey = getEnvOrDefault("APP_GORSE_EMBEDDING_API_KEY", cfg.Gorse.EmbeddingAPIKey)
cfg.Gorse.EmbeddingURL = getEnvOrDefault("APP_GORSE_EMBEDDING_URL", cfg.Gorse.EmbeddingURL)
cfg.Gorse.EmbeddingModel = getEnvOrDefault("APP_GORSE_EMBEDDING_MODEL", cfg.Gorse.EmbeddingModel)
cfg.OpenAI.BaseURL = getEnvOrDefault("APP_OPENAI_BASE_URL", cfg.OpenAI.BaseURL)
cfg.OpenAI.APIKey = getEnvOrDefault("APP_OPENAI_API_KEY", cfg.OpenAI.APIKey)
cfg.OpenAI.ModerationModel = getEnvOrDefault("APP_OPENAI_MODERATION_MODEL", cfg.OpenAI.ModerationModel)
cfg.OpenAI.ModerationMaxImagesPerRequest, _ = strconv.Atoi(getEnvOrDefault("APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST", fmt.Sprintf("%d", cfg.OpenAI.ModerationMaxImagesPerRequest)))
cfg.OpenAI.RequestTimeout, _ = strconv.Atoi(getEnvOrDefault("APP_OPENAI_REQUEST_TIMEOUT", fmt.Sprintf("%d", cfg.OpenAI.RequestTimeout)))
cfg.OpenAI.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_OPENAI_ENABLED", fmt.Sprintf("%t", cfg.OpenAI.Enabled)))
cfg.OpenAI.StrictModeration, _ = strconv.ParseBool(getEnvOrDefault("APP_OPENAI_STRICT_MODERATION", fmt.Sprintf("%t", cfg.OpenAI.StrictModeration)))
cfg.Email.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_ENABLED", fmt.Sprintf("%t", cfg.Email.Enabled)))
cfg.Email.Host = getEnvOrDefault("APP_EMAIL_HOST", cfg.Email.Host)
cfg.Email.Port, _ = strconv.Atoi(getEnvOrDefault("APP_EMAIL_PORT", fmt.Sprintf("%d", cfg.Email.Port)))
cfg.Email.Username = getEnvOrDefault("APP_EMAIL_USERNAME", cfg.Email.Username)
cfg.Email.Password = getEnvOrDefault("APP_EMAIL_PASSWORD", cfg.Email.Password)
cfg.Email.FromAddress = getEnvOrDefault("APP_EMAIL_FROM_ADDRESS", cfg.Email.FromAddress)
cfg.Email.FromName = getEnvOrDefault("APP_EMAIL_FROM_NAME", cfg.Email.FromName)
cfg.Email.UseTLS, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_USE_TLS", fmt.Sprintf("%t", cfg.Email.UseTLS)))
cfg.Email.InsecureSkipVerify, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_INSECURE_SKIP_VERIFY", fmt.Sprintf("%t", cfg.Email.InsecureSkipVerify)))
cfg.Email.Timeout, _ = strconv.Atoi(getEnvOrDefault("APP_EMAIL_TIMEOUT", fmt.Sprintf("%d", cfg.Email.Timeout)))
return &cfg, nil
}
// getEnvOrDefault 获取环境变量值,如果未设置则返回默认值
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// NewRedis 创建Redis客户端真实Redis
func NewRedis(cfg *RedisConfig) (*redis.Client, error) {
client := redis.NewClient(&redis.Options{
Addr: cfg.Redis.Addr(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
PoolSize: cfg.PoolSize,
})
ctx := context.Background()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to redis: %w", err)
}
return client, nil
}
// NewS3 创建S3客户端
func NewS3(cfg *S3Config) (*minio.Client, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
client, err := minio.New(cfg.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
Secure: cfg.UseSSL,
})
if err != nil {
return nil, fmt.Errorf("failed to create S3 client: %w", err)
}
exists, err := client.BucketExists(ctx, cfg.Bucket)
if err != nil {
return nil, fmt.Errorf("failed to check bucket: %w", err)
}
if !exists {
if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{
Region: cfg.Region,
}); err != nil {
return nil, fmt.Errorf("failed to create bucket: %w", err)
}
}
return client, nil
}

885
internal/dto/converter.go Normal file
View File

@@ -0,0 +1,885 @@
package dto
import (
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/utils"
"context"
"encoding/json"
"strconv"
)
// ==================== User 转换 ====================
// getAvatarOrDefault 获取头像URL如果为空则返回在线头像生成服务的URL
func getAvatarOrDefault(user *model.User) string {
return utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar)
}
// ConvertUserToResponse 将User转换为UserResponse
func ConvertUserToResponse(user *model.User) *UserResponse {
if user == nil {
return nil
}
return &UserResponse{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Phone: user.Phone,
EmailVerified: user.EmailVerified,
Avatar: getAvatarOrDefault(user),
CoverURL: user.CoverURL,
Bio: user.Bio,
Website: user.Website,
Location: user.Location,
PostsCount: user.PostsCount,
FollowersCount: user.FollowersCount,
FollowingCount: user.FollowingCount,
CreatedAt: FormatTime(user.CreatedAt),
}
}
// ConvertUserToResponseWithFollowing 将User转换为UserResponse包含关注状态
func ConvertUserToResponseWithFollowing(user *model.User, isFollowing bool) *UserResponse {
if user == nil {
return nil
}
return &UserResponse{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Phone: user.Phone,
EmailVerified: user.EmailVerified,
Avatar: getAvatarOrDefault(user),
CoverURL: user.CoverURL,
Bio: user.Bio,
Website: user.Website,
Location: user.Location,
PostsCount: user.PostsCount,
FollowersCount: user.FollowersCount,
FollowingCount: user.FollowingCount,
IsFollowing: isFollowing,
IsFollowingMe: false, // 默认false需要单独计算
CreatedAt: FormatTime(user.CreatedAt),
}
}
// ConvertUserToResponseWithPostsCount 将User转换为UserResponse使用实时计算的帖子数量
func ConvertUserToResponseWithPostsCount(user *model.User, postsCount int) *UserResponse {
if user == nil {
return nil
}
return &UserResponse{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Phone: user.Phone,
EmailVerified: user.EmailVerified,
Avatar: getAvatarOrDefault(user),
CoverURL: user.CoverURL,
Bio: user.Bio,
Website: user.Website,
Location: user.Location,
PostsCount: postsCount,
FollowersCount: user.FollowersCount,
FollowingCount: user.FollowingCount,
CreatedAt: FormatTime(user.CreatedAt),
}
}
// ConvertUserToResponseWithMutualFollow 将User转换为UserResponse包含双向关注状态
func ConvertUserToResponseWithMutualFollow(user *model.User, isFollowing, isFollowingMe bool) *UserResponse {
if user == nil {
return nil
}
return &UserResponse{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Phone: user.Phone,
EmailVerified: user.EmailVerified,
Avatar: getAvatarOrDefault(user),
CoverURL: user.CoverURL,
Bio: user.Bio,
Website: user.Website,
Location: user.Location,
PostsCount: user.PostsCount,
FollowersCount: user.FollowersCount,
FollowingCount: user.FollowingCount,
IsFollowing: isFollowing,
IsFollowingMe: isFollowingMe,
CreatedAt: FormatTime(user.CreatedAt),
}
}
// ConvertUserToResponseWithMutualFollowAndPostsCount 将User转换为UserResponse包含双向关注状态和实时计算的帖子数量
func ConvertUserToResponseWithMutualFollowAndPostsCount(user *model.User, isFollowing, isFollowingMe bool, postsCount int) *UserResponse {
if user == nil {
return nil
}
return &UserResponse{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Phone: user.Phone,
EmailVerified: user.EmailVerified,
Avatar: getAvatarOrDefault(user),
CoverURL: user.CoverURL,
Bio: user.Bio,
Website: user.Website,
Location: user.Location,
PostsCount: postsCount,
FollowersCount: user.FollowersCount,
FollowingCount: user.FollowingCount,
IsFollowing: isFollowing,
IsFollowingMe: isFollowingMe,
CreatedAt: FormatTime(user.CreatedAt),
}
}
// ConvertUserToDetailResponse 将User转换为UserDetailResponse
func ConvertUserToDetailResponse(user *model.User) *UserDetailResponse {
if user == nil {
return nil
}
return &UserDetailResponse{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
EmailVerified: user.EmailVerified,
Avatar: getAvatarOrDefault(user),
CoverURL: user.CoverURL,
Bio: user.Bio,
Website: user.Website,
Location: user.Location,
PostsCount: user.PostsCount,
FollowersCount: user.FollowersCount,
FollowingCount: user.FollowingCount,
IsVerified: user.IsVerified,
CreatedAt: FormatTime(user.CreatedAt),
}
}
// ConvertUserToDetailResponseWithPostsCount 将User转换为UserDetailResponse使用实时计算的帖子数量
func ConvertUserToDetailResponseWithPostsCount(user *model.User, postsCount int) *UserDetailResponse {
if user == nil {
return nil
}
return &UserDetailResponse{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
EmailVerified: user.EmailVerified,
Phone: user.Phone, // 仅当前用户自己可见
Avatar: getAvatarOrDefault(user),
CoverURL: user.CoverURL,
Bio: user.Bio,
Website: user.Website,
Location: user.Location,
PostsCount: postsCount,
FollowersCount: user.FollowersCount,
FollowingCount: user.FollowingCount,
IsVerified: user.IsVerified,
CreatedAt: FormatTime(user.CreatedAt),
}
}
// ConvertUsersToResponse 将User列表转换为响应列表
func ConvertUsersToResponse(users []*model.User) []*UserResponse {
result := make([]*UserResponse, 0, len(users))
for _, user := range users {
result = append(result, ConvertUserToResponse(user))
}
return result
}
// ConvertUsersToResponseWithMutualFollow 将User列表转换为响应列表包含双向关注状态
// followingStatusMap: key是用户IDvalue是[isFollowing, isFollowingMe]
func ConvertUsersToResponseWithMutualFollow(users []*model.User, followingStatusMap map[string][2]bool) []*UserResponse {
result := make([]*UserResponse, 0, len(users))
for _, user := range users {
status, ok := followingStatusMap[user.ID]
if ok {
result = append(result, ConvertUserToResponseWithMutualFollow(user, status[0], status[1]))
} else {
result = append(result, ConvertUserToResponse(user))
}
}
return result
}
// ConvertUsersToResponseWithMutualFollowAndPostsCount 将User列表转换为响应列表包含双向关注状态和实时计算的帖子数量
// followingStatusMap: key是用户IDvalue是[isFollowing, isFollowingMe]
// postsCountMap: key是用户IDvalue是帖子数量
func ConvertUsersToResponseWithMutualFollowAndPostsCount(users []*model.User, followingStatusMap map[string][2]bool, postsCountMap map[string]int64) []*UserResponse {
result := make([]*UserResponse, 0, len(users))
for _, user := range users {
status, hasStatus := followingStatusMap[user.ID]
postsCount, hasPostsCount := postsCountMap[user.ID]
// 如果没有帖子数量,使用数据库中的值
if !hasPostsCount {
postsCount = int64(user.PostsCount)
}
if hasStatus {
result = append(result, ConvertUserToResponseWithMutualFollowAndPostsCount(user, status[0], status[1], int(postsCount)))
} else {
result = append(result, ConvertUserToResponseWithPostsCount(user, int(postsCount)))
}
}
return result
}
// ==================== Post 转换 ====================
// ConvertPostImageToResponse 将PostImage转换为PostImageResponse
func ConvertPostImageToResponse(img *model.PostImage) PostImageResponse {
if img == nil {
return PostImageResponse{}
}
return PostImageResponse{
ID: img.ID,
URL: img.URL,
ThumbnailURL: img.ThumbnailURL,
Width: img.Width,
Height: img.Height,
}
}
// ConvertPostImagesToResponse 将PostImage列表转换为响应列表
func ConvertPostImagesToResponse(images []model.PostImage) []PostImageResponse {
result := make([]PostImageResponse, 0, len(images))
for i := range images {
result = append(result, ConvertPostImageToResponse(&images[i]))
}
return result
}
// ConvertPostToResponse 将Post转换为PostResponse列表用
func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostResponse {
if post == nil {
return nil
}
images := make([]PostImageResponse, 0)
for _, img := range post.Images {
images = append(images, ConvertPostImageToResponse(&img))
}
var author *UserResponse
if post.User != nil {
author = ConvertUserToResponse(post.User)
}
return &PostResponse{
ID: post.ID,
UserID: post.UserID,
Title: post.Title,
Content: post.Content,
Images: images,
LikesCount: post.LikesCount,
CommentsCount: post.CommentsCount,
FavoritesCount: post.FavoritesCount,
SharesCount: post.SharesCount,
ViewsCount: post.ViewsCount,
IsPinned: post.IsPinned,
IsLocked: post.IsLocked,
IsVote: post.IsVote,
CreatedAt: FormatTime(post.CreatedAt),
Author: author,
IsLiked: isLiked,
IsFavorited: isFavorited,
}
}
// ConvertPostToDetailResponse 将Post转换为PostDetailResponse
func ConvertPostToDetailResponse(post *model.Post, isLiked, isFavorited bool) *PostDetailResponse {
if post == nil {
return nil
}
images := make([]PostImageResponse, 0)
for _, img := range post.Images {
images = append(images, ConvertPostImageToResponse(&img))
}
var author *UserResponse
if post.User != nil {
author = ConvertUserToResponse(post.User)
}
return &PostDetailResponse{
ID: post.ID,
UserID: post.UserID,
Title: post.Title,
Content: post.Content,
Images: images,
Status: string(post.Status),
LikesCount: post.LikesCount,
CommentsCount: post.CommentsCount,
FavoritesCount: post.FavoritesCount,
SharesCount: post.SharesCount,
ViewsCount: post.ViewsCount,
IsPinned: post.IsPinned,
IsLocked: post.IsLocked,
IsVote: post.IsVote,
CreatedAt: FormatTime(post.CreatedAt),
UpdatedAt: FormatTime(post.UpdatedAt),
Author: author,
IsLiked: isLiked,
IsFavorited: isFavorited,
}
}
// ConvertPostsToResponse 将Post列表转换为响应列表每个帖子独立检查点赞/收藏状态)
func ConvertPostsToResponse(posts []*model.Post, isLikedMap, isFavoritedMap map[string]bool) []*PostResponse {
result := make([]*PostResponse, 0, len(posts))
for _, post := range posts {
isLiked := false
isFavorited := false
if isLikedMap != nil {
isLiked = isLikedMap[post.ID]
}
if isFavoritedMap != nil {
isFavorited = isFavoritedMap[post.ID]
}
result = append(result, ConvertPostToResponse(post, isLiked, isFavorited))
}
return result
}
// ==================== Comment 转换 ====================
// ConvertCommentToResponse 将Comment转换为CommentResponse
func ConvertCommentToResponse(comment *model.Comment, isLiked bool) *CommentResponse {
if comment == nil {
return nil
}
var author *UserResponse
if comment.User != nil {
author = ConvertUserToResponse(comment.User)
}
// 转换子回复(扁平化结构)
var replies []*CommentResponse
if len(comment.Replies) > 0 {
replies = make([]*CommentResponse, 0, len(comment.Replies))
for _, reply := range comment.Replies {
replies = append(replies, ConvertCommentToResponse(reply, false))
}
}
// TargetID 就是 ParentID前端根据这个 ID 找到被回复用户的昵称
var targetID *string
if comment.ParentID != nil && *comment.ParentID != "" {
targetID = comment.ParentID
}
// 解析图片JSON
var images []CommentImageResponse
if comment.Images != "" {
var urlList []string
if err := json.Unmarshal([]byte(comment.Images), &urlList); err == nil {
images = make([]CommentImageResponse, 0, len(urlList))
for _, url := range urlList {
images = append(images, CommentImageResponse{URL: url})
}
}
}
return &CommentResponse{
ID: comment.ID,
PostID: comment.PostID,
UserID: comment.UserID,
ParentID: comment.ParentID,
RootID: comment.RootID,
Content: comment.Content,
Images: images,
LikesCount: comment.LikesCount,
RepliesCount: comment.RepliesCount,
CreatedAt: FormatTime(comment.CreatedAt),
Author: author,
IsLiked: isLiked,
TargetID: targetID,
Replies: replies,
}
}
// ConvertCommentsToResponse 将Comment列表转换为响应列表
func ConvertCommentsToResponse(comments []*model.Comment, isLiked bool) []*CommentResponse {
result := make([]*CommentResponse, 0, len(comments))
for _, comment := range comments {
result = append(result, ConvertCommentToResponse(comment, isLiked))
}
return result
}
// IsLikedChecker 点赞状态检查器接口
type IsLikedChecker interface {
IsLiked(ctx context.Context, commentID, userID string) bool
}
// ConvertCommentToResponseWithUser 将Comment转换为CommentResponse根据用户ID检查点赞状态
func ConvertCommentToResponseWithUser(comment *model.Comment, userID string, checker IsLikedChecker) *CommentResponse {
if comment == nil {
return nil
}
// 检查当前用户是否点赞了该评论
isLiked := false
if userID != "" && checker != nil {
isLiked = checker.IsLiked(context.Background(), comment.ID, userID)
}
var author *UserResponse
if comment.User != nil {
author = ConvertUserToResponse(comment.User)
}
// 转换子回复(扁平化结构),递归检查点赞状态
var replies []*CommentResponse
if len(comment.Replies) > 0 {
replies = make([]*CommentResponse, 0, len(comment.Replies))
for _, reply := range comment.Replies {
replies = append(replies, ConvertCommentToResponseWithUser(reply, userID, checker))
}
}
// TargetID 就是 ParentID前端根据这个 ID 找到被回复用户的昵称
var targetID *string
if comment.ParentID != nil && *comment.ParentID != "" {
targetID = comment.ParentID
}
// 解析图片JSON
var images []CommentImageResponse
if comment.Images != "" {
var urlList []string
if err := json.Unmarshal([]byte(comment.Images), &urlList); err == nil {
images = make([]CommentImageResponse, 0, len(urlList))
for _, url := range urlList {
images = append(images, CommentImageResponse{URL: url})
}
}
}
return &CommentResponse{
ID: comment.ID,
PostID: comment.PostID,
UserID: comment.UserID,
ParentID: comment.ParentID,
RootID: comment.RootID,
Content: comment.Content,
Images: images,
LikesCount: comment.LikesCount,
RepliesCount: comment.RepliesCount,
CreatedAt: FormatTime(comment.CreatedAt),
Author: author,
IsLiked: isLiked,
TargetID: targetID,
Replies: replies,
}
}
// ConvertCommentsToResponseWithUser 将Comment列表转换为响应列表根据用户ID检查点赞状态
func ConvertCommentsToResponseWithUser(comments []*model.Comment, userID string, checker IsLikedChecker) []*CommentResponse {
result := make([]*CommentResponse, 0, len(comments))
for _, comment := range comments {
result = append(result, ConvertCommentToResponseWithUser(comment, userID, checker))
}
return result
}
// ==================== Notification 转换 ====================
// ConvertNotificationToResponse 将Notification转换为NotificationResponse
func ConvertNotificationToResponse(notification *model.Notification) *NotificationResponse {
if notification == nil {
return nil
}
return &NotificationResponse{
ID: notification.ID,
UserID: notification.UserID,
Type: string(notification.Type),
Title: notification.Title,
Content: notification.Content,
Data: notification.Data,
IsRead: notification.IsRead,
CreatedAt: FormatTime(notification.CreatedAt),
}
}
// ConvertNotificationsToResponse 将Notification列表转换为响应列表
func ConvertNotificationsToResponse(notifications []*model.Notification) []*NotificationResponse {
result := make([]*NotificationResponse, 0, len(notifications))
for _, n := range notifications {
result = append(result, ConvertNotificationToResponse(n))
}
return result
}
// ==================== Message 转换 ====================
// ConvertMessageToResponse 将Message转换为MessageResponse
func ConvertMessageToResponse(message *model.Message) *MessageResponse {
if message == nil {
return nil
}
// 直接使用 segments不需要解析
segments := make(model.MessageSegments, len(message.Segments))
for i, seg := range message.Segments {
segments[i] = model.MessageSegment{
Type: seg.Type,
Data: seg.Data,
}
}
return &MessageResponse{
ID: message.ID,
ConversationID: message.ConversationID,
SenderID: message.SenderID,
Seq: message.Seq,
Segments: segments,
ReplyToID: message.ReplyToID,
Status: string(message.Status),
Category: string(message.Category),
CreatedAt: FormatTime(message.CreatedAt),
}
}
// ConvertConversationToResponse 将Conversation转换为ConversationResponse
// participants: 会话参与者列表(用户信息)
// unreadCount: 当前用户的未读消息数
// lastMessage: 最后一条消息
func ConvertConversationToResponse(conv *model.Conversation, participants []*model.User, unreadCount int, lastMessage *model.Message, isPinned bool) *ConversationResponse {
if conv == nil {
return nil
}
var participantResponses []*UserResponse
for _, p := range participants {
participantResponses = append(participantResponses, ConvertUserToResponse(p))
}
// 转换群组信息
var groupResponse *GroupResponse
if conv.Group != nil {
groupResponse = GroupToResponse(conv.Group)
}
return &ConversationResponse{
ID: conv.ID,
Type: string(conv.Type),
IsPinned: isPinned,
Group: groupResponse,
LastSeq: conv.LastSeq,
LastMessage: ConvertMessageToResponse(lastMessage),
LastMessageAt: FormatTimePointer(conv.LastMsgTime),
UnreadCount: unreadCount,
Participants: participantResponses,
CreatedAt: FormatTime(conv.CreatedAt),
UpdatedAt: FormatTime(conv.UpdatedAt),
}
}
// ConvertConversationToDetailResponse 将Conversation转换为ConversationDetailResponse
func ConvertConversationToDetailResponse(conv *model.Conversation, participants []*model.User, unreadCount int64, lastMessage *model.Message, myLastReadSeq int64, otherLastReadSeq int64, isPinned bool) *ConversationDetailResponse {
if conv == nil {
return nil
}
var participantResponses []*UserResponse
for _, p := range participants {
participantResponses = append(participantResponses, ConvertUserToResponse(p))
}
return &ConversationDetailResponse{
ID: conv.ID,
Type: string(conv.Type),
IsPinned: isPinned,
LastSeq: conv.LastSeq,
LastMessage: ConvertMessageToResponse(lastMessage),
LastMessageAt: FormatTimePointer(conv.LastMsgTime),
UnreadCount: unreadCount,
Participants: participantResponses,
MyLastReadSeq: myLastReadSeq,
OtherLastReadSeq: otherLastReadSeq,
CreatedAt: FormatTime(conv.CreatedAt),
UpdatedAt: FormatTime(conv.UpdatedAt),
}
}
// ConvertMessagesToResponse 将Message列表转换为响应列表
func ConvertMessagesToResponse(messages []*model.Message) []*MessageResponse {
result := make([]*MessageResponse, 0, len(messages))
for _, msg := range messages {
result = append(result, ConvertMessageToResponse(msg))
}
return result
}
// ConvertConversationsToResponse 将Conversation列表转换为响应列表
func ConvertConversationsToResponse(convs []*model.Conversation) []*ConversationResponse {
result := make([]*ConversationResponse, 0, len(convs))
for _, conv := range convs {
result = append(result, ConvertConversationToResponse(conv, nil, 0, nil, false))
}
return result
}
// ==================== PushRecord 转换 ====================
// PushRecordToResponse 将PushRecord转换为PushRecordResponse
func PushRecordToResponse(record *model.PushRecord) *PushRecordResponse {
if record == nil {
return nil
}
resp := &PushRecordResponse{
ID: record.ID,
MessageID: record.MessageID,
PushChannel: string(record.PushChannel),
PushStatus: string(record.PushStatus),
RetryCount: record.RetryCount,
CreatedAt: record.CreatedAt,
}
if record.PushedAt != nil {
resp.PushedAt = *record.PushedAt
}
if record.DeliveredAt != nil {
resp.DeliveredAt = *record.DeliveredAt
}
return resp
}
// PushRecordsToResponse 将PushRecord列表转换为响应列表
func PushRecordsToResponse(records []*model.PushRecord) []*PushRecordResponse {
result := make([]*PushRecordResponse, 0, len(records))
for _, record := range records {
result = append(result, PushRecordToResponse(record))
}
return result
}
// ==================== DeviceToken 转换 ====================
// DeviceTokenToResponse 将DeviceToken转换为DeviceTokenResponse
func DeviceTokenToResponse(token *model.DeviceToken) *DeviceTokenResponse {
if token == nil {
return nil
}
resp := &DeviceTokenResponse{
ID: token.ID,
DeviceID: token.DeviceID,
DeviceType: string(token.DeviceType),
IsActive: token.IsActive,
DeviceName: token.DeviceName,
CreatedAt: token.CreatedAt,
}
if token.LastUsedAt != nil {
resp.LastUsedAt = *token.LastUsedAt
}
return resp
}
// DeviceTokensToResponse 将DeviceToken列表转换为响应列表
func DeviceTokensToResponse(tokens []*model.DeviceToken) []*DeviceTokenResponse {
result := make([]*DeviceTokenResponse, 0, len(tokens))
for _, token := range tokens {
result = append(result, DeviceTokenToResponse(token))
}
return result
}
// ==================== SystemMessage 转换 ====================
// SystemMessageToResponse 将Message转换为SystemMessageResponse
func SystemMessageToResponse(msg *model.Message) *SystemMessageResponse {
if msg == nil {
return nil
}
// 从 segments 中提取文本内容
content := ExtractTextContentFromModel(msg.Segments)
resp := &SystemMessageResponse{
ID: msg.ID,
SenderID: msg.SenderID,
ReceiverID: "", // 系统消息的接收者需要从上下文获取
Content: content,
Category: string(msg.Category),
SystemType: string(msg.SystemType),
CreatedAt: msg.CreatedAt,
}
if msg.ExtraData != nil {
resp.ExtraData = map[string]interface{}{
"actor_id": msg.ExtraData.ActorID,
"actor_name": msg.ExtraData.ActorName,
"avatar_url": msg.ExtraData.AvatarURL,
"target_id": msg.ExtraData.TargetID,
"target_type": msg.ExtraData.TargetType,
"action_url": msg.ExtraData.ActionURL,
"action_time": msg.ExtraData.ActionTime,
}
}
return resp
}
// SystemMessagesToResponse 将Message列表转换为SystemMessageResponse列表
func SystemMessagesToResponse(messages []*model.Message) []*SystemMessageResponse {
result := make([]*SystemMessageResponse, 0, len(messages))
for _, msg := range messages {
result = append(result, SystemMessageToResponse(msg))
}
return result
}
// SystemNotificationToResponse 将SystemNotification转换为SystemMessageResponse
func SystemNotificationToResponse(n *model.SystemNotification) *SystemMessageResponse {
if n == nil {
return nil
}
resp := &SystemMessageResponse{
ID: strconv.FormatInt(n.ID, 10),
SenderID: model.SystemSenderIDStr, // 系统发送者
ReceiverID: n.ReceiverID,
Content: n.Content,
Category: "notification",
SystemType: string(n.Type),
IsRead: n.IsRead,
CreatedAt: n.CreatedAt,
}
if n.ExtraData != nil {
resp.ExtraData = map[string]interface{}{
"actor_id": n.ExtraData.ActorID,
"actor_id_str": n.ExtraData.ActorIDStr,
"actor_name": n.ExtraData.ActorName,
"avatar_url": n.ExtraData.AvatarURL,
"target_id": n.ExtraData.TargetID,
"target_title": n.ExtraData.TargetTitle,
"target_type": n.ExtraData.TargetType,
"action_url": n.ExtraData.ActionURL,
"action_time": n.ExtraData.ActionTime,
"group_id": n.ExtraData.GroupID,
"group_name": n.ExtraData.GroupName,
"group_avatar": n.ExtraData.GroupAvatar,
"group_description": n.ExtraData.GroupDescription,
"flag": n.ExtraData.Flag,
"request_type": n.ExtraData.RequestType,
"request_status": n.ExtraData.RequestStatus,
"reason": n.ExtraData.Reason,
"target_user_id": n.ExtraData.TargetUserID,
"target_user_name": n.ExtraData.TargetUserName,
"target_user_avatar": n.ExtraData.TargetUserAvatar,
}
}
return resp
}
// SystemNotificationsToResponse 将SystemNotification列表转换为SystemMessageResponse列表
func SystemNotificationsToResponse(notifications []*model.SystemNotification) []*SystemMessageResponse {
result := make([]*SystemMessageResponse, 0, len(notifications))
for _, n := range notifications {
result = append(result, SystemNotificationToResponse(n))
}
return result
}
// ==================== Group 转换 ====================
// GroupToResponse 将Group转换为GroupResponse
func GroupToResponse(group *model.Group) *GroupResponse {
if group == nil {
return nil
}
return &GroupResponse{
ID: group.ID,
Name: group.Name,
Avatar: group.Avatar,
Description: group.Description,
OwnerID: group.OwnerID,
MemberCount: group.MemberCount,
MaxMembers: group.MaxMembers,
JoinType: int(group.JoinType),
MuteAll: group.MuteAll,
CreatedAt: FormatTime(group.CreatedAt),
}
}
// GroupsToResponse 将Group列表转换为GroupResponse列表
func GroupsToResponse(groups []model.Group) []*GroupResponse {
result := make([]*GroupResponse, 0, len(groups))
for i := range groups {
result = append(result, GroupToResponse(&groups[i]))
}
return result
}
// GroupMemberToResponse 将GroupMember转换为GroupMemberResponse
func GroupMemberToResponse(member *model.GroupMember) *GroupMemberResponse {
if member == nil {
return nil
}
return &GroupMemberResponse{
ID: member.ID,
GroupID: member.GroupID,
UserID: member.UserID,
Role: member.Role,
Nickname: member.Nickname,
Muted: member.Muted,
JoinTime: FormatTime(member.JoinTime),
}
}
// GroupMemberToResponseWithUser 将GroupMember转换为GroupMemberResponse包含用户信息
func GroupMemberToResponseWithUser(member *model.GroupMember, user *model.User) *GroupMemberResponse {
if member == nil {
return nil
}
resp := GroupMemberToResponse(member)
if user != nil {
resp.User = ConvertUserToResponse(user)
}
return resp
}
// GroupMembersToResponse 将GroupMember列表转换为GroupMemberResponse列表
func GroupMembersToResponse(members []model.GroupMember) []*GroupMemberResponse {
result := make([]*GroupMemberResponse, 0, len(members))
for i := range members {
result = append(result, GroupMemberToResponse(&members[i]))
}
return result
}
// GroupAnnouncementToResponse 将GroupAnnouncement转换为GroupAnnouncementResponse
func GroupAnnouncementToResponse(announcement *model.GroupAnnouncement) *GroupAnnouncementResponse {
if announcement == nil {
return nil
}
return &GroupAnnouncementResponse{
ID: announcement.ID,
GroupID: announcement.GroupID,
Content: announcement.Content,
AuthorID: announcement.AuthorID,
IsPinned: announcement.IsPinned,
CreatedAt: FormatTime(announcement.CreatedAt),
}
}
// GroupAnnouncementsToResponse 将GroupAnnouncement列表转换为GroupAnnouncementResponse列表
func GroupAnnouncementsToResponse(announcements []model.GroupAnnouncement) []*GroupAnnouncementResponse {
result := make([]*GroupAnnouncementResponse, 0, len(announcements))
for i := range announcements {
result = append(result, GroupAnnouncementToResponse(&announcements[i]))
}
return result
}

819
internal/dto/dto.go Normal file
View File

@@ -0,0 +1,819 @@
package dto
import (
"carrot_bbs/internal/model"
"time"
)
// ==================== User DTOs ====================
// UserResponse 用户信息响应
type UserResponse struct {
ID string `json:"id"`
Username string `json:"username"`
Nickname string `json:"nickname"`
Email *string `json:"email,omitempty"`
Phone *string `json:"phone,omitempty"`
EmailVerified bool `json:"email_verified"`
Avatar string `json:"avatar"`
CoverURL string `json:"cover_url"` // 头图URL
Bio string `json:"bio"`
Website string `json:"website"`
Location string `json:"location"`
PostsCount int `json:"posts_count"`
FollowersCount int `json:"followers_count"`
FollowingCount int `json:"following_count"`
IsFollowing bool `json:"is_following"` // 当前用户是否关注了该用户
IsFollowingMe bool `json:"is_following_me"` // 该用户是否关注了当前用户
CreatedAt string `json:"created_at"`
}
// UserDetailResponse 用户详情响应
type UserDetailResponse struct {
ID string `json:"id"`
Username string `json:"username"`
Nickname string `json:"nickname"`
Email *string `json:"email"`
EmailVerified bool `json:"email_verified"`
Phone *string `json:"phone,omitempty"` // 仅当前用户自己可见
Avatar string `json:"avatar"`
CoverURL string `json:"cover_url"` // 头图URL
Bio string `json:"bio"`
Website string `json:"website"`
Location string `json:"location"`
PostsCount int `json:"posts_count"`
FollowersCount int `json:"followers_count"`
FollowingCount int `json:"following_count"`
IsVerified bool `json:"is_verified"`
IsFollowing bool `json:"is_following"` // 当前用户是否关注了该用户
IsFollowingMe bool `json:"is_following_me"` // 该用户是否关注了当前用户
CreatedAt string `json:"created_at"`
}
// ==================== Post DTOs ====================
// PostImageResponse 帖子图片响应
type PostImageResponse struct {
ID string `json:"id"`
URL string `json:"url"`
ThumbnailURL string `json:"thumbnail_url"`
Width int `json:"width"`
Height int `json:"height"`
}
// PostResponse 帖子响应(列表用)
type PostResponse struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Title string `json:"title"`
Content string `json:"content"`
Images []PostImageResponse `json:"images"`
LikesCount int `json:"likes_count"`
CommentsCount int `json:"comments_count"`
FavoritesCount int `json:"favorites_count"`
SharesCount int `json:"shares_count"`
ViewsCount int `json:"views_count"`
IsPinned bool `json:"is_pinned"`
IsLocked bool `json:"is_locked"`
IsVote bool `json:"is_vote"`
CreatedAt string `json:"created_at"`
Author *UserResponse `json:"author"`
IsLiked bool `json:"is_liked"`
IsFavorited bool `json:"is_favorited"`
}
// PostDetailResponse 帖子详情响应
type PostDetailResponse struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Title string `json:"title"`
Content string `json:"content"`
Images []PostImageResponse `json:"images"`
Status string `json:"status"`
LikesCount int `json:"likes_count"`
CommentsCount int `json:"comments_count"`
FavoritesCount int `json:"favorites_count"`
SharesCount int `json:"shares_count"`
ViewsCount int `json:"views_count"`
IsPinned bool `json:"is_pinned"`
IsLocked bool `json:"is_locked"`
IsVote bool `json:"is_vote"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
Author *UserResponse `json:"author"`
IsLiked bool `json:"is_liked"`
IsFavorited bool `json:"is_favorited"`
}
// ==================== Comment DTOs ====================
// CommentImageResponse 评论图片响应
type CommentImageResponse struct {
URL string `json:"url"`
}
// CommentResponse 评论响应扁平化结构类似B站/抖音)
// 第一层级正常展示,第二三四五层级在第一层级的评论区扁平展示
type CommentResponse struct {
ID string `json:"id"`
PostID string `json:"post_id"`
UserID string `json:"user_id"`
ParentID *string `json:"parent_id"`
RootID *string `json:"root_id"`
Content string `json:"content"`
Images []CommentImageResponse `json:"images"`
LikesCount int `json:"likes_count"`
RepliesCount int `json:"replies_count"`
CreatedAt string `json:"created_at"`
Author *UserResponse `json:"author"`
IsLiked bool `json:"is_liked"`
TargetID *string `json:"target_id,omitempty"` // 被回复的评论ID前端根据此ID找到被回复用户的昵称
Replies []*CommentResponse `json:"replies,omitempty"` // 子回复列表(扁平化,所有层级都在这里)
}
// ==================== Notification DTOs ====================
// NotificationResponse 通知响应
type NotificationResponse struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Type string `json:"type"`
Title string `json:"title"`
Content string `json:"content"`
Data string `json:"data"`
IsRead bool `json:"is_read"`
CreatedAt string `json:"created_at"`
}
// ==================== Message Segment DTOs ====================
// SegmentType Segment类型
type SegmentType string
const (
SegmentTypeText SegmentType = "text"
SegmentTypeImage SegmentType = "image"
SegmentTypeVoice SegmentType = "voice"
SegmentTypeVideo SegmentType = "video"
SegmentTypeFile SegmentType = "file"
SegmentTypeAt SegmentType = "at"
SegmentTypeReply SegmentType = "reply"
SegmentTypeFace SegmentType = "face"
SegmentTypeLink SegmentType = "link"
)
// TextSegmentData 文本数据
type TextSegmentData struct {
Text string `json:"text"`
}
// ImageSegmentData 图片数据
type ImageSegmentData struct {
URL string `json:"url"`
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
ThumbnailURL string `json:"thumbnail_url,omitempty"`
FileSize int64 `json:"file_size,omitempty"`
}
// VoiceSegmentData 语音数据
type VoiceSegmentData struct {
URL string `json:"url"`
Duration int `json:"duration,omitempty"` // 秒
FileSize int64 `json:"file_size,omitempty"`
}
// VideoSegmentData 视频数据
type VideoSegmentData struct {
URL string `json:"url"`
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
Duration int `json:"duration,omitempty"` // 秒
ThumbnailURL string `json:"thumbnail_url,omitempty"`
FileSize int64 `json:"file_size,omitempty"`
}
// FileSegmentData 文件数据
type FileSegmentData struct {
URL string `json:"url"`
Name string `json:"name"`
Size int64 `json:"size,omitempty"`
MimeType string `json:"mime_type,omitempty"`
}
// AtSegmentData @数据
type AtSegmentData struct {
UserID string `json:"user_id"` // "all" 表示@所有人
Nickname string `json:"nickname,omitempty"`
}
// ReplySegmentData 回复数据
type ReplySegmentData struct {
ID string `json:"id"` // 被回复消息的ID
}
// FaceSegmentData 表情数据
type FaceSegmentData struct {
ID int `json:"id"`
Name string `json:"name,omitempty"`
URL string `json:"url,omitempty"`
}
// LinkSegmentData 链接数据
type LinkSegmentData struct {
URL string `json:"url"`
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
Image string `json:"image,omitempty"`
}
// ==================== Message DTOs ====================
// MessageResponse 消息响应
type MessageResponse struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
SenderID string `json:"sender_id"`
Seq int64 `json:"seq"`
Segments model.MessageSegments `json:"segments"` // 消息链(必须)
ReplyToID *string `json:"reply_to_id,omitempty"` // 被回复消息的ID用于关联查找
Status string `json:"status"`
Category string `json:"category,omitempty"` // 消息类别chat, notification, announcement
CreatedAt string `json:"created_at"`
Sender *UserResponse `json:"sender"`
}
// ConversationResponse 会话响应
type ConversationResponse struct {
ID string `json:"id"`
Type string `json:"type"`
IsPinned bool `json:"is_pinned"`
Group *GroupResponse `json:"group,omitempty"`
LastSeq int64 `json:"last_seq"`
LastMessage *MessageResponse `json:"last_message"`
LastMessageAt string `json:"last_message_at"`
UnreadCount int `json:"unread_count"`
Participants []*UserResponse `json:"participants,omitempty"` // 私聊时使用
MemberCount int `json:"member_count,omitempty"` // 群聊时使用
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// ConversationParticipantResponse 会话参与者响应
type ConversationParticipantResponse struct {
UserID string `json:"user_id"`
LastReadSeq int64 `json:"last_read_seq"`
Muted bool `json:"muted"`
IsPinned bool `json:"is_pinned"`
}
// ==================== Auth DTOs ====================
// LoginResponse 登录响应
type LoginResponse struct {
User *UserResponse `json:"user"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// RegisterResponse 注册响应
type RegisterResponse struct {
User *UserResponse `json:"user"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// RefreshTokenResponse 刷新Token响应
type RefreshTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
// ==================== Common DTOs ====================
// SuccessResponse 通用成功响应
type SuccessResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
}
// AvailableResponse 可用性检查响应
type AvailableResponse struct {
Available bool `json:"available"`
}
// CountResponse 数量响应
type CountResponse struct {
Count int `json:"count"`
}
// URLResponse URL响应
type URLResponse struct {
URL string `json:"url"`
}
// ==================== Chat Request DTOs ====================
// CreateConversationRequest 创建会话请求
type CreateConversationRequest struct {
UserID string `json:"user_id" binding:"required"` // 目标用户ID (UUID格式)
}
// SendMessageRequest 发送消息请求
type SendMessageRequest struct {
Segments model.MessageSegments `json:"segments" binding:"required"` // 消息链(必须)
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID (string类型)
}
// MarkReadRequest 标记已读请求
type MarkReadRequest struct {
LastReadSeq int64 `json:"last_read_seq" binding:"required"` // 已读到的seq位置
}
// SetConversationPinnedRequest 设置会话置顶请求
type SetConversationPinnedRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
IsPinned bool `json:"is_pinned"`
}
// ==================== Chat Response DTOs ====================
// ConversationListResponse 会话列表响应
type ConversationListResponse struct {
Conversations []*ConversationResponse `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// ConversationDetailResponse 会话详情响应
type ConversationDetailResponse struct {
ID string `json:"id"`
Type string `json:"type"`
IsPinned bool `json:"is_pinned"`
LastSeq int64 `json:"last_seq"`
LastMessage *MessageResponse `json:"last_message"`
LastMessageAt string `json:"last_message_at"`
UnreadCount int64 `json:"unread_count"`
Participants []*UserResponse `json:"participants"`
MyLastReadSeq int64 `json:"my_last_read_seq"` // 当前用户的已读位置
OtherLastReadSeq int64 `json:"other_last_read_seq"` // 对方用户的已读位置
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// UnreadCountResponse 未读数响应
type UnreadCountResponse struct {
TotalUnreadCount int64 `json:"total_unread_count"` // 所有会话的未读总数
}
// ConversationUnreadCountResponse 单个会话未读数响应
type ConversationUnreadCountResponse struct {
ConversationID string `json:"conversation_id"`
UnreadCount int64 `json:"unread_count"`
}
// MessageListResponse 消息列表响应
type MessageListResponse struct {
Messages []*MessageResponse `json:"messages"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// MessageSyncResponse 消息同步响应(增量同步)
type MessageSyncResponse struct {
Messages []*MessageResponse `json:"messages"`
HasMore bool `json:"has_more"`
}
// ==================== 设备Token DTOs ====================
// RegisterDeviceRequest 注册设备请求
type RegisterDeviceRequest struct {
DeviceID string `json:"device_id" binding:"required"`
DeviceType string `json:"device_type" binding:"required,oneof=ios android web"`
PushToken string `json:"push_token"`
DeviceName string `json:"device_name"`
}
// DeviceTokenResponse 设备Token响应
type DeviceTokenResponse struct {
ID int64 `json:"id"`
DeviceID string `json:"device_id"`
DeviceType string `json:"device_type"`
IsActive bool `json:"is_active"`
DeviceName string `json:"device_name"`
LastUsedAt time.Time `json:"last_used_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// ==================== 推送记录 DTOs ====================
// PushRecordResponse 推送记录响应
type PushRecordResponse struct {
ID int64 `json:"id"`
MessageID string `json:"message_id"`
PushChannel string `json:"push_channel"`
PushStatus string `json:"push_status"`
RetryCount int `json:"retry_count"`
PushedAt time.Time `json:"pushed_at,omitempty"`
DeliveredAt time.Time `json:"delivered_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// PushRecordListResponse 推送记录列表响应
type PushRecordListResponse struct {
Records []*PushRecordResponse `json:"records"`
Total int64 `json:"total"`
}
// ==================== 系统消息 DTOs ====================
// SystemMessageResponse 系统消息响应
type SystemMessageResponse struct {
ID string `json:"id"`
SenderID string `json:"sender_id"`
ReceiverID string `json:"receiver_id"`
Content string `json:"content"`
Category string `json:"category"`
SystemType string `json:"system_type"`
ExtraData map[string]interface{} `json:"extra_data,omitempty"`
IsRead bool `json:"is_read"`
CreatedAt time.Time `json:"created_at"`
}
// SystemMessageListResponse 系统消息列表响应
type SystemMessageListResponse struct {
Messages []*SystemMessageResponse `json:"messages"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// SystemUnreadCountResponse 系统消息未读数响应
type SystemUnreadCountResponse struct {
UnreadCount int64 `json:"unread_count"`
}
// ==================== 时间格式化 ====================
// FormatTime 格式化时间
func FormatTime(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format("2006-01-02T15:04:05Z07:00")
}
// FormatTimePointer 格式化时间指针
func FormatTimePointer(t *time.Time) string {
if t == nil {
return ""
}
return FormatTime(*t)
}
// ==================== Group DTOs ====================
// CreateGroupRequest 创建群组请求
type CreateGroupRequest struct {
Name string `json:"name" binding:"required,max=50"`
Description string `json:"description" binding:"max=500"`
MemberIDs []string `json:"member_ids"`
}
// UpdateGroupRequest 更新群组请求
type UpdateGroupRequest struct {
Name string `json:"name" binding:"omitempty,max=50"`
Description string `json:"description" binding:"omitempty,max=500"`
Avatar string `json:"avatar" binding:"omitempty,url"`
}
// InviteMembersRequest 邀请成员请求
type InviteMembersRequest struct {
MemberIDs []string `json:"member_ids" binding:"required,min=1"`
}
// TransferOwnerRequest 转让群主请求
type TransferOwnerRequest struct {
NewOwnerID string `json:"new_owner_id" binding:"required"`
}
// SetRoleRequest 设置角色请求
type SetRoleRequest struct {
Role string `json:"role" binding:"required,oneof=admin member"`
}
// SetNicknameRequest 设置昵称请求
type SetNicknameRequest struct {
Nickname string `json:"nickname" binding:"max=50"`
}
// MuteMemberRequest 禁言成员请求
type MuteMemberRequest struct {
Muted bool `json:"muted"`
}
// SetMuteAllRequest 设置全员禁言请求
type SetMuteAllRequest struct {
MuteAll bool `json:"mute_all"`
}
// SetJoinTypeRequest 设置加群方式请求
type SetJoinTypeRequest struct {
JoinType int `json:"join_type" binding:"min=0,max=2"`
}
// CreateAnnouncementRequest 创建群公告请求
type CreateAnnouncementRequest struct {
Content string `json:"content" binding:"required,max=2000"`
}
// GroupResponse 群组响应
type GroupResponse struct {
ID string `json:"id"`
Name string `json:"name"`
Avatar string `json:"avatar"`
Description string `json:"description"`
OwnerID string `json:"owner_id"`
MemberCount int `json:"member_count"`
MaxMembers int `json:"max_members"`
JoinType int `json:"join_type"`
MuteAll bool `json:"mute_all"`
CreatedAt string `json:"created_at"`
}
// GroupMemberResponse 群成员响应
type GroupMemberResponse struct {
ID string `json:"id"`
GroupID string `json:"group_id"`
UserID string `json:"user_id"`
Role string `json:"role"`
Nickname string `json:"nickname"`
Muted bool `json:"muted"`
JoinTime string `json:"join_time"`
User *UserResponse `json:"user,omitempty"`
}
// GroupAnnouncementResponse 群公告响应
type GroupAnnouncementResponse struct {
ID string `json:"id"`
GroupID string `json:"group_id"`
Content string `json:"content"`
AuthorID string `json:"author_id"`
IsPinned bool `json:"is_pinned"`
CreatedAt string `json:"created_at"`
}
// GroupListResponse 群组列表响应
type GroupListResponse struct {
List []*GroupResponse `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// GroupMemberListResponse 群成员列表响应
type GroupMemberListResponse struct {
List []*GroupMemberResponse `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// GroupAnnouncementListResponse 群公告列表响应
type GroupAnnouncementListResponse struct {
List []*GroupAnnouncementResponse `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// ==================== WebSocket Event DTOs ====================
// WSEventResponse WebSocket事件响应结构体
// 用于后端推送消息给前端的标准格式
type WSEventResponse struct {
ID string `json:"id"` // 事件唯一ID (UUID)
Time int64 `json:"time"` // 时间戳 (毫秒)
Type string `json:"type"` // 事件类型 (message, notification, system等)
DetailType string `json:"detail_type"` // 详细类型 (private, group, like, comment等)
Seq string `json:"seq"` // 消息序列号
Segments model.MessageSegments `json:"segments"` // 消息段数组
SenderID string `json:"sender_id"` // 发送者用户ID
}
// ==================== WebSocket Request DTOs ====================
// SendMessageParams 发送消息参数(用于 REST API
type SendMessageParams struct {
DetailType string `json:"detail_type"` // 消息类型: private, group
ConversationID string `json:"conversation_id"` // 会话ID
Segments model.MessageSegments `json:"segments"` // 消息内容(消息段数组)
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
}
// DeleteMsgParams 撤回消息参数
type DeleteMsgParams struct {
MessageID string `json:"message_id"` // 消息ID
}
// ==================== Group Action Params ====================
// SetGroupKickParams 群组踢人参数
type SetGroupKickParams struct {
GroupID string `json:"group_id"` // 群组ID
UserID string `json:"user_id"` // 被踢用户ID
RejectAddRequest bool `json:"reject_add_request"` // 是否拒绝再次加群
}
// SetGroupBanParams 群组单人禁言参数
type SetGroupBanParams struct {
GroupID string `json:"group_id"` // 群组ID
UserID string `json:"user_id"` // 被禁言用户ID
Duration int64 `json:"duration"` // 禁言时长0表示解除禁言
}
// SetGroupWholeBanParams 群组全员禁言参数
type SetGroupWholeBanParams struct {
GroupID string `json:"group_id"` // 群组ID
Enable bool `json:"enable"` // 是否开启全员禁言
}
// SetGroupAdminParams 群组设置管理员参数
type SetGroupAdminParams struct {
GroupID string `json:"group_id"` // 群组ID
UserID string `json:"user_id"` // 被设置的用户ID
Enable bool `json:"enable"` // 是否设置为管理员
}
// SetGroupNameParams 设置群名参数
type SetGroupNameParams struct {
GroupID string `json:"group_id"` // 群组ID
GroupName string `json:"group_name"` // 新群名
}
// SetGroupAvatarParams 设置群头像参数
type SetGroupAvatarParams struct {
GroupID string `json:"group_id"` // 群组ID
Avatar string `json:"avatar"` // 头像URL
}
// SetGroupLeaveParams 退出群组参数
type SetGroupLeaveParams struct {
GroupID string `json:"group_id"` // 群组ID
}
// SetGroupAddRequestParams 处理加群请求参数
type SetGroupAddRequestParams struct {
Flag string `json:"flag"` // 加群请求的flag标识
Approve bool `json:"approve"` // 是否同意
Reason string `json:"reason"` // 拒绝理由当approve为false时
}
// GetConversationListParams 获取会话列表参数
type GetConversationListParams struct {
Page int `json:"page"` // 页码
PageSize int `json:"page_size"` // 每页数量
}
// GetGroupInfoParams 获取群信息参数
type GetGroupInfoParams struct {
GroupID string `json:"group_id"` // 群组ID
}
// GetGroupMemberListParams 获取群成员列表参数
type GetGroupMemberListParams struct {
GroupID string `json:"group_id"` // 群组ID
Page int `json:"page"` // 页码
PageSize int `json:"page_size"` // 每页数量
}
// ==================== Conversation Action Params ====================
// CreateConversationParams 创建会话参数
type CreateConversationParams struct {
UserID string `json:"user_id"` // 目标用户ID私聊
}
// MarkReadParams 标记已读参数
type MarkReadParams struct {
ConversationID string `json:"conversation_id"` // 会话ID
LastReadSeq int64 `json:"last_read_seq"` // 最后已读消息序号
}
// SetConversationPinnedParams 设置会话置顶参数
type SetConversationPinnedParams struct {
ConversationID string `json:"conversation_id"` // 会话ID
IsPinned bool `json:"is_pinned"` // 是否置顶
}
// ==================== Group Action Params (Additional) ====================
// CreateGroupParams 创建群组参数
type CreateGroupParams struct {
Name string `json:"name"` // 群名
Description string `json:"description,omitempty"` // 群描述
MemberIDs []string `json:"member_ids,omitempty"` // 初始成员ID列表
}
// GetUserGroupsParams 获取用户群组列表参数
type GetUserGroupsParams struct {
Page int `json:"page"` // 页码
PageSize int `json:"page_size"` // 每页数量
}
// TransferOwnerParams 转让群主参数
type TransferOwnerParams struct {
GroupID string `json:"group_id"` // 群组ID
NewOwnerID string `json:"new_owner_id"` // 新群主ID
}
// InviteMembersParams 邀请成员参数
type InviteMembersParams struct {
GroupID string `json:"group_id"` // 群组ID
MemberIDs []string `json:"member_ids"` // 被邀请的用户ID列表
}
// JoinGroupParams 加入群组参数
type JoinGroupParams struct {
GroupID string `json:"group_id"` // 群组ID
}
// SetNicknameParams 设置群内昵称参数
type SetNicknameParams struct {
GroupID string `json:"group_id"` // 群组ID
Nickname string `json:"nickname"` // 群内昵称
}
// SetJoinTypeParams 设置加群方式参数
type SetJoinTypeParams struct {
GroupID string `json:"group_id"` // 群组ID
JoinType int `json:"join_type"` // 加群方式0-允许任何人加入1-需要审批2-不允许加入
}
// CreateAnnouncementParams 创建群公告参数
type CreateAnnouncementParams struct {
GroupID string `json:"group_id"` // 群组ID
Content string `json:"content"` // 公告内容
}
// GetAnnouncementsParams 获取群公告列表参数
type GetAnnouncementsParams struct {
GroupID string `json:"group_id"` // 群组ID
Page int `json:"page"` // 页码
PageSize int `json:"page_size"` // 每页数量
}
// DeleteAnnouncementParams 删除群公告参数
type DeleteAnnouncementParams struct {
GroupID string `json:"group_id"` // 群组ID
AnnouncementID string `json:"announcement_id"` // 公告ID
}
// DissolveGroupParams 解散群组参数
type DissolveGroupParams struct {
GroupID string `json:"group_id"` // 群组ID
}
// GetMyMemberInfoParams 获取我在群组中的成员信息参数
type GetMyMemberInfoParams struct {
GroupID string `json:"group_id"` // 群组ID
}
// ==================== Vote DTOs ====================
// CreateVotePostRequest 创建投票帖子请求
type CreateVotePostRequest struct {
Title string `json:"title" binding:"required,max=200"`
Content string `json:"content" binding:"max=2000"`
CommunityID string `json:"community_id"`
Images []string `json:"images"`
VoteOptions []string `json:"vote_options" binding:"required,min=2,max=10"` // 投票选项至少2个最多10个
}
// VoteOptionDTO 投票选项DTO
type VoteOptionDTO struct {
ID string `json:"id"`
Content string `json:"content"`
VotesCount int `json:"votes_count"`
}
// VoteResultDTO 投票结果DTO
type VoteResultDTO struct {
Options []VoteOptionDTO `json:"options"`
TotalVotes int `json:"total_votes"`
HasVoted bool `json:"has_voted"`
VotedOptionID string `json:"voted_option_id,omitempty"`
}
// ==================== WebSocket Response DTOs ====================
// WSResponse WebSocket响应结构体
type WSResponse struct {
Success bool `json:"success"` // 是否成功
Action string `json:"action"` // 响应原action
Data interface{} `json:"data,omitempty"` // 响应数据
Error string `json:"error,omitempty"` // 错误信息
}

362
internal/dto/segment.go Normal file
View File

@@ -0,0 +1,362 @@
package dto
import (
"encoding/json"
"fmt"
"carrot_bbs/internal/model"
)
// ParseSegmentData 解析Segment数据到目标结构体
func ParseSegmentData(segment model.MessageSegment, target interface{}) error {
dataBytes, err := json.Marshal(segment.Data)
if err != nil {
return err
}
return json.Unmarshal(dataBytes, target)
}
// NewTextSegment 创建文本Segment
func NewTextSegment(content string) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeText),
Data: map[string]interface{}{"text": content},
}
}
// NewImageSegment 创建图片Segment
func NewImageSegment(url string, width, height int, thumbnailURL string) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeImage),
Data: map[string]interface{}{
"url": url,
"width": width,
"height": height,
"thumbnail_url": thumbnailURL,
},
}
}
// NewImageSegmentWithSize 创建带文件大小的图片Segment
func NewImageSegmentWithSize(url string, width, height int, thumbnailURL string, fileSize int64) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeImage),
Data: map[string]interface{}{
"url": url,
"width": width,
"height": height,
"thumbnail_url": thumbnailURL,
"file_size": fileSize,
},
}
}
// NewVoiceSegment 创建语音Segment
func NewVoiceSegment(url string, duration int, fileSize int64) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeVoice),
Data: map[string]interface{}{
"url": url,
"duration": duration,
"file_size": fileSize,
},
}
}
// NewVideoSegment 创建视频Segment
func NewVideoSegment(url string, width, height, duration int, thumbnailURL string, fileSize int64) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeVideo),
Data: map[string]interface{}{
"url": url,
"width": width,
"height": height,
"duration": duration,
"thumbnail_url": thumbnailURL,
"file_size": fileSize,
},
}
}
// NewFileSegment 创建文件Segment
func NewFileSegment(url, name string, size int64, mimeType string) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeFile),
Data: map[string]interface{}{
"url": url,
"name": name,
"size": size,
"mime_type": mimeType,
},
}
}
// NewAtSegment 创建@Segment只存储 user_id昵称由前端根据群成员列表实时解析
func NewAtSegment(userID string) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeAt),
Data: map[string]interface{}{
"user_id": userID,
},
}
}
// NewAtAllSegment 创建@所有人Segment
func NewAtAllSegment() model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeAt),
Data: map[string]interface{}{
"user_id": "all",
},
}
}
// NewReplySegment 创建回复Segment
func NewReplySegment(messageID string) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeReply),
Data: map[string]interface{}{"id": messageID},
}
}
// NewFaceSegment 创建表情Segment
func NewFaceSegment(id int, name, url string) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeFace),
Data: map[string]interface{}{
"id": id,
"name": name,
"url": url,
},
}
}
// NewLinkSegment 创建链接Segment
func NewLinkSegment(url, title, description, image string) model.MessageSegment {
return model.MessageSegment{
Type: string(SegmentTypeLink),
Data: map[string]interface{}{
"url": url,
"title": title,
"description": description,
"image": image,
},
}
}
// ExtractTextContentFromJSON 从JSON格式的segments中提取纯文本内容
// 用于从数据库读取的 []byte 格式的 segments
// 已废弃:现在数据库直接存储 model.MessageSegments 类型
func ExtractTextContentFromJSON(segmentsJSON []byte) string {
if len(segmentsJSON) == 0 {
return ""
}
var segments model.MessageSegments
if err := json.Unmarshal(segmentsJSON, &segments); err != nil {
return ""
}
return ExtractTextContentFromModel(segments)
}
// ExtractTextContentFromModel 从 model.MessageSegments 中提取纯文本内容
func ExtractTextContentFromModel(segments model.MessageSegments) string {
var result string
for _, segment := range segments {
switch segment.Type {
case "text":
if text, ok := segment.Data["text"].(string); ok {
result += text
}
case "at":
userID, _ := segment.Data["user_id"].(string)
if userID == "all" {
result += "@所有人 "
} else if userID != "" {
// 昵称由前端实时解析,后端文本提取仅用于推送通知兜底
result += "@某人 "
}
case "image":
result += "[图片]"
case "voice":
result += "[语音]"
case "video":
result += "[视频]"
case "file":
if name, ok := segment.Data["name"].(string); ok && name != "" {
result += fmt.Sprintf("[文件:%s]", name)
} else {
result += "[文件]"
}
case "face":
if name, ok := segment.Data["name"].(string); ok && name != "" {
result += fmt.Sprintf("[%s]", name)
} else {
result += "[表情]"
}
case "link":
if title, ok := segment.Data["title"].(string); ok && title != "" {
result += fmt.Sprintf("[链接:%s]", title)
} else {
result += "[链接]"
}
}
}
return result
}
// ExtractTextContent 从消息链中提取纯文本内容
// 用于搜索、通知展示等场景
func ExtractTextContent(segments model.MessageSegments) string {
return ExtractTextContentFromModel(segments)
}
// ExtractMentionedUsers 从消息链中提取被@的用户ID列表
// 不包括 "all"@所有人)
func ExtractMentionedUsers(segments model.MessageSegments) []string {
var userIDs []string
seen := make(map[string]bool)
for _, segment := range segments {
if segment.Type == string(SegmentTypeAt) {
userID, _ := segment.Data["user_id"].(string)
if userID != "all" && userID != "" && !seen[userID] {
userIDs = append(userIDs, userID)
seen[userID] = true
}
}
}
return userIDs
}
// IsAtAll 检查消息是否@了所有人
func IsAtAll(segments model.MessageSegments) bool {
for _, segment := range segments {
if segment.Type == string(SegmentTypeAt) {
if userID, ok := segment.Data["user_id"].(string); ok && userID == "all" {
return true
}
}
}
return false
}
// GetReplyMessageID 从消息链中获取被回复的消息ID
// 如果没有回复segment返回空字符串
func GetReplyMessageID(segments model.MessageSegments) string {
for _, segment := range segments {
if segment.Type == string(SegmentTypeReply) {
if id, ok := segment.Data["id"].(string); ok {
return id
}
}
}
return ""
}
// BuildSegmentsFromContent 从旧版content构建segments
// 用于兼容旧版本消息
func BuildSegmentsFromContent(contentType, content string, mediaURL *string) model.MessageSegments {
var segments model.MessageSegments
switch contentType {
case "text":
segments = append(segments, NewTextSegment(content))
case "image":
if mediaURL != nil {
segments = append(segments, NewImageSegment(*mediaURL, 0, 0, ""))
}
case "voice":
if mediaURL != nil {
segments = append(segments, NewVoiceSegment(*mediaURL, 0, 0))
}
case "video":
if mediaURL != nil {
segments = append(segments, NewVideoSegment(*mediaURL, 0, 0, 0, "", 0))
}
case "file":
if mediaURL != nil {
segments = append(segments, NewFileSegment(*mediaURL, content, 0, ""))
}
default:
// 默认当作文本处理
if content != "" {
segments = append(segments, NewTextSegment(content))
}
}
return segments
}
// HasSegmentType 检查消息链中是否包含指定类型的segment
func HasSegmentType(segments model.MessageSegments, segmentType SegmentType) bool {
for _, segment := range segments {
if segment.Type == string(segmentType) {
return true
}
}
return false
}
// GetSegmentsByType 获取消息链中所有指定类型的segment
func GetSegmentsByType(segments model.MessageSegments, segmentType SegmentType) []model.MessageSegment {
var result []model.MessageSegment
for _, segment := range segments {
if segment.Type == string(segmentType) {
result = append(result, segment)
}
}
return result
}
// GetFirstImageURL 获取消息链中第一张图片的URL
// 如果没有图片,返回空字符串
func GetFirstImageURL(segments model.MessageSegments) string {
for _, segment := range segments {
if segment.Type == string(SegmentTypeImage) {
if url, ok := segment.Data["url"].(string); ok {
return url
}
}
}
return ""
}
// GetFirstMediaURL 获取消息链中第一个媒体文件的URL图片/视频/语音/文件)
// 用于兼容旧版本API
func GetFirstMediaURL(segments model.MessageSegments) string {
for _, segment := range segments {
switch segment.Type {
case string(SegmentTypeImage), string(SegmentTypeVideo), string(SegmentTypeVoice), string(SegmentTypeFile):
if url, ok := segment.Data["url"].(string); ok {
return url
}
}
}
return ""
}
// DetermineContentType 从消息链推断消息类型(用于兼容旧版本)
func DetermineContentType(segments model.MessageSegments) string {
if len(segments) == 0 {
return "text"
}
// 优先检查媒体类型
for _, segment := range segments {
switch segment.Type {
case string(SegmentTypeImage):
return "image"
case string(SegmentTypeVideo):
return "video"
case string(SegmentTypeVoice):
return "voice"
case string(SegmentTypeFile):
return "file"
}
}
// 默认返回text
return "text"
}

View File

@@ -0,0 +1,253 @@
package handler
import (
"encoding/json"
"errors"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// CommentHandler 评论处理器
type CommentHandler struct {
commentService *service.CommentService
}
// NewCommentHandler 创建评论处理器
func NewCommentHandler(commentService *service.CommentService) *CommentHandler {
return &CommentHandler{
commentService: commentService,
}
}
// Create 创建评论
func (h *CommentHandler) Create(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type CreateRequest struct {
PostID string `json:"post_id" binding:"required"`
Content string `json:"content"` // 内容可选,允许只发图片
ParentID *string `json:"parent_id"`
Images []string `json:"images"` // 图片URL列表
}
var req CreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证:评论必须有内容或图片
if req.Content == "" && len(req.Images) == 0 {
response.BadRequest(c, "评论内容或图片不能同时为空")
return
}
// 将图片列表转换为JSON字符串
var imagesJSON string
if len(req.Images) > 0 {
imagesBytes, _ := json.Marshal(req.Images)
imagesJSON = string(imagesBytes)
}
comment, err := h.commentService.Create(c.Request.Context(), req.PostID, userID, req.Content, req.ParentID, imagesJSON, req.Images)
if err != nil {
var moderationErr *service.CommentModerationRejectedError
if errors.As(err, &moderationErr) {
response.BadRequest(c, moderationErr.UserMessage())
return
}
response.InternalServerError(c, "failed to create comment")
return
}
response.Success(c, dto.ConvertCommentToResponse(comment, false))
}
// GetByID 获取单条评论详情
func (h *CommentHandler) GetByID(c *gin.Context) {
userID := c.GetString("user_id")
id := c.Param("id")
comment, err := h.commentService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "comment not found")
return
}
resp := dto.ConvertCommentToResponse(comment, h.commentService.IsLiked(c.Request.Context(), id, userID))
response.Success(c, resp)
}
// GetByPostID 获取帖子评论
func (h *CommentHandler) GetByPostID(c *gin.Context) {
userID := c.GetString("user_id")
postID := c.Param("id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
comments, total, err := h.commentService.GetByPostID(c.Request.Context(), postID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get comments")
return
}
// 转换为响应结构,检查每个评论的点赞状态
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
response.Paginated(c, commentResponses, total, page, pageSize)
}
// GetReplies 获取回复
func (h *CommentHandler) GetReplies(c *gin.Context) {
userID := c.GetString("user_id")
parentID := c.Param("id")
comments, err := h.commentService.GetReplies(c.Request.Context(), parentID)
if err != nil {
response.InternalServerError(c, "failed to get replies")
return
}
// 转换为响应结构,检查每个回复的点赞状态
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
response.Success(c, commentResponses)
}
// GetRepliesByRootID 根据根评论ID分页获取回复扁平化
func (h *CommentHandler) GetRepliesByRootID(c *gin.Context) {
userID := c.GetString("user_id")
rootID := c.Param("id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
replies, total, err := h.commentService.GetRepliesByRootID(c.Request.Context(), rootID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get replies")
return
}
// 转换为响应结构,检查每个回复的点赞状态
replyResponses := dto.ConvertCommentsToResponseWithUser(replies, userID, h.commentService)
response.Paginated(c, replyResponses, total, page, pageSize)
}
// Update 更新评论
func (h *CommentHandler) Update(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
comment, err := h.commentService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "comment not found")
return
}
if comment.UserID != userID {
response.Forbidden(c, "cannot update others' comment")
return
}
type UpdateRequest struct {
Content string `json:"content" binding:"required"`
}
var req UpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
comment.Content = req.Content
err = h.commentService.Update(c.Request.Context(), comment)
if err != nil {
response.InternalServerError(c, "failed to update comment")
return
}
response.Success(c, dto.ConvertCommentToResponse(comment, false))
}
// Delete 删除评论
func (h *CommentHandler) Delete(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
comment, err := h.commentService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "comment not found")
return
}
if comment.UserID != userID {
response.Forbidden(c, "cannot delete others' comment")
return
}
err = h.commentService.Delete(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to delete comment")
return
}
response.SuccessWithMessage(c, "comment deleted", nil)
}
// Like 点赞评论
func (h *CommentHandler) Like(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
err := h.commentService.Like(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to like comment")
return
}
response.SuccessWithMessage(c, "liked", nil)
}
// Unlike 取消点赞评论
func (h *CommentHandler) Unlike(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
err := h.commentService.Unlike(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to unlike comment")
return
}
response.SuccessWithMessage(c, "unliked", nil)
}

View File

@@ -0,0 +1,234 @@
package handler
import (
"context"
"log"
"strings"
"time"
"carrot_bbs/internal/config"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/gorse"
"carrot_bbs/internal/pkg/response"
gorseio "github.com/gorse-io/gorse-go"
"github.com/gin-gonic/gin"
)
// GorseHandler Gorse推荐处理器
type GorseHandler struct {
importPassword string
gorseConfig config.GorseConfig
}
// NewGorseHandler 创建Gorse处理器
func NewGorseHandler(cfg config.GorseConfig) *GorseHandler {
return &GorseHandler{
importPassword: cfg.ImportPassword,
gorseConfig: cfg,
}
}
// ImportRequest 导入请求
type ImportRequest struct {
Password string `json:"password"`
}
// ImportData 导入数据到Gorse
// POST /api/v1/gorse/import
func (h *GorseHandler) ImportData(c *gin.Context) {
// 验证密码
if h.importPassword == "" {
response.BadRequest(c, "Gorse import is disabled")
return
}
var req ImportRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "invalid request body")
return
}
if req.Password != h.importPassword {
response.Unauthorized(c, "invalid password")
return
}
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Minute)
defer cancel()
stats, err := h.importAllData(ctx)
if err != nil {
log.Printf("[ERROR] gorse import failed: %v", err)
response.InternalServerError(c, "gorse import failed: "+err.Error())
return
}
response.Success(c, gin.H{
"message": "import completed",
"status": "done",
"stats": stats,
})
}
// GetStatus 获取Gorse状态
// GET /api/v1/gorse/status
func (h *GorseHandler) GetStatus(c *gin.Context) {
// 返回Gorse连接状态和配置信息
hasPassword := h.importPassword != ""
response.Success(c, gin.H{
"enabled": h.gorseConfig.Enabled,
"has_password": hasPassword,
"address": h.gorseConfig.Address,
"api_key": strings.Repeat("*", 8), // 不返回实际APIKey
})
}
func (h *GorseHandler) importAllData(ctx context.Context) (gin.H, error) {
gorseClient := gorseio.NewGorseClient(h.gorseConfig.Address, h.gorseConfig.APIKey)
gorse.InitEmbeddingWithConfig(h.gorseConfig.EmbeddingAPIKey, h.gorseConfig.EmbeddingURL, h.gorseConfig.EmbeddingModel)
stats := gin.H{
"items": 0,
"users": 0,
"likes": 0,
"favorites": 0,
"comments": 0,
}
// 导入帖子
var posts []model.Post
if err := model.DB.Find(&posts).Error; err != nil {
return nil, err
}
for _, post := range posts {
embedding, err := gorse.GetEmbedding(strings.TrimSpace(post.Title + " " + post.Content))
if err != nil {
log.Printf("[WARN] get embedding failed for post %s: %v", post.ID, err)
embedding = make([]float64, 1024)
}
_, err = gorseClient.InsertItem(ctx, gorseio.Item{
ItemId: post.ID,
IsHidden: post.DeletedAt.Valid,
Categories: buildPostCategories(&post),
Comment: post.Title,
Timestamp: post.CreatedAt.UTC().Truncate(time.Second),
Labels: map[string]any{
"embedding": embedding,
},
})
if err != nil {
log.Printf("[WARN] import item failed (%s): %v", post.ID, err)
continue
}
stats["items"] = stats["items"].(int) + 1
}
// 导入用户
var users []model.User
if err := model.DB.Find(&users).Error; err != nil {
return nil, err
}
for _, user := range users {
_, err := gorseClient.InsertUser(ctx, gorseio.User{
UserId: user.ID,
Labels: map[string]any{
"posts_count": user.PostsCount,
"followers_count": user.FollowersCount,
"following_count": user.FollowingCount,
},
Comment: user.Nickname,
})
if err != nil {
log.Printf("[WARN] import user failed (%s): %v", user.ID, err)
continue
}
stats["users"] = stats["users"].(int) + 1
}
// 导入点赞
var likes []model.PostLike
if err := model.DB.Find(&likes).Error; err != nil {
return nil, err
}
for _, like := range likes {
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
FeedbackType: string(gorse.FeedbackTypeLike),
UserId: like.UserID,
ItemId: like.PostID,
Timestamp: like.CreatedAt.UTC().Truncate(time.Second),
}})
if err != nil {
log.Printf("[WARN] import like failed (%s/%s): %v", like.UserID, like.PostID, err)
continue
}
stats["likes"] = stats["likes"].(int) + 1
}
// 导入收藏
var favorites []model.Favorite
if err := model.DB.Find(&favorites).Error; err != nil {
return nil, err
}
for _, fav := range favorites {
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
FeedbackType: string(gorse.FeedbackTypeStar),
UserId: fav.UserID,
ItemId: fav.PostID,
Timestamp: fav.CreatedAt.UTC().Truncate(time.Second),
}})
if err != nil {
log.Printf("[WARN] import favorite failed (%s/%s): %v", fav.UserID, fav.PostID, err)
continue
}
stats["favorites"] = stats["favorites"].(int) + 1
}
// 导入评论(按用户-帖子去重)
var comments []model.Comment
if err := model.DB.Where("status = ?", model.CommentStatusPublished).Find(&comments).Error; err != nil {
return nil, err
}
seen := make(map[string]struct{})
for _, cm := range comments {
key := cm.UserID + ":" + cm.PostID
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
FeedbackType: string(gorse.FeedbackTypeComment),
UserId: cm.UserID,
ItemId: cm.PostID,
Timestamp: cm.CreatedAt.UTC().Truncate(time.Second),
}})
if err != nil {
log.Printf("[WARN] import comment failed (%s/%s): %v", cm.UserID, cm.PostID, err)
continue
}
stats["comments"] = stats["comments"].(int) + 1
}
return stats, nil
}
func buildPostCategories(post *model.Post) []string {
var categories []string
if post.ViewsCount > 1000 {
categories = append(categories, "hot_high")
} else if post.ViewsCount > 100 {
categories = append(categories, "hot_medium")
}
if post.LikesCount > 100 {
categories = append(categories, "likes_100+")
} else if post.LikesCount > 10 {
categories = append(categories, "likes_10+")
}
age := time.Since(post.CreatedAt)
if age < 24*time.Hour {
categories = append(categories, "today")
} else if age < 7*24*time.Hour {
categories = append(categories, "this_week")
}
return categories
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,879 @@
package handler
import (
"context"
"fmt"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// MessageHandler 消息处理器
type MessageHandler struct {
chatService service.ChatService
messageService *service.MessageService
userService *service.UserService
groupService service.GroupService
}
// NewMessageHandler 创建消息处理器
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler {
return &MessageHandler{
chatService: chatService,
messageService: messageService,
userService: userService,
groupService: groupService,
}
}
// GetConversations 获取会话列表
// GET /api/conversations
func (h *MessageHandler) GetConversations(c *gin.Context) {
userID := c.GetString("user_id")
// 添加调试日志
if userID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get conversations")
return
}
// 过滤掉系统会话(系统通知现在使用独立的表)
filteredConvs := make([]*model.Conversation, 0)
for _, conv := range convs {
if conv.ID != model.SystemConversationID {
filteredConvs = append(filteredConvs, conv)
}
}
// 转换为响应格式
result := make([]*dto.ConversationResponse, len(filteredConvs))
for i, conv := range filteredConvs {
// 获取未读数
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
// 获取最后一条消息
var lastMessage *model.Message
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
if len(messages) > 0 {
lastMessage = messages[0]
}
// 群聊时返回member_count私聊时返回participants
var resp *dto.ConversationResponse
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
isPinned := myParticipant != nil && myParticipant.IsPinned
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
// 群聊:实时计算群成员数量
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
// 创建响应并设置member_count
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
resp.MemberCount = memberCount
} else {
// 私聊:获取参与者信息
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
}
result[i] = resp
}
// 更新 total 为过滤后的数量
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
}
// CreateConversation 创建私聊会话
// POST /api/conversations
func (h *MessageHandler) CreateConversation(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req dto.CreateConversationRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证目标用户是否存在
targetUser, err := h.userService.GetUserByID(c.Request.Context(), req.UserID)
if err != nil {
response.BadRequest(c, "target user not found")
return
}
// 不能和自己创建会话
if userID == req.UserID {
response.BadRequest(c, "cannot create conversation with yourself")
return
}
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, req.UserID)
if err != nil {
response.InternalServerError(c, "failed to create conversation")
return
}
// 获取参与者信息
participants := []*model.User{targetUser}
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
isPinned := myParticipant != nil && myParticipant.IsPinned
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
}
// GetConversationByID 获取会话详情
// GET /api/conversations/:id
func (h *MessageHandler) GetConversationByID(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
fmt.Printf("[DEBUG] GetConversationByID: conversationIDStr = %s\n", conversationIDStr)
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
fmt.Printf("[DEBUG] GetConversationByID: failed to parse conversation ID: %v\n", err)
response.BadRequest(c, "invalid conversation id")
return
}
fmt.Printf("[DEBUG] GetConversationByID: conversationID = %s\n", conversationID)
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 获取未读数
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
// 获取参与者信息
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
// 获取当前用户的已读位置
myLastReadSeq := int64(0)
isPinned := false
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
for _, p := range allParticipants {
if p.UserID == userID {
myLastReadSeq = p.LastReadSeq
isPinned = p.IsPinned
break
}
}
// 获取对方用户的已读位置
otherLastReadSeq := int64(0)
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
}
// GetMessages 获取消息列表
// GET /api/conversations/:id/messages
func (h *MessageHandler) GetMessages(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
response.BadRequest(c, "invalid conversation id")
return
}
// 检查是否使用增量同步after_seq参数
afterSeqStr := c.Query("after_seq")
if afterSeqStr != "" {
// 增量同步模式
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid after_seq")
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Success(c, &dto.MessageSyncResponse{
Messages: result,
HasMore: len(messages) == limit,
})
return
}
// 检查是否使用历史消息加载before_seq参数
beforeSeqStr := c.Query("before_seq")
if beforeSeqStr != "" {
// 加载更早的消息(下拉加载更多)
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid before_seq")
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Success(c, &dto.MessageSyncResponse{
Messages: result,
HasMore: len(messages) == limit,
})
return
}
// 分页模式
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Paginated(c, result, total, page, pageSize)
}
// SendMessage 发送消息
// POST /api/conversations/:id/messages
func (h *MessageHandler) SendMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
fmt.Printf("[DEBUG] SendMessage: conversationIDStr = %s\n", conversationIDStr)
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
fmt.Printf("[DEBUG] SendMessage: failed to parse conversation ID: %v\n", err)
response.BadRequest(c, "invalid conversation id")
return
}
fmt.Printf("[DEBUG] SendMessage: conversationID = %s, userID = %s\n", conversationID, userID)
var req dto.SendMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 直接使用 segments
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, req.Segments, req.ReplyToID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.Success(c, dto.ConvertMessageToResponse(msg))
}
// HandleSendMessage RESTful 风格的发送消息端点
// POST /api/v1/conversations/send_message
// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.SendMessageParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证参数
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
if params.DetailType == "" {
response.BadRequest(c, "detail_type is required")
return
}
if params.Segments == nil || len(params.Segments) == 0 {
response.BadRequest(c, "segments is required")
return
}
// 发送消息
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 构建 WSEventResponse 格式响应
wsResponse := dto.WSEventResponse{
ID: msg.ID,
Time: msg.CreatedAt.UnixMilli(),
Type: "message",
DetailType: params.DetailType,
Seq: strconv.FormatInt(msg.Seq, 10),
Segments: params.Segments,
SenderID: userID,
}
response.Success(c, wsResponse)
}
// HandleDeleteMsg 撤回消息
// POST /api/v1/messages/delete_msg
// 请求体格式: {"message_id": "xxx"}
func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.DeleteMsgParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证参数
if params.MessageID == "" {
response.BadRequest(c, "message_id is required")
return
}
// 撤回消息
err := h.chatService.RecallMessage(c.Request.Context(), params.MessageID, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "消息已撤回", nil)
}
// HandleGetConversationList 获取会话列表
// GET /api/v1/conversations/list
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get conversations")
return
}
// 过滤掉系统会话(系统通知现在使用独立的表)
filteredConvs := make([]*model.Conversation, 0)
for _, conv := range convs {
if conv.ID != model.SystemConversationID {
filteredConvs = append(filteredConvs, conv)
}
}
// 转换为响应格式
result := make([]*dto.ConversationResponse, len(filteredConvs))
for i, conv := range filteredConvs {
// 获取未读数
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
// 获取最后一条消息
var lastMessage *model.Message
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
if len(messages) > 0 {
lastMessage = messages[0]
}
// 群聊时返回member_count私聊时返回participants
var resp *dto.ConversationResponse
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
isPinned := myParticipant != nil && myParticipant.IsPinned
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
// 群聊:实时计算群成员数量
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
// 创建响应并设置member_count
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
resp.MemberCount = memberCount
} else {
// 私聊:获取参与者信息
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
}
result[i] = resp
}
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
}
// HandleDeleteConversationForSelf 仅自己删除会话
// DELETE /api/v1/conversations/:id/self
func (h *MessageHandler) HandleDeleteConversationForSelf(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationID := c.Param("id")
if conversationID == "" {
response.BadRequest(c, "conversation id is required")
return
}
if err := h.chatService.DeleteConversationForSelf(c.Request.Context(), conversationID, userID); err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "conversation deleted for self", nil)
}
// MarkAsRead 标记为已读
// POST /api/conversations/:id/read
func (h *MessageHandler) MarkAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
response.BadRequest(c, "invalid conversation id")
return
}
var req dto.MarkReadRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "last_read_seq is required")
return
}
err = h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "marked as read", nil)
}
// GetUnreadCount 获取未读消息总数
// GET /api/conversations/unread/count
func (h *MessageHandler) GetUnreadCount(c *gin.Context) {
userID := c.GetString("user_id")
// 添加调试日志
fmt.Printf("[DEBUG] GetUnreadCount: user_id from context = %q\n", userID)
if userID == "" {
fmt.Printf("[DEBUG] GetUnreadCount: user_id is empty, returning 401\n")
response.Unauthorized(c, "")
return
}
count, err := h.chatService.GetAllUnreadCount(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to get unread count")
return
}
response.Success(c, &dto.UnreadCountResponse{
TotalUnreadCount: count,
})
}
// GetConversationUnreadCount 获取单个会话的未读数
// GET /api/conversations/:id/unread/count
func (h *MessageHandler) GetConversationUnreadCount(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
response.BadRequest(c, "invalid conversation id")
return
}
count, err := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.Success(c, &dto.ConversationUnreadCountResponse{
ConversationID: conversationID,
UnreadCount: count,
})
}
// RecallMessage 撤回消息
// POST /api/messages/:id/recall
func (h *MessageHandler) RecallMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
messageIDStr := c.Param("id")
// 直接使用 string 类型的 messageID
err := h.chatService.RecallMessage(c.Request.Context(), messageIDStr, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "message recalled", nil)
}
// DeleteMessage 删除消息
// DELETE /api/messages/:id
func (h *MessageHandler) DeleteMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
messageIDStr := c.Param("id")
// 直接使用 string 类型的 messageID
err := h.chatService.DeleteMessage(c.Request.Context(), messageIDStr, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "message deleted", nil)
}
// 辅助函数:验证内容类型
func isValidContentType(contentType model.ContentType) bool {
switch contentType {
case model.ContentTypeText, model.ContentTypeImage, model.ContentTypeVideo, model.ContentTypeAudio, model.ContentTypeFile:
return true
default:
return false
}
}
// 辅助函数:获取会话参与者信息
func (h *MessageHandler) getConversationParticipants(ctx context.Context, conversationID string, currentUserID string) ([]*model.User, error) {
// 从repository获取参与者列表
participants, err := h.messageService.GetConversationParticipants(conversationID)
if err != nil {
return nil, err
}
// 获取参与者用户信息
var users []*model.User
for _, p := range participants {
// 跳过当前用户
if p.UserID == currentUserID {
continue
}
user, err := h.userService.GetUserByID(ctx, p.UserID)
if err != nil {
continue
}
users = append(users, user)
}
return users, nil
}
// 获取当前用户在会话中的参与者信息
func (h *MessageHandler) getMyConversationParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
participants, err := h.messageService.GetConversationParticipants(conversationID)
if err != nil {
return nil, err
}
for _, p := range participants {
if p.UserID == userID {
return p, nil
}
}
return nil, nil
}
// ==================== RESTful Action 端点 ====================
// HandleCreateConversation 创建会话
// POST /api/v1/conversations/create
func (h *MessageHandler) HandleCreateConversation(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.CreateConversationParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证目标用户是否存在
targetUser, err := h.userService.GetUserByID(c.Request.Context(), params.UserID)
if err != nil {
response.BadRequest(c, "target user not found")
return
}
// 不能和自己创建会话
if userID == params.UserID {
response.BadRequest(c, "cannot create conversation with yourself")
return
}
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, params.UserID)
if err != nil {
response.InternalServerError(c, "failed to create conversation")
return
}
// 获取参与者信息
participants := []*model.User{targetUser}
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
isPinned := myParticipant != nil && myParticipant.IsPinned
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
}
// HandleGetConversation 获取会话详情
// GET /api/v1/conversations/get?conversation_id=xxx
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationID := c.Query("conversation_id")
if conversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 获取未读数
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
// 获取参与者信息
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
// 获取当前用户的已读位置
myLastReadSeq := int64(0)
isPinned := false
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
for _, p := range allParticipants {
if p.UserID == userID {
myLastReadSeq = p.LastReadSeq
isPinned = p.IsPinned
break
}
}
// 获取对方用户的已读位置
otherLastReadSeq := int64(0)
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
}
// HandleGetMessages 获取会话消息
// GET /api/v1/conversations/get_messages?conversation_id=xxx
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationID := c.Query("conversation_id")
if conversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
// 检查是否使用增量同步after_seq参数
afterSeqStr := c.Query("after_seq")
if afterSeqStr != "" {
// 增量同步模式
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid after_seq")
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Success(c, &dto.MessageSyncResponse{
Messages: result,
HasMore: len(messages) == limit,
})
return
}
// 检查是否使用历史消息加载before_seq参数
beforeSeqStr := c.Query("before_seq")
if beforeSeqStr != "" {
// 加载更早的消息(下拉加载更多)
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid before_seq")
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Success(c, &dto.MessageSyncResponse{
Messages: result,
HasMore: len(messages) == limit,
})
return
}
// 分页模式
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Paginated(c, result, total, page, pageSize)
}
// HandleMarkRead 标记已读
// POST /api/v1/conversations/mark_read
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.MarkReadParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "marked as read", nil)
}
// HandleSetConversationPinned 设置会话置顶
// POST /api/v1/conversations/set_pinned
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.SetConversationPinnedParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
"conversation_id": params.ConversationID,
"is_pinned": params.IsPinned,
})
}

View File

@@ -0,0 +1,132 @@
package handler
import (
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// NotificationHandler 通知处理器
type NotificationHandler struct {
notificationService *service.NotificationService
}
// NewNotificationHandler 创建通知处理器
func NewNotificationHandler(notificationService *service.NotificationService) *NotificationHandler {
return &NotificationHandler{
notificationService: notificationService,
}
}
// GetNotifications 获取通知列表
func (h *NotificationHandler) GetNotifications(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
unreadOnly := c.Query("unread_only") == "true"
notifications, total, err := h.notificationService.GetByUserID(c.Request.Context(), userID, page, pageSize, unreadOnly)
if err != nil {
response.InternalServerError(c, "failed to get notifications")
return
}
response.Paginated(c, notifications, total, page, pageSize)
}
// MarkAsRead 标记为已读
func (h *NotificationHandler) MarkAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
err := h.notificationService.MarkAsReadWithUserID(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to mark as read")
return
}
response.SuccessWithMessage(c, "marked as read", nil)
}
// MarkAllAsRead 标记所有为已读
func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
err := h.notificationService.MarkAllAsRead(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to mark all as read")
return
}
response.SuccessWithMessage(c, "all marked as read", nil)
}
// GetUnreadCount 获取未读数量
func (h *NotificationHandler) GetUnreadCount(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
count, err := h.notificationService.GetUnreadCount(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to get unread count")
return
}
response.Success(c, gin.H{"count": count})
}
// DeleteNotification 删除通知
func (h *NotificationHandler) DeleteNotification(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
err := h.notificationService.DeleteNotification(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to delete notification")
return
}
response.Success(c, gin.H{"success": true})
}
// ClearAllNotifications 清空所有通知
func (h *NotificationHandler) ClearAllNotifications(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
err := h.notificationService.ClearAllNotifications(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to clear notifications")
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -0,0 +1,511 @@
package handler
import (
"errors"
"fmt"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// PostHandler 帖子处理器
type PostHandler struct {
postService *service.PostService
userService *service.UserService
}
// NewPostHandler 创建帖子处理器
func NewPostHandler(postService *service.PostService, userService *service.UserService) *PostHandler {
return &PostHandler{
postService: postService,
userService: userService,
}
}
// Create 创建帖子
func (h *PostHandler) Create(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type CreateRequest struct {
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
Images []string `json:"images"`
}
var req CreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
post, err := h.postService.Create(c.Request.Context(), userID, req.Title, req.Content, req.Images)
if err != nil {
var moderationErr *service.PostModerationRejectedError
if errors.As(err, &moderationErr) {
response.BadRequest(c, moderationErr.UserMessage())
return
}
response.InternalServerError(c, "failed to create post")
return
}
response.Success(c, dto.ConvertPostToResponse(post, false, false))
}
// GetByID 获取帖子(不增加浏览量)
func (h *PostHandler) GetByID(c *gin.Context) {
id := c.Param("id")
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "post not found")
return
}
// 非作者不可查看未发布内容
currentUserID := c.GetString("user_id")
if post.Status != model.PostStatusPublished && post.UserID != currentUserID {
response.NotFound(c, "post not found")
return
}
// 注意:不再自动增加浏览量,浏览量通过 RecordView 端点单独记录
// 获取当前用户ID用于判断点赞和收藏状态
fmt.Printf("[DEBUG] GetByID - postID: %s, currentUserID: %s\n", id, currentUserID)
var isLiked, isFavorited bool
if currentUserID != "" {
isLiked = h.postService.IsLiked(c.Request.Context(), id, currentUserID)
isFavorited = h.postService.IsFavorited(c.Request.Context(), id, currentUserID)
fmt.Printf("[DEBUG] GetByID - isLiked: %v, isFavorited: %v\n", isLiked, isFavorited)
} else {
fmt.Printf("[DEBUG] GetByID - user not logged in, isLiked: false, isFavorited: false\n")
}
// 如果有当前用户,检查与帖子作者的相互关注状态
var authorWithFollowStatus *dto.UserResponse
if currentUserID != "" && post.User != nil {
_, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), post.UserID, currentUserID)
if err == nil {
authorWithFollowStatus = dto.ConvertUserToResponseWithMutualFollow(post.User, isFollowing, isFollowingMe)
} else {
// 如果出错使用默认的author
authorWithFollowStatus = dto.ConvertUserToResponse(post.User)
}
}
// 构建响应
responseData := &dto.PostResponse{
ID: post.ID,
UserID: post.UserID,
Title: post.Title,
Content: post.Content,
Images: dto.ConvertPostImagesToResponse(post.Images),
LikesCount: post.LikesCount,
CommentsCount: post.CommentsCount,
FavoritesCount: post.FavoritesCount,
SharesCount: post.SharesCount,
ViewsCount: post.ViewsCount,
IsPinned: post.IsPinned,
IsLocked: post.IsLocked,
IsVote: post.IsVote,
CreatedAt: dto.FormatTime(post.CreatedAt),
Author: authorWithFollowStatus,
IsLiked: isLiked,
IsFavorited: isFavorited,
}
response.Success(c, responseData)
}
// RecordView 记录帖子浏览(增加浏览量)
func (h *PostHandler) RecordView(c *gin.Context) {
id := c.Param("id")
userID := c.GetString("user_id")
// 验证帖子存在
_, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "post not found")
return
}
// 增加浏览量
if err := h.postService.IncrementViews(c.Request.Context(), id, userID); err != nil {
fmt.Printf("[DEBUG] Failed to increment views for post %s: %v\n", id, err)
response.InternalServerError(c, "failed to record view")
return
}
response.Success(c, gin.H{"success": true})
}
// List 获取帖子列表
func (h *PostHandler) List(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
userID := c.Query("user_id")
tab := c.Query("tab") // recommend, follow, hot, latest
// 获取当前用户ID
currentUserID := c.GetString("user_id")
var posts []*model.Post
var total int64
var err error
// 根据 tab 参数选择不同的获取方式
switch tab {
case "follow":
// 获取关注用户的帖子,需要登录
if currentUserID == "" {
response.Unauthorized(c, "请先登录")
return
}
posts, total, err = h.postService.GetFollowingPosts(c.Request.Context(), currentUserID, page, pageSize)
case "hot":
// 获取热门帖子
posts, total, err = h.postService.GetHotPosts(c.Request.Context(), page, pageSize)
case "recommend":
// 推荐帖子从Gorse获取个性化推荐
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
case "latest":
// 最新帖子
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
default:
// 默认获取最新帖子
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
}
if err != nil {
response.InternalServerError(c, "failed to get posts")
return
}
fmt.Printf("[DEBUG] List - tab: %s, currentUserID: %s, posts count: %d\n", tab, currentUserID, len(posts))
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {
for _, post := range posts {
isLiked := h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
isLikedMap[post.ID] = isLiked
isFavoritedMap[post.ID] = isFavorited
fmt.Printf("[DEBUG] List - postID: %s, isLiked: %v, isFavorited: %v\n", post.ID, isLiked, isFavorited)
}
} else {
fmt.Printf("[DEBUG] List - user not logged in\n")
}
// 转换为响应结构
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
response.Paginated(c, postResponses, total, page, pageSize)
}
// Update 更新帖子
func (h *PostHandler) Update(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "post not found")
return
}
if post.UserID != userID {
response.Forbidden(c, "cannot update others' post")
return
}
type UpdateRequest struct {
Title string `json:"title"`
Content string `json:"content"`
}
var req UpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if req.Title != "" {
post.Title = req.Title
}
if req.Content != "" {
post.Content = req.Content
}
err = h.postService.Update(c.Request.Context(), post)
if err != nil {
response.InternalServerError(c, "failed to update post")
return
}
currentUserID := c.GetString("user_id")
var isLiked, isFavorited bool
if currentUserID != "" {
isLiked = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavorited = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
}
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// Delete 删除帖子
func (h *PostHandler) Delete(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "post not found")
return
}
if post.UserID != userID {
response.Forbidden(c, "cannot delete others' post")
return
}
err = h.postService.Delete(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to delete post")
return
}
response.SuccessWithMessage(c, "post deleted", nil)
}
// Like 点赞帖子
func (h *PostHandler) Like(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
fmt.Printf("[DEBUG] Like - postID: %s, userID: %s\n", id, userID)
err := h.postService.Like(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to like post")
return
}
// 获取更新后的帖子状态
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to get post")
return
}
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
fmt.Printf("[DEBUG] Like - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// Unlike 取消点赞
func (h *PostHandler) Unlike(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
fmt.Printf("[DEBUG] Unlike - postID: %s, userID: %s\n", id, userID)
err := h.postService.Unlike(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to unlike post")
return
}
// 获取更新后的帖子状态
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to get post")
return
}
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
fmt.Printf("[DEBUG] Unlike - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// Favorite 收藏帖子
func (h *PostHandler) Favorite(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
fmt.Printf("[DEBUG] Favorite - postID: %s, userID: %s\n", id, userID)
err := h.postService.Favorite(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to favorite post")
return
}
// 获取更新后的帖子状态
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to get post")
return
}
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
fmt.Printf("[DEBUG] Favorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// Unfavorite 取消收藏
func (h *PostHandler) Unfavorite(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
fmt.Printf("[DEBUG] Unfavorite - postID: %s, userID: %s\n", id, userID)
err := h.postService.Unfavorite(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to unfavorite post")
return
}
// 获取更新后的帖子状态
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to get post")
return
}
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
fmt.Printf("[DEBUG] Unfavorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// GetUserPosts 获取用户帖子列表
func (h *PostHandler) GetUserPosts(c *gin.Context) {
userID := c.Param("id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get user posts")
return
}
// 获取当前用户ID用于判断点赞和收藏状态
currentUserID := c.GetString("user_id")
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {
for _, post := range posts {
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
}
}
// 转换为响应结构
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
response.Paginated(c, postResponses, total, page, pageSize)
}
// GetFavorites 获取收藏列表
func (h *PostHandler) GetFavorites(c *gin.Context) {
userID := c.Param("id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
posts, total, err := h.postService.GetFavorites(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get favorites")
return
}
// 获取当前用户ID用于判断点赞和收藏状态
currentUserID := c.GetString("user_id")
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {
for _, post := range posts {
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
}
}
// 转换为响应结构
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
response.Paginated(c, postResponses, total, page, pageSize)
}
// Search 搜索帖子
func (h *PostHandler) Search(c *gin.Context) {
keyword := c.Query("keyword")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
posts, total, err := h.postService.Search(c.Request.Context(), keyword, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to search posts")
return
}
// 获取当前用户ID用于判断点赞和收藏状态
currentUserID := c.GetString("user_id")
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {
for _, post := range posts {
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
}
}
// 转换为响应结构
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
response.Paginated(c, postResponses, total, page, pageSize)
}

View File

@@ -0,0 +1,157 @@
package handler
import (
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
"github.com/gin-gonic/gin"
)
// PushHandler 推送处理器
type PushHandler struct {
pushService service.PushService
}
// NewPushHandler 创建推送处理器
func NewPushHandler(pushService service.PushService) *PushHandler {
return &PushHandler{
pushService: pushService,
}
}
// RegisterDevice 注册设备
// POST /api/v1/push/devices
func (h *PushHandler) RegisterDevice(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req dto.RegisterDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证设备类型
deviceType := model.DeviceType(req.DeviceType)
if !isValidDeviceType(deviceType) {
response.BadRequest(c, "invalid device type")
return
}
err := h.pushService.RegisterDevice(c.Request.Context(), userID, req.DeviceID, deviceType, req.PushToken)
if err != nil {
response.InternalServerError(c, "failed to register device")
return
}
response.SuccessWithMessage(c, "device registered successfully", nil)
}
// UnregisterDevice 注销设备
// DELETE /api/v1/push/devices/:device_id
func (h *PushHandler) UnregisterDevice(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
deviceID := c.Param("device_id")
if deviceID == "" {
response.BadRequest(c, "device_id is required")
return
}
err := h.pushService.UnregisterDevice(c.Request.Context(), deviceID)
if err != nil {
response.InternalServerError(c, "failed to unregister device")
return
}
response.SuccessWithMessage(c, "device unregistered successfully", nil)
}
// GetMyDevices 获取当前用户的设备列表
// GET /api/v1/push/devices
func (h *PushHandler) GetMyDevices(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
// 这里需要从DeviceTokenRepository获取设备列表
// 由于PushService接口没有提供获取设备列表的方法我们暂时返回空列表
// TODO: 在PushService接口中添加GetUserDevices方法
_ = userID // 避免未使用变量警告
response.Success(c, []*dto.DeviceTokenResponse{})
}
// GetPushRecords 获取推送记录
// GET /api/v1/push/records
func (h *PushHandler) GetPushRecords(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
records, err := h.pushService.GetPendingPushes(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to get push records")
return
}
response.Success(c, &dto.PushRecordListResponse{
Records: dto.PushRecordsToResponse(records),
Total: int64(len(records)),
})
}
// 辅助函数:验证设备类型
func isValidDeviceType(deviceType model.DeviceType) bool {
switch deviceType {
case model.DeviceTypeIOS, model.DeviceTypeAndroid, model.DeviceTypeWeb:
return true
default:
return false
}
}
// UpdateDeviceToken 更新设备推送Token
// PUT /api/v1/push/devices/:device_id/token
func (h *PushHandler) UpdateDeviceToken(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
deviceID := c.Param("device_id")
if deviceID == "" {
response.BadRequest(c, "device_id is required")
return
}
var req struct {
PushToken string `json:"push_token" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
err := h.pushService.UpdateDeviceToken(c.Request.Context(), deviceID, req.PushToken)
if err != nil {
response.InternalServerError(c, "failed to update device token")
return
}
response.SuccessWithMessage(c, "device token updated successfully", nil)
}

View File

@@ -0,0 +1,164 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// StickerHandler 自定义表情处理器
type StickerHandler struct {
stickerService service.StickerService
}
// NewStickerHandler 创建自定义表情处理器
func NewStickerHandler(stickerService service.StickerService) *StickerHandler {
return &StickerHandler{
stickerService: stickerService,
}
}
// GetStickersRequest 获取表情列表请求
type GetStickersRequest struct {
UserID string `form:"user_id" binding:"required"`
}
// AddStickerRequest 添加表情请求
type AddStickerRequest struct {
URL string `json:"url" binding:"required"`
Width int `json:"width"`
Height int `json:"height"`
}
// DeleteStickerRequest 删除表情请求
type DeleteStickerRequest struct {
StickerID string `json:"sticker_id" binding:"required"`
}
// ReorderStickersRequest 重新排序请求
type ReorderStickersRequest struct {
Orders map[string]int `json:"orders" binding:"required"`
}
// CheckStickerRequest 检查表情是否存在请求
type CheckStickerRequest struct {
URL string `form:"url" binding:"required"`
}
// GetStickers 获取用户的表情列表
func (h *StickerHandler) GetStickers(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
stickers, err := h.stickerService.GetUserStickers(userID)
if err != nil {
response.InternalServerError(c, "failed to get stickers")
return
}
response.Success(c, gin.H{"stickers": stickers})
}
// AddSticker 添加表情
func (h *StickerHandler) AddSticker(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req AddStickerRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
sticker, err := h.stickerService.AddSticker(userID, req.URL, req.Width, req.Height)
if err != nil {
if err == service.ErrStickerAlreadyExists {
response.Error(c, http.StatusConflict, "sticker already exists")
return
}
if err == service.ErrInvalidStickerURL {
response.BadRequest(c, "invalid sticker url, only http/https is allowed")
return
}
response.InternalServerError(c, err.Error())
return
}
response.Success(c, gin.H{"sticker": sticker})
}
// DeleteSticker 删除表情
func (h *StickerHandler) DeleteSticker(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req DeleteStickerRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.stickerService.DeleteSticker(userID, req.StickerID); err != nil {
response.InternalServerError(c, err.Error())
return
}
response.SuccessWithMessage(c, "sticker deleted successfully", nil)
}
// ReorderStickers 重新排序表情
func (h *StickerHandler) ReorderStickers(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req ReorderStickersRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.stickerService.ReorderStickers(userID, req.Orders); err != nil {
response.InternalServerError(c, err.Error())
return
}
response.SuccessWithMessage(c, "stickers reordered successfully", nil)
}
// CheckStickerExists 检查表情是否存在
func (h *StickerHandler) CheckStickerExists(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
url := c.Query("url")
if url == "" {
response.BadRequest(c, "url is required")
return
}
exists, err := h.stickerService.CheckExists(userID, url)
if err != nil {
response.InternalServerError(c, err.Error())
return
}
response.Success(c, gin.H{"exists": exists})
}

View File

@@ -0,0 +1,154 @@
package handler
import (
"strconv"
"carrot_bbs/internal/cache"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/repository"
"carrot_bbs/internal/service"
)
// SystemMessageHandler 系统消息处理器
type SystemMessageHandler struct {
systemMsgService service.SystemMessageService
notifyRepo *repository.SystemNotificationRepository
}
// NewSystemMessageHandler 创建系统消息处理器
func NewSystemMessageHandler(
systemMsgService service.SystemMessageService,
notifyRepo *repository.SystemNotificationRepository,
) *SystemMessageHandler {
return &SystemMessageHandler{
systemMsgService: systemMsgService,
notifyRepo: notifyRepo,
}
}
// GetSystemMessages 获取系统消息列表
// GET /api/v1/messages/system
func (h *SystemMessageHandler) GetSystemMessages(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
// 获取当前用户的系统通知(从独立表中获取)
notifications, total, err := h.notifyRepo.GetByReceiverID(userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get system messages")
return
}
// 转换为响应格式
result := make([]*dto.SystemMessageResponse, 0)
for _, n := range notifications {
resp := dto.SystemNotificationToResponse(n)
result = append(result, resp)
}
response.Paginated(c, result, total, page, pageSize)
}
// GetUnreadCount 获取系统消息未读数
// GET /api/v1/messages/system/unread-count
func (h *SystemMessageHandler) GetUnreadCount(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
// 获取当前用户的未读通知数
unreadCount, err := h.notifyRepo.GetUnreadCount(userID)
if err != nil {
response.InternalServerError(c, "failed to get unread count")
return
}
response.Success(c, &dto.SystemUnreadCountResponse{
UnreadCount: unreadCount,
})
}
// MarkAsRead 标记系统消息为已读
// PUT /api/v1/messages/system/:id/read
func (h *SystemMessageHandler) MarkAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
notificationIDStr := c.Param("id")
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid notification id")
return
}
// 标记为已读
err = h.notifyRepo.MarkAsRead(notificationID, userID)
if err != nil {
response.InternalServerError(c, "failed to mark as read")
return
}
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
response.SuccessWithMessage(c, "marked as read", nil)
}
// MarkAllAsRead 标记所有系统消息为已读
// PUT /api/v1/messages/system/read-all
func (h *SystemMessageHandler) MarkAllAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
// 标记当前用户所有通知为已读
err := h.notifyRepo.MarkAllAsRead(userID)
if err != nil {
response.InternalServerError(c, "failed to mark all as read")
return
}
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
response.SuccessWithMessage(c, "all messages marked as read", nil)
}
// DeleteSystemMessage 删除系统消息
// DELETE /api/v1/messages/system/:id
func (h *SystemMessageHandler) DeleteSystemMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
notificationIDStr := c.Param("id")
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid notification id")
return
}
// 删除通知
err = h.notifyRepo.Delete(notificationID, userID)
if err != nil {
response.InternalServerError(c, "failed to delete notification")
return
}
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
response.SuccessWithMessage(c, "notification deleted", nil)
}

View File

@@ -0,0 +1,90 @@
package handler
import (
"github.com/gin-gonic/gin"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// UploadHandler 上传处理器
type UploadHandler struct {
uploadService *service.UploadService
}
// NewUploadHandler 创建上传处理器
func NewUploadHandler(uploadService *service.UploadService) *UploadHandler {
return &UploadHandler{
uploadService: uploadService,
}
}
// UploadImage 上传图片
func (h *UploadHandler) UploadImage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
file, err := c.FormFile("image")
if err != nil {
response.BadRequest(c, "image file is required")
return
}
url, err := h.uploadService.UploadImage(c.Request.Context(), file)
if err != nil {
response.InternalServerError(c, "failed to upload image")
return
}
response.Success(c, gin.H{"url": url})
}
// UploadAvatar 上传头像
func (h *UploadHandler) UploadAvatar(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
file, err := c.FormFile("image")
if err != nil {
response.BadRequest(c, "avatar file is required")
return
}
url, err := h.uploadService.UploadAvatar(c.Request.Context(), userID, file)
if err != nil {
response.InternalServerError(c, "failed to upload avatar")
return
}
response.Success(c, gin.H{"url": url})
}
// UploadCover 上传头图(个人主页封面)
func (h *UploadHandler) UploadCover(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
file, err := c.FormFile("image")
if err != nil {
response.BadRequest(c, "image file is required")
return
}
url, err := h.uploadService.UploadCover(c.Request.Context(), userID, file)
if err != nil {
response.InternalServerError(c, "failed to upload cover")
return
}
response.Success(c, gin.H{"url": url})
}

View File

@@ -0,0 +1,705 @@
package handler
import (
"fmt"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// UserHandler 用户处理器
type UserHandler struct {
userService *service.UserService
jwtService *service.JWTService
}
// NewUserHandler 创建用户处理器
func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{
userService: userService,
}
}
// Register 用户注册
func (h *UserHandler) Register(c *gin.Context) {
type RegisterRequest struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Nickname string `json:"nickname" binding:"required"`
Phone string `json:"phone"`
VerificationCode string `json:"verification_code" binding:"required"`
}
var req RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
user, err := h.userService.Register(c.Request.Context(), req.Username, req.Email, req.Password, req.Nickname, req.Phone, req.VerificationCode)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to register")
return
}
// 生成Token
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
response.Success(c, gin.H{
"user": dto.ConvertUserToResponse(user),
"token": accessToken,
"refresh_token": refreshToken,
})
}
// Login 用户登录
func (h *UserHandler) Login(c *gin.Context) {
type LoginRequest struct {
Username string `json:"username"`
Account string `json:"account"`
Password string `json:"password" binding:"required"`
}
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
account := req.Account
if account == "" {
account = req.Username
}
if account == "" {
response.BadRequest(c, "username or account is required")
return
}
user, err := h.userService.Login(c.Request.Context(), account, req.Password)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to login")
return
}
// 生成Token
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
response.Success(c, gin.H{
"user": dto.ConvertUserToResponse(user),
"token": accessToken,
"refresh_token": refreshToken,
})
}
// SendRegisterCode 发送注册验证码
func (h *UserHandler) SendRegisterCode(c *gin.Context) {
type SendCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
var req SendCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.SendRegisterCode(c.Request.Context(), req.Email); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to send register verification code")
return
}
response.Success(c, gin.H{"success": true})
}
// SendPasswordResetCode 发送找回密码验证码
func (h *UserHandler) SendPasswordResetCode(c *gin.Context) {
type SendCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
var req SendCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.SendPasswordResetCode(c.Request.Context(), req.Email); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to send reset verification code")
return
}
response.Success(c, gin.H{"success": true})
}
// ResetPassword 找回密码并重置
func (h *UserHandler) ResetPassword(c *gin.Context) {
type ResetPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
VerificationCode string `json:"verification_code" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
var req ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.ResetPasswordByEmail(c.Request.Context(), req.Email, req.VerificationCode, req.NewPassword); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to reset password")
return
}
response.Success(c, gin.H{"success": true})
}
// GetCurrentUser 获取当前用户
func (h *UserHandler) GetCurrentUser(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "user not found")
return
}
// 实时计算帖子数量
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
if err != nil {
// 如果获取失败,使用数据库中的值
postsCount = int64(user.PostsCount)
}
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
}
// GetUserByID 获取指定用户
func (h *UserHandler) GetUserByID(c *gin.Context) {
id := c.Param("id")
currentUserID := c.GetString("user_id")
// 获取用户信息,包含双向关注状态
user, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), id, currentUserID)
if err != nil {
response.NotFound(c, "user not found")
return
}
// 实时计算帖子数量
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), id)
if err != nil {
// 如果获取失败,使用数据库中的值
postsCount = int64(user.PostsCount)
}
// 转换为响应格式,包含双向关注状态和实时计算的帖子数量
userResponse := dto.ConvertUserToResponseWithMutualFollowAndPostsCount(user, isFollowing, isFollowingMe, int(postsCount))
response.Success(c, userResponse)
}
// UpdateUser 更新用户
func (h *UserHandler) UpdateUser(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type UpdateRequest struct {
Nickname string `json:"nickname"`
Bio string `json:"bio"`
Website string `json:"website"`
Location string `json:"location"`
Avatar string `json:"avatar"`
Phone *string `json:"phone"`
Email *string `json:"email"`
}
var req UpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "user not found")
return
}
if req.Nickname != "" {
user.Nickname = req.Nickname
}
if req.Bio != "" {
user.Bio = req.Bio
}
if req.Website != "" {
user.Website = req.Website
}
if req.Location != "" {
user.Location = req.Location
}
if req.Avatar != "" {
user.Avatar = req.Avatar
}
if req.Phone != nil {
user.Phone = req.Phone
}
if req.Email != nil {
if user.Email == nil || *user.Email != *req.Email {
user.EmailVerified = false
}
user.Email = req.Email
}
err = h.userService.UpdateUser(c.Request.Context(), user)
if err != nil {
response.InternalServerError(c, "failed to update user")
return
}
// 实时计算帖子数量
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
if err != nil {
// 如果获取失败,使用数据库中的值
postsCount = int64(user.PostsCount)
}
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
}
// SendEmailVerifyCode 发送当前用户邮箱验证码
func (h *UserHandler) SendEmailVerifyCode(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type SendCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
var req SendCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.SendCurrentUserEmailVerifyCode(c.Request.Context(), userID, req.Email); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to send email verify code")
return
}
response.Success(c, gin.H{"success": true})
}
// VerifyEmail 验证当前用户邮箱
func (h *UserHandler) VerifyEmail(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type VerifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
VerificationCode string `json:"verification_code" binding:"required"`
}
var req VerifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.VerifyCurrentUserEmail(c.Request.Context(), userID, req.Email, req.VerificationCode); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to verify email")
return
}
response.Success(c, gin.H{"success": true})
}
// RefreshToken 刷新Token
func (h *UserHandler) RefreshToken(c *gin.Context) {
type RefreshRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
var req RefreshRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 解析 refresh token
claims, err := h.jwtService.ParseToken(req.RefreshToken)
if err != nil {
response.Unauthorized(c, "invalid refresh token")
return
}
// 生成新 token
accessToken, _ := h.jwtService.GenerateAccessToken(claims.UserID, claims.Username)
refreshToken, _ := h.jwtService.GenerateRefreshToken(claims.UserID, claims.Username)
response.Success(c, gin.H{
"token": accessToken,
"refresh_token": refreshToken,
})
}
// SetJWTService 设置JWT服务
func (h *UserHandler) SetJWTService(jwtSvc *service.JWTService) {
h.jwtService = jwtSvc
}
// FollowUser 关注用户
func (h *UserHandler) FollowUser(c *gin.Context) {
userID := c.Param("id")
currentUserID := c.GetString("user_id")
if userID == currentUserID {
response.BadRequest(c, "cannot follow yourself")
return
}
err := h.userService.FollowUser(c.Request.Context(), currentUserID, userID)
if err != nil {
response.InternalServerError(c, "failed to follow user")
return
}
response.Success(c, gin.H{"success": true})
}
// UnfollowUser 取消关注用户
func (h *UserHandler) UnfollowUser(c *gin.Context) {
userID := c.Param("id")
currentUserID := c.GetString("user_id")
err := h.userService.UnfollowUser(c.Request.Context(), currentUserID, userID)
if err != nil {
response.InternalServerError(c, "failed to unfollow user")
return
}
response.Success(c, gin.H{"success": true})
}
// BlockUser 拉黑用户
func (h *UserHandler) BlockUser(c *gin.Context) {
targetUserID := c.Param("id")
currentUserID := c.GetString("user_id")
if targetUserID == currentUserID {
response.BadRequest(c, "cannot block yourself")
return
}
err := h.userService.BlockUser(c.Request.Context(), currentUserID, targetUserID)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to block user")
return
}
response.Success(c, gin.H{"success": true})
}
// UnblockUser 取消拉黑
func (h *UserHandler) UnblockUser(c *gin.Context) {
targetUserID := c.Param("id")
currentUserID := c.GetString("user_id")
if targetUserID == currentUserID {
response.BadRequest(c, "cannot unblock yourself")
return
}
err := h.userService.UnblockUser(c.Request.Context(), currentUserID, targetUserID)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to unblock user")
return
}
response.Success(c, gin.H{"success": true})
}
// GetBlockedUsers 获取黑名单列表
func (h *UserHandler) GetBlockedUsers(c *gin.Context) {
currentUserID := c.GetString("user_id")
if currentUserID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
users, total, err := h.userService.GetBlockedUsers(c.Request.Context(), currentUserID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get blocked users")
return
}
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses := dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
response.Paginated(c, userResponses, total, page, pageSize)
}
// GetBlockStatus 获取拉黑状态
func (h *UserHandler) GetBlockStatus(c *gin.Context) {
targetUserID := c.Param("id")
currentUserID := c.GetString("user_id")
if currentUserID == "" {
response.Unauthorized(c, "")
return
}
if targetUserID == "" {
response.BadRequest(c, "target user id is required")
return
}
isBlocked, err := h.userService.IsBlocked(c.Request.Context(), currentUserID, targetUserID)
if err != nil {
response.InternalServerError(c, "failed to get block status")
return
}
response.Success(c, gin.H{"is_blocked": isBlocked})
}
// GetFollowingList 获取关注列表
func (h *UserHandler) GetFollowingList(c *gin.Context) {
userID := c.Param("id")
currentUserID := c.GetString("user_id")
page := c.DefaultQuery("page", "1")
pageSize := c.DefaultQuery("page_size", "20")
users, err := h.userService.GetFollowingList(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get following list")
return
}
// 如果已登录,获取双向关注状态和实时计算的帖子数量
var userResponses []*dto.UserResponse
if currentUserID != "" && len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
} else if len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
} else {
userResponses = dto.ConvertUsersToResponse(users)
}
response.Success(c, gin.H{
"list": userResponses,
})
}
// GetFollowersList 获取粉丝列表
func (h *UserHandler) GetFollowersList(c *gin.Context) {
userID := c.Param("id")
currentUserID := c.GetString("user_id")
page := c.DefaultQuery("page", "1")
pageSize := c.DefaultQuery("page_size", "20")
fmt.Printf("[DEBUG] GetFollowersList: userID=%s, currentUserID=%s\n", userID, currentUserID)
users, err := h.userService.GetFollowersList(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get followers list")
return
}
fmt.Printf("[DEBUG] GetFollowersList: found %d users\n", len(users))
// 如果已登录,获取双向关注状态和实时计算的帖子数量
var userResponses []*dto.UserResponse
if currentUserID != "" && len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
fmt.Printf("[DEBUG] GetFollowersList: checking mutual follow status for userIDs=%v\n", userIDs)
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
} else if len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
} else {
userResponses = dto.ConvertUsersToResponse(users)
}
response.Success(c, gin.H{
"list": userResponses,
})
}
// CheckUsername 检查用户名是否可用
func (h *UserHandler) CheckUsername(c *gin.Context) {
username := c.Query("username")
if username == "" {
response.BadRequest(c, "username is required")
return
}
available, err := h.userService.CheckUsernameAvailable(c.Request.Context(), username)
if err != nil {
response.InternalServerError(c, "failed to check username")
return
}
response.Success(c, gin.H{"available": available})
}
// ChangePassword 修改密码
func (h *UserHandler) ChangePassword(c *gin.Context) {
currentUserID := c.GetString("user_id")
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
VerificationCode string `json:"verification_code" binding:"required"`
}
var req ChangePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
err := h.userService.ChangePassword(c.Request.Context(), currentUserID, req.OldPassword, req.NewPassword, req.VerificationCode)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to change password")
return
}
response.Success(c, gin.H{"success": true})
}
// SendChangePasswordCode 发送修改密码验证码
func (h *UserHandler) SendChangePasswordCode(c *gin.Context) {
currentUserID := c.GetString("user_id")
if currentUserID == "" {
response.Unauthorized(c, "")
return
}
err := h.userService.SendChangePasswordCode(c.Request.Context(), currentUserID)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to send change password code")
return
}
response.Success(c, gin.H{"success": true})
}
// Search 搜索用户
func (h *UserHandler) Search(c *gin.Context) {
keyword := c.Query("keyword")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
users, total, err := h.userService.Search(c.Request.Context(), keyword, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to search users")
return
}
// 获取实时计算的帖子数量
var userResponses []*dto.UserResponse
if len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
} else {
userResponses = dto.ConvertUsersToResponse(users)
}
response.Paginated(c, userResponses, total, page, pageSize)
}

View File

@@ -0,0 +1,216 @@
package handler
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// VoteHandler 投票处理器
type VoteHandler struct {
voteService *service.VoteService
postService *service.PostService
}
// NewVoteHandler 创建投票处理器
func NewVoteHandler(voteService *service.VoteService, postService *service.PostService) *VoteHandler {
return &VoteHandler{
voteService: voteService,
postService: postService,
}
}
// CreateVotePost 创建投票帖子
// POST /api/v1/posts/vote
func (h *VoteHandler) CreateVotePost(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "请先登录")
return
}
var req dto.CreateVotePostRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
post, err := h.voteService.CreateVotePost(c.Request.Context(), userID, &req)
if err != nil {
var moderationErr *service.PostModerationRejectedError
if errors.As(err, &moderationErr) {
response.BadRequest(c, moderationErr.UserMessage())
return
}
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, post)
}
// GetVoteResult 获取投票结果
// GET /api/v1/posts/:id/vote
func (h *VoteHandler) GetVoteResult(c *gin.Context) {
postID := c.Param("id")
if postID == "" {
response.BadRequest(c, "帖子ID不能为空")
return
}
// 验证帖子存在
_, err := h.postService.GetByID(c.Request.Context(), postID)
if err != nil {
response.NotFound(c, "帖子不存在")
return
}
// 获取当前用户ID可选登录
userID := c.GetString("user_id")
// 如果用户未登录返回不带has_voted的结果
if userID == "" {
options, err := h.voteService.GetVoteOptions(postID)
if err != nil {
response.InternalServerError(c, "获取投票选项失败")
return
}
// 计算总票数
totalVotes := 0
for _, option := range options {
totalVotes += option.VotesCount
}
result := &dto.VoteResultDTO{
Options: options,
TotalVotes: totalVotes,
HasVoted: false,
}
response.Success(c, result)
return
}
// 用户已登录,获取完整的投票结果
result, err := h.voteService.GetVoteResult(postID, userID)
if err != nil {
response.InternalServerError(c, "获取投票结果失败")
return
}
response.Success(c, result)
}
// Vote 投票
// POST /api/v1/posts/:id/vote
func (h *VoteHandler) Vote(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "请先登录")
return
}
postID := c.Param("id")
if postID == "" {
response.BadRequest(c, "帖子ID不能为空")
return
}
// 验证帖子存在
_, err := h.postService.GetByID(c.Request.Context(), postID)
if err != nil {
response.NotFound(c, "帖子不存在")
return
}
// 解析请求体
var req struct {
OptionID string `json:"option_id" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.voteService.Vote(c.Request.Context(), postID, userID, req.OptionID); err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"success": true})
}
// Unvote 取消投票
// DELETE /api/v1/posts/:id/vote
func (h *VoteHandler) Unvote(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "请先登录")
return
}
postID := c.Param("id")
if postID == "" {
response.BadRequest(c, "帖子ID不能为空")
return
}
// 验证帖子存在
_, err := h.postService.GetByID(c.Request.Context(), postID)
if err != nil {
response.NotFound(c, "帖子不存在")
return
}
if err := h.voteService.Unvote(c.Request.Context(), postID, userID); err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"success": true})
}
// UpdateVoteOption 更新投票选项(仅作者)
// PUT /api/v1/vote-options/:id
func (h *VoteHandler) UpdateVoteOption(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "请先登录")
return
}
optionID := c.Param("id")
if optionID == "" {
response.BadRequest(c, "选项ID不能为空")
return
}
// 解析请求体
var req struct {
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 获取帖子ID从查询参数或请求体中获取
postID := c.Query("post_id")
if postID == "" {
response.BadRequest(c, "帖子ID不能为空")
return
}
if err := h.voteService.UpdateVoteOption(c.Request.Context(), postID, optionID, userID, req.Content); err != nil {
response.Error(c, http.StatusForbidden, err.Error())
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -0,0 +1,866 @@
package handler
import (
"context"
"encoding/json"
"log"
"net/http"
"strconv"
"strings"
"time"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
ws "carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
"carrot_bbs/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // 允许所有来源,生产环境应该限制
},
}
// WebSocketHandler WebSocket处理器
type WebSocketHandler struct {
jwtService *service.JWTService
chatService service.ChatService
groupService service.GroupService
groupRepo repository.GroupRepository
userRepo *repository.UserRepository
wsManager *ws.WebSocketManager
}
// NewWebSocketHandler 创建WebSocket处理器
func NewWebSocketHandler(
jwtService *service.JWTService,
chatService service.ChatService,
groupService service.GroupService,
groupRepo repository.GroupRepository,
userRepo *repository.UserRepository,
wsManager *ws.WebSocketManager,
) *WebSocketHandler {
return &WebSocketHandler{
jwtService: jwtService,
chatService: chatService,
groupService: groupService,
groupRepo: groupRepo,
userRepo: userRepo,
wsManager: wsManager,
}
}
// HandleWebSocket 处理WebSocket连接
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
// 调试:打印请求头信息
log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
c.GetHeader("Connection"),
c.GetHeader("Upgrade"))
log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
c.GetHeader("Sec-WebSocket-Key"),
c.GetHeader("Sec-WebSocket-Version"))
// 从query参数获取token
token := c.Query("token")
if token == "" {
// 尝试从header获取
authHeader := c.GetHeader("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
}
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
return
}
// 验证token
claims, err := h.jwtService.ParseToken(token)
if err != nil {
log.Printf("Invalid token: %v", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
userID := claims.UserID
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
return
}
// 升级HTTP连接为WebSocket连接
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("Failed to upgrade connection: %v", err)
log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
c.GetHeader("User-Agent"),
c.GetHeader("Content-Type"))
return
}
// 如果用户已在线,先注销旧连接
if h.wsManager.IsUserOnline(userID) {
log.Printf("[DEBUG] 用户 %s 已有在线连接,复用该连接", userID)
} else {
log.Printf("[DEBUG] 用户 %s 当前不在线,创建新连接", userID)
}
// 创建客户端
client := &ws.Client{
ID: userID,
UserID: userID,
Conn: conn,
Send: make(chan []byte, 256),
Manager: h.wsManager,
}
// 注册客户端
h.wsManager.Register(client)
// 启动读写协程
go client.WritePump()
go h.handleMessages(client)
log.Printf("[DEBUG] WebSocket连接建立: userID=%s, 当前在线=%v", userID, h.wsManager.IsUserOnline(userID))
}
// handleMessages 处理客户端消息
// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
defer func() {
h.wsManager.Unregister(client)
client.Conn.Close()
}()
client.Conn.SetReadLimit(512 * 1024) // 512KB
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
client.Conn.SetPongHandler(func(string) error {
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
return nil
})
// 心跳定时器 - 服务端主动 ping 间隔保持 30 秒
pingTicker := time.NewTicker(30 * time.Second)
defer pingTicker.Stop()
for {
select {
case <-pingTicker.C:
// 发送心跳
if err := client.SendPing(); err != nil {
log.Printf("Failed to send ping: %v", err)
return
}
default:
_, message, err := client.Conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error: %v", err)
}
return
}
var wsMsg ws.WSMessage
if err := json.Unmarshal(message, &wsMsg); err != nil {
log.Printf("Failed to unmarshal message: %v", err)
continue
}
h.processMessage(client, &wsMsg)
}
}
}
// processMessage 处理消息
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
switch msg.Type {
case ws.MessageTypePing:
// 响应心跳
if err := client.SendPong(); err != nil {
log.Printf("Failed to send pong: %v", err)
}
case ws.MessageTypePong:
// 客户端响应心跳
case ws.MessageTypeMessage:
// 处理聊天消息
h.handleChatMessage(client, msg)
case ws.MessageTypeTyping:
// 处理正在输入状态
h.handleTyping(client, msg)
case ws.MessageTypeRead:
// 处理已读回执
h.handleReadReceipt(client, msg)
// 群组消息处理
case ws.MessageTypeGroupMessage:
// 处理群消息
h.handleGroupMessage(client, msg)
case ws.MessageTypeGroupTyping:
// 处理群输入状态
h.handleGroupTyping(client, msg)
case ws.MessageTypeGroupRead:
// 处理群消息已读
h.handleGroupReadReceipt(client, msg)
case ws.MessageTypeGroupRecall:
// 处理群消息撤回
h.handleGroupRecall(client, msg)
default:
log.Printf("Unknown message type: %s", msg.Type)
}
}
// handleChatMessage 处理聊天消息
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
log.Printf("Invalid message data format")
return
}
log.Printf("[DEBUG handleChatMessage] 完整data: %+v", data)
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
log.Printf("Missing conversationId")
return
}
// 解析会话ID
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
log.Printf("Invalid conversation ID: %v", err)
return
}
// 解析 segments
var segments model.MessageSegments
if data["segments"] != nil {
segmentsBytes, err := json.Marshal(data["segments"])
if err == nil {
json.Unmarshal(segmentsBytes, &segments)
}
}
// 从 segments 中提取回复消息ID
replyToID := dto.GetReplyMessageID(segments)
var replyToIDPtr *string
if replyToID != "" {
replyToIDPtr = &replyToID
}
// 发送消息 - 使用 segments
message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
if err != nil {
log.Printf("Failed to send message: %v", err)
// 发送错误消息
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to send message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 发送确认消息(使用 meta 事件格式,包含完整的消息内容)
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
"detail_type": ws.MetaDetailTypeAck,
"conversation_id": conversationID,
"id": message.ID,
"user_id": client.UserID,
"sender_id": client.UserID,
"seq": message.Seq,
"segments": message.Segments,
"created_at": message.CreatedAt.UnixMilli(),
})
if client.Send != nil {
msgBytes, _ := json.Marshal(metaAckMsg)
log.Printf("[DEBUG handleChatMessage] 私聊 ack 消息: %s", string(msgBytes))
log.Printf("[DEBUG handleChatMessage] message.Segments 类型: %T, 值: %+v", message.Segments, message.Segments)
client.Send <- msgBytes
}
}
// handleTyping 处理正在输入状态
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
return
}
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
return
}
// 直接使用 string 类型的 userID
h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
}
// handleReadReceipt 处理已读回执
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
return
}
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
return
}
// 获取lastReadSeq
lastReadSeq, _ := data["lastReadSeq"].(float64)
if lastReadSeq == 0 {
return
}
// 直接使用 string 类型的 userID 和 conversationID
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
log.Printf("Failed to mark as read: %v", err)
}
}
// ==================== 群组消息处理 ====================
// handleGroupMessage 处理群消息
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
// 打印接收到的消息类型和数据,用于调试
log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)
data, ok := msg.Data.(map[string]interface{})
if !ok {
log.Printf("Invalid group message data format: data is not map[string]interface{}")
return
}
// 解析群组ID支持 camelCase 和 snake_case
var groupIDFloat float64
groupID := "" // 使用 groupID 作为最终变量名
if val, ok := data["groupId"].(float64); ok {
groupIDFloat = val
groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
} else if val, ok := data["group_id"].(string); ok {
groupID = val
}
if groupID == "" {
log.Printf("Missing groupId in group message")
return
}
// 解析会话ID支持 camelCase 和 snake_case
var conversationID string
if val, ok := data["conversationId"].(string); ok {
conversationID = val
} else if val, ok := data["conversation_id"].(string); ok {
conversationID = val
}
if conversationID == "" {
log.Printf("Missing conversationId in group message")
return
}
// 解析 segments
var segments model.MessageSegments
if data["segments"] != nil {
segmentsBytes, err := json.Marshal(data["segments"])
if err == nil {
json.Unmarshal(segmentsBytes, &segments)
}
}
// 解析@用户列表(支持 camelCase 和 snake_case
var mentionUsers []uint64
var mentionUsersInterface []interface{}
if val, ok := data["mentionUsers"].([]interface{}); ok {
mentionUsersInterface = val
} else if val, ok := data["mention_users"].([]interface{}); ok {
mentionUsersInterface = val
}
if len(mentionUsersInterface) > 0 {
for _, uid := range mentionUsersInterface {
if uidFloat, ok := uid.(float64); ok {
mentionUsers = append(mentionUsers, uint64(uidFloat))
} else if uidStr, ok := uid.(string); ok {
// 处理字符串格式的用户ID
if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil {
mentionUsers = append(mentionUsers, uidInt)
}
}
}
}
// 解析@所有人(支持 camelCase 和 snake_case
var mentionAll bool
if val, ok := data["mentionAll"].(bool); ok {
mentionAll = val
} else if val, ok := data["mention_all"].(bool); ok {
mentionAll = val
}
// 检查用户是否可以发送群消息(验证成员身份和禁言状态)
// client.UserID 已经是 string 格式的 UUID
if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
log.Printf("User cannot send group message: %v", err)
// 发送错误消息
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Cannot send group message",
"reason": err.Error(),
"type": "group_message_error",
"groupId": groupID,
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 检查@所有人权限(只有群主和管理员可以@所有人)
if mentionAll {
if !h.groupService.IsGroupAdmin(client.UserID, groupID) {
log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
mentionAll = false // 取消@所有人标记
}
}
// 创建消息
message := &model.Message{
ConversationID: conversationID,
SenderID: client.UserID,
Segments: segments,
Status: model.MessageStatusNormal,
MentionAll: mentionAll,
}
// 序列化mentionUsers为JSON
if len(mentionUsers) > 0 {
mentionUsersJSON, _ := json.Marshal(mentionUsers)
message.MentionUsers = string(mentionUsersJSON)
}
// 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播)
savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
if err != nil {
log.Printf("Failed to save group message: %v", err)
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to save group message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 更新消息的mention信息
if len(mentionUsers) > 0 || mentionAll {
message.ID = savedMessage.ID
message.Seq = savedMessage.Seq
}
// 构造群消息响应
groupMsg := &ws.GroupMessage{
ID: savedMessage.ID,
ConversationID: conversationID,
GroupID: groupID,
SenderID: client.UserID,
Seq: savedMessage.Seq,
Segments: segments,
MentionUsers: mentionUsers,
MentionAll: mentionAll,
CreatedAt: savedMessage.CreatedAt.UnixMilli(),
}
// 广播消息给群组所有成员(排除发送者)
h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)
// 发送确认消息给发送者作为meta事件
// 使用 meta 事件格式发送 ack
log.Printf("[DEBUG HandleGroupMessageSend] 准备发送 ack 消息, userID=%s, messageID=%s, seq=%d",
client.UserID, savedMessage.ID, savedMessage.Seq)
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
"detail_type": ws.MetaDetailTypeAck,
"conversation_id": conversationID,
"group_id": groupID,
"id": savedMessage.ID,
"user_id": client.UserID,
"sender_id": client.UserID,
"seq": savedMessage.Seq,
"segments": segments,
"created_at": savedMessage.CreatedAt.UnixMilli(),
})
if client.Send != nil {
msgBytes, _ := json.Marshal(metaAckMsg)
log.Printf("[DEBUG HandleGroupMessageSend] 发送 ack 消息到 channel, userID=%s, msg=%s",
client.UserID, string(msgBytes))
client.Send <- msgBytes
} else {
log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
}
// 处理@提及通知
if len(mentionUsers) > 0 || mentionAll {
// 提取文本正文(不含 @ 部分)
textContent := dto.ExtractTextContentFromModel(segments)
// 在通知内容前拼接被@的真实昵称,通过群成员列表查找
mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
}
}
// handleGroupTyping 处理群输入状态
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
groupIDFloat, _ := data["groupId"].(float64)
if groupIDFloat == 0 {
return
}
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
isTyping, _ := data["isTyping"].(bool)
// 验证用户是否是群成员
// client.UserID 已经是 string 格式的 UUID
isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
if err != nil || !isMember {
return
}
// 获取用户信息
user, err := h.userRepo.GetByID(client.UserID)
if err != nil {
return
}
// 构造输入状态消息
typingMsg := &ws.GroupTypingMessage{
GroupID: groupID,
UserID: client.UserID,
Username: user.Username,
IsTyping: isTyping,
}
// 广播给群组其他成员
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
}
// handleGroupReadReceipt 处理群消息已读回执
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationID, _ := data["conversationId"].(string)
if conversationID == "" {
return
}
lastReadSeq, _ := data["lastReadSeq"].(float64)
if lastReadSeq == 0 {
return
}
// 标记已读
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
log.Printf("Failed to mark group message as read: %v", err)
}
}
// handleGroupRecall 处理群消息撤回
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
messageID, _ := data["messageId"].(string)
if messageID == "" {
return
}
groupIDFloat, _ := data["groupId"].(float64)
if groupIDFloat == 0 {
return
}
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
// 撤回消息
if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
log.Printf("Failed to recall group message: %v", err)
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to recall message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 广播撤回通知给群组所有成员
recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
"messageId": messageID,
"groupId": groupID,
"userId": client.UserID,
"timestamp": time.Now().UnixMilli(),
})
h.BroadcastGroupNotice(groupID, recallNotice)
}
// handleGroupMention 处理群消息@提及通知
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
// 如果@所有人,获取所有群成员
if mentionAll {
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for mention all: %v", err)
return
}
for _, member := range members {
// 不通知发送者自己
memberIDStr := member.UserID
if memberIDStr == senderID {
continue
}
// 发送@提及通知
mentionMsg := &ws.GroupMentionMessage{
GroupID: groupID,
MessageID: messageID,
FromUserID: senderID,
Content: truncateContent(content, 50),
MentionAll: true,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
return
}
// 处理特定用户的@提及
for _, userID := range mentionUsers {
// userID 是 uint64转换为 string
userIDStr := strconv.FormatUint(userID, 10)
if userIDStr == senderID {
continue // 不通知发送者自己
}
mentionMsg := &ws.GroupMentionMessage{
GroupID: groupID,
MessageID: messageID,
FromUserID: senderID,
Content: truncateContent(content, 50),
MentionAll: false,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
h.wsManager.SendToUser(userIDStr, wsMsg)
}
}
// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
var prefix string
if mentionAll {
prefix = "@所有人 "
} else if len(mentionUsers) > 0 {
// 查询群成员列表,找到被@用户的昵称
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err == nil {
memberNickMap := make(map[string]string, len(members))
for _, m := range members {
displayName := m.Nickname
if displayName == "" {
displayName = m.UserID
}
memberNickMap[m.UserID] = displayName
}
for _, uid := range mentionUsers {
uidStr := strconv.FormatUint(uid, 10)
if name, ok := memberNickMap[uidStr]; ok {
prefix += "@" + name + " "
} else {
prefix += "@某人 "
}
}
} else {
for range mentionUsers {
prefix += "@某人 "
}
}
}
return prefix + textBody
}
// BroadcastGroupMessage 向群组所有成员广播消息
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除发送者
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}
// BroadcastGroupNotice 广播群组通知给所有成员
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for notice broadcast: %v", err)
return
}
// 遍历成员,如果在线则发送通知
for _, member := range members {
memberIDStr := member.UserID
h.wsManager.SendToUser(memberIDStr, notice)
}
}
// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户)
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for notice broadcast: %v", err)
return
}
// 遍历成员,如果在线则发送通知
for _, member := range members {
memberIDStr := member.UserID
if memberIDStr == excludeUserID {
continue
}
h.wsManager.SendToUser(memberIDStr, notice)
}
}
// SendGroupMemberNotice 发送群成员变动通知
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
notice := &ws.GroupNoticeMessage{
NoticeType: noticeType,
GroupID: groupID,
Data: data,
Timestamp: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
h.BroadcastGroupNotice(groupID, wsMsg)
}
// truncateContent 截断内容
func truncateContent(content string, maxLen int) string {
if len(content) <= maxLen {
return content
}
return content[:maxLen] + "..."
}
// BroadcastGroupTyping 向群组所有成员广播输入状态
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for typing broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除指定用户
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}
// BroadcastGroupRead 向群组所有成员广播已读状态
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for read broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除指定用户
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}

View File

@@ -0,0 +1,95 @@
package middleware
import (
"fmt"
"strings"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// Auth 认证中间件
func Auth(jwtService *service.JWTService) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
fmt.Printf("[DEBUG] Auth middleware: Authorization header = %q\n", authHeader)
if authHeader == "" {
fmt.Printf("[DEBUG] Auth middleware: no Authorization header, returning 401\n")
response.Unauthorized(c, "authorization header is required")
c.Abort()
return
}
// 提取Token
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
fmt.Printf("[DEBUG] Auth middleware: invalid Authorization header format\n")
response.Unauthorized(c, "invalid authorization header format")
c.Abort()
return
}
token := parts[1]
fmt.Printf("[DEBUG] Auth middleware: token = %q\n", token[:min(20, len(token))]+"...")
// 验证Token
claims, err := jwtService.ParseToken(token)
if err != nil {
fmt.Printf("[DEBUG] Auth middleware: failed to parse token: %v\n", err)
response.Unauthorized(c, "invalid token")
c.Abort()
return
}
fmt.Printf("[DEBUG] Auth middleware: parsed claims, user_id = %q, username = %q\n", claims.UserID, claims.Username)
// 将用户信息存入上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Next()
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
// OptionalAuth 可选认证中间件
func OptionalAuth(jwtService *service.JWTService) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.Next()
return
}
// 提取Token
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
c.Next()
return
}
token := parts[1]
// 验证Token
claims, err := jwtService.ParseToken(token)
if err != nil {
c.Next()
return
}
// 将用户信息存入上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Next()
}
}

View File

@@ -0,0 +1,46 @@
package middleware
import (
"log"
"strings"
"github.com/gin-gonic/gin"
)
// CORS CORS中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取请求路径
path := c.Request.URL.Path
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
// 添加 WebSocket 升级所需的头
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, Connection, Upgrade, Sec-WebSocket-Key, Sec-WebSocket-Version, Sec-WebSocket-Protocol, Sec-WebSocket-Extensions")
c.Header("Access-Control-Expose-Headers", "Content-Length, Connection, Upgrade")
c.Header("Access-Control-Allow-Credentials", "true")
// 处理 WebSocket 升级请求的预检
if c.Request.Method == "OPTIONS" {
log.Printf("[CORS] OPTIONS 预检请求: %s", path)
c.AbortWithStatus(204)
return
}
// 针对 WebSocket 路径的特殊处理
if path == "/ws" {
connection := c.GetHeader("Connection")
upgrade := c.GetHeader("Upgrade")
log.Printf("[CORS] WebSocket 请求: Connection=%s, Upgrade=%s", connection, upgrade)
// 检查是否是有效的 WebSocket 升级请求
if strings.Contains(strings.ToLower(connection), "upgrade") && strings.ToLower(upgrade) == "websocket" {
log.Printf("[CORS] 有效的 WebSocket 升级请求")
} else {
log.Printf("[CORS] 警告: 不是有效的 WebSocket 升级请求!")
}
}
c.Next()
}
}

View File

@@ -0,0 +1,49 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Logger 日志中间件
func Logger(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
c.Next()
latency := time.Since(start)
statusCode := c.Writer.Status()
logger.Info("request",
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.Int("status", statusCode),
zap.Duration("latency", latency),
zap.String("ip", c.ClientIP()),
zap.String("user-agent", c.Request.UserAgent()),
)
}
}
// Recovery 恢复中间件
func Recovery(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
logger.Error("panic recovered",
zap.Any("error", err),
)
c.JSON(500, gin.H{
"code": 500,
"message": "internal server error",
})
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -0,0 +1,102 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter 限流器
type RateLimiter struct {
requests map[string][]time.Time
mu sync.Mutex
limit int
window time.Duration
}
// NewRateLimiter 创建限流器
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
rl := &RateLimiter{
requests: make(map[string][]time.Time),
limit: limit,
window: window,
}
// 定期清理过期的记录
go func() {
for {
time.Sleep(window)
rl.cleanup()
}
}()
return rl
}
// cleanup 清理过期的记录
func (rl *RateLimiter) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
for key, times := range rl.requests {
var valid []time.Time
for _, t := range times {
if now.Sub(t) < rl.window {
valid = append(valid, t)
}
}
if len(valid) == 0 {
delete(rl.requests, key)
} else {
rl.requests[key] = valid
}
}
}
// isAllowed 检查是否允许请求
func (rl *RateLimiter) isAllowed(key string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
times := rl.requests[key]
// 过滤掉过期的
var valid []time.Time
for _, t := range times {
if now.Sub(t) < rl.window {
valid = append(valid, t)
}
}
if len(valid) >= rl.limit {
rl.requests[key] = valid
return false
}
rl.requests[key] = append(valid, now)
return true
}
// RateLimit 限流中间件
func RateLimit(requestsPerMinute int) gin.HandlerFunc {
limiter := NewRateLimiter(requestsPerMinute, time.Minute)
return func(c *gin.Context) {
ip := c.ClientIP()
if !limiter.isAllowed(ip) {
c.JSON(http.StatusTooManyRequests, gin.H{
"code": 429,
"message": "too many requests",
})
c.Abort()
return
}
c.Next()
}
}

118
internal/model/audit_log.go Normal file
View File

@@ -0,0 +1,118 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// AuditTargetType 审核对象类型
type AuditTargetType string
const (
AuditTargetTypePost AuditTargetType = "post" // 帖子
AuditTargetTypeComment AuditTargetType = "comment" // 评论
AuditTargetTypeMessage AuditTargetType = "message" // 私信
AuditTargetTypeUser AuditTargetType = "user" // 用户资料
AuditTargetTypeImage AuditTargetType = "image" // 图片
)
// AuditResult 审核结果
type AuditResult string
const (
AuditResultPass AuditResult = "pass" // 通过
AuditResultReview AuditResult = "review" // 需人工复审
AuditResultBlock AuditResult = "block" // 违规拦截
AuditResultUnknown AuditResult = "unknown" // 未知
)
// AuditRiskLevel 风险等级
type AuditRiskLevel string
const (
AuditRiskLevelLow AuditRiskLevel = "low" // 低风险
AuditRiskLevelMedium AuditRiskLevel = "medium" // 中风险
AuditRiskLevelHigh AuditRiskLevel = "high" // 高风险
)
// AuditSource 审核来源
type AuditSource string
const (
AuditSourceAuto AuditSource = "auto" // 自动审核
AuditSourceManual AuditSource = "manual" // 人工审核
AuditSourceCallback AuditSource = "callback" // 回调审核
)
// AuditLog 审核日志实体
type AuditLog struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
TargetType AuditTargetType `json:"target_type" gorm:"type:varchar(50);index"`
TargetID string `json:"target_id" gorm:"type:varchar(255);index"`
Content string `json:"content" gorm:"type:text"` // 待审核内容
ContentType string `json:"content_type" gorm:"type:varchar(50)"` // 内容类型: text, image
ContentURL string `json:"content_url" gorm:"type:text"` // 图片/文件URL
AuditType string `json:"audit_type" gorm:"type:varchar(50)"` // 审核类型: porn, violence, ad, political, fraud, gamble
Result AuditResult `json:"result" gorm:"type:varchar(50);index"`
RiskLevel AuditRiskLevel `json:"risk_level" gorm:"type:varchar(20)"`
Labels string `json:"labels" gorm:"type:text"` // JSON数组标签列表
Suggestion string `json:"suggestion" gorm:"type:varchar(50)"` // pass, review, block
Detail string `json:"detail" gorm:"type:text"` // 详细说明
ThirdPartyID string `json:"third_party_id" gorm:"type:varchar(255)"` // 第三方审核服务返回的ID
Source AuditSource `json:"source" gorm:"type:varchar(20);default:auto"`
ReviewerID string `json:"reviewer_id" gorm:"type:varchar(255)"` // 审核人ID人工审核时使用
ReviewerName string `json:"reviewer_name" gorm:"type:varchar(100)"` // 审核人名称
ReviewTime *time.Time `json:"review_time" gorm:"index"` // 审核时间
UserID string `json:"user_id" gorm:"type:varchar(255);index"` // 内容发布者ID
UserIP string `json:"user_ip" gorm:"type:varchar(45)"` // 用户IP
Status string `json:"status" gorm:"type:varchar(20);default:pending"` // pending, completed, failed
RejectReason string `json:"reject_reason" gorm:"type:text"` // 拒绝原因(人工审核时使用)
ExtraData string `json:"extra_data" gorm:"type:text"` // 额外数据JSON格式
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
}
// BeforeCreate 创建前生成UUID
func (al *AuditLog) BeforeCreate(tx *gorm.DB) error {
if al.ID == "" {
al.ID = uuid.New().String()
}
return nil
}
func (AuditLog) TableName() string {
return "audit_logs"
}
// AuditLogRequest 创建审核日志请求
type AuditLogRequest struct {
TargetType AuditTargetType `json:"target_type" validate:"required"`
TargetID string `json:"target_id" validate:"required"`
Content string `json:"content"`
ContentType string `json:"content_type"`
ContentURL string `json:"content_url"`
AuditType string `json:"audit_type"`
UserID string `json:"user_id"`
UserIP string `json:"user_ip"`
}
// AuditLogListItem 审核日志列表项
type AuditLogListItem struct {
ID string `json:"id"`
TargetType AuditTargetType `json:"target_type"`
TargetID string `json:"target_id"`
Content string `json:"content"`
ContentType string `json:"content_type"`
Result AuditResult `json:"result"`
RiskLevel AuditRiskLevel `json:"risk_level"`
Suggestion string `json:"suggestion"`
Source AuditSource `json:"source"`
ReviewerID string `json:"reviewer_id"`
ReviewTime *time.Time `json:"review_time"`
UserID string `json:"user_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}

80
internal/model/comment.go Normal file
View File

@@ -0,0 +1,80 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// CommentStatus 评论状态
type CommentStatus string
const (
CommentStatusDraft CommentStatus = "draft"
CommentStatusPending CommentStatus = "pending"
CommentStatusPublished CommentStatus = "published"
CommentStatusRejected CommentStatus = "rejected"
CommentStatusDeleted CommentStatus = "deleted"
)
// Comment 评论实体
type Comment struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index:idx_comments_post_parent_status_created,priority:1"`
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
ParentID *string `json:"parent_id" gorm:"type:varchar(36);index:idx_comments_post_parent_status_created,priority:2"` // 父评论 ID支持嵌套
RootID *string `json:"root_id" gorm:"type:varchar(36);index:idx_comments_root_status_created,priority:1"` // 根评论 ID用于高效查询
Content string `json:"content" gorm:"type:text;not null"`
Images string `json:"images" gorm:"type:text"` // 图片URL列表JSON数组格式
// 关联
User *User `json:"-" gorm:"foreignKey:UserID"`
Replies []*Comment `json:"-" gorm:"-"` // 子回复(手动加载,非 GORM 关联)
// 审核状态
Status CommentStatus `json:"status" gorm:"type:varchar(20);default:published;index:idx_comments_post_parent_status_created,priority:3;index:idx_comments_root_status_created,priority:2"`
// 统计
LikesCount int `json:"likes_count" gorm:"default:0"`
RepliesCount int `json:"replies_count" gorm:"default:0"`
// 软删除
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_comments_post_parent_status_created,priority:4,sort:asc;index:idx_comments_root_status_created,priority:3,sort:asc"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// BeforeCreate 创建前生成UUID
func (c *Comment) BeforeCreate(tx *gorm.DB) error {
if c.ID == "" {
c.ID = uuid.New().String()
}
return nil
}
func (Comment) TableName() string {
return "comments"
}
// CommentLike 评论点赞
type CommentLike struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
CommentID string `json:"comment_id" gorm:"type:varchar(36);index;not null"`
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
}
// BeforeCreate 创建前生成UUID
func (cl *CommentLike) BeforeCreate(tx *gorm.DB) error {
if cl.ID == "" {
cl.ID = uuid.New().String()
}
return nil
}
func (CommentLike) TableName() string {
return "comment_likes"
}

View File

@@ -0,0 +1,68 @@
package model
import (
"strconv"
"time"
"gorm.io/gorm"
"carrot_bbs/internal/pkg/utils"
)
// ConversationType 会话类型
type ConversationType string
const (
ConversationTypePrivate ConversationType = "private" // 私聊
ConversationTypeGroup ConversationType = "group" // 群聊
ConversationTypeSystem ConversationType = "system" // 系统通知会话
)
// Conversation 会话实体
// 使用雪花算法ID作为string存储和seq机制实现消息排序和增量同步
type Conversation struct {
ID string `gorm:"primaryKey;size:20" json:"id"` // 雪花算法ID转为string避免精度丢失
Type ConversationType `gorm:"type:varchar(20);default:'private'" json:"type"` // 会话类型
GroupID *string `gorm:"index" json:"group_id,omitempty"` // 关联的群组ID群聊时使用string类型避免JS精度丢失使用指针支持NULL值
LastSeq int64 `gorm:"default:0" json:"last_seq"` // 最后一条消息的seq
LastMsgTime *time.Time `json:"last_msg_time,omitempty"` // 最后消息时间
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
// 关联 - 使用 polymorphic 模式避免外键约束问题
Participants []ConversationParticipant `gorm:"foreignKey:ConversationID" json:"participants,omitempty"`
Group *Group `gorm:"foreignKey:GroupID;references:ID" json:"group,omitempty"`
}
// BeforeCreate 创建前生成雪花算法ID
func (c *Conversation) BeforeCreate(tx *gorm.DB) error {
if c.ID == "" {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
c.ID = strconv.FormatInt(id, 10)
}
return nil
}
func (Conversation) TableName() string {
return "conversations"
}
// ConversationParticipant 会话参与者
type ConversationParticipant struct {
ID uint `gorm:"primaryKey" json:"id"`
ConversationID string `gorm:"not null;size:20;uniqueIndex:idx_conversation_user,priority:1;index:idx_cp_conversation_user,priority:1" json:"conversation_id"` // 雪花算法IDstring类型
UserID string `gorm:"column:user_id;type:varchar(50);not null;uniqueIndex:idx_conversation_user,priority:2;index:idx_cp_conversation_user,priority:2;index:idx_cp_user_hidden_pinned_updated,priority:1" json:"user_id"` // UUID格式与JWT中user_id保持一致
LastReadSeq int64 `gorm:"default:0" json:"last_read_seq"` // 已读到的seq位置
Muted bool `gorm:"default:false" json:"muted"` // 是否免打扰
IsPinned bool `gorm:"default:false;index:idx_cp_user_hidden_pinned_updated,priority:3" json:"is_pinned"` // 是否置顶会话(用户维度)
HiddenAt *time.Time `gorm:"index:idx_cp_user_hidden_pinned_updated,priority:2" json:"hidden_at,omitempty"` // 仅自己删除会话时使用,收到新消息后自动恢复
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime;index:idx_cp_user_hidden_pinned_updated,priority:4,sort:desc"`
}
func (ConversationParticipant) TableName() string {
return "conversation_participants"
}

View File

@@ -0,0 +1,94 @@
package model
import (
"time"
"gorm.io/gorm"
"carrot_bbs/internal/pkg/utils"
)
// DeviceType 设备类型
type DeviceType string
const (
DeviceTypeIOS DeviceType = "ios" // iOS设备
DeviceTypeAndroid DeviceType = "android" // Android设备
DeviceTypeWeb DeviceType = "web" // Web端
)
// DeviceToken 设备Token实体
// 用于管理用户的多设备推送Token
type DeviceToken struct {
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` // 雪花算法ID
UserID string `gorm:"column:user_id;type:varchar(50);index;not null" json:"user_id"` // 用户ID (UUID格式)
DeviceID string `gorm:"type:varchar(100);not null" json:"device_id"` // 设备唯一标识
DeviceType DeviceType `gorm:"type:varchar(20);not null" json:"device_type"` // 设备类型
PushToken string `gorm:"type:varchar(256);not null" json:"push_token"` // 推送TokenFCM/APNs等
IsActive bool `gorm:"default:true" json:"is_active"` // 是否活跃
DeviceName string `gorm:"type:varchar(100)" json:"device_name,omitempty"` // 设备名称(可选)
// 时间戳
LastUsedAt *time.Time `json:"last_used_at,omitempty"` // 最后使用时间
// 软删除
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// BeforeCreate 创建前生成雪花算法ID
func (d *DeviceToken) BeforeCreate(tx *gorm.DB) error {
if d.ID == 0 {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
d.ID = id
}
return nil
}
func (DeviceToken) TableName() string {
return "device_tokens"
}
// UpdateLastUsed 更新最后使用时间
func (d *DeviceToken) UpdateLastUsed() {
now := time.Now()
d.LastUsedAt = &now
}
// Deactivate 停用设备
func (d *DeviceToken) Deactivate() {
d.IsActive = false
}
// Activate 激活设备
func (d *DeviceToken) Activate() {
d.IsActive = true
now := time.Now()
d.LastUsedAt = &now
}
// IsIOS 判断是否为iOS设备
func (d *DeviceToken) IsIOS() bool {
return d.DeviceType == DeviceTypeIOS
}
// IsAndroid 判断是否为Android设备
func (d *DeviceToken) IsAndroid() bool {
return d.DeviceType == DeviceTypeAndroid
}
// IsWeb 判断是否为Web端
func (d *DeviceToken) IsWeb() bool {
return d.DeviceType == DeviceTypeWeb
}
// SupportsMobilePush 判断是否支持手机推送
func (d *DeviceToken) SupportsMobilePush() bool {
return d.DeviceType == DeviceTypeIOS || d.DeviceType == DeviceTypeAndroid
}

View File

@@ -0,0 +1,28 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// Favorite 收藏
type Favorite struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_favorite_post_user,priority:1"`
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_favorite_post_user,priority:2"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
}
// BeforeCreate 创建前生成UUID
func (f *Favorite) BeforeCreate(tx *gorm.DB) error {
if f.ID == "" {
f.ID = uuid.New().String()
}
return nil
}
func (Favorite) TableName() string {
return "favorites"
}

28
internal/model/follow.go Normal file
View File

@@ -0,0 +1,28 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// Follow 关注关系
type Follow struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
FollowerID string `json:"follower_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_follower_following"` // 关注者
FollowingID string `json:"following_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_follower_following"` // 被关注者
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
}
// BeforeCreate 创建前生成UUID
func (f *Follow) BeforeCreate(tx *gorm.DB) error {
if f.ID == "" {
f.ID = uuid.New().String()
}
return nil
}
func (Follow) TableName() string {
return "follows"
}

57
internal/model/group.go Normal file
View File

@@ -0,0 +1,57 @@
package model
import (
"strconv"
"time"
"carrot_bbs/internal/pkg/utils"
"gorm.io/gorm"
)
// JoinType 群组加入类型
type JoinType int
const (
JoinTypeAnyone JoinType = 0 // 允许任何人加入
JoinTypeApproval JoinType = 1 // 需要审批
JoinTypeForbidden JoinType = 2 // 不允许加入
)
// Group 群组模型
type Group struct {
ID string `gorm:"primaryKey;size:20" json:"id"`
Name string `gorm:"size:50;not null" json:"name"`
Avatar string `gorm:"size:512" json:"avatar"`
Description string `gorm:"size:500" json:"description"`
OwnerID string `gorm:"type:varchar(36);not null;index" json:"owner_id"`
MemberCount int `gorm:"default:0" json:"member_count"`
MaxMembers int `gorm:"default:500" json:"max_members"`
JoinType JoinType `gorm:"default:0" json:"join_type"` // 0:允许任何人加入 1:需要审批 2:不允许加入
MuteAll bool `gorm:"default:false" json:"mute_all"` // 全员禁言
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// BeforeCreate 创建前生成雪花算法ID
func (g *Group) BeforeCreate(tx *gorm.DB) error {
if g.ID == "" {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
g.ID = strconv.FormatInt(id, 10)
}
return nil
}
// GetIDInt 获取数字类型的ID用于比较
func (g *Group) GetIDInt() uint64 {
id, _ := strconv.ParseUint(g.ID, 10, 64)
return id
}
// TableName 指定表名
func (Group) TableName() string {
return "groups"
}

View File

@@ -0,0 +1,38 @@
package model
import (
"strconv"
"time"
"carrot_bbs/internal/pkg/utils"
"gorm.io/gorm"
)
// GroupAnnouncement 群公告模型
type GroupAnnouncement struct {
ID string `gorm:"primaryKey;size:20" json:"id"`
GroupID string `gorm:"not null;index" json:"group_id"`
Content string `gorm:"type:text;not null" json:"content"`
AuthorID string `gorm:"type:varchar(36);not null" json:"author_id"`
IsPinned bool `gorm:"default:false" json:"is_pinned"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// BeforeCreate 创建前生成雪花算法ID
func (ga *GroupAnnouncement) BeforeCreate(tx *gorm.DB) error {
if ga.ID == "" {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
ga.ID = strconv.FormatInt(id, 10)
}
return nil
}
// TableName 指定表名
func (GroupAnnouncement) TableName() string {
return "group_announcements"
}

View File

@@ -0,0 +1,59 @@
package model
import (
"strconv"
"time"
"carrot_bbs/internal/pkg/utils"
"gorm.io/gorm"
)
type GroupJoinRequestType string
const (
GroupJoinRequestTypeInvite GroupJoinRequestType = "invite"
GroupJoinRequestTypeJoinApply GroupJoinRequestType = "join_apply"
)
type GroupJoinRequestStatus string
const (
GroupJoinRequestStatusPending GroupJoinRequestStatus = "pending"
GroupJoinRequestStatusAccepted GroupJoinRequestStatus = "accepted"
GroupJoinRequestStatusRejected GroupJoinRequestStatus = "rejected"
GroupJoinRequestStatusCancelled GroupJoinRequestStatus = "cancelled"
GroupJoinRequestStatusExpired GroupJoinRequestStatus = "expired"
)
// GroupJoinRequest 统一保存邀请入群和主动加群申请
type GroupJoinRequest struct {
ID string `gorm:"primaryKey;size:20" json:"id"`
Flag string `gorm:"size:64;uniqueIndex;not null" json:"flag"`
GroupID string `gorm:"not null;index;index:idx_gjr_group_target_type_status_created,priority:1" json:"group_id"`
InitiatorID string `gorm:"type:varchar(36);not null;index" json:"initiator_id"`
TargetUserID string `gorm:"type:varchar(36);not null;index;index:idx_gjr_group_target_type_status_created,priority:2" json:"target_user_id"`
RequestType GroupJoinRequestType `gorm:"size:20;not null;index;index:idx_gjr_group_target_type_status_created,priority:3" json:"request_type"`
Status GroupJoinRequestStatus `gorm:"size:20;not null;index;index:idx_gjr_group_target_type_status_created,priority:4" json:"status"`
Reason string `gorm:"size:500" json:"reason"`
ReviewerID string `gorm:"type:varchar(36);index" json:"reviewer_id"`
ReviewedAt *time.Time `json:"reviewed_at"`
ExpireAt *time.Time `json:"expire_at"`
CreatedAt time.Time `gorm:"index:idx_gjr_group_target_type_status_created,priority:5,sort:desc" json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (r *GroupJoinRequest) BeforeCreate(tx *gorm.DB) error {
if r.ID == "" {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
r.ID = strconv.FormatInt(id, 10)
}
return nil
}
func (GroupJoinRequest) TableName() string {
return "group_join_requests"
}

View File

@@ -0,0 +1,47 @@
package model
import (
"strconv"
"time"
"carrot_bbs/internal/pkg/utils"
"gorm.io/gorm"
)
// GroupMember 群成员模型
type GroupMember struct {
ID string `gorm:"primaryKey;size:20" json:"id"`
GroupID string `gorm:"not null;uniqueIndex:idx_group_user" json:"group_id"`
UserID string `gorm:"type:varchar(36);not null;uniqueIndex:idx_group_user" json:"user_id"`
Role string `gorm:"size:20;default:'member'" json:"role"` // owner, admin, member
Nickname string `gorm:"size:50" json:"nickname"` // 群内昵称
Muted bool `gorm:"default:false" json:"muted"` // 是否被禁言
JoinTime time.Time `json:"join_time"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// BeforeCreate 创建前生成雪花算法ID
func (gm *GroupMember) BeforeCreate(tx *gorm.DB) error {
if gm.ID == "" {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
gm.ID = strconv.FormatInt(id, 10)
}
return nil
}
// TableName 指定表名
func (GroupMember) TableName() string {
return "group_members"
}
// Role 常量
const (
GroupRoleOwner = "owner"
GroupRoleAdmin = "admin"
GroupRoleMember = "member"
)

159
internal/model/init.go Normal file
View File

@@ -0,0 +1,159 @@
package model
import (
"fmt"
"log"
"os"
"time"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"carrot_bbs/internal/config"
)
// DB 全局数据库连接
var DB *gorm.DB
// InitDB 初始化数据库连接
func InitDB(cfg *config.DatabaseConfig) error {
var err error
var db *gorm.DB
gormLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: time.Duration(cfg.SlowThresholdMs) * time.Millisecond,
LogLevel: parseGormLogLevel(cfg.LogLevel),
IgnoreRecordNotFoundError: cfg.IgnoreRecordNotFound,
ParameterizedQueries: cfg.ParameterizedQueries,
Colorful: false,
},
)
// 根据数据库类型选择驱动
switch cfg.Type {
case "sqlite":
db, err = gorm.Open(sqlite.Open(cfg.SQLite.Path), &gorm.Config{
Logger: gormLogger,
})
case "postgres", "postgresql":
dsn := cfg.Postgres.DSN()
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: gormLogger,
})
default:
// 默认使用PostgreSQL
dsn := cfg.Postgres.DSN()
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: gormLogger,
})
}
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
DB = db
// 配置连接池SQLite不支持连接池配置跳过
if cfg.Type != "sqlite" {
sqlDB, err := DB.DB()
if err != nil {
return fmt.Errorf("failed to get database instance: %w", err)
}
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
}
// 自动迁移
if err := autoMigrate(DB); err != nil {
return fmt.Errorf("failed to auto migrate: %w", err)
}
log.Printf("Database connected (%s) and migrated successfully", cfg.Type)
return nil
}
func parseGormLogLevel(level string) logger.LogLevel {
switch level {
case "silent":
return logger.Silent
case "error":
return logger.Error
case "warn":
return logger.Warn
case "info":
return logger.Info
default:
return logger.Warn
}
}
// autoMigrate 自动迁移数据库表
func autoMigrate(db *gorm.DB) error {
err := db.AutoMigrate(
// 用户相关
&User{},
// 帖子相关
&Post{},
&PostImage{},
// 评论相关
&Comment{},
&CommentLike{},
// 消息相关使用雪花算法ID和seq机制
// 已读位置存储在 ConversationParticipant.LastReadSeq 中
&Conversation{},
&ConversationParticipant{},
&Message{},
// 系统通知(独立表,每个用户只能看到自己的通知)
&SystemNotification{},
// 通知
&Notification{},
// 推送中心相关
&PushRecord{}, // 推送记录
&DeviceToken{}, // 设备Token
// 社交
&Follow{},
&UserBlock{},
&PostLike{},
&Favorite{},
// 投票
&VoteOption{},
&UserVote{},
// 敏感词和审核
&SensitiveWord{},
&AuditLog{},
// 群组相关
&Group{},
&GroupMember{},
&GroupAnnouncement{},
&GroupJoinRequest{},
// 自定义表情
&UserSticker{},
)
if err != nil {
return err
}
return nil
}
// CloseDB 关闭数据库连接
func CloseDB() {
if sqlDB, err := DB.DB(); err == nil {
sqlDB.Close()
}
}

28
internal/model/like.go Normal file
View File

@@ -0,0 +1,28 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// PostLike 帖子点赞
type PostLike struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
PostID string `json:"post_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_post_like_user,priority:1"`
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_post_like_user,priority:2"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
}
// BeforeCreate 创建前生成UUID
func (pl *PostLike) BeforeCreate(tx *gorm.DB) error {
if pl.ID == "" {
pl.ID = uuid.New().String()
}
return nil
}
func (PostLike) TableName() string {
return "post_likes"
}

205
internal/model/message.go Normal file
View File

@@ -0,0 +1,205 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"strconv"
"time"
"gorm.io/gorm"
"carrot_bbs/internal/pkg/utils"
)
// 系统消息相关常量
const (
// SystemSenderID 系统消息发送者ID (类似QQ 10000)
SystemSenderID int64 = 10000
// SystemSenderIDStr 系统消息发送者ID字符串版本
SystemSenderIDStr string = "10000"
// SystemConversationID 系统通知会话ID (string类型)
SystemConversationID string = "9999999999"
)
// ContentType 消息内容类型
type ContentType string
const (
ContentTypeText ContentType = "text"
ContentTypeImage ContentType = "image"
ContentTypeVideo ContentType = "video"
ContentTypeAudio ContentType = "audio"
ContentTypeFile ContentType = "file"
)
// MessageStatus 消息状态
type MessageStatus string
const (
MessageStatusNormal MessageStatus = "normal" // 正常
MessageStatusRecalled MessageStatus = "recalled" // 已撤回
MessageStatusDeleted MessageStatus = "deleted" // 已删除
)
// MessageCategory 消息类别
type MessageCategory string
const (
CategoryChat MessageCategory = "chat" // 普通聊天
CategoryNotification MessageCategory = "notification" // 通知类消息
CategoryAnnouncement MessageCategory = "announcement" // 系统公告
CategoryMarketing MessageCategory = "marketing" // 营销消息
)
// SystemMessageType 系统消息类型 (对应原NotificationType)
type SystemMessageType string
const (
// 互动通知
SystemTypeLikePost SystemMessageType = "like_post" // 点赞帖子
SystemTypeLikeComment SystemMessageType = "like_comment" // 点赞评论
SystemTypeComment SystemMessageType = "comment" // 评论
SystemTypeReply SystemMessageType = "reply" // 回复
SystemTypeFollow SystemMessageType = "follow" // 关注
SystemTypeMention SystemMessageType = "mention" // @提及
// 系统消息
SystemTypeSystem SystemMessageType = "system" // 系统通知
SystemTypeAnnounce SystemMessageType = "announce" // 系统公告
)
// ExtraData 消息额外数据,用于存储系统消息的相关信息
type ExtraData struct {
// 操作者信息
ActorID int64 `json:"actor_id,omitempty"` // 操作者ID (数字格式,兼容旧数据)
ActorIDStr string `json:"actor_id_str,omitempty"` // 操作者ID (UUID字符串格式)
ActorName string `json:"actor_name,omitempty"` // 操作者名称
AvatarURL string `json:"avatar_url,omitempty"` // 操作者头像
// 目标信息
TargetID int64 `json:"target_id,omitempty"` // 目标ID帖子ID、评论ID等
TargetTitle string `json:"target_title,omitempty"` // 目标标题
TargetType string `json:"target_type,omitempty"` // 目标类型post/comment等
// 其他信息
ActionURL string `json:"action_url,omitempty"` // 跳转链接
ActionTime string `json:"action_time,omitempty"` // 操作时间
}
// Value 实现driver.Valuer接口用于数据库存储
func (e ExtraData) Value() (driver.Value, error) {
return json.Marshal(e)
}
// Scan 实现sql.Scanner接口用于数据库读取
func (e *ExtraData) Scan(value interface{}) error {
if value == nil {
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, e)
}
// MessageSegmentData 单个消息段的数据
type MessageSegmentData map[string]interface{}
// MessageSegment 消息段
type MessageSegment struct {
Type string `json:"type"`
Data MessageSegmentData `json:"data"`
}
// MessageSegments 消息链类型
type MessageSegments []MessageSegment
// Value 实现driver.Valuer接口用于数据库存储
func (s MessageSegments) Value() (driver.Value, error) {
return json.Marshal(s)
}
// Scan 实现sql.Scanner接口用于数据库读取
func (s *MessageSegments) Scan(value interface{}) error {
if value == nil {
*s = nil
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, s)
}
// Message 消息实体
// 使用雪花算法IDstring类型和seq机制实现消息排序和增量同步
type Message struct {
ID string `gorm:"primaryKey;size:20" json:"id"` // 雪花算法IDstring类型
ConversationID string `gorm:"not null;size:20;index:idx_msg_conversation_seq,priority:1" json:"conversation_id"` // 会话IDstring类型
SenderID string `gorm:"column:sender_id;type:varchar(50);index;not null" json:"sender_id"` // 发送者ID (UUID格式)
Seq int64 `gorm:"not null;index:idx_msg_conversation_seq,priority:2" json:"seq"` // 会话内序号,用于排序和增量同步
Segments MessageSegments `gorm:"type:json" json:"segments"` // 消息链(结构体数组)
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息IDstring类型
Status MessageStatus `gorm:"type:varchar(20);default:'normal'" json:"status"` // 消息状态
// 新增字段:消息分类和系统消息类型
Category MessageCategory `gorm:"type:varchar(20);default:'chat'" json:"category"` // 消息分类
SystemType SystemMessageType `gorm:"type:varchar(30)" json:"system_type,omitempty"` // 系统消息类型
ExtraData *ExtraData `gorm:"type:json" json:"extra_data,omitempty"` // 额外数据JSON格式
// @相关字段
MentionUsers string `gorm:"type:text" json:"mention_users"` // @的用户ID列表JSON数组
MentionAll bool `gorm:"default:false" json:"mention_all"` // 是否@所有人
// 软删除
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// SenderIDStr 返回发送者ID字符串保持兼容性
func (m *Message) SenderIDStr() string {
return m.SenderID
}
// BeforeCreate 创建前生成雪花算法ID
func (m *Message) BeforeCreate(tx *gorm.DB) error {
if m.ID == "" {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
m.ID = strconv.FormatInt(id, 10)
}
return nil
}
func (Message) TableName() string {
return "messages"
}
// IsSystemMessage 判断是否为系统消息
func (m *Message) IsSystemMessage() bool {
return m.SenderID == SystemSenderIDStr || m.Category == CategoryNotification || m.Category == CategoryAnnouncement
}
// IsInteractionNotification 判断是否为互动通知
func (m *Message) IsInteractionNotification() bool {
if m.Category != CategoryNotification {
return false
}
switch m.SystemType {
case SystemTypeLikePost, SystemTypeLikeComment, SystemTypeComment,
SystemTypeReply, SystemTypeFollow, SystemTypeMention:
return true
default:
return false
}
}

View File

@@ -0,0 +1,19 @@
package model
import (
"time"
)
// MessageRead 消息已读状态
// 记录每个用户在每个会话中的已读位置
type MessageRead struct {
ID uint `gorm:"primaryKey" json:"id"`
ConversationID int64 `gorm:"uniqueIndex:idx_conversation_user;not null" json:"conversation_id"`
UserID uint `gorm:"uniqueIndex:idx_conversation_user;not null" json:"user_id"`
LastReadSeq int64 `gorm:"not null" json:"last_read_seq"` // 已读到的seq位置
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
func (MessageRead) TableName() string {
return "message_reads"
}

View File

@@ -0,0 +1,53 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// NotificationType 通知类型
type NotificationType string
const (
NotificationTypeLikePost NotificationType = "like_post"
NotificationTypeLikeComment NotificationType = "like_comment"
NotificationTypeComment NotificationType = "comment"
NotificationTypeReply NotificationType = "reply"
NotificationTypeFollow NotificationType = "follow"
NotificationTypeMention NotificationType = "mention"
NotificationTypeSystem NotificationType = "system"
)
// Notification 通知实体
type Notification struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index:idx_notifications_user_read_created,priority:1"` // 接收者
Type NotificationType `json:"type" gorm:"type:varchar(30);not null"`
Title string `json:"title" gorm:"type:varchar(200);not null"`
Content string `json:"content" gorm:"type:text"`
Data string `json:"data" gorm:"type:jsonb"` // 相关数据JSON
// 已读状态
IsRead bool `json:"is_read" gorm:"default:false;index:idx_notifications_user_read_created,priority:2"`
ReadAt *time.Time `json:"read_at" gorm:"type:timestamp"`
// 软删除
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_notifications_user_read_created,priority:3,sort:desc"`
}
// BeforeCreate 创建前生成UUID
func (n *Notification) BeforeCreate(tx *gorm.DB) error {
if n.ID == "" {
n.ID = uuid.New().String()
}
return nil
}
func (Notification) TableName() string {
return "notifications"
}

100
internal/model/post.go Normal file
View File

@@ -0,0 +1,100 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// PostStatus 帖子状态
type PostStatus string
const (
PostStatusDraft PostStatus = "draft"
PostStatusPending PostStatus = "pending" // 待审核
PostStatusPublished PostStatus = "published"
PostStatusRejected PostStatus = "rejected"
PostStatusDeleted PostStatus = "deleted"
)
// Post 帖子实体
type Post struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
UserID string `json:"user_id" gorm:"type:varchar(36);index;index:idx_posts_user_status_created,priority:1;not null"`
CommunityID string `json:"community_id" gorm:"type:varchar(36);index"`
Title string `json:"title" gorm:"type:varchar(200);not null"`
Content string `json:"content" gorm:"type:text;not null"`
// 关联
// User 需要参与缓存序列化;否则列表命中缓存后会丢失作者信息,前端退化为“匿名用户”
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
Images []PostImage `json:"images" gorm:"foreignKey:PostID"`
// 审核状态
Status PostStatus `json:"status" gorm:"type:varchar(20);default:published;index:idx_posts_status_created,priority:1;index:idx_posts_user_status_created,priority:2"`
ReviewedAt *time.Time `json:"reviewed_at" gorm:"type:timestamp"`
ReviewedBy string `json:"reviewed_by" gorm:"type:varchar(50)"`
RejectReason string `json:"reject_reason" gorm:"type:varchar(500)"`
// 统计
LikesCount int `json:"likes_count" gorm:"column:likes_count;default:0"`
CommentsCount int `json:"comments_count" gorm:"column:comments_count;default:0"`
FavoritesCount int `json:"favorites_count" gorm:"column:favorites_count;default:0"`
SharesCount int `json:"shares_count" gorm:"column:shares_count;default:0"`
ViewsCount int `json:"views_count" gorm:"column:views_count;default:0"`
HotScore float64 `json:"hot_score" gorm:"column:hot_score;default:0;index:idx_posts_hot_score_created,priority:1"`
// 置顶/锁定
IsPinned bool `json:"is_pinned" gorm:"default:false"`
IsLocked bool `json:"is_locked" gorm:"default:false"`
IsDeleted bool `json:"-" gorm:"default:false"`
// 投票
IsVote bool `json:"is_vote" gorm:"column:is_vote;default:false"`
// 软删除
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// BeforeCreate 创建前生成UUID
func (p *Post) BeforeCreate(tx *gorm.DB) error {
if p.ID == "" {
p.ID = uuid.New().String()
}
return nil
}
func (Post) TableName() string {
return "posts"
}
// PostImage 帖子图片
type PostImage struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
PostID string `json:"post_id" gorm:"type:varchar(36);index;not null"`
URL string `json:"url" gorm:"type:text;not null"`
ThumbnailURL string `json:"thumbnail_url" gorm:"type:text"`
Width int `json:"width" gorm:"default:0"`
Height int `json:"height" gorm:"default:0"`
Size int64 `json:"size" gorm:"default:0"` // 文件大小(字节)
MimeType string `json:"mime_type" gorm:"type:varchar(50)"`
SortOrder int `json:"sort_order" gorm:"default:0"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
}
// BeforeCreate 创建前生成UUID
func (pi *PostImage) BeforeCreate(tx *gorm.DB) error {
if pi.ID == "" {
pi.ID = uuid.New().String()
}
return nil
}
func (PostImage) TableName() string {
return "post_images"
}

View File

@@ -0,0 +1,129 @@
package model
import (
"time"
"gorm.io/gorm"
"carrot_bbs/internal/pkg/utils"
)
// PushChannel 推送通道类型
type PushChannel string
const (
PushChannelWebSocket PushChannel = "websocket" // WebSocket推送
PushChannelFCM PushChannel = "fcm" // Firebase Cloud Messaging
PushChannelAPNs PushChannel = "apns" // Apple Push Notification service
PushChannelHuawei PushChannel = "huawei" // 华为推送
)
// PushStatus 推送状态
type PushStatus string
const (
PushStatusPending PushStatus = "pending" // 待推送
PushStatusPushing PushStatus = "pushing" // 推送中
PushStatusPushed PushStatus = "pushed" // 已推送(成功发送到推送服务)
PushStatusDelivered PushStatus = "delivered" // 已送达(客户端确认)
PushStatusFailed PushStatus = "failed" // 推送失败
PushStatusExpired PushStatus = "expired" // 消息过期
)
// PushRecord 推送记录实体
// 用于跟踪消息的推送状态,支持多设备推送和重试机制
type PushRecord struct {
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` // 雪花算法ID
UserID string `gorm:"column:user_id;type:varchar(50);index;not null" json:"user_id"` // 目标用户ID (UUID格式)
MessageID string `gorm:"index;not null;size:20" json:"message_id"` // 关联的消息ID (string类型)
PushChannel PushChannel `gorm:"type:varchar(20);not null" json:"push_channel"` // 推送通道
PushStatus PushStatus `gorm:"type:varchar(20);not null;default:'pending'" json:"push_status"` // 推送状态
// 设备信息
DeviceToken string `gorm:"type:varchar(256)" json:"device_token,omitempty"` // 设备Token用于手机推送
DeviceType string `gorm:"type:varchar(20)" json:"device_type,omitempty"` // 设备类型 (ios/android/web)
// 重试机制
RetryCount int `gorm:"default:0" json:"retry_count"` // 重试次数
MaxRetry int `gorm:"default:3" json:"max_retry"` // 最大重试次数
// 时间戳
PushedAt *time.Time `json:"pushed_at,omitempty"` // 推送时间
DeliveredAt *time.Time `json:"delivered_at,omitempty"` // 送达时间
ExpiredAt *time.Time `gorm:"index" json:"expired_at,omitempty"` // 过期时间
// 错误信息
ErrorMessage string `gorm:"type:varchar(500)" json:"error_message,omitempty"` // 错误信息
// 软删除
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// BeforeCreate 创建前生成雪花算法ID
func (r *PushRecord) BeforeCreate(tx *gorm.DB) error {
if r.ID == 0 {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
r.ID = id
}
return nil
}
func (PushRecord) TableName() string {
return "push_records"
}
// CanRetry 判断是否可以重试
func (r *PushRecord) CanRetry() bool {
return r.RetryCount < r.MaxRetry && r.PushStatus != PushStatusDelivered && r.PushStatus != PushStatusExpired
}
// IsExpired 判断是否已过期
func (r *PushRecord) IsExpired() bool {
if r.ExpiredAt == nil {
return false
}
return time.Now().After(*r.ExpiredAt)
}
// MarkPushing 标记为推送中
func (r *PushRecord) MarkPushing() {
r.PushStatus = PushStatusPushing
}
// MarkPushed 标记为已推送
func (r *PushRecord) MarkPushed() {
now := time.Now()
r.PushStatus = PushStatusPushed
r.PushedAt = &now
}
// MarkDelivered 标记为已送达
func (r *PushRecord) MarkDelivered() {
now := time.Now()
r.PushStatus = PushStatusDelivered
r.DeliveredAt = &now
}
// MarkFailed 标记为推送失败
func (r *PushRecord) MarkFailed(errMsg string) {
r.PushStatus = PushStatusFailed
r.ErrorMessage = errMsg
r.RetryCount++
}
// MarkExpired 标记为已过期
func (r *PushRecord) MarkExpired() {
r.PushStatus = PushStatusExpired
}
// IncrementRetry 增加重试次数
func (r *PushRecord) IncrementRetry() {
r.RetryCount++
}

View File

@@ -0,0 +1,77 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// SensitiveWordLevel 敏感词级别
type SensitiveWordLevel int
const (
SensitiveWordLevelLow SensitiveWordLevel = 1 // 低危
SensitiveWordLevelMedium SensitiveWordLevel = 2 // 中危
SensitiveWordLevelHigh SensitiveWordLevel = 3 // 高危
)
// SensitiveWordCategory 敏感词分类
type SensitiveWordCategory string
const (
SensitiveWordCategoryPolitical SensitiveWordCategory = "political" // 政治
SensitiveWordCategoryPorn SensitiveWordCategory = "porn" // 色情
SensitiveWordCategoryViolence SensitiveWordCategory = "violence" // 暴力
SensitiveWordCategoryAd SensitiveWordCategory = "ad" // 广告
SensitiveWordCategoryGamble SensitiveWordCategory = "gamble" // 赌博
SensitiveWordCategoryFraud SensitiveWordCategory = "fraud" // 诈骗
SensitiveWordCategoryOther SensitiveWordCategory = "other" // 其他
)
// SensitiveWord 敏感词实体
type SensitiveWord struct {
ID string `gorm:"type:varchar(36);primaryKey"`
Word string `gorm:"type:varchar(255);uniqueIndex;not null"`
Category SensitiveWordCategory `gorm:"type:varchar(50);index"`
Level SensitiveWordLevel `gorm:"type:int;default:1"`
IsActive bool `gorm:"default:true"`
CreatedBy string `gorm:"type:varchar(255)"`
UpdatedBy string `gorm:"type:varchar(255)"`
Remark string `gorm:"type:text"`
CreatedAt time.Time `gorm:"autoCreateTime"`
UpdatedAt time.Time `gorm:"autoUpdateTime"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
// BeforeCreate 创建前生成UUID
func (sw *SensitiveWord) BeforeCreate(tx *gorm.DB) error {
if sw.ID == "" {
sw.ID = uuid.New().String()
}
return nil
}
func (SensitiveWord) TableName() string {
return "sensitive_words"
}
// SensitiveWordRequest 创建/更新敏感词请求
type SensitiveWordRequest struct {
Word string `json:"word" validate:"required,min=1,max=255"`
Category SensitiveWordCategory `json:"category"`
Level SensitiveWordLevel `json:"level"`
Remark string `json:"remark"`
CreatedBy string `json:"-"`
}
// SensitiveWordListItem 敏感词列表项(用于列表展示)
type SensitiveWordListItem struct {
ID string `json:"id"`
Word string `json:"word"`
Category SensitiveWordCategory `json:"category"`
Level SensitiveWordLevel `json:"level"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
Remark string `json:"remark"`
}

33
internal/model/sticker.go Normal file
View File

@@ -0,0 +1,33 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// UserSticker 用户自定义表情
type UserSticker struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
UserID string `json:"user_id" gorm:"type:varchar(36);not null;index:idx_user_stickers"`
URL string `json:"url" gorm:"type:text;not null"`
Width int `json:"width" gorm:"default:0"`
Height int `json:"height" gorm:"default:0"`
SortOrder int `json:"sort_order" gorm:"default:0;index:idx_user_stickers_sort"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// TableName 表名
func (UserSticker) TableName() string {
return "user_stickers"
}
// BeforeCreate 创建前生成UUID
func (s *UserSticker) BeforeCreate(tx *gorm.DB) error {
if s.ID == "" {
s.ID = uuid.New().String()
}
return nil
}

View File

@@ -0,0 +1,127 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"time"
"gorm.io/gorm"
"carrot_bbs/internal/pkg/utils"
)
// SystemNotificationType 系统通知类型
type SystemNotificationType string
const (
// 互动通知
SysNotifyLikePost SystemNotificationType = "like_post" // 点赞帖子
SysNotifyLikeComment SystemNotificationType = "like_comment" // 点赞评论
SysNotifyComment SystemNotificationType = "comment" // 评论
SysNotifyReply SystemNotificationType = "reply" // 回复
SysNotifyFollow SystemNotificationType = "follow" // 关注
SysNotifyMention SystemNotificationType = "mention" // @提及
SysNotifyFavoritePost SystemNotificationType = "favorite_post" // 收藏帖子
SysNotifyLikeReply SystemNotificationType = "like_reply" // 点赞回复
// 系统消息
SysNotifySystem SystemNotificationType = "system" // 系统通知
SysNotifyAnnounce SystemNotificationType = "announce" // 系统公告
SysNotifyGroupInvite SystemNotificationType = "group_invite" // 群邀请
SysNotifyGroupJoinApply SystemNotificationType = "group_join_apply" // 加群申请待审批
SysNotifyGroupJoinApproved SystemNotificationType = "group_join_approved" // 加群申请通过
SysNotifyGroupJoinRejected SystemNotificationType = "group_join_rejected" // 加群申请拒绝
)
// SystemNotificationExtra 额外数据
type SystemNotificationExtra struct {
// 操作者信息
ActorID int64 `json:"actor_id,omitempty"`
ActorIDStr string `json:"actor_id_str,omitempty"`
ActorName string `json:"actor_name,omitempty"`
AvatarURL string `json:"avatar_url,omitempty"`
// 目标信息
TargetID string `json:"target_id,omitempty"` // 改为string类型以支持UUID
TargetTitle string `json:"target_title,omitempty"`
TargetType string `json:"target_type,omitempty"`
// 其他信息
ActionURL string `json:"action_url,omitempty"`
ActionTime string `json:"action_time,omitempty"`
// 群邀请/加群申请扩展字段
GroupID string `json:"group_id,omitempty"`
GroupName string `json:"group_name,omitempty"`
GroupAvatar string `json:"group_avatar,omitempty"`
GroupDescription string `json:"group_description,omitempty"`
Flag string `json:"flag,omitempty"`
RequestType string `json:"request_type,omitempty"`
RequestStatus string `json:"request_status,omitempty"`
Reason string `json:"reason,omitempty"`
TargetUserID string `json:"target_user_id,omitempty"`
TargetUserName string `json:"target_user_name,omitempty"`
TargetUserAvatar string `json:"target_user_avatar,omitempty"`
}
// Value 实现driver.Valuer接口
func (e SystemNotificationExtra) Value() (driver.Value, error) {
return json.Marshal(e)
}
// Scan 实现sql.Scanner接口
func (e *SystemNotificationExtra) Scan(value interface{}) error {
if value == nil {
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, e)
}
// SystemNotification 系统通知(独立表,与消息完全分离)
// 每个用户只能看到自己的系统通知
type SystemNotification struct {
ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"`
ReceiverID string `gorm:"column:receiver_id;type:varchar(50);not null;index:idx_sys_notifications_receiver_read_created,priority:1" json:"receiver_id"` // 接收者ID (UUID)
Type SystemNotificationType `gorm:"type:varchar(30);not null" json:"type"` // 通知类型
Title string `gorm:"type:varchar(200)" json:"title,omitempty"` // 标题
Content string `gorm:"type:text;not null" json:"content"` // 内容
ExtraData *SystemNotificationExtra `gorm:"type:json" json:"extra_data,omitempty"` // 额外数据
IsRead bool `gorm:"default:false;index:idx_sys_notifications_receiver_read_created,priority:2" json:"is_read"` // 是否已读
ReadAt *time.Time `json:"read_at,omitempty"` // 阅读时间
// 软删除
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_sys_notifications_receiver_read_created,priority:3,sort:desc"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// BeforeCreate 创建前生成雪花算法ID
func (n *SystemNotification) BeforeCreate(tx *gorm.DB) error {
if n.ID == 0 {
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
return err
}
n.ID = id
}
return nil
}
// TableName 指定表名
func (SystemNotification) TableName() string {
return "system_notifications"
}
// MarkAsRead 标记为已读
func (n *SystemNotification) MarkAsRead() {
now := time.Now()
n.IsRead = true
n.ReadAt = &now
}

66
internal/model/user.go Normal file
View File

@@ -0,0 +1,66 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// UserStatus 用户状态
type UserStatus string
const (
UserStatusActive UserStatus = "active"
UserStatusBanned UserStatus = "banned"
UserStatusInactive UserStatus = "inactive"
)
// User 用户实体
type User struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
Username string `json:"username" gorm:"type:varchar(50);uniqueIndex;not null"`
Nickname string `json:"nickname" gorm:"type:varchar(100);not null"`
Email *string `json:"email" gorm:"type:varchar(255);uniqueIndex"`
Phone *string `json:"phone" gorm:"type:varchar(20);uniqueIndex"`
EmailVerified bool `json:"email_verified" gorm:"default:false"`
PasswordHash string `json:"-" gorm:"type:varchar(255);not null"`
Avatar string `json:"avatar" gorm:"type:text"`
CoverURL string `json:"cover_url" gorm:"type:text"` // 头图URL
Bio string `json:"bio" gorm:"type:text"`
Website string `json:"website" gorm:"type:varchar(255)"`
Location string `json:"location" gorm:"type:varchar(100)"`
// 实名认证信息(可选)
RealName string `json:"real_name" gorm:"type:varchar(100)"` // 真实姓名
IDCard string `json:"-" gorm:"type:varchar(18)"` // 身份证号(加密存储)
IsVerified bool `json:"is_verified" gorm:"default:false"` // 是否实名认证
VerifiedAt *time.Time `json:"verified_at" gorm:"type:timestamp"`
// 统计计数
PostsCount int `json:"posts_count" gorm:"default:0"`
FollowersCount int `json:"followers_count" gorm:"default:0"`
FollowingCount int `json:"following_count" gorm:"default:0"`
// 状态
Status UserStatus `json:"status" gorm:"type:varchar(20);default:active"`
LastLoginAt *time.Time `json:"last_login_at" gorm:"type:timestamp"`
LastLoginIP string `json:"last_login_ip" gorm:"type:varchar(45)"`
// 时间戳
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
}
// BeforeCreate 创建前生成UUID
func (u *User) BeforeCreate(tx *gorm.DB) error {
if u.ID == "" {
u.ID = uuid.New().String()
}
return nil
}
func (User) TableName() string {
return "users"
}

View File

@@ -0,0 +1,27 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// UserBlock 用户拉黑关系
type UserBlock struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
BlockerID string `json:"blocker_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_blocker_blocked"` // 拉黑人
BlockedID string `json:"blocked_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_blocker_blocked"` // 被拉黑人
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
}
func (b *UserBlock) BeforeCreate(tx *gorm.DB) error {
if b.ID == "" {
b.ID = uuid.New().String()
}
return nil
}
func (UserBlock) TableName() string {
return "user_blocks"
}

52
internal/model/vote.go Normal file
View File

@@ -0,0 +1,52 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// VoteOption 投票选项
type VoteOption struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
PostID string `json:"post_id" gorm:"type:varchar(36);index:idx_vote_option_post_sort,priority:1;not null"`
Content string `json:"content" gorm:"type:varchar(200);not null"`
SortOrder int `json:"sort_order" gorm:"default:0;index:idx_vote_option_post_sort,priority:2"`
VotesCount int `json:"votes_count" gorm:"default:0"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"`
}
// BeforeCreate 创建前生成UUID
func (vo *VoteOption) BeforeCreate(tx *gorm.DB) error {
if vo.ID == "" {
vo.ID = uuid.New().String()
}
return nil
}
func (VoteOption) TableName() string {
return "vote_options"
}
// UserVote 用户投票记录
type UserVote struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
PostID string `json:"post_id" gorm:"type:varchar(36);index;uniqueIndex:idx_user_vote_post_user,priority:1;not null"`
UserID string `json:"user_id" gorm:"type:varchar(36);index;uniqueIndex:idx_user_vote_post_user,priority:2;not null"`
OptionID string `json:"option_id" gorm:"type:varchar(36);index;not null"`
CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"`
}
// BeforeCreate 创建前生成UUID
func (uv *UserVote) BeforeCreate(tx *gorm.DB) error {
if uv.ID == "" {
uv.ID = uuid.New().String()
}
return nil
}
func (UserVote) TableName() string {
return "user_votes"
}

View File

@@ -0,0 +1,115 @@
package avatar
import (
"encoding/base64"
"fmt"
"unicode/utf8"
)
// 预定义一组好看的颜色
var colors = []string{
"#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
"#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
"#BB8FCE", "#85C1E9", "#F8B500", "#00CED1",
"#E74C3C", "#3498DB", "#2ECC71", "#9B59B6",
"#1ABC9C", "#F39C12", "#E67E22", "#16A085",
}
// SVG模板
const svgTemplate = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100" width="%d" height="%d">
<rect width="100" height="100" fill="%s"/>
<text x="50" y="50" font-family="Arial, sans-serif" font-size="40" font-weight="bold" fill="#ffffff" text-anchor="middle" dominant-baseline="central">%s</text>
</svg>`
// GenerateSVGAvatar 根据用户名生成SVG头像
// username: 用户名
// size: 头像尺寸(像素)
func GenerateSVGAvatar(username string, size int) string {
initials := getInitials(username)
color := stringToColor(username)
return fmt.Sprintf(svgTemplate, size, size, color, initials)
}
// GenerateAvatarDataURI 生成Data URI格式的头像
// 可以直接在HTML img标签或CSS background-image中使用
func GenerateAvatarDataURI(username string, size int) string {
svg := GenerateSVGAvatar(username, size)
encoded := base64.StdEncoding.EncodeToString([]byte(svg))
return fmt.Sprintf("data:image/svg+xml;base64,%s", encoded)
}
// getInitials 获取用户名首字母
// 中文取第一个字英文取首字母最多2个
func getInitials(username string) string {
if username == "" {
return "?"
}
// 检查是否是中文字符
firstRune, _ := utf8.DecodeRuneInString(username)
if isChinese(firstRune) {
// 中文直接返回第一个字符
return string(firstRune)
}
// 英文处理:取前两个单词的首字母
// 例如: "John Doe" -> "JD", "john" -> "J"
result := []rune{}
for i, r := range username {
if i == 0 {
result = append(result, toUpper(r))
} else if r == ' ' || r == '_' || r == '-' {
// 找到下一个字符作为第二个首字母
nextIdx := i + 1
if nextIdx < len(username) {
nextRune, _ := utf8.DecodeRuneInString(username[nextIdx:])
if nextRune != utf8.RuneError && nextRune != ' ' {
result = append(result, toUpper(nextRune))
break
}
}
}
}
if len(result) == 0 {
return "?"
}
// 最多返回2个字符
if len(result) > 2 {
result = result[:2]
}
return string(result)
}
// isChinese 判断是否是中文字符
func isChinese(r rune) bool {
return r >= 0x4E00 && r <= 0x9FFF
}
// toUpper 将字母转换为大写
func toUpper(r rune) rune {
if r >= 'a' && r <= 'z' {
return r - 32
}
return r
}
// stringToColor 根据字符串生成颜色
// 使用简单的哈希算法确保同一用户名每次生成的颜色一致
func stringToColor(s string) string {
if s == "" {
return colors[0]
}
hash := 0
for _, r := range s {
hash = (hash*31 + int(r)) % len(colors)
}
if hash < 0 {
hash = -hash
}
return colors[hash%len(colors)]
}

View File

@@ -0,0 +1,118 @@
package avatar
import (
"strings"
"testing"
)
func TestGetInitials(t *testing.T) {
tests := []struct {
name string
username string
want string
}{
{"中文用户名", "张三", "张"},
{"英文用户名", "John", "J"},
{"英文全名", "John Doe", "JD"},
{"带下划线", "john_doe", "JD"},
{"带连字符", "john-doe", "JD"},
{"空字符串", "", "?"},
{"小写英文", "alice", "A"},
{"中文复合", "李小龙", "李"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getInitials(tt.username)
if got != tt.want {
t.Errorf("getInitials(%q) = %q, want %q", tt.username, got, tt.want)
}
})
}
}
func TestStringToColor(t *testing.T) {
// 测试同一用户名生成的颜色一致
color1 := stringToColor("张三")
color2 := stringToColor("张三")
if color1 != color2 {
t.Errorf("stringToColor should return consistent colors for the same input")
}
// 测试不同用户名生成不同颜色(大概率)
color3 := stringToColor("李四")
if color1 == color3 {
t.Logf("Warning: different usernames generated the same color (possible but unlikely)")
}
// 测试空字符串
color4 := stringToColor("")
if color4 == "" {
t.Errorf("stringToColor should return a color for empty string")
}
// 验证颜色格式
if !strings.HasPrefix(color4, "#") {
t.Errorf("stringToColor should return hex color format starting with #")
}
}
func TestGenerateSVGAvatar(t *testing.T) {
svg := GenerateSVGAvatar("张三", 100)
// 验证SVG结构
if !strings.Contains(svg, "<svg") {
t.Errorf("SVG should contain <svg tag")
}
if !strings.Contains(svg, "</svg>") {
t.Errorf("SVG should contain </svg> tag")
}
if !strings.Contains(svg, "width=\"100\"") {
t.Errorf("SVG should have width=100")
}
if !strings.Contains(svg, "height=\"100\"") {
t.Errorf("SVG should have height=100")
}
if !strings.Contains(svg, "张") {
t.Errorf("SVG should contain the initial character")
}
}
func TestGenerateAvatarDataURI(t *testing.T) {
dataURI := GenerateAvatarDataURI("张三", 100)
// 验证Data URI格式
if !strings.HasPrefix(dataURI, "data:image/svg+xml;base64,") {
t.Errorf("Data URI should start with data:image/svg+xml;base64,")
}
// 验证base64部分不为空
parts := strings.Split(dataURI, ",")
if len(parts) != 2 {
t.Errorf("Data URI should have two parts separated by comma")
}
if parts[1] == "" {
t.Errorf("Base64 part should not be empty")
}
}
func TestIsChinese(t *testing.T) {
tests := []struct {
r rune
want bool
}{
{'中', true},
{'文', true},
{'a', false},
{'Z', false},
{'0', false},
{'_', false},
}
for _, tt := range tests {
got := isChinese(tt.r)
if got != tt.want {
t.Errorf("isChinese(%q) = %v, want %v", tt.r, got, tt.want)
}
}
}

View File

@@ -0,0 +1,131 @@
package email
import (
"context"
"crypto/tls"
"fmt"
"strings"
"time"
gomail "gopkg.in/gomail.v2"
)
// Message 发信参数
type Message struct {
To []string
Cc []string
Bcc []string
ReplyTo []string
Subject string
TextBody string
HTMLBody string
Attachments []string
}
type Client interface {
IsEnabled() bool
Config() Config
Send(ctx context.Context, msg Message) error
}
type clientImpl struct {
cfg Config
}
func NewClient(cfg Config) Client {
return &clientImpl{cfg: cfg}
}
func (c *clientImpl) IsEnabled() bool {
return c.cfg.Enabled &&
strings.TrimSpace(c.cfg.Host) != "" &&
c.cfg.Port > 0 &&
strings.TrimSpace(c.cfg.FromAddress) != ""
}
func (c *clientImpl) Config() Config {
return c.cfg
}
func (c *clientImpl) Send(ctx context.Context, msg Message) error {
if !c.IsEnabled() {
return fmt.Errorf("email client is disabled or misconfigured")
}
if len(msg.To) == 0 {
return fmt.Errorf("email recipient is empty")
}
if strings.TrimSpace(msg.Subject) == "" {
return fmt.Errorf("email subject is empty")
}
if strings.TrimSpace(msg.TextBody) == "" && strings.TrimSpace(msg.HTMLBody) == "" {
return fmt.Errorf("email body is empty")
}
m := gomail.NewMessage()
m.SetAddressHeader("From", c.cfg.FromAddress, c.cfg.FromName)
m.SetHeader("To", msg.To...)
if len(msg.Cc) > 0 {
m.SetHeader("Cc", msg.Cc...)
}
if len(msg.Bcc) > 0 {
m.SetHeader("Bcc", msg.Bcc...)
}
if len(msg.ReplyTo) > 0 {
m.SetHeader("Reply-To", msg.ReplyTo...)
}
m.SetHeader("Subject", msg.Subject)
if strings.TrimSpace(msg.TextBody) != "" && strings.TrimSpace(msg.HTMLBody) != "" {
m.SetBody("text/plain", msg.TextBody)
m.AddAlternative("text/html", msg.HTMLBody)
} else if strings.TrimSpace(msg.HTMLBody) != "" {
m.SetBody("text/html", msg.HTMLBody)
} else {
m.SetBody("text/plain", msg.TextBody)
}
for _, attachment := range msg.Attachments {
if strings.TrimSpace(attachment) == "" {
continue
}
m.Attach(attachment)
}
timeout := c.cfg.TimeoutSeconds
if timeout <= 0 {
timeout = 15
}
dialer := gomail.NewDialer(c.cfg.Host, c.cfg.Port, c.cfg.Username, c.cfg.Password)
if c.cfg.UseTLS {
dialer.TLSConfig = &tls.Config{
ServerName: c.cfg.Host,
InsecureSkipVerify: c.cfg.InsecureSkipVerify,
}
// 465 端口通常要求直接 TLSImplicit TLS
if c.cfg.Port == 465 {
dialer.SSL = true
}
}
sendCtx := ctx
cancel := func() {}
if timeout > 0 {
sendCtx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
}
defer cancel()
done := make(chan error, 1)
go func() {
done <- dialer.DialAndSend(m)
}()
select {
case <-sendCtx.Done():
return fmt.Errorf("send email canceled: %w", sendCtx.Err())
case err := <-done:
if err != nil {
return fmt.Errorf("send email failed: %w", err)
}
return nil
}
}

View File

@@ -0,0 +1,33 @@
package email
import "carrot_bbs/internal/config"
// Config SMTP 邮件配置(由应用配置转换)
type Config struct {
Enabled bool
Host string
Port int
Username string
Password string
FromAddress string
FromName string
UseTLS bool
InsecureSkipVerify bool
TimeoutSeconds int
}
// ConfigFromAppConfig 从应用配置转换
func ConfigFromAppConfig(cfg *config.EmailConfig) Config {
return Config{
Enabled: cfg.Enabled,
Host: cfg.Host,
Port: cfg.Port,
Username: cfg.Username,
Password: cfg.Password,
FromAddress: cfg.FromAddress,
FromName: cfg.FromName,
UseTLS: cfg.UseTLS,
InsecureSkipVerify: cfg.InsecureSkipVerify,
TimeoutSeconds: cfg.Timeout,
}
}

View File

@@ -0,0 +1,286 @@
package gorse
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
gorseio "github.com/gorse-io/gorse-go"
)
// FeedbackType 反馈类型
type FeedbackType string
const (
FeedbackTypeLike FeedbackType = "like" // 点赞
FeedbackTypeStar FeedbackType = "star" // 收藏
FeedbackTypeComment FeedbackType = "comment" // 评论
FeedbackTypeRead FeedbackType = "read" // 浏览
)
// Score 非个性化推荐返回的评分项
type Score struct {
Id string `json:"Id"`
Score float64 `json:"Score"`
}
// Client Gorse客户端接口
type Client interface {
// InsertFeedback 插入用户反馈
InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error
// DeleteFeedback 删除用户反馈
DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error
// GetRecommend 获取个性化推荐列表
GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error)
// GetNonPersonalized 获取非个性化推荐(通过名称)
GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error)
// UpsertItem 插入或更新物品无embedding
UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error
// UpsertItemWithEmbedding 插入或更新物品带embedding
UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error
// DeleteItem 删除物品
DeleteItem(ctx context.Context, itemID string) error
// UpsertUser 插入或更新用户
UpsertUser(ctx context.Context, userID string, labels map[string]any) error
// IsEnabled 检查是否启用
IsEnabled() bool
}
// client Gorse客户端实现
type client struct {
config Config
gorse *gorseio.GorseClient
httpClient *http.Client
}
// NewClient 创建新的Gorse客户端
func NewClient(cfg Config) Client {
if !cfg.Enabled {
return &noopClient{}
}
gorse := gorseio.NewGorseClient(cfg.Address, cfg.APIKey)
return &client{
config: cfg,
gorse: gorse,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// IsEnabled 检查是否启用
func (c *client) IsEnabled() bool {
return c.config.Enabled
}
// InsertFeedback 插入用户反馈
func (c *client) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
if !c.config.Enabled {
return nil
}
_, err := c.gorse.InsertFeedback(ctx, []gorseio.Feedback{
{
FeedbackType: string(feedbackType),
UserId: userID,
ItemId: itemID,
Timestamp: time.Now().UTC().Truncate(time.Second),
},
})
return err
}
// DeleteFeedback 删除用户反馈
func (c *client) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
if !c.config.Enabled {
return nil
}
_, err := c.gorse.DeleteFeedback(ctx, string(feedbackType), userID, itemID)
return err
}
// GetRecommend 获取个性化推荐列表
func (c *client) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) {
if !c.config.Enabled {
return nil, nil
}
result, err := c.gorse.GetRecommend(ctx, userID, "", n, offset)
if err != nil {
return nil, err
}
return result, nil
}
// GetNonPersonalized 获取非个性化推荐
// name: 推荐器名称,如 "most_liked_weekly"
// n: 返回数量
// offset: 偏移量
// userID: 可选,用于排除用户已读物品
func (c *client) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) {
if !c.config.Enabled {
return nil, nil
}
// 构建URL
url := fmt.Sprintf("%s/api/non-personalized/%s?n=%d&offset=%d", c.config.Address, name, n, offset)
if userID != "" {
url += fmt.Sprintf("&user-id=%s", userID)
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// 设置API Key
if c.config.APIKey != "" {
req.Header.Set("X-API-Key", c.config.APIKey)
}
// 发送请求
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
// 读取响应
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("gorse api error: status=%d, body=%s", resp.StatusCode, string(body))
}
// 解析响应
var scores []Score
if err := json.Unmarshal(body, &scores); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
// 提取ID
ids := make([]string, len(scores))
for i, score := range scores {
ids[i] = score.Id
}
return ids, nil
}
// UpsertItem 插入或更新物品
func (c *client) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error {
if !c.config.Enabled {
return nil
}
_, err := c.gorse.InsertItem(ctx, gorseio.Item{
ItemId: itemID,
IsHidden: false,
Categories: categories,
Comment: comment,
Timestamp: time.Now().UTC().Truncate(time.Second),
})
return err
}
// UpsertItemWithEmbedding 插入或更新物品带embedding
func (c *client) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error {
if !c.config.Enabled {
return nil
}
// 生成embedding
var embedding []float64
if textToEmbed != "" {
var err error
embedding, err = GetEmbedding(textToEmbed)
if err != nil {
log.Printf("[WARN] Failed to get embedding for item %s: %v, using zero vector", itemID, err)
embedding = make([]float64, 1024)
}
} else {
embedding = make([]float64, 1024)
}
_, err := c.gorse.InsertItem(ctx, gorseio.Item{
ItemId: itemID,
IsHidden: false,
Categories: categories,
Comment: comment,
Timestamp: time.Now().UTC().Truncate(time.Second),
Labels: map[string]any{
"embedding": embedding,
},
})
return err
}
// DeleteItem 删除物品
func (c *client) DeleteItem(ctx context.Context, itemID string) error {
if !c.config.Enabled {
return nil
}
_, err := c.gorse.DeleteItem(ctx, itemID)
return err
}
// UpsertUser 插入或更新用户
func (c *client) UpsertUser(ctx context.Context, userID string, labels map[string]any) error {
if !c.config.Enabled {
return nil
}
_, err := c.gorse.InsertUser(ctx, gorseio.User{
UserId: userID,
Labels: labels,
})
return err
}
// noopClient 空操作客户端(用于未启用推荐功能时)
type noopClient struct{}
func (c *noopClient) IsEnabled() bool { return false }
func (c *noopClient) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
return nil
}
func (c *noopClient) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
return nil
}
func (c *noopClient) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) {
return nil, nil
}
func (c *noopClient) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) {
return nil, nil
}
func (c *noopClient) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error {
return nil
}
func (c *noopClient) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error {
return nil
}
func (c *noopClient) DeleteItem(ctx context.Context, itemID string) error { return nil }
func (c *noopClient) UpsertUser(ctx context.Context, userID string, labels map[string]any) error {
return nil
}
// 确保实现了接口
var _ Client = (*client)(nil)
var _ Client = (*noopClient)(nil)
// log 用于内部日志
func init() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
}

View File

@@ -0,0 +1,23 @@
package gorse
import (
"carrot_bbs/internal/config"
)
// Config Gorse客户端配置从config.GorseConfig转换
type Config struct {
Address string
APIKey string
Enabled bool
Dashboard string
}
// ConfigFromAppConfig 从应用配置创建Gorse配置
func ConfigFromAppConfig(cfg *config.GorseConfig) Config {
return Config{
Address: cfg.Address,
APIKey: cfg.APIKey,
Enabled: cfg.Enabled,
Dashboard: cfg.Dashboard,
}
}

View File

@@ -0,0 +1,106 @@
package gorse
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
)
// EmbeddingConfig embedding服务配置
type EmbeddingConfig struct {
APIKey string
URL string
Model string
}
var defaultEmbeddingConfig = EmbeddingConfig{
APIKey: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD",
URL: "https://api.littlelan.cn/v1/embeddings",
Model: "BAAI/bge-m3",
}
// SetEmbeddingConfig 设置embedding配置
func SetEmbeddingConfig(apiKey, url, model string) {
if apiKey != "" {
defaultEmbeddingConfig.APIKey = apiKey
}
if url != "" {
defaultEmbeddingConfig.URL = url
}
if model != "" {
defaultEmbeddingConfig.Model = model
}
}
// GetEmbedding 获取文本的embedding
func GetEmbedding(text string) ([]float64, error) {
type embeddingRequest struct {
Input string `json:"input"`
Model string `json:"model"`
}
type embeddingResponse struct {
Data []struct {
Embedding []float64 `json:"embedding"`
} `json:"data"`
}
reqBody := embeddingRequest{
Input: text,
Model: defaultEmbeddingConfig.Model,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequest("POST", defaultEmbeddingConfig.URL, bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+defaultEmbeddingConfig.APIKey)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("embedding API error: status=%d, body=%s", resp.StatusCode, string(body))
}
var result embeddingResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no embedding returned")
}
return result.Data[0].Embedding, nil
}
// InitEmbeddingWithConfig 从应用配置初始化embedding
func InitEmbeddingWithConfig(apiKey, url, model string) {
if apiKey == "" {
log.Println("[WARN] Gorse embedding API key not set, using default")
}
defaultEmbeddingConfig.APIKey = apiKey
if url != "" {
defaultEmbeddingConfig.URL = url
}
if model != "" {
defaultEmbeddingConfig.Model = model
}
}

105
internal/pkg/jwt/jwt.go Normal file
View File

@@ -0,0 +1,105 @@
package jwt
import (
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
)
var (
ErrInvalidToken = errors.New("invalid token")
ErrExpiredToken = errors.New("token has expired")
)
// Claims JWT 声明
type Claims struct {
UserID string `json:"user_id"`
Username string `json:"username"`
jwt.RegisteredClaims
}
// JWT JWT工具
type JWT struct {
secretKey string
accessTokenExpire time.Duration
refreshTokenExpire time.Duration
}
// New 创建JWT实例
func New(secret string, accessExpire, refreshExpire time.Duration) *JWT {
return &JWT{
secretKey: secret,
accessTokenExpire: accessExpire,
refreshTokenExpire: refreshExpire,
}
}
// GenerateAccessToken 生成访问令牌
func (j *JWT) GenerateAccessToken(userID, username string) (string, error) {
now := time.Now()
claims := Claims{
UserID: userID,
Username: username,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
Issuer: "carrot_bbs",
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(j.secretKey))
}
// GenerateRefreshToken 生成刷新令牌
func (j *JWT) GenerateRefreshToken(userID, username string) (string, error) {
now := time.Now()
claims := Claims{
UserID: userID,
Username: username,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
Issuer: "carrot_bbs",
ID: "refresh",
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(j.secretKey))
}
// ParseToken 解析令牌
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(j.secretKey), nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil
}
return nil, ErrInvalidToken
}
// ValidateToken 验证令牌
func (j *JWT) ValidateToken(tokenString string) error {
claims, err := j.ParseToken(tokenString)
if err != nil {
return err
}
// 检查是否是刷新令牌
if claims.ID == "refresh" {
return errors.New("cannot use refresh token as access token")
}
return nil
}

View File

@@ -0,0 +1,438 @@
package openai
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"image"
_ "image/gif"
"image/jpeg"
_ "image/png"
"io"
"net/http"
"net/url"
"strings"
"time"
xdraw "golang.org/x/image/draw"
)
const moderationSystemPrompt = "你是中文社区的内容审核助手负责对“帖子标题、正文、配图”做联合审核。目标是平衡社区安全与正常交流必须拦截高风险违规内容但不要误伤正常玩梗、二创、吐槽和轻度调侃。请只输出指定JSON。\n\n审核流程\n1) 先判断是否命中硬性违规;\n2) 再判断语境(玩笑/自嘲/朋友间互动/作品讨论);\n3) 做文图交叉判断(文本+图片合并理解);\n4) 给出 approved 与简短 reason。\n\n硬性违规命中任一项必须 approved=false\nA. 宣传对立与煽动撕裂:\n- 明确煽动群体对立、地域对立、性别对立、民族宗教对立,鼓动仇恨、排斥、报复。\nB. 严重人身攻击与网暴引导:\n- 持续性侮辱贬损、羞辱人格、号召围攻/骚扰/挂人/线下冲突。\nC. 开盒/人肉/隐私暴露:\n- 故意公开、拼接、索取他人可识别隐私信息(姓名+联系方式、身份证号、住址、学校单位、车牌、定位轨迹等);\n- 图片/截图中出现可识别隐私信息并伴随曝光意图,也按违规处理。\nD. 其他高危违规:\n- 违法犯罪、暴力威胁、极端仇恨、色情低俗、诈骗引流、恶意广告等。\n\n放行规则以下通常 approved=true\n- 正常玩梗、表情包、谐音梗、二次创作、无恶意的吐槽;\n- 非定向、轻度口语化吐槽(无明确攻击对象、无网暴号召、无隐私暴露);\n- 对社会事件/作品的理性讨论、观点争论(即使语气尖锐,但未煽动对立或人身攻击)。\n\n边界判定\n- 若只是“梗文化表达”且不指向现实伤害,优先通过;\n- 若存在明确伤害意图(煽动、围攻、曝光隐私),必须拒绝;\n- 对模糊内容不因个别粗口直接拒绝,需结合对象、意图、号召性和可执行性综合判断。\n\nreason 要求:\n- approved=false 时中文10-30字说明核心违规点\n- approved=true 时reason 为空字符串。\n\n输出格式严格\n仅输出一行JSON对象不要Markdown不要额外解释\n{\"approved\": true/false, \"reason\": \"...\"}"
const (
defaultMaxImagesPerModerationRequest = 1
maxModerationResultRetries = 3
maxChatCompletionRetries = 3
initialRetryBackoff = 500 * time.Millisecond
maxDownloadImageBytes = 10 * 1024 * 1024
maxModerationImageSide = 1280
compressedJPEGQuality = 72
maxCompressedPayloadBytes = 1536 * 1024
)
type Client interface {
IsEnabled() bool
Config() Config
ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error)
ModerateComment(ctx context.Context, content string, images []string) (bool, string, error)
}
type clientImpl struct {
cfg Config
httpClient *http.Client
}
func NewClient(cfg Config) Client {
timeout := cfg.RequestTimeoutSeconds
if timeout <= 0 {
timeout = 30
}
return &clientImpl{
cfg: cfg,
httpClient: &http.Client{
Timeout: time.Duration(timeout) * time.Second,
},
}
}
func (c *clientImpl) IsEnabled() bool {
return c.cfg.Enabled && c.cfg.APIKey != "" && c.cfg.BaseURL != ""
}
func (c *clientImpl) Config() Config {
return c.cfg
}
func (c *clientImpl) ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error) {
if !c.IsEnabled() {
return true, "", nil
}
return c.moderateContentInBatches(ctx, fmt.Sprintf("帖子标题:%s\n帖子内容%s", title, content), images)
}
func (c *clientImpl) ModerateComment(ctx context.Context, content string, images []string) (bool, string, error) {
if !c.IsEnabled() {
return true, "", nil
}
return c.moderateContentInBatches(ctx, fmt.Sprintf("评论内容:%s", content), images)
}
func (c *clientImpl) moderateContentInBatches(ctx context.Context, contentPrompt string, images []string) (bool, string, error) {
cleanImages := normalizeImageURLs(images)
optimizedImages := c.optimizeImagesForModeration(ctx, cleanImages)
maxImagesPerRequest := c.maxImagesPerModerationRequest()
totalBatches := 1
if len(optimizedImages) > 0 {
totalBatches = (len(optimizedImages) + maxImagesPerRequest - 1) / maxImagesPerRequest
}
// 图片超过单批上限时分批审核,任一批次拒绝即整体拒绝
for i := 0; i < totalBatches; i++ {
start := i * maxImagesPerRequest
end := start + maxImagesPerRequest
if end > len(optimizedImages) {
end = len(optimizedImages)
}
batchImages := []string{}
if len(optimizedImages) > 0 {
batchImages = optimizedImages[start:end]
}
approved, reason, err := c.moderateSingleBatch(ctx, contentPrompt, batchImages, i+1, totalBatches)
if err != nil {
return false, "", err
}
if !approved {
if strings.TrimSpace(reason) != "" && totalBatches > 1 {
reason = fmt.Sprintf("第%d/%d批图片未通过%s", i+1, totalBatches, reason)
}
return false, reason, nil
}
}
return true, "", nil
}
func (c *clientImpl) moderateSingleBatch(
ctx context.Context,
contentPrompt string,
images []string,
batchNo, totalBatches int,
) (bool, string, error) {
userPrompt := fmt.Sprintf(
"%s\n图片批次%d/%d本次仅提供当前批次图片",
contentPrompt,
batchNo,
totalBatches,
)
var lastErr error
for attempt := 0; attempt < maxModerationResultRetries; attempt++ {
replyText, err := c.chatCompletion(ctx, c.cfg.ModerationModel, moderationSystemPrompt, userPrompt, images, 0.1, 220)
if err != nil {
lastErr = err
} else {
parsed := struct {
Approved bool `json:"approved"`
Reason string `json:"reason"`
}{}
if err := json.Unmarshal([]byte(extractJSONObject(replyText)), &parsed); err != nil {
lastErr = fmt.Errorf("failed to parse moderation result: %w", err)
} else {
return parsed.Approved, parsed.Reason, nil
}
}
if attempt == maxModerationResultRetries-1 {
break
}
if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
return false, "", sleepErr
}
}
return false, "", fmt.Errorf(
"moderation batch %d/%d failed after %d attempts: %w",
batchNo,
totalBatches,
maxModerationResultRetries,
lastErr,
)
}
type chatCompletionsRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
}
type chatMessage struct {
Role string `json:"role"`
Content interface{} `json:"content"`
}
type contentPart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL *imageURLPart `json:"image_url,omitempty"`
}
type imageURLPart struct {
URL string `json:"url"`
}
type chatCompletionsResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
func (c *clientImpl) chatCompletion(
ctx context.Context,
model string,
systemPrompt string,
userPrompt string,
images []string,
temperature float64,
maxTokens int,
) (string, error) {
if model == "" {
return "", fmt.Errorf("model is empty")
}
cleanImages := normalizeImageURLs(images)
userParts := []contentPart{
{Type: "text", Text: userPrompt},
}
for _, image := range cleanImages {
userParts = append(userParts, contentPart{
Type: "image_url",
ImageURL: &imageURLPart{URL: image},
})
}
reqBody := chatCompletionsRequest{
Model: model,
Messages: []chatMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userParts},
},
Temperature: temperature,
MaxTokens: maxTokens,
}
data, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshal request: %w", err)
}
baseURL := strings.TrimRight(c.cfg.BaseURL, "/")
endpoint := baseURL + "/v1/chat/completions"
if strings.HasSuffix(baseURL, "/v1") {
endpoint = baseURL + "/chat/completions"
}
var lastErr error
for attempt := 0; attempt < maxChatCompletionRetries; attempt++ {
body, statusCode, err := c.doChatCompletionRequest(ctx, endpoint, data)
if err != nil {
lastErr = err
} else if statusCode >= 400 {
lastErr = fmt.Errorf("openai error status=%d body=%s", statusCode, string(body))
if !isRetryableStatusCode(statusCode) {
return "", lastErr
}
} else {
var parsed chatCompletionsResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return "", fmt.Errorf("decode response: %w", err)
}
if len(parsed.Choices) == 0 {
return "", fmt.Errorf("empty response choices")
}
return parsed.Choices[0].Message.Content, nil
}
if attempt == maxChatCompletionRetries-1 {
break
}
if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
return "", sleepErr
}
}
return "", lastErr
}
func (c *clientImpl) doChatCompletionRequest(ctx context.Context, endpoint string, data []byte) ([]byte, int, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
if err != nil {
return nil, 0, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("request openai: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, 0, fmt.Errorf("read response: %w", err)
}
return body, resp.StatusCode, nil
}
func isRetryableStatusCode(statusCode int) bool {
if statusCode == http.StatusTooManyRequests {
return true
}
return statusCode >= 500 && statusCode <= 599
}
func sleepWithBackoff(ctx context.Context, attempt int) error {
delay := initialRetryBackoff * time.Duration(1<<attempt)
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case <-ctx.Done():
return fmt.Errorf("request cancelled: %w", ctx.Err())
case <-timer.C:
return nil
}
}
func normalizeImageURLs(images []string) []string {
clean := make([]string, 0, len(images))
for _, image := range images {
trimmed := strings.TrimSpace(image)
if trimmed == "" {
continue
}
clean = append(clean, trimmed)
}
return clean
}
func extractJSONObject(raw string) string {
text := strings.TrimSpace(raw)
start := strings.Index(text, "{")
end := strings.LastIndex(text, "}")
if start >= 0 && end > start {
return text[start : end+1]
}
return text
}
func (c *clientImpl) maxImagesPerModerationRequest() int {
// 审核固定单图请求降低单次payload体积减少超时风险。
if c.cfg.ModerationMaxImagesPerRequest <= 0 {
return defaultMaxImagesPerModerationRequest
}
if c.cfg.ModerationMaxImagesPerRequest > 1 {
return 1
}
return c.cfg.ModerationMaxImagesPerRequest
}
func (c *clientImpl) optimizeImagesForModeration(ctx context.Context, images []string) []string {
if len(images) == 0 {
return images
}
optimized := make([]string, 0, len(images))
for _, imageURL := range images {
optimized = append(optimized, c.tryCompressImageForModeration(ctx, imageURL))
}
return optimized
}
func (c *clientImpl) tryCompressImageForModeration(ctx context.Context, imageURL string) string {
parsed, err := url.Parse(imageURL)
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
return imageURL
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil)
if err != nil {
return imageURL
}
resp, err := c.httpClient.Do(req)
if err != nil {
return imageURL
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return imageURL
}
if !strings.HasPrefix(strings.ToLower(resp.Header.Get("Content-Type")), "image/") {
return imageURL
}
originBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxDownloadImageBytes))
if err != nil || len(originBytes) == 0 {
return imageURL
}
srcImg, _, err := image.Decode(bytes.NewReader(originBytes))
if err != nil {
return imageURL
}
dstImg := resizeIfNeeded(srcImg, maxModerationImageSide)
var buf bytes.Buffer
if err := jpeg.Encode(&buf, dstImg, &jpeg.Options{Quality: compressedJPEGQuality}); err != nil {
return imageURL
}
compressed := buf.Bytes()
if len(compressed) == 0 || len(compressed) > maxCompressedPayloadBytes {
return imageURL
}
// 压缩效果不明显时直接使用原图URL避免增大请求体。
if len(compressed) >= int(float64(len(originBytes))*0.95) {
return imageURL
}
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(compressed)
}
func resizeIfNeeded(src image.Image, maxSide int) image.Image {
bounds := src.Bounds()
w := bounds.Dx()
h := bounds.Dy()
if w <= 0 || h <= 0 || max(w, h) <= maxSide {
return src
}
newW, newH := w, h
if w >= h {
newW = maxSide
newH = int(float64(h) * (float64(maxSide) / float64(w)))
} else {
newH = maxSide
newW = int(float64(w) * (float64(maxSide) / float64(h)))
}
if newW < 1 {
newW = 1
}
if newH < 1 {
newH = 1
}
dst := image.NewRGBA(image.Rect(0, 0, newW, newH))
xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, bounds, xdraw.Over, nil)
return dst
}

View File

@@ -0,0 +1,27 @@
package openai
import "carrot_bbs/internal/config"
// Config OpenAI 兼容接口配置
type Config struct {
Enabled bool
BaseURL string
APIKey string
ModerationModel string
ModerationMaxImagesPerRequest int
RequestTimeoutSeconds int
StrictModeration bool
}
// ConfigFromAppConfig 从应用配置转换
func ConfigFromAppConfig(cfg *config.OpenAIConfig) Config {
return Config{
Enabled: cfg.Enabled,
BaseURL: cfg.BaseURL,
APIKey: cfg.APIKey,
ModerationModel: cfg.ModerationModel,
ModerationMaxImagesPerRequest: cfg.ModerationMaxImagesPerRequest,
RequestTimeoutSeconds: cfg.RequestTimeout,
StrictModeration: cfg.StrictModeration,
}
}

119
internal/pkg/redis/redis.go Normal file
View File

@@ -0,0 +1,119 @@
package redis
import (
"context"
"fmt"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"carrot_bbs/internal/config"
)
// Client Redis客户端
type Client struct {
rdb *redis.Client
isMiniRedis bool
mr *miniredis.Miniredis
}
// New 创建Redis客户端
func New(cfg *config.RedisConfig) (*Client, error) {
switch cfg.Type {
case "miniredis":
// 启动内嵌Redis模拟
mr, err := miniredis.Run()
if err != nil {
return nil, fmt.Errorf("failed to start miniredis: %w", err)
}
rdb := redis.NewClient(&redis.Options{
Addr: mr.Addr(),
Password: "",
DB: 0,
})
return &Client{
rdb: rdb,
isMiniRedis: true,
mr: mr,
}, nil
case "redis":
// 使用真实Redis
rdb := redis.NewClient(&redis.Options{
Addr: cfg.Redis.Addr(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
PoolSize: cfg.PoolSize,
})
ctx := context.Background()
if err := rdb.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to redis: %w", err)
}
return &Client{rdb: rdb, isMiniRedis: false}, nil
default:
// 默认使用miniredis
mr, err := miniredis.Run()
if err != nil {
return nil, fmt.Errorf("failed to start miniredis: %w", err)
}
rdb := redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
return &Client{
rdb: rdb,
isMiniRedis: true,
mr: mr,
}, nil
}
}
// Get 获取值
func (c *Client) Get(ctx context.Context, key string) (string, error) {
return c.rdb.Get(ctx, key).Result()
}
// Set 设置值
func (c *Client) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
return c.rdb.Set(ctx, key, value, expiration).Err()
}
// Del 删除键
func (c *Client) Del(ctx context.Context, keys ...string) error {
return c.rdb.Del(ctx, keys...).Err()
}
// Exists 检查键是否存在
func (c *Client) Exists(ctx context.Context, keys ...string) (int64, error) {
return c.rdb.Exists(ctx, keys...).Result()
}
// Incr 递增
func (c *Client) Incr(ctx context.Context, key string) (int64, error) {
return c.rdb.Incr(ctx, key).Result()
}
// Expire 设置过期时间
func (c *Client) Expire(ctx context.Context, key string, expiration time.Duration) (bool, error) {
return c.rdb.Expire(ctx, key, expiration).Result()
}
// GetClient 获取原生客户端
func (c *Client) GetClient() *redis.Client {
return c.rdb
}
// Close 关闭连接
func (c *Client) Close() error {
if err := c.rdb.Close(); err != nil {
return err
}
if c.mr != nil {
c.mr.Close()
}
return nil
}
// IsMiniRedis 返回是否是miniredis
func (c *Client) IsMiniRedis() bool {
return c.isMiniRedis
}

View File

@@ -0,0 +1,117 @@
package response
import (
"net/http"
"github.com/gin-gonic/gin"
)
// Response 统一响应结构
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// ResponseSnakeCase 统一响应结构(snake_case)
type ResponseSnakeCase struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// Success 成功响应
func Success(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
Data: data,
})
}
// SuccessWithMessage 成功响应带消息
func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: message,
Data: data,
})
}
// Error 错误响应
func Error(c *gin.Context, code int, message string) {
c.JSON(http.StatusBadRequest, Response{
Code: code,
Message: message,
})
}
// ErrorWithStatusCode 带状态码的错误响应
func ErrorWithStatusCode(c *gin.Context, statusCode int, code int, message string) {
c.JSON(statusCode, Response{
Code: code,
Message: message,
})
}
// BadRequest 参数错误
func BadRequest(c *gin.Context, message string) {
ErrorWithStatusCode(c, http.StatusBadRequest, 400, message)
}
// Unauthorized 未授权
func Unauthorized(c *gin.Context, message string) {
if message == "" {
message = "unauthorized"
}
ErrorWithStatusCode(c, http.StatusUnauthorized, 401, message)
}
// Forbidden 禁止访问
func Forbidden(c *gin.Context, message string) {
if message == "" {
message = "forbidden"
}
ErrorWithStatusCode(c, http.StatusForbidden, 403, message)
}
// NotFound 资源不存在
func NotFound(c *gin.Context, message string) {
if message == "" {
message = "resource not found"
}
ErrorWithStatusCode(c, http.StatusNotFound, 404, message)
}
// InternalServerError 服务器内部错误
func InternalServerError(c *gin.Context, message string) {
if message == "" {
message = "internal server error"
}
ErrorWithStatusCode(c, http.StatusInternalServerError, 500, message)
}
// PaginatedResponse 分页响应
type PaginatedResponse struct {
List interface{} `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}
// Paginated 分页成功响应
func Paginated(c *gin.Context, list interface{}, total int64, page, pageSize int) {
totalPages := int(total) / pageSize
if int(total)%pageSize > 0 {
totalPages++
}
Success(c, PaginatedResponse{
List: list,
Total: total,
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
})
}

119
internal/pkg/s3/s3.go Normal file
View File

@@ -0,0 +1,119 @@
package s3
import (
"bytes"
"context"
"fmt"
"time"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"carrot_bbs/internal/config"
)
// Client S3客户端
type Client struct {
client *minio.Client
bucket string
domain string
}
// New 创建S3客户端
func New(cfg *config.S3Config) (*Client, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
client, err := minio.New(cfg.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
Secure: cfg.UseSSL,
})
if err != nil {
return nil, fmt.Errorf("failed to create S3 client: %w", err)
}
// 检查bucket是否存在
exists, err := client.BucketExists(ctx, cfg.Bucket)
if err != nil {
return nil, fmt.Errorf("failed to check bucket: %w", err)
}
if !exists {
if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{
Region: cfg.Region,
}); err != nil {
return nil, fmt.Errorf("failed to create bucket: %w", err)
}
}
// 如果没有配置domain则使用默认的endpoint
domain := cfg.Domain
if domain == "" {
domain = cfg.Endpoint
}
return &Client{
client: client,
bucket: cfg.Bucket,
domain: domain,
}, nil
}
// Upload 上传文件
func (c *Client) Upload(ctx context.Context, objectName string, filePath string, contentType string) (string, error) {
_, err := c.client.FPutObject(ctx, c.bucket, objectName, filePath, minio.PutObjectOptions{
ContentType: contentType,
})
if err != nil {
return "", fmt.Errorf("failed to upload file: %w", err)
}
return fmt.Sprintf("%s/%s", c.bucket, objectName), nil
}
// UploadData 上传数据
func (c *Client) UploadData(ctx context.Context, objectName string, data []byte, contentType string) (string, error) {
_, err := c.client.PutObject(ctx, c.bucket, objectName, bytes.NewReader(data), int64(len(data)), minio.PutObjectOptions{
ContentType: contentType,
})
if err != nil {
return "", fmt.Errorf("failed to upload data: %w", err)
}
// 返回完整URL包含bucket名称
scheme := "https"
if c.domain == c.bucket || c.domain == "" {
scheme = "http"
}
return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil
}
// GetURL 获取文件URL - 使用自定义域名
func (c *Client) GetURL(ctx context.Context, objectName string) (string, error) {
// 使用自定义域名构建URL包含bucket名称
scheme := "https"
if c.domain == c.bucket || c.domain == "" {
scheme = "http"
}
return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil
}
// GetPresignedURL 获取预签名URL用于私有桶
func (c *Client) GetPresignedURL(ctx context.Context, objectName string) (string, error) {
url, err := c.client.PresignedGetObject(ctx, c.bucket, objectName, time.Hour*24, nil)
if err != nil {
return "", fmt.Errorf("failed to get presigned URL: %w", err)
}
return url.String(), nil
}
// Delete 删除文件
func (c *Client) Delete(ctx context.Context, objectName string) error {
return c.client.RemoveObject(ctx, c.bucket, objectName, minio.RemoveObjectOptions{})
}
// GetClient 获取原生客户端
func (c *Client) GetClient() *minio.Client {
return c.client
}

View File

@@ -0,0 +1,52 @@
package utils
import (
"net/url"
)
// AvatarServiceBaseURL 默认头像服务基础URL (使用 UI Avatars API)
const AvatarServiceBaseURL = "https://ui-avatars.com/api"
// DefaultAvatarSize 默认头像尺寸
const DefaultAvatarSize = 100
// AvatarInfo 头像信息
type AvatarInfo struct {
Username string
Nickname string
Avatar string
}
// GetAvatarOrDefault 获取头像URL如果为空则返回在线头像生成服务的URL
// 优先使用已有的头像,否则使用昵称或用户名生成默认头像
func GetAvatarOrDefault(username, nickname, avatar string) string {
if avatar != "" {
return avatar
}
// 使用用户名生成默认头像URL优先使用昵称
displayName := nickname
if displayName == "" {
displayName = username
}
return GenerateDefaultAvatarURL(displayName)
}
// GetAvatarOrDefaultFromInfo 从 AvatarInfo 获取头像URL
func GetAvatarOrDefaultFromInfo(info AvatarInfo) string {
return GetAvatarOrDefault(info.Username, info.Nickname, info.Avatar)
}
// GenerateDefaultAvatarURL 生成默认头像URL
// 使用 UI Avatars API 生成基于用户名首字母的头像
func GenerateDefaultAvatarURL(name string) string {
if name == "" {
name = "?"
}
// 使用 UI Avatars API 生成头像
params := url.Values{}
params.Set("name", url.QueryEscape(name))
params.Set("size", "100")
params.Set("background", "0D8ABC") // 默认蓝色背景
params.Set("color", "ffffff") // 白色文字
return AvatarServiceBaseURL + "?" + params.Encode()
}

View File

@@ -0,0 +1,17 @@
package utils
import (
"golang.org/x/crypto/bcrypt"
)
// HashPassword 密码哈希
func HashPassword(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(bytes), err
}
// CheckPasswordHash 验证密码
func CheckPasswordHash(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}

View File

@@ -0,0 +1,261 @@
package utils
import (
"errors"
"os"
"sync"
"time"
)
// 雪花算法常量定义
const (
// 64位ID结构1位符号位 + 41位时间戳 + 10位机器ID + 12位序列号
// 机器ID占用的位数
nodeIDBits uint64 = 10
// 序列号占用的位数
sequenceBits uint64 = 12
// 机器ID的最大值 (0-1023)
maxNodeID int64 = -1 ^ (-1 << nodeIDBits)
// 序列号的最大值 (0-4095)
maxSequence int64 = -1 ^ (-1 << sequenceBits)
// 机器ID左移位数
nodeIDShift uint64 = sequenceBits
// 时间戳左移位数
timestampShift uint64 = sequenceBits + nodeIDBits
// 自定义纪元时间2024-01-01 00:00:00 UTC
// 使用自定义纪元可以延长ID有效期约24年从2024年开始
customEpoch int64 = 1704067200000 // 2024-01-01 00:00:00 UTC 的毫秒时间戳
)
// 错误定义
var (
// ErrInvalidNodeID 机器ID无效
ErrInvalidNodeID = errors.New("node ID must be between 0 and 1023")
// ErrClockBackwards 时钟回拨
ErrClockBackwards = errors.New("clock moved backwards, refusing to generate ID")
)
// IDInfo 解析后的ID信息
type IDInfo struct {
Timestamp int64 // 生成ID时的时间戳毫秒
NodeID int64 // 机器ID
Sequence int64 // 序列号
}
// Snowflake 雪花算法ID生成器
type Snowflake struct {
mu sync.Mutex // 互斥锁,保证线程安全
nodeID int64 // 机器ID (0-1023)
sequence int64 // 当前序列号 (0-4095)
lastTimestamp int64 // 上次生成ID的时间戳
}
// 全局雪花算法实例
var (
globalSnowflake *Snowflake
globalSnowflakeOnce sync.Once
globalSnowflakeErr error
)
// InitSnowflake 初始化全局雪花算法实例
// nodeID: 机器ID范围0-1023可以通过环境变量 NODE_ID 配置
func InitSnowflake(nodeID int64) error {
globalSnowflake, globalSnowflakeErr = NewSnowflake(nodeID)
return globalSnowflakeErr
}
// GetSnowflake 获取全局雪花算法实例
// 如果未初始化,会自动使用默认配置初始化
func GetSnowflake() *Snowflake {
globalSnowflakeOnce.Do(func() {
if globalSnowflake == nil {
globalSnowflake, globalSnowflakeErr = NewSnowflake(-1)
}
})
return globalSnowflake
}
// NewSnowflake 创建雪花算法ID生成器实例
// nodeID: 机器ID范围0-1023可以通过环境变量 NODE_ID 配置
// 如果nodeID为-1则尝试从环境变量 NODE_ID 读取
func NewSnowflake(nodeID int64) (*Snowflake, error) {
// 如果传入-1尝试从环境变量读取
if nodeID == -1 {
nodeIDStr := os.Getenv("NODE_ID")
if nodeIDStr != "" {
// 解析环境变量
parsedID, err := parseInt(nodeIDStr)
if err != nil {
return nil, ErrInvalidNodeID
}
nodeID = parsedID
} else {
// 默认使用0
nodeID = 0
}
}
// 验证机器ID范围
if nodeID < 0 || nodeID > maxNodeID {
return nil, ErrInvalidNodeID
}
return &Snowflake{
nodeID: nodeID,
sequence: 0,
lastTimestamp: 0,
}, nil
}
// parseInt 辅助函数:解析整数
func parseInt(s string) (int64, error) {
var result int64
var negative bool
if len(s) == 0 {
return 0, errors.New("empty string")
}
i := 0
if s[0] == '-' {
negative = true
i = 1
}
for ; i < len(s); i++ {
if s[i] < '0' || s[i] > '9' {
return 0, errors.New("invalid character")
}
result = result*10 + int64(s[i]-'0')
}
if negative {
result = -result
}
return result, nil
}
// GenerateID 生成唯一的雪花算法ID
// 返回值生成的ID以及可能的错误如时钟回拨
// 线程安全:使用互斥锁保证并发安全
func (s *Snowflake) GenerateID() (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
// 获取当前时间戳(毫秒)
now := currentTimestamp()
// 处理时钟回拨
if now < s.lastTimestamp {
return 0, ErrClockBackwards
}
// 同一毫秒内
if now == s.lastTimestamp {
// 序列号递增
s.sequence = (s.sequence + 1) & maxSequence
// 序列号溢出,等待下一毫秒
if s.sequence == 0 {
now = s.waitNextMillis(now)
}
} else {
// 不同毫秒序列号重置为0
s.sequence = 0
}
// 更新上次生成时间
s.lastTimestamp = now
// 组装ID
// ID结构时间戳部分 | 机器ID部分 | 序列号部分
id := ((now - customEpoch) << timestampShift) |
(s.nodeID << nodeIDShift) |
s.sequence
return id, nil
}
// waitNextMillis 等待到下一毫秒
// 参数:当前时间戳
// 返回值:下一毫秒的时间戳
func (s *Snowflake) waitNextMillis(timestamp int64) int64 {
now := currentTimestamp()
for now <= timestamp {
now = currentTimestamp()
}
return now
}
// ParseID 解析雪花算法ID提取其中的信息
// id: 要解析的雪花算法ID
// 返回值包含时间戳、机器ID、序列号的结构体
func ParseID(id int64) *IDInfo {
// 提取序列号低12位
sequence := id & maxSequence
// 提取机器ID中间10位
nodeID := (id >> nodeIDShift) & maxNodeID
// 提取时间戳高41位
timestamp := (id >> timestampShift) + customEpoch
return &IDInfo{
Timestamp: timestamp,
NodeID: nodeID,
Sequence: sequence,
}
}
// currentTimestamp 获取当前时间戳(毫秒)
func currentTimestamp() int64 {
return time.Now().UnixNano() / 1e6
}
// GetNodeID 获取当前机器ID
func (s *Snowflake) GetNodeID() int64 {
return s.nodeID
}
// GetCustomEpoch 获取自定义纪元时间
func GetCustomEpoch() int64 {
return customEpoch
}
// IDToTime 将雪花算法ID转换为生成时间
// 这是一个便捷方法,等价于 ParseID(id).Timestamp
func IDToTime(id int64) time.Time {
info := ParseID(id)
return time.Unix(0, info.Timestamp*1e6) // 毫秒转纳秒
}
// ValidateID 验证ID是否为有效的雪花算法ID
// 检查时间戳是否在合理范围内
func ValidateID(id int64) bool {
if id <= 0 {
return false
}
info := ParseID(id)
// 检查时间戳是否在合理范围内
// 不能早于纪元时间不能晚于当前时间太多允许1分钟的时钟偏差
now := currentTimestamp()
if info.Timestamp < customEpoch || info.Timestamp > now+60000 {
return false
}
// 检查机器ID和序列号是否在有效范围内
if info.NodeID < 0 || info.NodeID > maxNodeID {
return false
}
if info.Sequence < 0 || info.Sequence > maxSequence {
return false
}
return true
}

View File

@@ -0,0 +1,46 @@
package utils
import (
"regexp"
"strings"
)
// ValidateEmail 验证邮箱
func ValidateEmail(email string) bool {
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
matched, _ := regexp.MatchString(pattern, email)
return matched
}
// ValidateUsername 验证用户名
func ValidateUsername(username string) bool {
if len(username) < 3 || len(username) > 30 {
return false
}
pattern := `^[a-zA-Z0-9_]+$`
matched, _ := regexp.MatchString(pattern, username)
return matched
}
// ValidatePassword 验证密码强度
func ValidatePassword(password string) bool {
if len(password) < 6 || len(password) > 50 {
return false
}
return true
}
// ValidatePhone 验证手机号
func ValidatePhone(phone string) bool {
pattern := `^1[3-9]\d{9}$`
matched, _ := regexp.MatchString(pattern, phone)
return matched
}
// SanitizeHTML 清理HTML
func SanitizeHTML(input string) string {
// 简单实现,实际使用建议用专门库
input = strings.ReplaceAll(input, "<", "<")
input = strings.ReplaceAll(input, ">", ">")
return input
}

View File

@@ -0,0 +1,440 @@
package websocket
import (
"carrot_bbs/internal/model"
"encoding/json"
"log"
"sync"
"time"
"github.com/gorilla/websocket"
)
// WebSocket消息类型常量
const (
MessageTypePing = "ping"
MessageTypePong = "pong"
MessageTypeMessage = "message"
MessageTypeTyping = "typing"
MessageTypeRead = "read"
MessageTypeAck = "ack"
MessageTypeError = "error"
MessageTypeRecall = "recall" // 撤回消息
MessageTypeSystem = "system" // 系统消息
MessageTypeNotification = "notification" // 通知消息
MessageTypeAnnouncement = "announcement" // 公告消息
// 群组相关消息类型
MessageTypeGroupMessage = "group_message" // 群消息
MessageTypeGroupTyping = "group_typing" // 群输入状态
MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等)
MessageTypeGroupMention = "group_mention" // @提及通知
MessageTypeGroupRead = "group_read" // 群消息已读
MessageTypeGroupRecall = "group_recall" // 群消息撤回
// Meta事件详细类型
MetaDetailTypeHeartbeat = "heartbeat"
MetaDetailTypeTyping = "typing"
MetaDetailTypeAck = "ack" // 消息发送确认
MetaDetailTypeRead = "read" // 已读回执
)
// WSMessage WebSocket消息结构
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
Timestamp int64 `json:"timestamp"`
}
// ChatMessage 聊天消息结构
type ChatMessage struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
SenderID string `json:"sender_id"`
Seq int64 `json:"seq"`
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
ReplyToID *string `json:"reply_to_id,omitempty"`
CreatedAt int64 `json:"created_at"`
}
// SystemMessage 系统消息结构
type SystemMessage struct {
ID string `json:"id"` // 消息ID
Type string `json:"type"` // 消息子类型account_banned, post_approved等
Title string `json:"title"` // 消息标题
Content string `json:"content"` // 消息内容
Data map[string]interface{} `json:"data"` // 额外数据
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// NotificationMessage 通知消息结构
type NotificationMessage struct {
ID string `json:"id"` // 通知ID
Type string `json:"type"` // 通知类型like, comment, follow, mention等
Title string `json:"title"` // 通知标题
Content string `json:"content"` // 通知内容
TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户
ResourceType string `json:"resource_type"` // 资源类型post, comment等
ResourceID string `json:"resource_id"` // 资源ID
Extra map[string]interface{} `json:"extra"` // 额外数据
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// NotificationUser 通知中的用户信息
type NotificationUser struct {
ID string `json:"id"`
Username string `json:"username"`
Avatar string `json:"avatar"`
}
// AnnouncementMessage 公告消息结构
type AnnouncementMessage struct {
ID string `json:"id"` // 公告ID
Title string `json:"title"` // 公告标题
Content string `json:"content"` // 公告内容
Priority int `json:"priority"` // 优先级1-10
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// GroupMessage 群消息结构
type GroupMessage struct {
ID string `json:"id"` // 消息ID
ConversationID string `json:"conversation_id"` // 会话ID群聊会话
GroupID string `json:"group_id"` // 群组ID
SenderID string `json:"sender_id"` // 发送者ID
Seq int64 `json:"seq"` // 消息序号
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表
MentionAll bool `json:"mention_all"` // 是否@所有人
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// GroupTypingMessage 群输入状态消息
type GroupTypingMessage struct {
GroupID string `json:"group_id"` // 群组ID
UserID string `json:"user_id"` // 用户ID
Username string `json:"username"` // 用户名
IsTyping bool `json:"is_typing"` // 是否正在输入
}
// GroupNoticeMessage 群组通知消息
type GroupNoticeMessage struct {
NoticeType string `json:"notice_type"` // 通知类型member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved
GroupID string `json:"group_id"` // 群组ID
Data interface{} `json:"data"` // 通知数据
Timestamp int64 `json:"timestamp"` // 时间戳
MessageID string `json:"message_id,omitempty"` // 消息ID如果通知保存为消息
Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息)
}
// GroupNoticeData 通知数据结构
type GroupNoticeData struct {
// 成员变动
UserID string `json:"user_id,omitempty"` // 变动的用户ID
Username string `json:"username,omitempty"` // 用户名
OperatorID string `json:"operator_id,omitempty"` // 操作者ID
OpName string `json:"op_name,omitempty"` // 操作者名称
NewRole string `json:"new_role,omitempty"` // 新角色
OldRole string `json:"old_role,omitempty"` // 旧角色
MemberCount int `json:"member_count,omitempty"` // 当前成员数
// 群设置变更
MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态
}
// GroupMentionMessage @提及通知消息
type GroupMentionMessage struct {
GroupID string `json:"group_id"` // 群组ID
MessageID string `json:"message_id"` // 消息ID
FromUserID string `json:"from_user_id"` // 发送者ID
FromName string `json:"from_name"` // 发送者名称
Content string `json:"content"` // 消息内容预览
MentionAll bool `json:"mention_all"` // 是否@所有人
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// AckMessage 消息发送确认结构
type AckMessage struct {
ConversationID string `json:"conversation_id"` // 会话ID
GroupID string `json:"group_id,omitempty"` // 群组ID群聊时
ID string `json:"id"` // 消息ID
SenderID string `json:"sender_id"` // 发送者ID
Seq int64 `json:"seq"` // 消息序号
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
CreatedAt int64 `json:"created_at"` // 创建时间戳
}
// Client WebSocket客户端
type Client struct {
ID string
UserID string
Conn *websocket.Conn
Send chan []byte
Manager *WebSocketManager
IsClosed bool
Mu sync.Mutex
closeOnce sync.Once // 确保 Send channel 只关闭一次
}
// WebSocketManager WebSocket连接管理器
type WebSocketManager struct {
clients map[string]*Client // userID -> Client
register chan *Client
unregister chan *Client
broadcast chan *BroadcastMessage
mutex sync.RWMutex
}
// BroadcastMessage 广播消息
type BroadcastMessage struct {
Message *WSMessage
ExcludeUser string // 排除的用户ID为空表示不排除
TargetUser string // 目标用户ID为空表示广播给所有用户
}
// NewWebSocketManager 创建WebSocket管理器
func NewWebSocketManager() *WebSocketManager {
return &WebSocketManager{
clients: make(map[string]*Client),
register: make(chan *Client, 100),
unregister: make(chan *Client, 100),
broadcast: make(chan *BroadcastMessage, 100),
}
}
// Start 启动管理器
func (m *WebSocketManager) Start() {
go func() {
for {
select {
case client := <-m.register:
m.mutex.Lock()
m.clients[client.UserID] = client
m.mutex.Unlock()
log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients))
case client := <-m.unregister:
m.mutex.Lock()
if _, ok := m.clients[client.UserID]; ok {
delete(m.clients, client.UserID)
// 使用 closeOnce 确保 channel 只关闭一次,避免 panic
client.closeOnce.Do(func() {
close(client.Send)
})
log.Printf("WebSocket client disconnected: userID=%s", client.UserID)
}
m.mutex.Unlock()
case broadcast := <-m.broadcast:
m.sendMessage(broadcast)
}
}
}()
}
// Register 注册客户端
func (m *WebSocketManager) Register(client *Client) {
m.register <- client
}
// Unregister 注销客户端
func (m *WebSocketManager) Unregister(client *Client) {
m.unregister <- client
}
// Broadcast 广播消息给所有用户
func (m *WebSocketManager) Broadcast(msg *WSMessage) {
m.broadcast <- &BroadcastMessage{
Message: msg,
TargetUser: "",
}
}
// SendToUser 发送消息给指定用户
func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) {
m.broadcast <- &BroadcastMessage{
Message: msg,
TargetUser: userID,
}
}
// SendToUsers 发送消息给指定用户列表
func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) {
for _, userID := range userIDs {
m.SendToUser(userID, msg)
}
}
// GetClient 获取客户端
func (m *WebSocketManager) GetClient(userID string) (*Client, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
client, ok := m.clients[userID]
return client, ok
}
// GetAllClients 获取所有客户端
func (m *WebSocketManager) GetAllClients() map[string]*Client {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.clients
}
// GetClientCount 获取在线用户数量
func (m *WebSocketManager) GetClientCount() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.clients)
}
// IsUserOnline 检查用户是否在线
func (m *WebSocketManager) IsUserOnline(userID string) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
_, ok := m.clients[userID]
log.Printf("[DEBUG IsUserOnline] 检查用户 %s, 结果=%v, 当前在线用户=%v", userID, ok, m.clients)
return ok
}
// sendMessage 发送消息
func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) {
msgBytes, err := json.Marshal(broadcast.Message)
if err != nil {
log.Printf("Failed to marshal message: %v", err)
return
}
log.Printf("[DEBUG WebSocket] sendMessage: 目标用户=%s, 当前在线用户数=%d, 消息类型=%s",
broadcast.TargetUser, len(m.clients), broadcast.Message.Type)
m.mutex.RLock()
defer m.mutex.RUnlock()
for userID, client := range m.clients {
// 如果指定了目标用户,只发送给目标用户
if broadcast.TargetUser != "" && userID != broadcast.TargetUser {
continue
}
// 如果指定了排除用户,跳过
if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser {
continue
}
select {
case client.Send <- msgBytes:
log.Printf("[DEBUG WebSocket] 成功发送消息到用户 %s, 消息类型=%s", userID, broadcast.Message.Type)
default:
log.Printf("Failed to send message to user %s: channel full", userID)
}
}
}
// SendPing 发送心跳
func (c *Client) SendPing() error {
c.Mu.Lock()
defer c.Mu.Unlock()
if c.IsClosed {
return nil
}
msg := WSMessage{
Type: MessageTypePing,
Data: nil,
Timestamp: time.Now().UnixMilli(),
}
msgBytes, _ := json.Marshal(msg)
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
}
// SendPong 发送Pong响应
func (c *Client) SendPong() error {
c.Mu.Lock()
defer c.Mu.Unlock()
if c.IsClosed {
return nil
}
msg := WSMessage{
Type: MessageTypePong,
Data: nil,
Timestamp: time.Now().UnixMilli(),
}
msgBytes, _ := json.Marshal(msg)
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
}
// WritePump 写入泵将消息从Manager发送到客户端
func (c *Client) WritePump() {
defer func() {
c.Conn.Close()
c.Mu.Lock()
c.IsClosed = true
c.Mu.Unlock()
}()
for {
message, ok := <-c.Send
if !ok {
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
c.Mu.Lock()
if c.IsClosed {
c.Mu.Unlock()
return
}
err := c.Conn.WriteMessage(websocket.TextMessage, message)
c.Mu.Unlock()
if err != nil {
log.Printf("Write error: %v", err)
return
}
}
}
// ReadPump 读取泵,从客户端读取消息
func (c *Client) ReadPump(handler func(msg *WSMessage)) {
defer func() {
c.Manager.Unregister(c)
c.Conn.Close()
c.Mu.Lock()
c.IsClosed = true
c.Mu.Unlock()
}()
c.Conn.SetReadLimit(512 * 1024) // 512KB
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
c.Conn.SetPongHandler(func(string) error {
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
for {
_, message, err := c.Conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error: %v", err)
}
break
}
var wsMsg WSMessage
if err := json.Unmarshal(message, &wsMsg); err != nil {
log.Printf("Failed to unmarshal message: %v", err)
continue
}
handler(&wsMsg)
}
}
// CreateWSMessage 创建WebSocket消息
func CreateWSMessage(msgType string, data interface{}) *WSMessage {
return &WSMessage{
Type: msgType,
Data: data,
Timestamp: time.Now().UnixMilli(),
}
}

View File

@@ -0,0 +1,296 @@
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// CommentRepository 评论仓储
type CommentRepository struct {
db *gorm.DB
}
// NewCommentRepository 创建评论仓储
func NewCommentRepository(db *gorm.DB) *CommentRepository {
return &CommentRepository{db: db}
}
// Create 创建评论
func (r *CommentRepository) Create(comment *model.Comment) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 创建评论
err := tx.Create(comment).Error
if err != nil {
return err
}
// 增加帖子的评论数并同步热度分
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
Updates(map[string]interface{}{
"comments_count": gorm.Expr("comments_count + 1"),
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
}).Error; err != nil {
return err
}
// 如果是回复,增加父评论的回复数
if comment.ParentID != nil && *comment.ParentID != "" {
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
return err
}
}
return nil
})
}
// GetByID 根据ID获取评论
func (r *CommentRepository) GetByID(id string) (*model.Comment, error) {
var comment model.Comment
err := r.db.Preload("User").First(&comment, "id = ?", id).Error
if err != nil {
return nil, err
}
return &comment, nil
}
// Update 更新评论
func (r *CommentRepository) Update(comment *model.Comment) error {
return r.db.Save(comment).Error
}
// UpdateModerationStatus 更新评论审核状态
func (r *CommentRepository) UpdateModerationStatus(commentID string, status model.CommentStatus) error {
return r.db.Model(&model.Comment{}).
Where("id = ?", commentID).
Update("status", status).Error
}
// Delete 删除评论(软删除,同时清理关联数据)
func (r *CommentRepository) Delete(id string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 先查询评论获取post_id和parent_id
var comment model.Comment
if err := tx.First(&comment, "id = ?", id).Error; err != nil {
return err
}
// 删除评论点赞记录
if err := tx.Where("comment_id = ?", id).Delete(&model.CommentLike{}).Error; err != nil {
return err
}
// 删除评论(软删除)
if err := tx.Delete(&model.Comment{}, "id = ?", id).Error; err != nil {
return err
}
// 减少帖子的评论数并同步热度分
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
Updates(map[string]interface{}{
"comments_count": gorm.Expr("comments_count - 1"),
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
}).Error; err != nil {
return err
}
// 如果是回复,减少父评论的回复数
if comment.ParentID != nil && *comment.ParentID != "" {
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
return err
}
}
return nil
})
}
// GetByPostID 获取帖子评论
func (r *CommentRepository) GetByPostID(postID string, page, pageSize int) ([]*model.Comment, int64, error) {
var comments []*model.Comment
var total int64
r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
Preload("User").
Offset(offset).Limit(pageSize).
Order("created_at ASC").
Find(&comments).Error
return comments, total, err
}
// GetByPostIDWithReplies 获取帖子评论(包含回复,扁平化结构)
// 所有层级的回复都扁平化展示在顶级评论的 replies 中
func (r *CommentRepository) GetByPostIDWithReplies(postID string, page, pageSize, replyLimit int) ([]*model.Comment, int64, error) {
var comments []*model.Comment
var total int64
r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
Preload("User").
Offset(offset).Limit(pageSize).
Order("created_at ASC").
Find(&comments).Error
if err != nil {
return nil, 0, err
}
if len(comments) == 0 {
return comments, total, nil
}
rootIDs := make([]string, 0, len(comments))
commentsByID := make(map[string]*model.Comment, len(comments))
for _, comment := range comments {
rootIDs = append(rootIDs, comment.ID)
commentsByID[comment.ID] = comment
}
// 批量加载所有回复,内存中按 root_id 分组并裁剪每个根评论的返回条数
var allReplies []*model.Comment
if err := r.db.Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished).
Preload("User").
Order("created_at ASC").
Find(&allReplies).Error; err != nil {
return nil, 0, err
}
repliesByRoot := make(map[string][]*model.Comment, len(rootIDs))
for _, reply := range allReplies {
if reply.RootID == nil {
continue
}
rootID := *reply.RootID
if replyLimit <= 0 || len(repliesByRoot[rootID]) < replyLimit {
repliesByRoot[rootID] = append(repliesByRoot[rootID], reply)
}
}
type replyCountRow struct {
RootID string
Total int64
}
var replyCountRows []replyCountRow
if err := r.db.Model(&model.Comment{}).
Select("root_id, COUNT(*) AS total").
Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished).
Group("root_id").
Scan(&replyCountRows).Error; err != nil {
return nil, 0, err
}
replyCountMap := make(map[string]int64, len(replyCountRows))
for _, row := range replyCountRows {
replyCountMap[row.RootID] = row.Total
}
for _, rootID := range rootIDs {
comment := commentsByID[rootID]
comment.Replies = repliesByRoot[rootID]
comment.RepliesCount = int(replyCountMap[rootID])
}
return comments, total, nil
}
// loadFlatReplies 加载评论的所有回复(扁平化,所有层级都在同一层)
func (r *CommentRepository) loadFlatReplies(rootComment *model.Comment, limit int) {
var allReplies []*model.Comment
// 查询所有以该评论为根评论的回复(不包括顶级评论本身)
r.db.Where("root_id = ? AND status = ?", rootComment.ID, model.CommentStatusPublished).
Preload("User").
Order("created_at ASC").
Limit(limit).
Find(&allReplies)
rootComment.Replies = allReplies
}
// GetRepliesByRootID 根据根评论ID分页获取回复扁平化
func (r *CommentRepository) GetRepliesByRootID(rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
var replies []*model.Comment
var total int64
// 统计总数
r.db.Model(&model.Comment{}).Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).Count(&total)
// 分页查询
offset := (page - 1) * pageSize
err := r.db.Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).
Preload("User").
Order("created_at ASC").
Offset(offset).
Limit(pageSize).
Find(&replies).Error
return replies, total, err
}
// GetReplies 获取回复
func (r *CommentRepository) GetReplies(parentID string) ([]*model.Comment, error) {
var comments []*model.Comment
err := r.db.Where("parent_id = ? AND status = ?", parentID, model.CommentStatusPublished).
Preload("User").
Order("created_at ASC").
Find(&comments).Error
return comments, err
}
// Like 点赞评论
func (r *CommentRepository) Like(commentID, userID string) error {
// 检查是否已经点赞
var existing model.CommentLike
err := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).First(&existing).Error
if err == nil {
// 已经点赞
return nil
}
if err != gorm.ErrRecordNotFound {
return err
}
// 创建点赞记录
like := &model.CommentLike{
CommentID: commentID,
UserID: userID,
}
err = r.db.Create(like).Error
if err != nil {
return err
}
// 增加评论点赞数
return r.db.Model(&model.Comment{}).Where("id = ?", commentID).
UpdateColumn("likes_count", gorm.Expr("likes_count + 1")).Error
}
// Unlike 取消点赞评论
func (r *CommentRepository) Unlike(commentID, userID string) error {
result := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).Delete(&model.CommentLike{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected > 0 {
// 减少评论点赞数
return r.db.Model(&model.Comment{}).Where("id = ?", commentID).
UpdateColumn("likes_count", gorm.Expr("likes_count - 1")).Error
}
return nil
}
// IsLiked 检查是否已点赞
func (r *CommentRepository) IsLiked(commentID, userID string) bool {
var count int64
r.db.Model(&model.CommentLike{}).Where("comment_id = ? AND user_id = ?", commentID, userID).Count(&count)
return count > 0
}

View File

@@ -0,0 +1,166 @@
package repository
import (
"time"
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// DeviceTokenRepository 设备Token仓储
type DeviceTokenRepository struct {
db *gorm.DB
}
// NewDeviceTokenRepository 创建设备Token仓储
func NewDeviceTokenRepository(db *gorm.DB) *DeviceTokenRepository {
return &DeviceTokenRepository{db: db}
}
// Create 创建设备Token
func (r *DeviceTokenRepository) Create(token *model.DeviceToken) error {
return r.db.Create(token).Error
}
// GetByID 根据ID获取设备Token
func (r *DeviceTokenRepository) GetByID(id int64) (*model.DeviceToken, error) {
var token model.DeviceToken
err := r.db.First(&token, "id = ?", id).Error
if err != nil {
return nil, err
}
return &token, nil
}
// Update 更新设备Token
func (r *DeviceTokenRepository) Update(token *model.DeviceToken) error {
return r.db.Save(token).Error
}
// Delete 删除设备Token软删除
func (r *DeviceTokenRepository) Delete(id int64) error {
return r.db.Delete(&model.DeviceToken{}, id).Error
}
// GetByUserID 获取用户所有设备
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *DeviceTokenRepository) GetByUserID(userID string) ([]*model.DeviceToken, error) {
var tokens []*model.DeviceToken
err := r.db.Where("user_id = ?", userID).
Order("created_at DESC").
Find(&tokens).Error
return tokens, err
}
// GetActiveByUserID 获取用户活跃设备
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *DeviceTokenRepository) GetActiveByUserID(userID string) ([]*model.DeviceToken, error) {
var tokens []*model.DeviceToken
err := r.db.Where("user_id = ? AND is_active = ?", userID, true).
Order("last_used_at DESC").
Find(&tokens).Error
return tokens, err
}
// GetByDeviceID 根据设备ID获取设备Token
func (r *DeviceTokenRepository) GetByDeviceID(deviceID string) (*model.DeviceToken, error) {
var token model.DeviceToken
err := r.db.Where("device_id = ?", deviceID).First(&token).Error
if err != nil {
return nil, err
}
return &token, nil
}
// GetByPushToken 根据推送Token获取设备信息
func (r *DeviceTokenRepository) GetByPushToken(pushToken string) (*model.DeviceToken, error) {
var token model.DeviceToken
err := r.db.Where("push_token = ?", pushToken).First(&token).Error
if err != nil {
return nil, err
}
return &token, nil
}
// DeactivateAllExcept 登出其他设备(停用除指定设备外的所有设备)
func (r *DeviceTokenRepository) DeactivateAllExcept(userID int64, deviceID string) error {
return r.db.Model(&model.DeviceToken{}).
Where("user_id = ? AND device_id != ?", userID, deviceID).
Update("is_active", false).Error
}
// Upsert 创建或更新设备Token
// 如果设备ID已存在则更新Token和激活状态否则创建新记录
func (r *DeviceTokenRepository) Upsert(token *model.DeviceToken) error {
var existing model.DeviceToken
err := r.db.Where("device_id = ?", token.DeviceID).First(&existing).Error
if err == gorm.ErrRecordNotFound {
// 创建新记录
return r.db.Create(token).Error
} else if err != nil {
return err
}
// 更新现有记录
return r.db.Model(&existing).Updates(map[string]interface{}{
"push_token": token.PushToken,
"is_active": true,
"device_name": token.DeviceName,
"last_used_at": time.Now(),
}).Error
}
// UpdateLastUsed 更新最后使用时间
func (r *DeviceTokenRepository) UpdateLastUsed(deviceID string) error {
return r.db.Model(&model.DeviceToken{}).
Where("device_id = ?", deviceID).
Update("last_used_at", time.Now()).Error
}
// Deactivate 停用设备
func (r *DeviceTokenRepository) Deactivate(deviceID string) error {
return r.db.Model(&model.DeviceToken{}).
Where("device_id = ?", deviceID).
Update("is_active", false).Error
}
// Activate 激活设备
func (r *DeviceTokenRepository) Activate(deviceID string) error {
return r.db.Model(&model.DeviceToken{}).
Where("device_id = ?", deviceID).
Updates(map[string]interface{}{
"is_active": true,
"last_used_at": time.Now(),
}).Error
}
// DeleteByUserID 删除用户所有设备Token
func (r *DeviceTokenRepository) DeleteByUserID(userID int64) error {
return r.db.Where("user_id = ?", userID).Delete(&model.DeviceToken{}).Error
}
// GetDeviceCountByUserID 获取用户设备数量
func (r *DeviceTokenRepository) GetDeviceCountByUserID(userID int64) (int64, error) {
var count int64
err := r.db.Model(&model.DeviceToken{}).
Where("user_id = ?", userID).
Count(&count).Error
return count, err
}
// GetActiveDeviceCountByUserID 获取用户活跃设备数量
func (r *DeviceTokenRepository) GetActiveDeviceCountByUserID(userID int64) (int64, error) {
var count int64
err := r.db.Model(&model.DeviceToken{}).
Where("user_id = ? AND is_active = ?", userID, true).
Count(&count).Error
return count, err
}
// DeleteInactiveDevices 删除长时间未使用的设备
func (r *DeviceTokenRepository) DeleteInactiveDevices(before time.Time) error {
return r.db.Where("is_active = ? AND last_used_at < ?", false, before).
Delete(&model.DeviceToken{}).Error
}

View File

@@ -0,0 +1,50 @@
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
type GroupJoinRequestRepository interface {
Create(req *model.GroupJoinRequest) error
GetByFlag(flag string) (*model.GroupJoinRequest, error)
Update(req *model.GroupJoinRequest) error
GetPendingByGroupAndTarget(groupID, targetUserID string, reqType model.GroupJoinRequestType) (*model.GroupJoinRequest, error)
}
type groupJoinRequestRepository struct {
db *gorm.DB
}
func NewGroupJoinRequestRepository(db *gorm.DB) GroupJoinRequestRepository {
return &groupJoinRequestRepository{db: db}
}
func (r *groupJoinRequestRepository) Create(req *model.GroupJoinRequest) error {
return r.db.Create(req).Error
}
func (r *groupJoinRequestRepository) GetByFlag(flag string) (*model.GroupJoinRequest, error) {
var req model.GroupJoinRequest
if err := r.db.Where("flag = ?", flag).First(&req).Error; err != nil {
return nil, err
}
return &req, nil
}
func (r *groupJoinRequestRepository) Update(req *model.GroupJoinRequest) error {
return r.db.Save(req).Error
}
func (r *groupJoinRequestRepository) GetPendingByGroupAndTarget(groupID, targetUserID string, reqType model.GroupJoinRequestType) (*model.GroupJoinRequest, error) {
var req model.GroupJoinRequest
err := r.db.Where("group_id = ? AND target_user_id = ? AND request_type = ? AND status = ?",
groupID, targetUserID, reqType, model.GroupJoinRequestStatusPending).
Order("created_at DESC").
First(&req).Error
if err != nil {
return nil, err
}
return &req, nil
}

View File

@@ -0,0 +1,242 @@
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// GroupRepository 群组仓库接口
type GroupRepository interface {
// 群组操作
Create(group *model.Group) error
GetByID(id string) (*model.Group, error)
Update(group *model.Group) error
Delete(id string) error
GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error)
// 群成员操作
AddMember(member *model.GroupMember) error
GetMember(groupID string, userID string) (*model.GroupMember, error)
GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error)
UpdateMember(member *model.GroupMember) error
RemoveMember(groupID string, userID string) error
GetMemberCount(groupID string) (int64, error)
IsMember(groupID string, userID string) (bool, error)
GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error)
// 角色相关
GetMemberRole(groupID string, userID string) (string, error)
SetMemberRole(groupID string, userID string, role string) error
GetAdmins(groupID string) ([]model.GroupMember, error)
// 群公告操作
CreateAnnouncement(announcement *model.GroupAnnouncement) error
GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error)
GetAnnouncementByID(id string) (*model.GroupAnnouncement, error)
DeleteAnnouncement(id string) error
}
// groupRepository 群组仓库实现
type groupRepository struct {
db *gorm.DB
}
// NewGroupRepository 创建群组仓库
func NewGroupRepository(db *gorm.DB) GroupRepository {
return &groupRepository{db: db}
}
// Create 创建群组
func (r *groupRepository) Create(group *model.Group) error {
return r.db.Create(group).Error
}
// GetByID 根据ID获取群组
func (r *groupRepository) GetByID(id string) (*model.Group, error) {
var group model.Group
err := r.db.First(&group, "id = ?", id).Error
if err != nil {
return nil, err
}
return &group, nil
}
// Update 更新群组
func (r *groupRepository) Update(group *model.Group) error {
return r.db.Save(group).Error
}
// Delete 删除群组
func (r *groupRepository) Delete(id string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 删除群成员
if err := tx.Where("group_id = ?", id).Delete(&model.GroupMember{}).Error; err != nil {
return err
}
// 删除群公告
if err := tx.Where("group_id = ?", id).Delete(&model.GroupAnnouncement{}).Error; err != nil {
return err
}
// 删除群组
if err := tx.Delete(&model.Group{}, "id = ?", id).Error; err != nil {
return err
}
return nil
})
}
// GetByOwnerID 根据群主ID获取群组列表
func (r *groupRepository) GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error) {
var groups []model.Group
var total int64
query := r.db.Model(&model.Group{}).Where("owner_id = ?", ownerID)
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error
return groups, total, err
}
// AddMember 添加群成员
func (r *groupRepository) AddMember(member *model.GroupMember) error {
return r.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Create(member).Error; err != nil {
return err
}
// 更新群组成员数量
return tx.Model(&model.Group{}).Where("id = ?", member.GroupID).
Update("member_count", gorm.Expr("member_count + ?", 1)).Error
})
}
// GetMember 获取群成员
func (r *groupRepository) GetMember(groupID string, userID string) (*model.GroupMember, error) {
var member model.GroupMember
err := r.db.First(&member, "group_id = ? AND user_id = ?", groupID, userID).Error
if err != nil {
return nil, err
}
return &member, nil
}
// GetMembers 获取群成员列表
func (r *groupRepository) GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) {
var members []model.GroupMember
var total int64
query := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID)
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Offset(offset).Limit(pageSize).Order("created_at ASC").Find(&members).Error
return members, total, err
}
// UpdateMember 更新群成员
func (r *groupRepository) UpdateMember(member *model.GroupMember) error {
return r.db.Save(member).Error
}
// RemoveMember 移除群成员
func (r *groupRepository) RemoveMember(groupID string, userID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 删除成员
if err := tx.Where("group_id = ? AND user_id = ?", groupID, userID).Delete(&model.GroupMember{}).Error; err != nil {
return err
}
// 更新群组成员数量
return tx.Model(&model.Group{}).Where("id = ?", groupID).
Update("member_count", gorm.Expr("member_count - ?", 1)).Error
})
}
// GetMemberCount 获取群成员数量
func (r *groupRepository) GetMemberCount(groupID string) (int64, error) {
var count int64
err := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
// IsMember 检查是否是群成员
func (r *groupRepository) IsMember(groupID string, userID string) (bool, error) {
var count int64
err := r.db.Model(&model.GroupMember{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count).Error
return count > 0, err
}
// GetUserGroups 获取用户加入的群组列表
func (r *groupRepository) GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) {
var groups []model.Group
var total int64
// 通过群成员表查询用户加入的群组
subQuery := r.db.Model(&model.GroupMember{}).
Select("group_id").
Where("user_id = ?", userID)
query := r.db.Model(&model.Group{}).Where("id IN (?)", subQuery)
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error
return groups, total, err
}
// GetMemberRole 获取成员角色
func (r *groupRepository) GetMemberRole(groupID string, userID string) (string, error) {
member, err := r.GetMember(groupID, userID)
if err != nil {
return "", err
}
return member.Role, nil
}
// SetMemberRole 设置成员角色
func (r *groupRepository) SetMemberRole(groupID string, userID string, role string) error {
return r.db.Model(&model.GroupMember{}).
Where("group_id = ? AND user_id = ?", groupID, userID).
Update("role", role).Error
}
// GetAdmins 获取群管理员列表
func (r *groupRepository) GetAdmins(groupID string) ([]model.GroupMember, error) {
var admins []model.GroupMember
err := r.db.Where("group_id = ? AND role = ?", groupID, model.GroupRoleAdmin).Find(&admins).Error
return admins, err
}
// CreateAnnouncement 创建群公告
func (r *groupRepository) CreateAnnouncement(announcement *model.GroupAnnouncement) error {
return r.db.Create(announcement).Error
}
// GetAnnouncements 获取群公告列表
func (r *groupRepository) GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) {
var announcements []model.GroupAnnouncement
var total int64
query := r.db.Model(&model.GroupAnnouncement{}).Where("group_id = ?", groupID)
query.Count(&total)
offset := (page - 1) * pageSize
// 置顶的排在前面,然后按时间倒序
err := query.Offset(offset).Limit(pageSize).Order("is_pinned DESC, created_at DESC").Find(&announcements).Error
return announcements, total, err
}
// GetAnnouncementByID 根据ID获取群公告
func (r *groupRepository) GetAnnouncementByID(id string) (*model.GroupAnnouncement, error) {
var announcement model.GroupAnnouncement
err := r.db.First(&announcement, "id = ?", id).Error
if err != nil {
return nil, err
}
return &announcement, nil
}
// DeleteAnnouncement 删除群公告
func (r *groupRepository) DeleteAnnouncement(id string) error {
return r.db.Delete(&model.GroupAnnouncement{}, "id = ?", id).Error
}

View File

@@ -0,0 +1,543 @@
package repository
import (
"carrot_bbs/internal/model"
"fmt"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// MessageRepository 消息仓储
type MessageRepository struct {
db *gorm.DB
}
// NewMessageRepository 创建消息仓储
func NewMessageRepository(db *gorm.DB) *MessageRepository {
return &MessageRepository{db: db}
}
// CreateMessage 创建消息
func (r *MessageRepository) CreateMessage(msg *model.Message) error {
return r.db.Create(msg).Error
}
// GetConversation 获取会话
func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) {
var conv model.Conversation
err := r.db.Preload("Group").First(&conv, "id = ?", id).Error
if err != nil {
return nil, err
}
return &conv, nil
}
// GetOrCreatePrivateConversation 获取或创建私聊会话
// 使用参与者关系表来管理会话
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) {
var conv model.Conversation
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: user1ID=%s, user2ID=%s\n", user1ID, user2ID)
// 查找两个用户共同参与的私聊会话
err := r.db.Table("conversations c").
Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID).
Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID).
Where("c.type = ?", model.ConversationTypePrivate).
First(&conv).Error
if err == nil {
_ = r.db.Model(&model.ConversationParticipant{}).
Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}).
Update("hidden_at", nil).Error
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: found existing conversation, ID=%s\n", conv.ID)
return &conv, nil
}
if err != gorm.ErrRecordNotFound {
return nil, err
}
// 没找到会话,创建新会话
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: no existing conversation found, creating new one\n")
conv = model.Conversation{
Type: model.ConversationTypePrivate,
}
// 使用事务创建会话和参与者
err = r.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Create(&conv).Error; err != nil {
return err
}
// 创建参与者记录 - UserID 存储为 string (UUID)
participants := []model.ConversationParticipant{
{ConversationID: conv.ID, UserID: user1ID},
{ConversationID: conv.ID, UserID: user2ID},
}
if err := tx.Create(&participants).Error; err != nil {
return err
}
return nil
})
if err == nil {
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: created new conversation, ID=%s\n", conv.ID)
}
return &conv, err
}
// GetConversations 获取用户会话列表
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
var convs []*model.Conversation
var total int64
// 获取总数
r.db.Model(&model.ConversationParticipant{}).
Where("user_id = ? AND hidden_at IS NULL", userID).
Count(&total)
if total == 0 {
return convs, total, nil
}
offset := (page - 1) * pageSize
// 查询会话列表并预加载关联数据:
// 当前用户维度先按置顶排序,再按更新时间排序
err := r.db.Model(&model.Conversation{}).
Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id").
Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID).
Preload("Group").
Offset(offset).
Limit(pageSize).
Order("cp.is_pinned DESC").
Order("conversations.updated_at DESC").
Find(&convs).Error
return convs, total, err
}
// GetMessages 获取会话消息
func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
var messages []*model.Message
var total int64
r.db.Model(&model.Message{}).Where("conversation_id = ?", conversationID).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("conversation_id = ?", conversationID).
Offset(offset).
Limit(pageSize).
Order("seq DESC").
Find(&messages).Error
return messages, total, err
}
// GetMessagesAfterSeq 获取指定seq之后的消息用于增量同步
func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
var messages []*model.Message
err := r.db.Where("conversation_id = ? AND seq > ?", conversationID, afterSeq).
Order("seq ASC").
Limit(limit).
Find(&messages).Error
return messages, err
}
// GetMessagesBeforeSeq 获取指定seq之前的历史消息用于下拉加载更多
func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) {
var messages []*model.Message
fmt.Printf("[DEBUG] GetMessagesBeforeSeq: conversationID=%s, beforeSeq=%d, limit=%d\n", conversationID, beforeSeq, limit)
err := r.db.Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq).
Order("seq DESC"). // 降序获取最新消息在前
Limit(limit).
Find(&messages).Error
fmt.Printf("[DEBUG] GetMessagesBeforeSeq: found %d messages, seq range: ", len(messages))
for i, m := range messages {
if i < 5 || i >= len(messages)-2 {
fmt.Printf("%d ", m.Seq)
} else if i == 5 {
fmt.Printf("... ")
}
}
fmt.Println()
// 反转回正序
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
messages[i], messages[j] = messages[j], messages[i]
}
return messages, err
}
// GetConversationParticipants 获取会话参与者
func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
var participants []*model.ConversationParticipant
err := r.db.Where("conversation_id = ?", conversationID).Find(&participants).Error
return participants, err
}
// GetParticipant 获取用户在会话中的参与者信息
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *MessageRepository) GetParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
var participant model.ConversationParticipant
err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
if err != nil {
// 如果找不到参与者,尝试添加(修复没有参与者记录的问题)
if err == gorm.ErrRecordNotFound {
// 检查会话是否存在
var conv model.Conversation
if err := r.db.First(&conv, conversationID).Error; err == nil {
// 会话存在,添加参与者
participant = model.ConversationParticipant{
ConversationID: conversationID,
UserID: userID,
}
if err := r.db.Create(&participant).Error; err != nil {
return nil, err
}
return &participant, nil
}
}
return nil, err
}
return &participant, nil
}
// UpdateLastReadSeq 更新已读位置
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error {
result := r.db.Model(&model.ConversationParticipant{}).
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
Update("last_read_seq", lastReadSeq)
if result.Error != nil {
return result.Error
}
// 如果没有更新任何记录,说明参与者记录不存在,需要插入
if result.RowsAffected == 0 {
// 尝试插入新记录(跨数据库 upsert
err := r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "conversation_id"},
{Name: "user_id"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"last_read_seq": lastReadSeq,
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
}),
}).Create(&model.ConversationParticipant{
ConversationID: conversationID,
UserID: userID,
LastReadSeq: lastReadSeq,
}).Error
if err != nil {
return err
}
}
return nil
}
// UpdatePinned 更新会话置顶状态(用户维度)
func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error {
result := r.db.Model(&model.ConversationParticipant{}).
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
Update("is_pinned", isPinned)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "conversation_id"},
{Name: "user_id"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"is_pinned": isPinned,
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
}),
}).Create(&model.ConversationParticipant{
ConversationID: conversationID,
UserID: userID,
IsPinned: isPinned,
}).Error
}
return nil
}
// GetUnreadCount 获取未读消息数
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) {
var participant model.ConversationParticipant
err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
if err != nil {
return 0, err
}
var count int64
err = r.db.Model(&model.Message{}).
Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, participant.LastReadSeq).
Count(&count).Error
return count, err
}
// UpdateConversationLastSeq 更新会话的最后消息seq和时间
func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error {
return r.db.Model(&model.Conversation{}).
Where("id = ?", conversationID).
Updates(map[string]interface{}{
"last_seq": seq,
"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
}).Error
}
// GetNextSeq 获取会话的下一个seq值
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
var conv model.Conversation
err := r.db.Select("last_seq").First(&conv, conversationID).Error
if err != nil {
return 0, err
}
return conv.LastSeq + 1, nil
}
// CreateMessageWithSeq 创建消息并更新seq事务操作
func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 获取当前seq并+1
var conv model.Conversation
if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil {
return err
}
msg.Seq = conv.LastSeq + 1
// 创建消息
if err := tx.Create(msg).Error; err != nil {
return err
}
// 更新会话的last_seq
if err := tx.Model(&model.Conversation{}).
Where("id = ?", msg.ConversationID).
Updates(map[string]interface{}{
"last_seq": msg.Seq,
"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
}).Error; err != nil {
return err
}
// 新消息到达后,自动恢复被“仅自己删除”的会话
if err := tx.Model(&model.ConversationParticipant{}).
Where("conversation_id = ?", msg.ConversationID).
Update("hidden_at", nil).Error; err != nil {
return err
}
return nil
})
}
// GetAllUnreadCount 获取用户所有会话的未读消息总数
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) {
var totalUnread int64
err := r.db.Table("conversation_participants AS cp").
Joins("LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL", userID).
Where("cp.user_id = ?", userID).
Select("COALESCE(COUNT(m.id), 0)").
Scan(&totalUnread).Error
return totalUnread, err
}
// GetMessageByID 根据ID获取消息
func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) {
var message model.Message
err := r.db.First(&message, "id = ?", messageID).Error
if err != nil {
return nil, err
}
return &message, nil
}
// CountMessagesBySenderInConversation 统计会话中某用户已发送消息数
func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) {
var count int64
err := r.db.Model(&model.Message{}).
Where("conversation_id = ? AND sender_id = ?", conversationID, senderID).
Count(&count).Error
return count, err
}
// UpdateMessageStatus 更新消息状态
func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error {
return r.db.Model(&model.Message{}).
Where("id = ?", messageID).
Update("status", status).Error
}
// GetOrCreateSystemParticipant 获取或创建用户在系统会话中的参与者记录
// 系统会话是虚拟会话,但需要参与者记录来跟踪已读状态
func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) {
var participant model.ConversationParticipant
err := r.db.Where("conversation_id = ? AND user_id = ?",
model.SystemConversationID, userID).First(&participant).Error
if err == nil {
return &participant, nil
}
if err != gorm.ErrRecordNotFound {
return nil, err
}
// 自动创建参与者记录
participant = model.ConversationParticipant{
ConversationID: model.SystemConversationID,
UserID: userID,
LastReadSeq: 0,
}
if err := r.db.Create(&participant).Error; err != nil {
return nil, err
}
return &participant, nil
}
// GetSystemMessagesUnreadCount 获取系统消息未读数
func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) {
// 获取或创建参与者记录
participant, err := r.GetOrCreateSystemParticipant(userID)
if err != nil {
return 0, err
}
// 计算未读数:查询 seq > last_read_seq 的消息
var count int64
err = r.db.Model(&model.Message{}).
Where("conversation_id = ? AND seq > ?",
model.SystemConversationID, participant.LastReadSeq).
Count(&count).Error
return count, err
}
// MarkAllSystemMessagesAsRead 标记所有系统消息已读
func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error {
// 获取系统会话的最新 seq
var maxSeq int64
err := r.db.Model(&model.Message{}).
Where("conversation_id = ?", model.SystemConversationID).
Select("COALESCE(MAX(seq), 0)").
Scan(&maxSeq).Error
if err != nil {
return err
}
// 使用跨数据库 upsert 方式更新或创建参与者记录
return r.db.Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "conversation_id"},
{Name: "user_id"},
},
DoUpdates: clause.Assignments(map[string]interface{}{
"last_read_seq": maxSeq,
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
}),
}).Create(&model.ConversationParticipant{
ConversationID: model.SystemConversationID,
UserID: userID,
LastReadSeq: maxSeq,
}).Error
}
// GetConversationByGroupID 通过群组ID获取会话
func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) {
var conv model.Conversation
err := r.db.Where("group_id = ?", groupID).First(&conv).Error
if err != nil {
return nil, err
}
return &conv, nil
}
// RemoveParticipant 移除会话参与者
// 当用户退出群聊时,需要同时移除其在对应会话中的参与者记录
func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error {
return r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).
Delete(&model.ConversationParticipant{}).Error
}
// AddParticipant 添加会话参与者
// 当用户加入群聊时,需要同时将其添加到对应会话的参与者记录
func (r *MessageRepository) AddParticipant(conversationID string, userID string) error {
// 先检查是否已经是参与者
var count int64
err := r.db.Model(&model.ConversationParticipant{}).
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
Count(&count).Error
if err != nil {
return err
}
// 如果已经是参与者,直接返回
if count > 0 {
return nil
}
// 添加参与者
participant := model.ConversationParticipant{
ConversationID: conversationID,
UserID: userID,
LastReadSeq: 0,
}
return r.db.Create(&participant).Error
}
// DeleteConversationByGroupID 删除群组对应的会话及其参与者
// 当解散群组时调用
func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error {
// 获取群组对应的会话
conv, err := r.GetConversationByGroupID(groupID)
if err != nil {
// 如果会话不存在,直接返回
if err == gorm.ErrRecordNotFound {
return nil
}
return err
}
return r.db.Transaction(func(tx *gorm.DB) error {
// 删除会话参与者
if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error; err != nil {
return err
}
// 删除会话中的消息
if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error; err != nil {
return err
}
// 删除会话
if err := tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error; err != nil {
return err
}
return nil
})
}
// HideConversationForUser 仅对当前用户隐藏会话(私聊删除)
func (r *MessageRepository) HideConversationForUser(conversationID, userID string) error {
now := time.Now()
return r.db.Model(&model.ConversationParticipant{}).
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
Update("hidden_at", &now).Error
}

View File

@@ -0,0 +1,78 @@
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// NotificationRepository 通知仓储
type NotificationRepository struct {
db *gorm.DB
}
// NewNotificationRepository 创建通知仓储
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
return &NotificationRepository{db: db}
}
// Create 创建通知
func (r *NotificationRepository) Create(notification *model.Notification) error {
return r.db.Create(notification).Error
}
// GetByID 根据ID获取通知
func (r *NotificationRepository) GetByID(id string) (*model.Notification, error) {
var notification model.Notification
err := r.db.First(&notification, "id = ?", id).Error
if err != nil {
return nil, err
}
return &notification, nil
}
// GetByUserID 获取用户通知
func (r *NotificationRepository) GetByUserID(userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
var notifications []*model.Notification
var total int64
query := r.db.Model(&model.Notification{}).Where("user_id = ?", userID)
if unreadOnly {
query = query.Where("is_read = ?", false)
}
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&notifications).Error
return notifications, total, err
}
// MarkAsRead 标记为已读
func (r *NotificationRepository) MarkAsRead(id string) error {
return r.db.Model(&model.Notification{}).Where("id = ?", id).Update("is_read", true).Error
}
// MarkAllAsRead 标记所有为已读
func (r *NotificationRepository) MarkAllAsRead(userID string) error {
return r.db.Model(&model.Notification{}).Where("user_id = ?", userID).Update("is_read", true).Error
}
// Delete 删除通知
func (r *NotificationRepository) Delete(id string) error {
return r.db.Delete(&model.Notification{}, "id = ?", id).Error
}
// GetUnreadCount 获取未读数量
func (r *NotificationRepository) GetUnreadCount(userID string) (int64, error) {
var count int64
err := r.db.Model(&model.Notification{}).Where("user_id = ? AND is_read = ?", userID, false).Count(&count).Error
return count, err
}
// DeleteAllByUserID 删除用户所有通知
func (r *NotificationRepository) DeleteAllByUserID(userID string) error {
return r.db.Where("user_id = ?", userID).Delete(&model.Notification{}).Error
}

View File

@@ -0,0 +1,360 @@
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// PostRepository 帖子仓储
type PostRepository struct {
db *gorm.DB
}
// NewPostRepository 创建帖子仓储
func NewPostRepository(db *gorm.DB) *PostRepository {
return &PostRepository{db: db}
}
// Create 创建帖子
func (r *PostRepository) Create(post *model.Post, images []string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 创建帖子
if err := tx.Create(post).Error; err != nil {
return err
}
// 创建图片记录
for i, url := range images {
image := &model.PostImage{
PostID: post.ID,
URL: url,
SortOrder: i,
}
if err := tx.Create(image).Error; err != nil {
return err
}
}
return nil
})
}
// GetByID 根据ID获取帖子
func (r *PostRepository) GetByID(id string) (*model.Post, error) {
var post model.Post
err := r.db.Preload("User").Preload("Images").First(&post, "id = ?", id).Error
if err != nil {
return nil, err
}
return &post, nil
}
// Update 更新帖子
func (r *PostRepository) Update(post *model.Post) error {
return r.db.Save(post).Error
}
// UpdateModerationStatus 更新帖子审核状态
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
updates := map[string]interface{}{
"status": status,
"reviewed_at": gorm.Expr("CURRENT_TIMESTAMP"),
"reviewed_by": reviewedBy,
"reject_reason": rejectReason,
}
return r.db.Model(&model.Post{}).Where("id = ?", postID).Updates(updates).Error
}
// Delete 删除帖子(软删除,同时清理关联数据)
func (r *PostRepository) Delete(id string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 删除帖子图片
if err := tx.Where("post_id = ?", id).Delete(&model.PostImage{}).Error; err != nil {
return err
}
// 删除帖子点赞记录
if err := tx.Where("post_id = ?", id).Delete(&model.PostLike{}).Error; err != nil {
return err
}
// 删除帖子收藏记录
if err := tx.Where("post_id = ?", id).Delete(&model.Favorite{}).Error; err != nil {
return err
}
// 删除评论点赞记录子查询获取该帖子所有评论ID
if err := tx.Where("comment_id IN (SELECT id FROM comments WHERE post_id = ?)", id).Delete(&model.CommentLike{}).Error; err != nil {
return err
}
// 删除帖子评论
if err := tx.Where("post_id = ?", id).Delete(&model.Comment{}).Error; err != nil {
return err
}
// 最后删除帖子本身(软删除)
return tx.Delete(&model.Post{}, "id = ?", id).Error
})
}
// List 分页获取帖子列表
func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) {
var posts []*model.Post
var total int64
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
if userID != "" {
query = query.Where("user_id = ?", userID)
}
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
return posts, total, err
}
// GetUserPosts 获取用户帖子
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
var posts []*model.Post
var total int64
r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
return posts, total, err
}
// GetFavorites 获取用户收藏
func (r *PostRepository) GetFavorites(userID string, page, pageSize int) ([]*model.Post, int64, error) {
var posts []*model.Post
var total int64
subQuery := r.db.Model(&model.Favorite{}).Where("user_id = ?", userID).Select("post_id")
r.db.Model(&model.Post{}).Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
return posts, total, err
}
// Like 点赞帖子
func (r *PostRepository) Like(postID, userID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 检查是否已经点赞
var existing model.PostLike
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
if err == nil {
// 已经点赞,直接返回
return nil
}
if err != gorm.ErrRecordNotFound {
return err
}
// 创建点赞记录
if err := tx.Create(&model.PostLike{
PostID: postID,
UserID: userID,
}).Error; err != nil {
return err
}
// 增加帖子点赞数并同步热度分
return tx.Model(&model.Post{}).Where("id = ?", postID).
Updates(map[string]interface{}{
"likes_count": gorm.Expr("likes_count + 1"),
"hot_score": gorm.Expr("(likes_count + 1) * 2 + comments_count * 3 + views_count * 0.1"),
}).Error
})
}
// Unlike 取消点赞
func (r *PostRepository) Unlike(postID, userID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.PostLike{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected > 0 {
// 减少帖子点赞数并同步热度分
return tx.Model(&model.Post{}).Where("id = ?", postID).
Updates(map[string]interface{}{
"likes_count": gorm.Expr("likes_count - 1"),
"hot_score": gorm.Expr("(likes_count - 1) * 2 + comments_count * 3 + views_count * 0.1"),
}).Error
}
return nil
})
}
// IsLiked 检查是否点赞
func (r *PostRepository) IsLiked(postID, userID string) bool {
var count int64
r.db.Model(&model.PostLike{}).Where("post_id = ? AND user_id = ?", postID, userID).Count(&count)
return count > 0
}
// Favorite 收藏帖子
func (r *PostRepository) Favorite(postID, userID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 检查是否已经收藏
var existing model.Favorite
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
if err == nil {
// 已经收藏,直接返回
return nil
}
if err != gorm.ErrRecordNotFound {
return err
}
// 创建收藏记录
if err := tx.Create(&model.Favorite{
PostID: postID,
UserID: userID,
}).Error; err != nil {
return err
}
// 增加帖子收藏数
return tx.Model(&model.Post{}).Where("id = ?", postID).
UpdateColumn("favorites_count", gorm.Expr("favorites_count + 1")).Error
})
}
// Unfavorite 取消收藏
func (r *PostRepository) Unfavorite(postID, userID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.Favorite{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected > 0 {
// 减少帖子收藏数
return tx.Model(&model.Post{}).Where("id = ?", postID).
UpdateColumn("favorites_count", gorm.Expr("favorites_count - 1")).Error
}
return nil
})
}
// IsFavorited 检查是否收藏
func (r *PostRepository) IsFavorited(postID, userID string) bool {
var count int64
r.db.Model(&model.Favorite{}).Where("post_id = ? AND user_id = ?", postID, userID).Count(&count)
return count > 0
}
// IncrementViews 增加帖子观看量
func (r *PostRepository) IncrementViews(postID string) error {
return r.db.Model(&model.Post{}).Where("id = ?", postID).
Updates(map[string]interface{}{
"views_count": gorm.Expr("views_count + 1"),
"hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
}).Error
}
// Search 搜索帖子
func (r *PostRepository) Search(keyword string, page, pageSize int) ([]*model.Post, int64, error) {
var posts []*model.Post
var total int64
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
// 搜索标题和内容
if keyword != "" {
if r.db.Dialector.Name() == "postgres" {
// PostgreSQL 使用全文检索表达式,为 pg_trgm/GIN 索引升级预留路径
query = query.Where(
"to_tsvector('simple', COALESCE(title, '') || ' ' || COALESCE(content, '')) @@ plainto_tsquery('simple', ?)",
keyword,
)
} else {
searchPattern := "%" + keyword + "%"
query = query.Where("title LIKE ? OR content LIKE ?", searchPattern, searchPattern)
}
}
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
return posts, total, err
}
// GetFollowingPosts 获取关注用户的帖子
func (r *PostRepository) GetFollowingPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
var posts []*model.Post
var total int64
// 子查询获取当前用户关注的所有用户ID
subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
// 统计总数
r.db.Model(&model.Post{}).Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).
Preload("User").Preload("Images").
Offset(offset).Limit(pageSize).
Order("created_at DESC").
Find(&posts).Error
return posts, total, err
}
// GetHotPosts 获取热门帖子(按点赞数和评论数排序)
func (r *PostRepository) GetHotPosts(page, pageSize int) ([]*model.Post, int64, error) {
var posts []*model.Post
var total int64
r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished).Count(&total)
offset := (page - 1) * pageSize
// 热门排序使用预计算热度分,避免每次请求进行表达式排序计算
err := r.db.Where("status = ?", model.PostStatusPublished).Preload("User").Preload("Images").
Offset(offset).Limit(pageSize).
Order("hot_score DESC, created_at DESC").
Find(&posts).Error
return posts, total, err
}
// GetByIDs 根据ID列表获取帖子保持传入顺序
func (r *PostRepository) GetByIDs(ids []string) ([]*model.Post, error) {
if len(ids) == 0 {
return []*model.Post{}, nil
}
var posts []*model.Post
err := r.db.Preload("User").Preload("Images").
Where("id IN ? AND status = ?", ids, model.PostStatusPublished).
Find(&posts).Error
if err != nil {
return nil, err
}
// 按传入ID顺序排序
postMap := make(map[string]*model.Post)
for _, post := range posts {
postMap[post.ID] = post
}
ordered := make([]*model.Post, 0, len(ids))
for _, id := range ids {
if post, ok := postMap[id]; ok {
ordered = append(ordered, post)
}
}
return ordered, nil
}

View File

@@ -0,0 +1,172 @@
package repository
import (
"time"
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// PushRecordRepository 推送记录仓储
type PushRecordRepository struct {
db *gorm.DB
}
// NewPushRecordRepository 创建推送记录仓储
func NewPushRecordRepository(db *gorm.DB) *PushRecordRepository {
return &PushRecordRepository{db: db}
}
// Create 创建推送记录
func (r *PushRecordRepository) Create(record *model.PushRecord) error {
return r.db.Create(record).Error
}
// GetByID 根据ID获取推送记录
func (r *PushRecordRepository) GetByID(id int64) (*model.PushRecord, error) {
var record model.PushRecord
err := r.db.First(&record, "id = ?", id).Error
if err != nil {
return nil, err
}
return &record, nil
}
// Update 更新推送记录
func (r *PushRecordRepository) Update(record *model.PushRecord) error {
return r.db.Save(record).Error
}
// GetPendingPushes 获取待推送记录
func (r *PushRecordRepository) GetPendingPushes(limit int) ([]*model.PushRecord, error) {
var records []*model.PushRecord
err := r.db.Where("push_status = ?", model.PushStatusPending).
Where("expired_at IS NULL OR expired_at > ?", time.Now()).
Order("created_at ASC").
Limit(limit).
Find(&records).Error
return records, err
}
// GetByUserID 根据用户ID获取推送记录
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (r *PushRecordRepository) GetByUserID(userID string, limit, offset int) ([]*model.PushRecord, error) {
var records []*model.PushRecord
err := r.db.Where("user_id = ?", userID).
Order("created_at DESC").
Offset(offset).
Limit(limit).
Find(&records).Error
return records, err
}
// GetByMessageID 根据消息ID获取推送记录
func (r *PushRecordRepository) GetByMessageID(messageID int64) ([]*model.PushRecord, error) {
var records []*model.PushRecord
err := r.db.Where("message_id = ?", messageID).
Order("created_at DESC").
Find(&records).Error
return records, err
}
// GetFailedPushesForRetry 获取失败待重试的推送
func (r *PushRecordRepository) GetFailedPushesForRetry(limit int) ([]*model.PushRecord, error) {
var records []*model.PushRecord
err := r.db.Where("push_status = ?", model.PushStatusFailed).
Where("retry_count < max_retry").
Where("expired_at IS NULL OR expired_at > ?", time.Now()).
Order("created_at ASC").
Limit(limit).
Find(&records).Error
return records, err
}
// BatchCreate 批量创建推送记录
func (r *PushRecordRepository) BatchCreate(records []*model.PushRecord) error {
if len(records) == 0 {
return nil
}
return r.db.Create(&records).Error
}
// BatchUpdateStatus 批量更新推送状态
func (r *PushRecordRepository) BatchUpdateStatus(ids []int64, status model.PushStatus) error {
if len(ids) == 0 {
return nil
}
updates := map[string]interface{}{
"push_status": status,
}
if status == model.PushStatusPushed {
updates["pushed_at"] = time.Now()
}
return r.db.Model(&model.PushRecord{}).
Where("id IN ?", ids).
Updates(updates).Error
}
// UpdateStatus 更新单条记录状态
func (r *PushRecordRepository) UpdateStatus(id int64, status model.PushStatus) error {
updates := map[string]interface{}{
"push_status": status,
}
if status == model.PushStatusPushed {
updates["pushed_at"] = time.Now()
}
return r.db.Model(&model.PushRecord{}).
Where("id = ?", id).
Updates(updates).Error
}
// MarkAsFailed 标记为失败
func (r *PushRecordRepository) MarkAsFailed(id int64, errMsg string) error {
return r.db.Model(&model.PushRecord{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"push_status": model.PushStatusFailed,
"error_message": errMsg,
"retry_count": gorm.Expr("retry_count + 1"),
}).Error
}
// MarkAsDelivered 标记为已送达
func (r *PushRecordRepository) MarkAsDelivered(id int64) error {
return r.db.Model(&model.PushRecord{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"push_status": model.PushStatusDelivered,
"delivered_at": time.Now(),
}).Error
}
// DeleteExpiredRecords 删除过期的推送记录(软删除)
func (r *PushRecordRepository) DeleteExpiredRecords() error {
return r.db.Where("expired_at IS NOT NULL AND expired_at < ?", time.Now()).
Delete(&model.PushRecord{}).Error
}
// GetStatsByUserID 获取用户推送统计
func (r *PushRecordRepository) GetStatsByUserID(userID int64) (map[model.PushStatus]int64, error) {
type statusCount struct {
Status model.PushStatus
Count int64
}
var results []statusCount
err := r.db.Model(&model.PushRecord{}).
Select("push_status as status, count(*) as count").
Where("user_id = ?", userID).
Group("push_status").
Scan(&results).Error
if err != nil {
return nil, err
}
stats := make(map[model.PushStatus]int64)
for _, r := range results {
stats[r.Status] = r.Count
}
return stats, nil
}

View File

@@ -0,0 +1,112 @@
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// StickerRepository 自定义表情仓库接口
type StickerRepository interface {
// 获取用户的所有表情
GetByUserID(userID string) ([]model.UserSticker, error)
// 根据ID获取表情
GetByID(id string) (*model.UserSticker, error)
// 创建表情
Create(sticker *model.UserSticker) error
// 删除表情
Delete(id string) error
// 删除用户的所有表情
DeleteByUserID(userID string) error
// 检查表情是否存在
Exists(userID string, url string) (bool, error)
// 更新排序
UpdateSortOrder(id string, sortOrder int) error
// 批量更新排序
BatchUpdateSortOrder(userID string, orders map[string]int) error
// 获取用户表情数量
CountByUserID(userID string) (int64, error)
}
// stickerRepository 自定义表情仓库实现
type stickerRepository struct {
db *gorm.DB
}
// NewStickerRepository 创建自定义表情仓库
func NewStickerRepository(db *gorm.DB) StickerRepository {
return &stickerRepository{db: db}
}
// GetByUserID 获取用户的所有表情
func (r *stickerRepository) GetByUserID(userID string) ([]model.UserSticker, error) {
var stickers []model.UserSticker
err := r.db.Where("user_id = ?", userID).
Order("sort_order ASC, created_at DESC").
Find(&stickers).Error
return stickers, err
}
// GetByID 根据ID获取表情
func (r *stickerRepository) GetByID(id string) (*model.UserSticker, error) {
var sticker model.UserSticker
err := r.db.Where("id = ?", id).First(&sticker).Error
if err != nil {
return nil, err
}
return &sticker, nil
}
// Create 创建表情
func (r *stickerRepository) Create(sticker *model.UserSticker) error {
return r.db.Create(sticker).Error
}
// Delete 删除表情
func (r *stickerRepository) Delete(id string) error {
return r.db.Where("id = ?", id).Delete(&model.UserSticker{}).Error
}
// DeleteByUserID 删除用户的所有表情
func (r *stickerRepository) DeleteByUserID(userID string) error {
return r.db.Where("user_id = ?", userID).Delete(&model.UserSticker{}).Error
}
// Exists 检查表情是否存在
func (r *stickerRepository) Exists(userID string, url string) (bool, error) {
var count int64
err := r.db.Model(&model.UserSticker{}).
Where("user_id = ? AND url = ?", userID, url).
Count(&count).Error
return count > 0, err
}
// UpdateSortOrder 更新排序
func (r *stickerRepository) UpdateSortOrder(id string, sortOrder int) error {
return r.db.Model(&model.UserSticker{}).
Where("id = ?", id).
Update("sort_order", sortOrder).Error
}
// BatchUpdateSortOrder 批量更新排序
func (r *stickerRepository) BatchUpdateSortOrder(userID string, orders map[string]int) error {
return r.db.Transaction(func(tx *gorm.DB) error {
for id, sortOrder := range orders {
if err := tx.Model(&model.UserSticker{}).
Where("id = ? AND user_id = ?", id, userID).
Update("sort_order", sortOrder).Error; err != nil {
return err
}
}
return nil
})
}
// CountByUserID 获取用户表情数量
func (r *stickerRepository) CountByUserID(userID string) (int64, error) {
var count int64
err := r.db.Model(&model.UserSticker{}).
Where("user_id = ?", userID).
Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,114 @@
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// SystemNotificationRepository 系统通知仓储
type SystemNotificationRepository struct {
db *gorm.DB
}
// NewSystemNotificationRepository 创建系统通知仓储
func NewSystemNotificationRepository(db *gorm.DB) *SystemNotificationRepository {
return &SystemNotificationRepository{db: db}
}
// Create 创建系统通知
func (r *SystemNotificationRepository) Create(notification *model.SystemNotification) error {
return r.db.Create(notification).Error
}
// GetByID 根据ID获取通知
func (r *SystemNotificationRepository) GetByID(id int64) (*model.SystemNotification, error) {
var notification model.SystemNotification
err := r.db.First(&notification, "id = ?", id).Error
if err != nil {
return nil, err
}
return &notification, nil
}
// GetByReceiverID 获取用户的通知列表
func (r *SystemNotificationRepository) GetByReceiverID(receiverID string, page, pageSize int) ([]*model.SystemNotification, int64, error) {
var notifications []*model.SystemNotification
var total int64
query := r.db.Model(&model.SystemNotification{}).Where("receiver_id = ?", receiverID)
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Offset(offset).
Limit(pageSize).
Order("created_at DESC").
Find(&notifications).Error
return notifications, total, err
}
// GetUnreadByReceiverID 获取用户的未读通知列表
func (r *SystemNotificationRepository) GetUnreadByReceiverID(receiverID string, limit int) ([]*model.SystemNotification, error) {
var notifications []*model.SystemNotification
err := r.db.Where("receiver_id = ? AND is_read = ?", receiverID, false).
Order("created_at DESC").
Limit(limit).
Find(&notifications).Error
return notifications, err
}
// GetUnreadCount 获取用户未读通知数量
func (r *SystemNotificationRepository) GetUnreadCount(receiverID string) (int64, error) {
var count int64
err := r.db.Model(&model.SystemNotification{}).
Where("receiver_id = ? AND is_read = ?", receiverID, false).
Count(&count).Error
return count, err
}
// MarkAsRead 标记单条通知为已读
func (r *SystemNotificationRepository) MarkAsRead(id int64, receiverID string) error {
now := model.SystemNotification{}.UpdatedAt
return r.db.Model(&model.SystemNotification{}).
Where("id = ? AND receiver_id = ?", id, receiverID).
Updates(map[string]interface{}{
"is_read": true,
"read_at": now,
}).Error
}
// MarkAllAsRead 标记用户所有通知为已读
func (r *SystemNotificationRepository) MarkAllAsRead(receiverID string) error {
now := model.SystemNotification{}.UpdatedAt
return r.db.Model(&model.SystemNotification{}).
Where("receiver_id = ? AND is_read = ?", receiverID, false).
Updates(map[string]interface{}{
"is_read": true,
"read_at": now,
}).Error
}
// Delete 删除通知(软删除)
func (r *SystemNotificationRepository) Delete(id int64, receiverID string) error {
return r.db.Where("id = ? AND receiver_id = ?", id, receiverID).
Delete(&model.SystemNotification{}).Error
}
// GetByType 获取用户指定类型的通知
func (r *SystemNotificationRepository) GetByType(receiverID string, notifyType model.SystemNotificationType, page, pageSize int) ([]*model.SystemNotification, int64, error) {
var notifications []*model.SystemNotification
var total int64
query := r.db.Model(&model.SystemNotification{}).
Where("receiver_id = ? AND type = ?", receiverID, notifyType)
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Offset(offset).
Limit(pageSize).
Order("created_at DESC").
Find(&notifications).Error
return notifications, total, err
}

View File

@@ -0,0 +1,404 @@
package repository
import (
"carrot_bbs/internal/model"
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UserRepository 用户仓储
type UserRepository struct {
db *gorm.DB
}
// NewUserRepository 创建用户仓储
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
}
// Create 创建用户
func (r *UserRepository) Create(user *model.User) error {
return r.db.Create(user).Error
}
// GetByID 根据ID获取用户
func (r *UserRepository) GetByID(id string) (*model.User, error) {
var user model.User
err := r.db.First(&user, "id = ?", id).Error
if err != nil {
return nil, err
}
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *UserRepository) GetByUsername(username string) (*model.User, error) {
var user model.User
err := r.db.First(&user, "username = ?", username).Error
if err != nil {
return nil, err
}
return &user, nil
}
// GetByEmail 根据邮箱获取用户
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
var user model.User
err := r.db.First(&user, "email = ?", email).Error
if err != nil {
return nil, err
}
return &user, nil
}
// GetByPhone 根据手机号获取用户
func (r *UserRepository) GetByPhone(phone string) (*model.User, error) {
var user model.User
err := r.db.First(&user, "phone = ?", phone).Error
if err != nil {
return nil, err
}
return &user, nil
}
// Update 更新用户
func (r *UserRepository) Update(user *model.User) error {
return r.db.Save(user).Error
}
// Delete 删除用户
func (r *UserRepository) Delete(id string) error {
return r.db.Delete(&model.User{}, "id = ?", id).Error
}
// List 分页获取用户列表
func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) {
var users []*model.User
var total int64
r.db.Model(&model.User{}).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Order("created_at DESC, id DESC").Offset(offset).Limit(pageSize).Find(&users).Error
return users, total, err
}
// GetFollowers 获取用户粉丝
func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) {
var users []*model.User
var total int64
subQuery := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id")
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("id IN (?)", subQuery).
Order("created_at DESC, id DESC").
Offset(offset).Limit(pageSize).
Find(&users).Error
return users, total, err
}
// GetFollowing 获取用户关注
func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) {
var users []*model.User
var total int64
subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("id IN (?)", subQuery).
Order("created_at DESC, id DESC").
Offset(offset).Limit(pageSize).
Find(&users).Error
return users, total, err
}
// CreateFollow 创建关注关系
func (r *UserRepository) CreateFollow(follow *model.Follow) error {
return r.db.Create(follow).Error
}
// DeleteFollow 删除关注关系
func (r *UserRepository) DeleteFollow(followerID, followingID string) error {
return r.db.Where("follower_id = ? AND following_id = ?", followerID, followingID).Delete(&model.Follow{}).Error
}
// IsFollowing 检查是否关注了某用户
func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) {
var count int64
err := r.db.Model(&model.Follow{}).Where("follower_id = ? AND following_id = ?", followerID, followingID).Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// IncrementFollowersCount 增加用户粉丝数
func (r *UserRepository) IncrementFollowersCount(userID string) error {
return r.db.Model(&model.User{}).Where("id = ?", userID).
UpdateColumn("followers_count", gorm.Expr("followers_count + 1")).Error
}
// DecrementFollowersCount 减少用户粉丝数
func (r *UserRepository) DecrementFollowersCount(userID string) error {
return r.db.Model(&model.User{}).Where("id = ? AND followers_count > 0", userID).
UpdateColumn("followers_count", gorm.Expr("followers_count - 1")).Error
}
// IncrementFollowingCount 增加用户关注数
func (r *UserRepository) IncrementFollowingCount(userID string) error {
return r.db.Model(&model.User{}).Where("id = ?", userID).
UpdateColumn("following_count", gorm.Expr("following_count + 1")).Error
}
// DecrementFollowingCount 减少用户关注数
func (r *UserRepository) DecrementFollowingCount(userID string) error {
return r.db.Model(&model.User{}).Where("id = ? AND following_count > 0", userID).
UpdateColumn("following_count", gorm.Expr("following_count - 1")).Error
}
// RefreshFollowersCount 刷新用户粉丝数(通过实际计数)
func (r *UserRepository) RefreshFollowersCount(userID string) error {
var count int64
err := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&count).Error
if err != nil {
return err
}
return r.db.Model(&model.User{}).Where("id = ?", userID).
UpdateColumn("followers_count", count).Error
}
// GetPostsCount 获取用户帖子数(实时计算)
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
var count int64
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
// GetPostsCountBatch 批量获取用户帖子数(实时计算)
// 返回 map[userID]postsCount
func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) {
result := make(map[string]int64)
if len(userIDs) == 0 {
return result, nil
}
// 初始化所有用户ID的计数为0
for _, userID := range userIDs {
result[userID] = 0
}
// 使用 GROUP BY 一次性查询所有用户的帖子数
type CountResult struct {
UserID string
Count int64
}
var counts []CountResult
err := r.db.Model(&model.Post{}).
Select("user_id, count(*) as count").
Where("user_id IN ?", userIDs).
Group("user_id").
Scan(&counts).Error
if err != nil {
return nil, err
}
// 更新查询结果
for _, c := range counts {
result[c.UserID] = c.Count
}
return result, nil
}
// RefreshFollowingCount 刷新用户关注数(通过实际计数)
func (r *UserRepository) RefreshFollowingCount(userID string) error {
var count int64
err := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&count).Error
if err != nil {
return err
}
return r.db.Model(&model.User{}).Where("id = ?", userID).
UpdateColumn("following_count", count).Error
}
// IsBlocked 检查拉黑关系是否存在blocker -> blocked
func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) {
var count int64
err := r.db.Model(&model.UserBlock{}).
Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// IsBlockedEitherDirection 检查是否任一方向存在拉黑
func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) {
var count int64
err := r.db.Model(&model.UserBlock{}).
Where("(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)",
userA, userB, userB, userA).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// BlockUserAndCleanupRelations 拉黑用户并清理双向关注关系(事务)
func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
block := &model.UserBlock{
BlockerID: blockerID,
BlockedID: blockedID,
}
if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}},
DoNothing: true,
}).Create(block).Error; err != nil {
return err
}
if err := tx.Where("follower_id = ? AND following_id = ?", blockerID, blockedID).
Delete(&model.Follow{}).Error; err != nil {
return err
}
if err := tx.Where("follower_id = ? AND following_id = ?", blockedID, blockerID).
Delete(&model.Follow{}).Error; err != nil {
return err
}
for _, uid := range []string{blockerID, blockedID} {
var followersCount int64
if err := tx.Model(&model.Follow{}).Where("following_id = ?", uid).Count(&followersCount).Error; err != nil {
return err
}
if err := tx.Model(&model.User{}).Where("id = ?", uid).
UpdateColumn("followers_count", followersCount).Error; err != nil {
return err
}
var followingCount int64
if err := tx.Model(&model.Follow{}).Where("follower_id = ?", uid).Count(&followingCount).Error; err != nil {
return err
}
if err := tx.Model(&model.User{}).Where("id = ?", uid).
UpdateColumn("following_count", followingCount).Error; err != nil {
return err
}
}
return nil
})
}
// UnblockUser 取消拉黑
func (r *UserRepository) UnblockUser(blockerID, blockedID string) error {
return r.db.Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
Delete(&model.UserBlock{}).Error
}
// GetBlockedUsers 获取用户黑名单列表
func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) {
var users []*model.User
var total int64
subQuery := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id")
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
offset := (page - 1) * pageSize
err := r.db.Where("id IN (?)", subQuery).
Order("created_at DESC, id DESC").
Offset(offset).
Limit(pageSize).
Find(&users).Error
return users, total, err
}
// Search 搜索用户
func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) {
var users []*model.User
var total int64
query := r.db.Model(&model.User{})
// 搜索用户名、昵称、简介
if keyword != "" {
if r.db.Dialector.Name() == "postgres" {
query = query.Where(
"to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)",
keyword,
)
} else {
searchPattern := "%" + keyword + "%"
query = query.Where("username LIKE ? OR nickname LIKE ? OR bio LIKE ?", searchPattern, searchPattern, searchPattern)
}
}
query.Count(&total)
offset := (page - 1) * pageSize
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&users).Error
return users, total, err
}
// GetMutualFollowStatus 批量获取双向关注状态
// 返回 map[userID][isFollowing, isFollowingMe]
func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
result := make(map[string][2]bool)
if len(targetUserIDs) == 0 {
return result, nil
}
fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUserID=%s, targetUserIDs=%v\n", currentUserID, targetUserIDs)
// 初始化所有目标用户为未关注状态
for _, userID := range targetUserIDs {
result[userID] = [2]bool{false, false}
}
// 查询当前用户关注了哪些目标用户 (isFollowing)
var followingIDs []string
err := r.db.Model(&model.Follow{}).
Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs).
Pluck("following_id", &followingIDs).Error
if err != nil {
return nil, err
}
fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUser follows these targets: %v\n", followingIDs)
for _, id := range followingIDs {
status := result[id]
status[0] = true
result[id] = status
}
// 查询哪些目标用户关注了当前用户 (isFollowingMe)
var followerIDs []string
err = r.db.Model(&model.Follow{}).
Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID).
Pluck("follower_id", &followerIDs).Error
if err != nil {
return nil, err
}
fmt.Printf("[DEBUG] GetMutualFollowStatus: these targets follow currentUser: %v\n", followerIDs)
for _, id := range followerIDs {
status := result[id]
status[1] = true
result[id] = status
}
fmt.Printf("[DEBUG] GetMutualFollowStatus: final result=%v\n", result)
return result, nil
}

View File

@@ -0,0 +1,141 @@
package repository
import (
"carrot_bbs/internal/model"
"errors"
"gorm.io/gorm"
)
// VoteRepository 投票仓储
type VoteRepository struct {
db *gorm.DB
}
// NewVoteRepository 创建投票仓储
func NewVoteRepository(db *gorm.DB) *VoteRepository {
return &VoteRepository{db: db}
}
// CreateOptions 批量创建投票选项
func (r *VoteRepository) CreateOptions(postID string, options []string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
for i, content := range options {
option := &model.VoteOption{
PostID: postID,
Content: content,
SortOrder: i,
}
if err := tx.Create(option).Error; err != nil {
return err
}
}
return nil
})
}
// GetOptionsByPostID 获取帖子的所有投票选项
func (r *VoteRepository) GetOptionsByPostID(postID string) ([]model.VoteOption, error) {
var options []model.VoteOption
err := r.db.Where("post_id = ?", postID).Order("sort_order ASC").Find(&options).Error
return options, err
}
// Vote 用户投票
func (r *VoteRepository) Vote(postID, userID, optionID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 检查用户是否已投票
var existing model.UserVote
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
if err == nil {
// 已经投票,返回错误
return errors.New("user already voted")
}
if err != gorm.ErrRecordNotFound {
return err
}
// 验证选项是否属于该帖子
var option model.VoteOption
if err := tx.Where("id = ? AND post_id = ?", optionID, postID).First(&option).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return errors.New("invalid option")
}
return err
}
// 创建投票记录
if err := tx.Create(&model.UserVote{
PostID: postID,
UserID: userID,
OptionID: optionID,
}).Error; err != nil {
return err
}
// 原子增加选项投票数
return tx.Model(&model.VoteOption{}).Where("id = ?", optionID).
UpdateColumn("votes_count", gorm.Expr("votes_count + 1")).Error
})
}
// Unvote 取消投票
func (r *VoteRepository) Unvote(postID, userID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 获取用户的投票记录
var userVote model.UserVote
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil // 没有投票记录,直接返回
}
return err
}
// 删除投票记录
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.UserVote{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected > 0 {
// 原子减少选项投票数
return tx.Model(&model.VoteOption{}).Where("id = ?", userVote.OptionID).
UpdateColumn("votes_count", gorm.Expr("votes_count - 1")).Error
}
return nil
})
}
// GetUserVote 获取用户在指定帖子的投票
func (r *VoteRepository) GetUserVote(postID, userID string) (*model.UserVote, error) {
var userVote model.UserVote
err := r.db.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &userVote, nil
}
// UpdateOption 更新选项内容
func (r *VoteRepository) UpdateOption(optionID, content string) error {
return r.db.Model(&model.VoteOption{}).Where("id = ?", optionID).
Update("content", content).Error
}
// DeleteOptionsByPostID 删除帖子的所有投票选项
func (r *VoteRepository) DeleteOptionsByPostID(postID string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 删除关联的用户投票记录
if err := tx.Where("post_id = ?", postID).Delete(&model.UserVote{}).Error; err != nil {
return err
}
// 删除投票选项
return tx.Where("post_id = ?", postID).Delete(&model.VoteOption{}).Error
})
}

334
internal/router/router.go Normal file
View File

@@ -0,0 +1,334 @@
package router
import (
"github.com/gin-gonic/gin"
"carrot_bbs/internal/handler"
"carrot_bbs/internal/middleware"
"carrot_bbs/internal/service"
)
// Router 路由配置
type Router struct {
engine *gin.Engine
userHandler *handler.UserHandler
postHandler *handler.PostHandler
commentHandler *handler.CommentHandler
messageHandler *handler.MessageHandler
notificationHandler *handler.NotificationHandler
uploadHandler *handler.UploadHandler
wsHandler *handler.WebSocketHandler
pushHandler *handler.PushHandler
systemMessageHandler *handler.SystemMessageHandler
groupHandler *handler.GroupHandler
stickerHandler *handler.StickerHandler
gorseHandler *handler.GorseHandler
voteHandler *handler.VoteHandler
jwtService *service.JWTService
}
// New 创建路由
func New(
userHandler *handler.UserHandler,
postHandler *handler.PostHandler,
commentHandler *handler.CommentHandler,
messageHandler *handler.MessageHandler,
notificationHandler *handler.NotificationHandler,
uploadHandler *handler.UploadHandler,
jwtService *service.JWTService,
wsHandler *handler.WebSocketHandler,
pushHandler *handler.PushHandler,
systemMessageHandler *handler.SystemMessageHandler,
groupHandler *handler.GroupHandler,
stickerHandler *handler.StickerHandler,
gorseHandler *handler.GorseHandler,
voteHandler *handler.VoteHandler,
) *Router {
// 设置JWT服务
userHandler.SetJWTService(jwtService)
r := &Router{
engine: gin.Default(),
userHandler: userHandler,
postHandler: postHandler,
commentHandler: commentHandler,
messageHandler: messageHandler,
notificationHandler: notificationHandler,
uploadHandler: uploadHandler,
wsHandler: wsHandler,
pushHandler: pushHandler,
systemMessageHandler: systemMessageHandler,
groupHandler: groupHandler,
stickerHandler: stickerHandler,
gorseHandler: gorseHandler,
voteHandler: voteHandler,
jwtService: jwtService,
}
r.setupRoutes()
return r
}
// setupRoutes 设置路由
func (r *Router) setupRoutes() {
// 中间件
r.engine.Use(middleware.CORS())
// 健康检查
r.engine.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok"})
})
// WebSocket 路由
if r.wsHandler != nil {
r.engine.GET("/ws", r.wsHandler.HandleWebSocket)
}
// API v1
v1 := r.engine.Group("/api/v1")
{
// 认证路由(公开)
auth := v1.Group("/auth")
{
auth.POST("/register", r.userHandler.Register)
auth.POST("/register/send-code", r.userHandler.SendRegisterCode)
auth.POST("/login", r.userHandler.Login)
auth.POST("/password/send-code", r.userHandler.SendPasswordResetCode)
auth.POST("/password/reset", r.userHandler.ResetPassword)
auth.POST("/refresh", r.userHandler.RefreshToken)
}
// 需要认证的路由
authMiddleware := middleware.Auth(r.jwtService)
// 用户路由
users := v1.Group("/users")
{
// 当前用户
users.GET("/me", authMiddleware, r.userHandler.GetCurrentUser)
users.PUT("/me", authMiddleware, r.userHandler.UpdateUser)
users.POST("/me/email/send-code", authMiddleware, r.userHandler.SendEmailVerifyCode)
users.POST("/me/email/verify", authMiddleware, r.userHandler.VerifyEmail)
users.POST("/me/avatar", authMiddleware, r.uploadHandler.UploadAvatar)
users.POST("/me/cover", authMiddleware, r.uploadHandler.UploadCover)
users.POST("/change-password/send-code", authMiddleware, r.userHandler.SendChangePasswordCode)
users.POST("/change-password", authMiddleware, r.userHandler.ChangePassword)
// 搜索用户
users.GET("/search", middleware.OptionalAuth(r.jwtService), r.userHandler.Search)
// 其他用户
users.GET("/:id", middleware.OptionalAuth(r.jwtService), r.userHandler.GetUserByID)
users.POST("/:id/follow", authMiddleware, r.userHandler.FollowUser)
users.DELETE("/:id/follow", authMiddleware, r.userHandler.UnfollowUser)
users.POST("/:id/block", authMiddleware, r.userHandler.BlockUser)
users.DELETE("/:id/block", authMiddleware, r.userHandler.UnblockUser)
users.GET("/:id/block-status", authMiddleware, r.userHandler.GetBlockStatus)
users.GET("/blocks", authMiddleware, r.userHandler.GetBlockedUsers)
users.GET("/:id/following", authMiddleware, r.userHandler.GetFollowingList)
users.GET("/:id/followers", authMiddleware, r.userHandler.GetFollowersList)
// 用户帖子 - 使用 OptionalAuth 获取当前用户点赞/收藏状态
users.GET("/:id/posts", middleware.OptionalAuth(r.jwtService), r.postHandler.GetUserPosts)
users.GET("/:id/favorites", middleware.OptionalAuth(r.jwtService), r.postHandler.GetFavorites)
}
// 认证路由(公开)
authPublic := v1.Group("/auth")
{
authPublic.GET("/check-username", r.userHandler.CheckUsername)
}
// 帖子路由
posts := v1.Group("/posts")
{
// 使用 OptionalAuth 中间件来获取用户登录状态
posts.GET("", middleware.OptionalAuth(r.jwtService), r.postHandler.List)
posts.GET("/search", middleware.OptionalAuth(r.jwtService), r.postHandler.Search)
posts.GET("/:id", middleware.OptionalAuth(r.jwtService), r.postHandler.GetByID)
posts.POST("", authMiddleware, r.postHandler.Create)
posts.PUT("/:id", authMiddleware, r.postHandler.Update)
posts.DELETE("/:id", authMiddleware, r.postHandler.Delete)
// 浏览量记录(可选认证,允许游客浏览)
posts.POST("/:id/view", middleware.OptionalAuth(r.jwtService), r.postHandler.RecordView)
// 点赞
posts.POST("/:id/like", authMiddleware, r.postHandler.Like)
posts.DELETE("/:id/like", authMiddleware, r.postHandler.Unlike)
// 收藏
posts.POST("/:id/favorite", authMiddleware, r.postHandler.Favorite)
posts.DELETE("/:id/favorite", authMiddleware, r.postHandler.Unfavorite)
// 投票相关路由
posts.POST("/vote", authMiddleware, r.voteHandler.CreateVotePost) // 创建投票帖子
posts.GET("/:id/vote", middleware.OptionalAuth(r.jwtService), r.voteHandler.GetVoteResult) // 获取投票结果
posts.POST("/:id/vote", authMiddleware, r.voteHandler.Vote) // 投票
posts.DELETE("/:id/vote", authMiddleware, r.voteHandler.Unvote) // 取消投票
}
// 投票选项路由
voteOptions := v1.Group("/vote-options")
voteOptions.Use(authMiddleware)
{
voteOptions.PUT("/:id", r.voteHandler.UpdateVoteOption) // 更新选项
}
// 评论路由
comments := v1.Group("/comments")
{
comments.GET("/post/:id", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetByPostID)
comments.GET("/:id", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetByID)
comments.POST("", authMiddleware, r.commentHandler.Create)
comments.PUT("/:id", authMiddleware, r.commentHandler.Update)
comments.DELETE("/:id", authMiddleware, r.commentHandler.Delete)
comments.GET("/:id/replies", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetReplies)
comments.GET("/:id/replies/flat", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetRepliesByRootID) // 扁平化分页获取回复
// 评论点赞
comments.POST("/:id/like", authMiddleware, r.commentHandler.Like)
comments.DELETE("/:id/like", authMiddleware, r.commentHandler.Unlike)
}
// 会话路由(新版 RESTful action 风格)
conversations := v1.Group("/conversations")
conversations.Use(authMiddleware)
{
// 获取会话列表
conversations.GET("/list", r.messageHandler.HandleGetConversationList)
// 创建会话
conversations.POST("/create", r.messageHandler.HandleCreateConversation)
// 获取会话详情
conversations.GET("/get", r.messageHandler.HandleGetConversation)
// 获取会话消息
conversations.GET("/get_messages", r.messageHandler.HandleGetMessages)
// 发送消息
conversations.POST("/send_message", r.messageHandler.HandleSendMessage)
// 标记已读
conversations.POST("/mark_read", r.messageHandler.HandleMarkRead)
// 会话置顶
conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned)
// 获取未读消息总数
conversations.GET("/unread/count", r.messageHandler.GetUnreadCount)
// 仅自己删除会话
conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf)
}
// 消息操作路由
messages := v1.Group("/messages")
messages.Use(authMiddleware)
{
// 撤回/删除消息(统一接口)
messages.POST("/delete_msg", r.messageHandler.HandleDeleteMsg)
}
// 通知路由
notifications := v1.Group("/notifications")
{
notifications.GET("", authMiddleware, r.notificationHandler.GetNotifications)
notifications.POST("/:id/read", authMiddleware, r.notificationHandler.MarkAsRead)
notifications.POST("/read-all", authMiddleware, r.notificationHandler.MarkAllAsRead)
notifications.GET("/unread-count", authMiddleware, r.notificationHandler.GetUnreadCount)
notifications.DELETE("/:id", authMiddleware, r.notificationHandler.DeleteNotification)
notifications.DELETE("", authMiddleware, r.notificationHandler.ClearAllNotifications)
}
// 上传路由
uploads := v1.Group("/uploads")
{
uploads.POST("/images", authMiddleware, r.uploadHandler.UploadImage)
}
// 推送相关路由
if r.pushHandler != nil {
pushGroup := v1.Group("/push")
pushGroup.Use(authMiddleware)
{
pushGroup.POST("/devices", r.pushHandler.RegisterDevice)
pushGroup.GET("/devices", r.pushHandler.GetMyDevices)
pushGroup.DELETE("/devices/:device_id", r.pushHandler.UnregisterDevice)
pushGroup.PUT("/devices/:device_id/token", r.pushHandler.UpdateDeviceToken)
pushGroup.GET("/records", r.pushHandler.GetPushRecords)
}
}
// 系统消息路由
if r.systemMessageHandler != nil {
msgGroup := v1.Group("/messages")
msgGroup.Use(authMiddleware)
{
msgGroup.GET("/system", r.systemMessageHandler.GetSystemMessages)
msgGroup.GET("/system/unread-count", r.systemMessageHandler.GetUnreadCount)
msgGroup.PUT("/system/:id/read", r.systemMessageHandler.MarkAsRead)
msgGroup.PUT("/system/read-all", r.systemMessageHandler.MarkAllAsRead)
}
}
// 群组路由(新版 RESTful action 风格)
if r.groupHandler != nil {
groups := v1.Group("/groups")
groups.Use(authMiddleware)
{
// 群组管理
groups.POST("/create", r.groupHandler.HandleCreateGroup)
groups.GET("/list", r.groupHandler.HandleGetUserGroups)
groups.GET("/get", r.groupHandler.HandleGetGroupInfo)
groups.GET("/get_my_info", r.groupHandler.HandleGetMyMemberInfo)
groups.POST("/dissolve", r.groupHandler.HandleDissolveGroup)
groups.POST("/transfer", r.groupHandler.HandleTransferOwner)
// 成员管理
groups.POST("/invite_members", r.groupHandler.HandleInviteMembers)
groups.POST("/join", r.groupHandler.HandleJoinGroup)
groups.POST("/respond_invite", r.groupHandler.HandleRespondInvite)
groups.POST("/set_group_leave", r.groupHandler.HandleSetGroupLeave)
groups.GET("/get_members", r.groupHandler.HandleGetGroupMemberList)
groups.POST("/set_group_kick", r.groupHandler.HandleSetGroupKick)
groups.POST("/set_group_admin", r.groupHandler.HandleSetGroupAdmin)
groups.POST("/set_nickname", r.groupHandler.HandleSetNickname)
groups.POST("/set_group_ban", r.groupHandler.HandleSetGroupBan)
// 群设置
groups.POST("/set_group_whole_ban", r.groupHandler.HandleSetGroupWholeBan)
groups.POST("/set_join_type", r.groupHandler.HandleSetJoinType)
groups.POST("/set_group_name", r.groupHandler.HandleSetGroupName)
groups.POST("/set_group_avatar", r.groupHandler.HandleSetGroupAvatar)
// 群公告
groups.POST("/create_announcement", r.groupHandler.HandleCreateAnnouncement)
groups.GET("/get_announcements", r.groupHandler.HandleGetAnnouncements)
groups.POST("/delete_announcement", r.groupHandler.HandleDeleteAnnouncement)
// 加群请求处理(预留)
groups.POST("/set_group_add_request", r.groupHandler.HandleSetGroupAddRequest)
}
}
// 自定义表情路由
if r.stickerHandler != nil {
stickers := v1.Group("/stickers")
stickers.Use(authMiddleware)
{
stickers.GET("", r.stickerHandler.GetStickers)
stickers.POST("", r.stickerHandler.AddSticker)
stickers.DELETE("", r.stickerHandler.DeleteSticker)
stickers.POST("/reorder", r.stickerHandler.ReorderStickers)
stickers.GET("/check", r.stickerHandler.CheckStickerExists)
}
}
// Gorse 管理路由
if r.gorseHandler != nil {
gorseGroup := v1.Group("/gorse")
{
gorseGroup.GET("/status", r.gorseHandler.GetStatus)
gorseGroup.POST("/import", r.gorseHandler.ImportData)
}
}
}
}
// Engine 获取引擎
func (r *Router) Engine() *gin.Engine {
return r.engine
}

View File

@@ -0,0 +1,759 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"sync"
"time"
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// ==================== 内容审核服务接口和实现 ====================
// AuditServiceProvider 内容审核服务提供商接口
type AuditServiceProvider interface {
// AuditText 审核文本
AuditText(ctx context.Context, text string, scene string) (*AuditResult, error)
// AuditImage 审核图片
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
// GetName 获取提供商名称
GetName() string
}
// AuditResult 审核结果
type AuditResult struct {
Pass bool `json:"pass"` // 是否通过
Risk string `json:"risk"` // 风险等级: low, medium, high
Labels []string `json:"labels"` // 标签列表
Suggest string `json:"suggest"` // 建议: pass, review, block
Detail string `json:"detail"` // 详细说明
Provider string `json:"provider"` // 服务提供商
}
// AuditService 内容审核服务接口
type AuditService interface {
// AuditText 审核文本
AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error)
// AuditImage 审核图片
AuditImage(ctx context.Context, imageURL string) (*AuditResult, error)
// GetAuditResult 获取审核结果
GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error)
// SetProvider 设置审核服务提供商
SetProvider(provider AuditServiceProvider)
// GetProvider 获取当前审核服务提供商
GetProvider() AuditServiceProvider
}
// auditServiceImpl 内容审核服务实现
type auditServiceImpl struct {
db *gorm.DB
provider AuditServiceProvider
config *AuditConfig
mu sync.RWMutex
}
// AuditConfig 内容审核服务配置
type AuditConfig struct {
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
// 审核服务提供商: local, aliyun, tencent, baidu
Provider string `mapstructure:"provider" yaml:"provider"`
// 阿里云配置
AliyunAccessKey string `mapstructure:"aliyun_access_key" yaml:"aliyun_access_key"`
AliyunSecretKey string `mapstructure:"aliyun_secret_key" yaml:"aliyun_secret_key"`
AliyunRegion string `mapstructure:"aliyun_region" yaml:"aliyun_region"`
// 腾讯云配置
TencentSecretID string `mapstructure:"tencent_secret_id" yaml:"tencent_secret_id"`
TencentSecretKey string `mapstructure:"tencent_secret_key" yaml:"tencent_secret_key"`
// 百度云配置
BaiduAPIKey string `mapstructure:"baidu_api_key" yaml:"baidu_api_key"`
BaiduSecretKey string `mapstructure:"baidu_secret_key" yaml:"baidu_secret_key"`
// 是否自动审核
AutoAudit bool `mapstructure:"auto_audit" yaml:"auto_audit"`
// 审核超时时间(秒)
Timeout int `mapstructure:"timeout" yaml:"timeout"`
}
// NewAuditService 创建内容审核服务
func NewAuditService(db *gorm.DB, config *AuditConfig) AuditService {
s := &auditServiceImpl{
db: db,
config: config,
}
// 根据配置初始化提供商
if config.Enabled {
provider := s.initProvider(config.Provider)
s.provider = provider
}
return s
}
// initProvider 根据配置初始化审核服务提供商
func (s *auditServiceImpl) initProvider(providerType string) AuditServiceProvider {
switch strings.ToLower(providerType) {
case "aliyun":
return NewAliyunAuditProvider(s.config.AliyunAccessKey, s.config.AliyunSecretKey, s.config.AliyunRegion)
case "tencent":
return NewTencentAuditProvider(s.config.TencentSecretID, s.config.TencentSecretKey)
case "baidu":
return NewBaiduAuditProvider(s.config.BaiduAPIKey, s.config.BaiduSecretKey)
case "local":
fallthrough
default:
// 默认使用本地审核服务
return NewLocalAuditProvider()
}
}
// AuditText 审核文本
func (s *auditServiceImpl) AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error) {
if !s.config.Enabled {
// 如果审核服务未启用,直接返回通过
return &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Detail: "Audit service disabled",
}, nil
}
if text == "" {
return &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Detail: "Empty text",
}, nil
}
var result *AuditResult
var err error
// 使用提供商审核
if s.provider != nil {
result, err = s.provider.AuditText(ctx, text, auditType)
} else {
// 如果没有设置提供商,使用本地审核
localProvider := NewLocalAuditProvider()
result, err = localProvider.AuditText(ctx, text, auditType)
}
if err != nil {
log.Printf("Audit text error: %v", err)
return &AuditResult{
Pass: false,
Risk: "high",
Suggest: "review",
Detail: fmt.Sprintf("Audit error: %v", err),
}, err
}
// 记录审核日志
go s.saveAuditLog(ctx, "text", "", text, auditType, result)
return result, nil
}
// AuditImage 审核图片
func (s *auditServiceImpl) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
if !s.config.Enabled {
return &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Detail: "Audit service disabled",
}, nil
}
if imageURL == "" {
return &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Detail: "Empty image URL",
}, nil
}
var result *AuditResult
var err error
// 使用提供商审核
if s.provider != nil {
result, err = s.provider.AuditImage(ctx, imageURL)
} else {
// 如果没有设置提供商,使用本地审核
localProvider := NewLocalAuditProvider()
result, err = localProvider.AuditImage(ctx, imageURL)
}
if err != nil {
log.Printf("Audit image error: %v", err)
return &AuditResult{
Pass: false,
Risk: "high",
Suggest: "review",
Detail: fmt.Sprintf("Audit error: %v", err),
}, err
}
// 记录审核日志
go s.saveAuditLog(ctx, "image", "", "", "image", result)
return result, nil
}
// GetAuditResult 获取审核结果
func (s *auditServiceImpl) GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error) {
if s.db == nil || auditID == "" {
return nil, fmt.Errorf("invalid audit ID")
}
var auditLog model.AuditLog
if err := s.db.Where("id = ?", auditID).First(&auditLog).Error; err != nil {
return nil, err
}
result := &AuditResult{
Pass: auditLog.Result == model.AuditResultPass,
Risk: string(auditLog.RiskLevel),
Suggest: auditLog.Suggestion,
Detail: auditLog.Detail,
}
// 解析标签
if auditLog.Labels != "" {
json.Unmarshal([]byte(auditLog.Labels), &result.Labels)
}
return result, nil
}
// SetProvider 设置审核服务提供商
func (s *auditServiceImpl) SetProvider(provider AuditServiceProvider) {
s.mu.Lock()
defer s.mu.Unlock()
s.provider = provider
}
// GetProvider 获取当前审核服务提供商
func (s *auditServiceImpl) GetProvider() AuditServiceProvider {
s.mu.RLock()
defer s.mu.RUnlock()
return s.provider
}
// saveAuditLog 保存审核日志
func (s *auditServiceImpl) saveAuditLog(ctx context.Context, contentType, content, imageURL, auditType string, result *AuditResult) {
if s.db == nil {
return
}
auditLog := model.AuditLog{
ContentType: contentType,
Content: content,
ContentURL: imageURL,
AuditType: auditType,
Labels: strings.Join(result.Labels, ","),
Suggestion: result.Suggest,
Detail: result.Detail,
Source: model.AuditSourceAuto,
Status: "completed",
}
if result.Pass {
auditLog.Result = model.AuditResultPass
} else if result.Suggest == "review" {
auditLog.Result = model.AuditResultReview
} else {
auditLog.Result = model.AuditResultBlock
}
switch result.Risk {
case "low":
auditLog.RiskLevel = model.AuditRiskLevelLow
case "medium":
auditLog.RiskLevel = model.AuditRiskLevelMedium
case "high":
auditLog.RiskLevel = model.AuditRiskLevelHigh
default:
auditLog.RiskLevel = model.AuditRiskLevelLow
}
if err := s.db.Create(&auditLog).Error; err != nil {
log.Printf("Failed to save audit log: %v", err)
}
}
// ==================== 本地审核服务提供商 ====================
// localAuditProvider 本地审核服务提供商
type localAuditProvider struct {
// 可以注入敏感词服务进行本地审核
sensitiveService SensitiveService
}
// NewLocalAuditProvider 创建本地审核服务提供商
func NewLocalAuditProvider() AuditServiceProvider {
return &localAuditProvider{
sensitiveService: nil,
}
}
// GetName 获取提供商名称
func (p *localAuditProvider) GetName() string {
return "local"
}
// AuditText 审核文本
func (p *localAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
// 本地审核逻辑
// 1. 敏感词检查
// 2. 规则匹配
// 3. 简单的关键词检测
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "local",
}
// 如果有敏感词服务,使用它进行检测
if p.sensitiveService != nil {
hasSensitive, words := p.sensitiveService.Check(ctx, text)
if hasSensitive {
result.Pass = false
result.Risk = "high"
result.Suggest = "block"
result.Detail = fmt.Sprintf("包含敏感词: %s", strings.Join(words, ","))
result.Labels = append(result.Labels, "sensitive")
}
}
// 简单的关键词检测规则
// 实际项目中应该从数据库加载
suspiciousPatterns := []string{
"诈骗",
"钓鱼",
"木马",
"病毒",
}
for _, pattern := range suspiciousPatterns {
if strings.Contains(text, pattern) {
result.Pass = false
result.Risk = "high"
result.Suggest = "block"
result.Labels = append(result.Labels, "suspicious")
if result.Detail == "" {
result.Detail = fmt.Sprintf("包含可疑内容: %s", pattern)
} else {
result.Detail += fmt.Sprintf(", %s", pattern)
}
}
}
return result, nil
}
// AuditImage 审核图片
func (p *localAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
// 本地图片审核逻辑
// 1. 图片URL合法性检查
// 2. 图片格式检查
// 3. 可以扩展接入本地图片识别服务
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "local",
}
// 检查URL是否为空
if imageURL == "" {
result.Detail = "Empty image URL"
return result, nil
}
// 检查是否为支持的图片URL格式
validPrefixes := []string{"http://", "https://", "s3://", "oss://", "cos://"}
isValid := false
for _, prefix := range validPrefixes {
if strings.HasPrefix(strings.ToLower(imageURL), prefix) {
isValid = true
break
}
}
if !isValid {
result.Pass = false
result.Risk = "medium"
result.Suggest = "review"
result.Detail = "Invalid image URL format"
result.Labels = append(result.Labels, "invalid_url")
}
return result, nil
}
// SetSensitiveService 设置敏感词服务
func (p *localAuditProvider) SetSensitiveService(ss SensitiveService) {
p.sensitiveService = ss
}
// ==================== 阿里云审核服务提供商 ====================
// aliyunAuditProvider 阿里云审核服务提供商
type aliyunAuditProvider struct {
accessKey string
secretKey string
region string
}
// NewAliyunAuditProvider 创建阿里云审核服务提供商
func NewAliyunAuditProvider(accessKey, secretKey, region string) AuditServiceProvider {
return &aliyunAuditProvider{
accessKey: accessKey,
secretKey: secretKey,
region: region,
}
}
// GetName 获取提供商名称
func (p *aliyunAuditProvider) GetName() string {
return "aliyun"
}
// AuditText 审核文本
func (p *aliyunAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
// 阿里云内容安全API调用
// 实际项目中需要实现阿里云SDK调用
// 这里预留接口
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "aliyun",
Detail: "Aliyun audit not implemented, using pass",
}
// TODO: 实现阿里云内容安全API调用
// 具体参考: https://help.aliyun.com/document_detail/28417.html
return result, nil
}
// AuditImage 审核图片
func (p *aliyunAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "aliyun",
Detail: "Aliyun image audit not implemented, using pass",
}
// TODO: 实现阿里云图片审核API调用
return result, nil
}
// ==================== 腾讯云审核服务提供商 ====================
// tencentAuditProvider 腾讯云审核服务提供商
type tencentAuditProvider struct {
secretID string
secretKey string
}
// NewTencentAuditProvider 创建腾讯云审核服务提供商
func NewTencentAuditProvider(secretID, secretKey string) AuditServiceProvider {
return &tencentAuditProvider{
secretID: secretID,
secretKey: secretKey,
}
}
// GetName 获取提供商名称
func (p *tencentAuditProvider) GetName() string {
return "tencent"
}
// AuditText 审核文本
func (p *tencentAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "tencent",
Detail: "Tencent audit not implemented, using pass",
}
// TODO: 实现腾讯云内容审核API调用
// 具体参考: https://cloud.tencent.com/document/product/1124/64508
return result, nil
}
// AuditImage 审核图片
func (p *tencentAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "tencent",
Detail: "Tencent image audit not implemented, using pass",
}
// TODO: 实现腾讯云图片审核API调用
return result, nil
}
// ==================== 百度云审核服务提供商 ====================
// baiduAuditProvider 百度云审核服务提供商
type baiduAuditProvider struct {
apiKey string
secretKey string
}
// NewBaiduAuditProvider 创建百度云审核服务提供商
func NewBaiduAuditProvider(apiKey, secretKey string) AuditServiceProvider {
return &baiduAuditProvider{
apiKey: apiKey,
secretKey: secretKey,
}
}
// GetName 获取提供商名称
func (p *baiduAuditProvider) GetName() string {
return "baidu"
}
// AuditText 审核文本
func (p *baiduAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "baidu",
Detail: "Baidu audit not implemented, using pass",
}
// TODO: 实现百度云内容审核API调用
// 具体参考: https://cloud.baidu.com/doc/ANTISPAM/s/Jjw0r1iF6
return result, nil
}
// AuditImage 审核图片
func (p *baiduAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) {
result := &AuditResult{
Pass: true,
Risk: "low",
Suggest: "pass",
Labels: []string{},
Provider: "baidu",
Detail: "Baidu image audit not implemented, using pass",
}
// TODO: 实现百度云图片审核API调用
return result, nil
}
// ==================== 审核结果回调处理 ====================
// AuditCallback 审核回调处理
type AuditCallback struct {
service AuditService
}
// NewAuditCallback 创建审核回调处理
func NewAuditCallback(service AuditService) *AuditCallback {
return &AuditCallback{
service: service,
}
}
// HandleTextCallback 处理文本审核回调
func (c *AuditCallback) HandleTextCallback(ctx context.Context, auditID string, result *AuditResult) error {
if c.service == nil || auditID == "" || result == nil {
return fmt.Errorf("invalid parameters")
}
log.Printf("Processing text audit callback: auditID=%s, result=%+v", auditID, result)
// 根据审核结果执行相应操作
// 例如: 更新帖子状态、发送通知等
return nil
}
// HandleImageCallback 处理图片审核回调
func (c *AuditCallback) HandleImageCallback(ctx context.Context, auditID string, result *AuditResult) error {
if c.service == nil || auditID == "" || result == nil {
return fmt.Errorf("invalid parameters")
}
log.Printf("Processing image audit callback: auditID=%s, result=%+v", auditID, result)
// 根据审核结果执行相应操作
// 例如: 更新图片状态、删除违规图片等
return nil
}
// ==================== 辅助函数 ====================
// IsContentSafe 判断内容是否安全
func IsContentSafe(result *AuditResult) bool {
if result == nil {
return true
}
return result.Pass && result.Suggest != "block"
}
// NeedReview 判断内容是否需要人工复审
func NeedReview(result *AuditResult) bool {
if result == nil {
return false
}
return result.Suggest == "review"
}
// GetRiskLevel 获取风险等级
func GetRiskLevel(result *AuditResult) string {
if result == nil {
return "low"
}
return result.Risk
}
// FormatAuditResult 格式化审核结果为字符串
func FormatAuditResult(result *AuditResult) string {
if result == nil {
return "{}"
}
data, _ := json.Marshal(result)
return string(data)
}
// ParseAuditResult 从字符串解析审核结果
func ParseAuditResult(data string) (*AuditResult, error) {
if data == "" {
return nil, fmt.Errorf("empty data")
}
var result AuditResult
if err := json.Unmarshal([]byte(data), &result); err != nil {
return nil, err
}
return &result, nil
}
// ==================== 审核日志查询 ====================
// GetAuditLogs 获取审核日志列表
func GetAuditLogs(db *gorm.DB, targetType string, targetID string, result string, page, pageSize int) ([]model.AuditLog, int64, error) {
query := db.Model(&model.AuditLog{})
if targetType != "" {
query = query.Where("target_type = ?", targetType)
}
if targetID != "" {
query = query.Where("target_id = ?", targetID)
}
if result != "" {
query = query.Where("result = ?", result)
}
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var logs []model.AuditLog
offset := (page - 1) * pageSize
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// ==================== 定时任务 ====================
// AuditScheduler 审核调度器
type AuditScheduler struct {
db *gorm.DB
service AuditService
interval time.Duration
stopCh chan bool
}
// NewAuditScheduler 创建审核调度器
func NewAuditScheduler(db *gorm.DB, service AuditService, interval time.Duration) *AuditScheduler {
return &AuditScheduler{
db: db,
service: service,
interval: interval,
stopCh: make(chan bool),
}
}
// Start 启动调度器
func (s *AuditScheduler) Start() {
go func() {
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.processPendingAudits()
case <-s.stopCh:
return
}
}
}()
}
// Stop 停止调度器
func (s *AuditScheduler) Stop() {
s.stopCh <- true
}
// processPendingAudits 处理待审核内容
func (s *AuditScheduler) processPendingAudits() {
// 查询待审核的内容
// 1. 查询审核状态为 pending 的记录
// 2. 调用审核服务
// 3. 更新审核状态
// 示例逻辑,实际需要根据业务需求实现
log.Println("Processing pending audits...")
}
// CleanupOldLogs 清理旧的审核日志
func CleanupOldLogs(db *gorm.DB, days int) error {
// 清理指定天数之前的审核日志
cutoffTime := time.Now().AddDate(0, 0, -days)
return db.Where("created_at < ? AND result = ?", cutoffTime, model.AuditResultPass).Delete(&model.AuditLog{}).Error
}

View File

@@ -0,0 +1,622 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"time"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
)
// 撤回消息的时间限制2分钟
const RecallMessageTimeout = 2 * time.Minute
// ChatService 聊天服务接口
type ChatService interface {
// 会话管理
GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error)
GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error)
GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error)
DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error
SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error
// 消息操作
SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error)
GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error)
GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error)
// 已读管理
MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error
GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error)
GetAllUnreadCount(ctx context.Context, userID string) (int64, error)
// 消息扩展功能
RecallMessage(ctx context.Context, messageID string, userID string) error
DeleteMessage(ctx context.Context, messageID string, userID string) error
// WebSocket相关
SendTyping(ctx context.Context, senderID string, conversationID string)
BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string)
// 系统消息推送
IsUserOnline(userID string) bool
PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error
PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error
PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error
// 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用)
SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error)
}
// chatServiceImpl 聊天服务实现
type chatServiceImpl struct {
db *gorm.DB
repo *repository.MessageRepository
userRepo *repository.UserRepository
sensitive SensitiveService
wsManager *websocket.WebSocketManager
}
// NewChatService 创建聊天服务
func NewChatService(
db *gorm.DB,
repo *repository.MessageRepository,
userRepo *repository.UserRepository,
sensitive SensitiveService,
wsManager *websocket.WebSocketManager,
) ChatService {
return &chatServiceImpl{
db: db,
repo: repo,
userRepo: userRepo,
sensitive: sensitive,
wsManager: wsManager,
}
}
// GetOrCreateConversation 获取或创建私聊会话
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
}
// GetConversationList 获取用户的会话列表
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
return s.repo.GetConversations(userID, page, pageSize)
}
// GetConversationByID 获取会话详情
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
// 验证用户是否是会话参与者
participant, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
// 获取会话信息
conv, err := s.repo.GetConversation(conversationID)
if err != nil {
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
// 填充用户的已读位置信息
_ = participant // 可以用于返回已读位置等信息
return conv, nil
}
// DeleteConversationForSelf 仅自己删除会话
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
participant, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
}
return fmt.Errorf("failed to get participant: %w", err)
}
if participant.ConversationID == "" {
return errors.New("conversation not found or no permission")
}
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
return fmt.Errorf("failed to hide conversation: %w", err)
}
return nil
}
// SetConversationPinned 设置会话置顶(用户维度)
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
participant, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
}
return fmt.Errorf("failed to get participant: %w", err)
}
if participant.ConversationID == "" {
return errors.New("conversation not found or no permission")
}
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
return fmt.Errorf("failed to update pinned status: %w", err)
}
return nil
}
// SendMessage 发送消息(使用 segments
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
// 首先验证会话是否存在
conv, err := s.repo.GetConversation(conversationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("会话不存在,请重新创建会话")
}
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
// 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
participants, pErr := s.repo.GetConversationParticipants(conversationID)
if pErr != nil {
return nil, fmt.Errorf("failed to get participants: %w", pErr)
}
var sentCount *int64
for _, p := range participants {
if p.UserID == senderID {
continue
}
blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID)
if bErr != nil {
return nil, fmt.Errorf("failed to check block status: %w", bErr)
}
if blocked {
return nil, ErrUserBlocked
}
// 陌生人限制:对方未回关前,只允许发送一条文本消息,且禁止发送图片
isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID)
if fErr != nil {
return nil, fmt.Errorf("failed to check follow status: %w", fErr)
}
if !isFollowedBack {
if containsImageSegment(segments) {
return nil, errors.New("对方未关注你,暂不支持发送图片")
}
if sentCount == nil {
c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID)
if cErr != nil {
return nil, fmt.Errorf("failed to count sender messages: %w", cErr)
}
sentCount = &c
}
if *sentCount >= 1 {
return nil, errors.New("对方未关注你前,仅允许发送一条消息")
}
}
}
}
// 验证用户是否是会话参与者
participant, err := s.repo.GetParticipant(conversationID, senderID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("您不是该会话的参与者")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
// 创建消息
message := &model.Message{
ConversationID: conversationID,
SenderID: senderID, // 直接使用string类型的UUID
Segments: segments,
ReplyToID: replyToID,
Status: model.MessageStatusNormal,
}
// 使用事务创建消息并更新seq
if err := s.repo.CreateMessageWithSeq(message); err != nil {
return nil, fmt.Errorf("failed to save message: %w", err)
}
// 发送消息给接收者
log.Printf("[DEBUG SendMessage] 私聊消息 segments 类型: %T, 值: %+v", message.Segments, message.Segments)
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{
ID: message.ID,
ConversationID: message.ConversationID,
SenderID: senderID,
Segments: message.Segments,
Seq: message.Seq,
CreatedAt: message.CreatedAt.UnixMilli(),
})
// 获取会话中的其他参与者
participants, err := s.repo.GetConversationParticipants(conversationID)
if err == nil {
for _, p := range participants {
// 不发给自己
if p.UserID == senderID {
continue
}
// 如果接收者在线,发送实时消息
if s.wsManager != nil {
isOnline := s.wsManager.IsUserOnline(p.UserID)
log.Printf("[DEBUG SendMessage] 接收者 UserID=%s, 在线状态=%v", p.UserID, isOnline)
if isOnline {
log.Printf("[DEBUG SendMessage] 发送WebSocket消息给 UserID=%s, 消息类型=%s", p.UserID, wsMsg.Type)
s.wsManager.SendToUser(p.UserID, wsMsg)
}
}
}
} else {
log.Printf("[DEBUG SendMessage] 获取参与者失败: %v", err)
}
_ = participant // 避免未使用变量警告
return message, nil
}
func containsImageSegment(segments model.MessageSegments) bool {
for _, seg := range segments {
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
return true
}
}
return false
}
// GetMessages 获取消息历史(分页)
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, 0, errors.New("conversation not found or no permission")
}
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
}
return s.repo.GetMessages(conversationID, page, pageSize)
}
// GetMessagesAfterSeq 获取指定seq之后的消息用于增量同步
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
if limit <= 0 {
limit = 100
}
return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}
// GetMessagesBeforeSeq 获取指定seq之前的历史消息用于下拉加载更多
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
if limit <= 0 {
limit = 20
}
return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit)
}
// MarkAsRead 标记已读
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
}
return fmt.Errorf("failed to get participant: %w", err)
}
// 更新参与者的已读位置
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
if err != nil {
return fmt.Errorf("failed to update last read seq: %w", err)
}
// 发送已读回执(作为 meta 事件)
if s.wsManager != nil {
wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{
"detail_type": websocket.MetaDetailTypeRead,
"conversation_id": conversationID,
"seq": seq,
"user_id": userID,
})
// 获取会话中的所有参与者
participants, err := s.repo.GetConversationParticipants(conversationID)
if err == nil {
// 推送给会话中的所有参与者(包括自己)
for _, p := range participants {
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
}
}
}
return nil
}
// GetUnreadCount 获取指定会话的未读消息数
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return 0, errors.New("conversation not found or no permission")
}
return 0, fmt.Errorf("failed to get participant: %w", err)
}
return s.repo.GetUnreadCount(conversationID, userID)
}
// GetAllUnreadCount 获取所有会话的未读消息总数
func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) {
return s.repo.GetAllUnreadCount(userID)
}
// RecallMessage 撤回消息2分钟内
func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error {
// 获取消息
var message model.Message
err := s.db.First(&message, "id = ?", messageID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("message not found")
}
return fmt.Errorf("failed to get message: %w", err)
}
// 验证是否是消息发送者
if message.SenderIDStr() != userID {
return errors.New("can only recall your own messages")
}
// 验证消息是否已被撤回
if message.Status == model.MessageStatusRecalled {
return errors.New("message already recalled")
}
// 验证是否在2分钟内
if time.Since(message.CreatedAt) > RecallMessageTimeout {
return errors.New("message recall timeout (2 minutes)")
}
// 更新消息状态为已撤回
err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error
if err != nil {
return fmt.Errorf("failed to recall message: %w", err)
}
// 发送撤回通知
if s.wsManager != nil {
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{
"messageId": messageID,
"conversationId": message.ConversationID,
"senderId": userID,
})
// 通知会话中的所有参与者
participants, err := s.repo.GetConversationParticipants(message.ConversationID)
if err == nil {
for _, p := range participants {
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
}
}
}
return nil
}
// DeleteMessage 删除消息(仅对自己可见)
func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error {
// 获取消息
var message model.Message
err := s.db.First(&message, "id = ?", messageID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("message not found")
}
return fmt.Errorf("failed to get message: %w", err)
}
// 验证用户是否是会话参与者
_, err = s.repo.GetParticipant(message.ConversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("no permission to delete this message")
}
return fmt.Errorf("failed to get participant: %w", err)
}
// 对于删除消息,我们使用软删除,但需要确保只对当前用户隐藏
// 这里简化处理:只有发送者可以删除自己的消息
if message.SenderIDStr() != userID {
return errors.New("can only delete your own messages")
}
// 更新消息状态为已删除
err = s.db.Model(&message).Update("status", model.MessageStatusDeleted).Error
if err != nil {
return fmt.Errorf("failed to delete message: %w", err)
}
return nil
}
// SendTyping 发送正在输入状态
func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) {
if s.wsManager == nil {
return
}
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, senderID)
if err != nil {
return
}
// 获取会话中的其他参与者
participants, err := s.repo.GetConversationParticipants(conversationID)
if err != nil {
return
}
for _, p := range participants {
if p.UserID == senderID {
continue
}
// 发送正在输入状态
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{
"conversationId": conversationID,
"senderId": senderID,
})
if s.wsManager.IsUserOnline(p.UserID) {
s.wsManager.SendToUser(p.UserID, wsMsg)
}
}
}
// BroadcastMessage 广播消息给用户
func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) {
if s.wsManager != nil {
s.wsManager.SendToUser(targetUser, msg)
}
}
// IsUserOnline 检查用户是否在线
func (s *chatServiceImpl) IsUserOnline(userID string) bool {
if s.wsManager == nil {
return false
}
return s.wsManager.IsUserOnline(userID)
}
// PushSystemMessage 推送系统消息给指定用户
func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if !s.wsManager.IsUserOnline(userID) {
return errors.New("user is offline")
}
sysMsg := &websocket.SystemMessage{
ID: "", // 由调用方生成
Type: msgType,
Title: title,
Content: content,
Data: data,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
s.wsManager.SendToUser(userID, wsMsg)
return nil
}
// PushNotificationMessage 推送通知消息给指定用户
func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if !s.wsManager.IsUserOnline(userID) {
return errors.New("user is offline")
}
// 确保时间戳已设置
if notification.CreatedAt == 0 {
notification.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
return nil
}
// PushAnnouncementMessage 广播公告消息给所有在线用户
func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
// 确保时间戳已设置
if announcement.CreatedAt == 0 {
announcement.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
s.wsManager.Broadcast(wsMsg)
return nil
}
// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送
// 适用于群聊等由调用方自行负责推送的场景
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
// 验证会话是否存在
_, err := s.repo.GetConversation(conversationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("会话不存在,请重新创建会话")
}
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
// 验证用户是否是会话参与者
_, err = s.repo.GetParticipant(conversationID, senderID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("您不是该会话的参与者")
}
return nil, fmt.Errorf("failed to get participant: %w", err)
}
message := &model.Message{
ConversationID: conversationID,
SenderID: senderID,
Segments: segments,
ReplyToID: replyToID,
Status: model.MessageStatusNormal,
}
if err := s.repo.CreateMessageWithSeq(message); err != nil {
return nil, fmt.Errorf("failed to save message: %w", err)
}
return message, nil
}

View File

@@ -0,0 +1,273 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"strings"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/gorse"
"carrot_bbs/internal/repository"
)
// CommentService 评论服务
type CommentService struct {
commentRepo *repository.CommentRepository
postRepo *repository.PostRepository
systemMessageService SystemMessageService
gorseClient gorse.Client
postAIService *PostAIService
}
// NewCommentService 创建评论服务
func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *CommentService {
return &CommentService{
commentRepo: commentRepo,
postRepo: postRepo,
systemMessageService: systemMessageService,
gorseClient: gorseClient,
postAIService: postAIService,
}
}
// Create 创建评论
func (s *CommentService) Create(ctx context.Context, postID, userID, content string, parentID *string, images string, imageURLs []string) (*model.Comment, error) {
if s.postAIService != nil {
// 采用异步审核,前端先立即返回
}
// 获取帖子信息用于发送通知
post, err := s.postRepo.GetByID(postID)
if err != nil {
return nil, err
}
comment := &model.Comment{
PostID: postID,
UserID: userID,
Content: content,
ParentID: parentID,
Images: images,
Status: model.CommentStatusPending,
}
// 如果有父评论设置根评论ID
var parentUserID string
if parentID != nil {
parent, err := s.commentRepo.GetByID(*parentID)
if err == nil && parent != nil {
if parent.RootID != nil {
comment.RootID = parent.RootID
} else {
comment.RootID = parentID
}
parentUserID = parent.UserID
}
}
err = s.commentRepo.Create(comment)
if err != nil {
return nil, err
}
// 重新查询以获取关联的 User
comment, err = s.commentRepo.GetByID(comment.ID)
if err != nil {
return nil, err
}
go s.reviewCommentAsync(comment.ID, userID, postID, content, imageURLs, parentID, parentUserID, post.UserID)
return comment, nil
}
func (s *CommentService) reviewCommentAsync(
commentID, userID, postID, content string,
imageURLs []string,
parentID *string,
parentUserID string,
postOwnerID string,
) {
// 未启用AI时直接通过审核并发送后续通知
if s.postAIService == nil || !s.postAIService.IsEnabled() {
if err := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); err != nil {
log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err)
return
}
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
return
}
err := s.postAIService.ModerateComment(context.Background(), content, imageURLs)
if err != nil {
var rejectedErr *CommentModerationRejectedError
if errors.As(err, &rejectedErr) {
if delErr := s.commentRepo.Delete(commentID); delErr != nil {
log.Printf("[WARN] Failed to delete rejected comment %s: %v", commentID, delErr)
}
s.notifyCommentModerationRejected(userID, rejectedErr.Reason)
return
}
// 审核服务异常时降级放行避免评论长期pending
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr)
return
}
log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err)
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
return
}
if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil {
log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr)
return
}
s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID)
}
func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) {
// 发送系统消息通知
if s.systemMessageService != nil {
go func() {
if parentID != nil && parentUserID != "" {
// 回复评论,通知被回复的人
if parentUserID != userID {
notifyErr := s.systemMessageService.SendReplyNotification(context.Background(), parentUserID, userID, postID, *parentID, commentID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending reply notification: %v\n", notifyErr)
}
}
} else {
// 评论帖子,通知帖子作者
if postOwnerID != userID {
notifyErr := s.systemMessageService.SendCommentNotification(context.Background(), postOwnerID, userID, postID, commentID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending comment notification: %v\n", notifyErr)
}
}
}
}()
}
// 推送评论行为到Gorse异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeComment, userID, postID); err != nil {
log.Printf("[WARN] Failed to insert comment feedback to Gorse: %v", err)
}
}
}()
}
func (s *CommentService) notifyCommentModerationRejected(userID, reason string) {
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
return
}
content := "您发布的评论未通过AI审核请修改后重试。"
if strings.TrimSpace(reason) != "" {
content = fmt.Sprintf("您发布的评论未通过AI审核原因%s。请修改后重试。", reason)
}
go func() {
if err := s.systemMessageService.SendSystemAnnouncement(
context.Background(),
[]string{userID},
"评论审核未通过",
content,
); err != nil {
log.Printf("[WARN] Failed to send comment moderation reject notification: %v", err)
}
}()
}
// GetByID 根据ID获取评论
func (s *CommentService) GetByID(ctx context.Context, id string) (*model.Comment, error) {
return s.commentRepo.GetByID(id)
}
// GetByPostID 获取帖子评论
func (s *CommentService) GetByPostID(ctx context.Context, postID string, page, pageSize int) ([]*model.Comment, int64, error) {
// 使用带回复的查询默认加载前3条回复
return s.commentRepo.GetByPostIDWithReplies(postID, page, pageSize, 3)
}
// GetRepliesByRootID 根据根评论ID分页获取回复
func (s *CommentService) GetRepliesByRootID(ctx context.Context, rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
return s.commentRepo.GetRepliesByRootID(rootID, page, pageSize)
}
// GetReplies 获取回复
func (s *CommentService) GetReplies(ctx context.Context, parentID string) ([]*model.Comment, error) {
return s.commentRepo.GetReplies(parentID)
}
// Update 更新评论
func (s *CommentService) Update(ctx context.Context, comment *model.Comment) error {
return s.commentRepo.Update(comment)
}
// Delete 删除评论
func (s *CommentService) Delete(ctx context.Context, id string) error {
return s.commentRepo.Delete(id)
}
// Like 点赞评论
func (s *CommentService) Like(ctx context.Context, commentID, userID string) error {
// 获取评论信息用于发送通知
comment, err := s.commentRepo.GetByID(commentID)
if err != nil {
return err
}
err = s.commentRepo.Like(commentID, userID)
if err != nil {
return err
}
// 发送评论/回复点赞通知(只有不是给自己点赞时才发送)
if s.systemMessageService != nil && comment.UserID != userID {
go func() {
var notifyErr error
if comment.ParentID != nil {
notifyErr = s.systemMessageService.SendLikeReplyNotification(
context.Background(),
comment.UserID,
userID,
comment.PostID,
commentID,
comment.Content,
)
} else {
notifyErr = s.systemMessageService.SendLikeCommentNotification(
context.Background(),
comment.UserID,
userID,
comment.PostID,
commentID,
comment.Content,
)
}
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
} else {
fmt.Printf("[DEBUG] Like notification sent successfully\n")
}
}()
}
return nil
}
// Unlike 取消点赞评论
func (s *CommentService) Unlike(ctx context.Context, commentID, userID string) error {
return s.commentRepo.Unlike(commentID, userID)
}
// IsLiked 检查是否已点赞
func (s *CommentService) IsLiked(ctx context.Context, commentID, userID string) bool {
return s.commentRepo.IsLiked(commentID, userID)
}

View File

@@ -0,0 +1,234 @@
package service
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"math/big"
"strings"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/pkg/utils"
)
const (
verifyCodeTTL = 10 * time.Minute
verifyCodeRateLimitTTL = 60 * time.Second
)
const (
CodePurposeRegister = "register"
CodePurposePasswordReset = "password_reset"
CodePurposeEmailVerify = "email_verify"
CodePurposeChangePassword = "change_password"
)
type verificationCodePayload struct {
Code string `json:"code"`
Purpose string `json:"purpose"`
Email string `json:"email"`
ExpiresAt int64 `json:"expires_at"`
}
type EmailCodeService interface {
SendCode(ctx context.Context, purpose, email string) error
VerifyCode(purpose, email, code string) error
}
type emailCodeServiceImpl struct {
emailService EmailService
cache cache.Cache
}
func NewEmailCodeService(emailService EmailService, cacheBackend cache.Cache) EmailCodeService {
if cacheBackend == nil {
cacheBackend = cache.GetCache()
}
return &emailCodeServiceImpl{
emailService: emailService,
cache: cacheBackend,
}
}
func verificationCodeCacheKey(purpose, email string) string {
return fmt.Sprintf("auth:verify_code:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
}
func verificationCodeRateLimitKey(purpose, email string) string {
return fmt.Sprintf("auth:verify_code_rate_limit:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email)))
}
func generateNumericCode(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("invalid code length")
}
max := big.NewInt(10)
result := make([]byte, length)
for i := 0; i < length; i++ {
n, err := rand.Int(rand.Reader, max)
if err != nil {
return "", err
}
result[i] = byte('0' + n.Int64())
}
return string(result), nil
}
func (s *emailCodeServiceImpl) SendCode(ctx context.Context, purpose, email string) error {
if strings.TrimSpace(email) == "" || !utils.ValidateEmail(email) {
return ErrInvalidEmail
}
if s.emailService == nil || !s.emailService.IsEnabled() {
return ErrEmailServiceUnavailable
}
if s.cache == nil {
return ErrVerificationCodeUnavailable
}
rateLimitKey := verificationCodeRateLimitKey(purpose, email)
if s.cache.Exists(rateLimitKey) {
return ErrVerificationCodeTooFrequent
}
code, err := generateNumericCode(6)
if err != nil {
return fmt.Errorf("generate verification code failed: %w", err)
}
payload := verificationCodePayload{
Code: code,
Purpose: purpose,
Email: strings.ToLower(strings.TrimSpace(email)),
ExpiresAt: time.Now().Add(verifyCodeTTL).Unix(),
}
cacheKey := verificationCodeCacheKey(purpose, email)
s.cache.Set(cacheKey, payload, verifyCodeTTL)
s.cache.Set(rateLimitKey, "1", verifyCodeRateLimitTTL)
subject, sceneText := verificationEmailMeta(purpose)
textBody := fmt.Sprintf("【%s】验证码%s\n有效期10分钟\n请勿将验证码泄露给他人。", sceneText, code)
htmlBody := buildVerificationEmailHTML(sceneText, code)
if err := s.emailService.Send(ctx, SendEmailRequest{
To: []string{email},
Subject: subject,
TextBody: textBody,
HTMLBody: htmlBody,
}); err != nil {
s.cache.Delete(cacheKey)
return fmt.Errorf("send verification email failed: %w", err)
}
return nil
}
func (s *emailCodeServiceImpl) VerifyCode(purpose, email, code string) error {
if strings.TrimSpace(email) == "" || strings.TrimSpace(code) == "" {
return ErrVerificationCodeInvalid
}
if s.cache == nil {
return ErrVerificationCodeUnavailable
}
cacheKey := verificationCodeCacheKey(purpose, email)
raw, ok := s.cache.Get(cacheKey)
if !ok {
return ErrVerificationCodeExpired
}
var payload verificationCodePayload
switch v := raw.(type) {
case string:
if err := json.Unmarshal([]byte(v), &payload); err != nil {
return ErrVerificationCodeInvalid
}
case []byte:
if err := json.Unmarshal(v, &payload); err != nil {
return ErrVerificationCodeInvalid
}
case verificationCodePayload:
payload = v
default:
data, err := json.Marshal(v)
if err != nil {
return ErrVerificationCodeInvalid
}
if err := json.Unmarshal(data, &payload); err != nil {
return ErrVerificationCodeInvalid
}
}
if payload.Purpose != purpose || payload.Email != strings.ToLower(strings.TrimSpace(email)) {
return ErrVerificationCodeInvalid
}
if payload.ExpiresAt > 0 && time.Now().Unix() > payload.ExpiresAt {
s.cache.Delete(cacheKey)
return ErrVerificationCodeExpired
}
if payload.Code != strings.TrimSpace(code) {
return ErrVerificationCodeInvalid
}
s.cache.Delete(cacheKey)
return nil
}
func verificationEmailMeta(purpose string) (subject string, sceneText string) {
switch purpose {
case CodePurposeRegister:
return "Carrot BBS 注册验证码", "注册账号"
case CodePurposePasswordReset:
return "Carrot BBS 找回密码验证码", "找回密码"
case CodePurposeEmailVerify:
return "Carrot BBS 邮箱验证验证码", "验证邮箱"
case CodePurposeChangePassword:
return "Carrot BBS 修改密码验证码", "修改密码"
default:
return "Carrot BBS 验证码", "身份验证"
}
}
func buildVerificationEmailHTML(sceneText, code string) string {
return fmt.Sprintf(`<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Carrot BBS 验证码</title>
</head>
<body style="margin:0;padding:0;background:#f4f6fb;font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,'PingFang SC','Microsoft YaHei',sans-serif;color:#1f2937;">
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="background:#f4f6fb;padding:24px 12px;">
<tr>
<td align="center">
<table role="presentation" width="100%%" cellspacing="0" cellpadding="0" style="max-width:560px;background:#ffffff;border-radius:14px;overflow:hidden;box-shadow:0 8px 30px rgba(15,23,42,0.08);">
<tr>
<td style="background:linear-gradient(135deg,#ff6b35,#ff8f66);padding:24px 28px;color:#ffffff;">
<div style="font-size:22px;font-weight:700;line-height:1.2;">Carrot BBS</div>
<div style="margin-top:6px;font-size:14px;opacity:0.95;">%s 验证</div>
</td>
</tr>
<tr>
<td style="padding:28px;">
<p style="margin:0 0 14px;font-size:15px;line-height:1.75;">你好,</p>
<p style="margin:0 0 20px;font-size:15px;line-height:1.75;">你正在进行 <strong>%s</strong> 操作,请使用下方验证码完成验证:</p>
<div style="margin:0 auto 18px;max-width:320px;border:1px dashed #ff8f66;background:#fff8f4;border-radius:12px;padding:14px 12px;text-align:center;">
<div style="font-size:13px;color:#9a3412;letter-spacing:0.5px;">验证码10分钟内有效</div>
<div style="margin-top:8px;font-size:34px;line-height:1;font-weight:800;letter-spacing:8px;color:#ea580c;">%s</div>
</div>
<p style="margin:0 0 8px;font-size:13px;color:#6b7280;line-height:1.7;">如果不是你本人操作,请忽略此邮件,并及时检查账号安全。</p>
<p style="margin:0;font-size:13px;color:#6b7280;line-height:1.7;">请勿向任何人透露验证码,平台不会以任何理由索取验证码。</p>
</td>
</tr>
<tr>
<td style="padding:14px 28px;background:#f8fafc;border-top:1px solid #e5e7eb;color:#94a3b8;font-size:12px;line-height:1.7;">
此邮件由系统自动发送,请勿直接回复。<br/>
© Carrot BBS
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>`, sceneText, sceneText, code)
}

View File

@@ -0,0 +1,82 @@
package service
import (
"context"
"fmt"
"strings"
emailpkg "carrot_bbs/internal/pkg/email"
)
// SendEmailRequest 发信请求
type SendEmailRequest struct {
To []string
Cc []string
Bcc []string
ReplyTo []string
Subject string
TextBody string
HTMLBody string
Attachments []string
}
type EmailService interface {
IsEnabled() bool
Send(ctx context.Context, req SendEmailRequest) error
SendText(ctx context.Context, to []string, subject, body string) error
SendHTML(ctx context.Context, to []string, subject, html string) error
}
type emailServiceImpl struct {
client emailpkg.Client
}
func NewEmailService(client emailpkg.Client) EmailService {
return &emailServiceImpl{client: client}
}
func (s *emailServiceImpl) IsEnabled() bool {
return s.client != nil && s.client.IsEnabled()
}
func (s *emailServiceImpl) Send(ctx context.Context, req SendEmailRequest) error {
if s.client == nil {
return fmt.Errorf("email client is nil")
}
if !s.client.IsEnabled() {
return fmt.Errorf("email service is disabled")
}
if len(req.To) == 0 {
return fmt.Errorf("email recipient is empty")
}
if strings.TrimSpace(req.Subject) == "" {
return fmt.Errorf("email subject is empty")
}
return s.client.Send(ctx, emailpkg.Message{
To: req.To,
Cc: req.Cc,
Bcc: req.Bcc,
ReplyTo: req.ReplyTo,
Subject: req.Subject,
TextBody: req.TextBody,
HTMLBody: req.HTMLBody,
Attachments: req.Attachments,
})
}
func (s *emailServiceImpl) SendText(ctx context.Context, to []string, subject, body string) error {
return s.Send(ctx, SendEmailRequest{
To: to,
Subject: subject,
TextBody: body,
})
}
func (s *emailServiceImpl) SendHTML(ctx context.Context, to []string, subject, html string) error {
return s.Send(ctx, SendEmailRequest{
To: to,
Subject: subject,
HTMLBody: html,
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,38 @@
package service
import (
"carrot_bbs/internal/pkg/jwt"
"time"
)
// JWTService JWT服务
type JWTService struct {
jwt *jwt.JWT
}
// NewJWTService 创建JWT服务
func NewJWTService(secret string, accessExpire, refreshExpire int64) *JWTService {
return &JWTService{
jwt: jwt.New(secret, time.Duration(accessExpire)*time.Second, time.Duration(refreshExpire)*time.Second),
}
}
// GenerateAccessToken 生成访问令牌
func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) {
return s.jwt.GenerateAccessToken(userID, username)
}
// GenerateRefreshToken 生成刷新令牌
func (s *JWTService) GenerateRefreshToken(userID, username string) (string, error) {
return s.jwt.GenerateRefreshToken(userID, username)
}
// ParseToken 解析令牌
func (s *JWTService) ParseToken(tokenString string) (*jwt.Claims, error) {
return s.jwt.ParseToken(tokenString)
}
// ValidateToken 验证令牌
func (s *JWTService) ValidateToken(tokenString string) error {
return s.jwt.ValidateToken(tokenString)
}

View File

@@ -0,0 +1,215 @@
package service
import (
"context"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
)
// 缓存TTL常量
const (
ConversationListTTL = 60 * time.Second // 会话列表缓存60秒
ConversationDetailTTL = 60 * time.Second // 会话详情缓存60秒
UnreadCountTTL = 30 * time.Second // 未读数缓存30秒
ConversationNullTTL = 5 * time.Second
UnreadNullTTL = 5 * time.Second
CacheJitterRatio = 0.1
)
// MessageService 消息服务
type MessageService struct {
messageRepo *repository.MessageRepository
cache cache.Cache
}
// NewMessageService 创建消息服务
func NewMessageService(messageRepo *repository.MessageRepository) *MessageService {
return &MessageService{
messageRepo: messageRepo,
cache: cache.GetCache(),
}
}
// ConversationListResult 会话列表缓存结果
type ConversationListResult struct {
Conversations []*model.Conversation
Total int64
}
// SendMessage 发送消息(使用 segments
// senderID 和 receiverID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID string, segments model.MessageSegments) (*model.Message, error) {
// 获取或创建会话
conv, err := s.messageRepo.GetOrCreatePrivateConversation(senderID, receiverID)
if err != nil {
return nil, err
}
msg := &model.Message{
ConversationID: conv.ID,
SenderID: senderID,
Segments: segments,
Status: model.MessageStatusNormal,
}
// 使用事务创建消息并更新seq
err = s.messageRepo.CreateMessageWithSeq(msg)
if err != nil {
return nil, err
}
// 失效会话列表缓存(发送者和接收者)
cache.InvalidateConversationList(s.cache, senderID)
cache.InvalidateConversationList(s.cache, receiverID)
// 失效未读数缓存
cache.InvalidateUnreadConversation(s.cache, receiverID)
cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID)
return msg, nil
}
// GetConversations 获取会话列表(带缓存)
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
cacheSettings := cache.GetSettings()
conversationTTL := cacheSettings.ConversationTTL
if conversationTTL <= 0 {
conversationTTL = ConversationListTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = ConversationNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = CacheJitterRatio
}
// 生成缓存键
cacheKey := cache.ConversationListKey(userID, page, pageSize)
result, err := cache.GetOrLoadTyped[*ConversationListResult](
s.cache,
cacheKey,
conversationTTL,
jitter,
nullTTL,
func() (*ConversationListResult, error) {
conversations, total, err := s.messageRepo.GetConversations(userID, page, pageSize)
if err != nil {
return nil, err
}
return &ConversationListResult{
Conversations: conversations,
Total: total,
}, nil
},
)
if err != nil {
return nil, 0, err
}
if result == nil {
return []*model.Conversation{}, 0, nil
}
return result.Conversations, result.Total, nil
}
// GetMessages 获取消息列表
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
return s.messageRepo.GetMessages(conversationID, page, pageSize)
}
// GetMessagesAfterSeq 获取指定seq之后的消息增量同步
func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}
// MarkAsRead 标记为已读
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadConversation(s.cache, userID)
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
// 失效会话列表缓存
cache.InvalidateConversationList(s.cache, userID)
return nil
}
// GetUnreadCount 获取未读消息数(带缓存)
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
cacheSettings := cache.GetSettings()
unreadTTL := cacheSettings.UnreadCountTTL
if unreadTTL <= 0 {
unreadTTL = UnreadCountTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = UnreadNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = CacheJitterRatio
}
// 生成缓存键
cacheKey := cache.UnreadDetailKey(userID, conversationID)
return cache.GetOrLoadTyped[int64](
s.cache,
cacheKey,
unreadTTL,
jitter,
nullTTL,
func() (int64, error) {
return s.messageRepo.GetUnreadCount(conversationID, userID)
},
)
}
// GetOrCreateConversation 获取或创建私聊会话
// user1ID 和 user2ID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
conv, err := s.messageRepo.GetOrCreatePrivateConversation(user1ID, user2ID)
if err != nil {
return nil, err
}
// 失效会话列表缓存
cache.InvalidateConversationList(s.cache, user1ID)
cache.InvalidateConversationList(s.cache, user2ID)
return conv, nil
}
// GetConversationParticipants 获取会话参与者列表
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
return s.messageRepo.GetConversationParticipants(conversationID)
}
// ParseConversationID 辅助函数直接返回字符串ID已经是string类型
func ParseConversationID(idStr string) (string, error) {
return idStr, nil
}
// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用)
func (s *MessageService) InvalidateUserConversationCache(userID string) {
cache.InvalidateConversationList(s.cache, userID)
cache.InvalidateUnreadConversation(s.cache, userID)
}
// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用)
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
cache.InvalidateUnreadConversation(s.cache, userID)
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
}

View File

@@ -0,0 +1,169 @@
package service
import (
"context"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
)
// 缓存TTL常量
const (
NotificationUnreadCountTTL = 30 * time.Second // 通知未读数缓存30秒
NotificationNullTTL = 5 * time.Second
NotificationCacheJitter = 0.1
)
// NotificationService 通知服务
type NotificationService struct {
notificationRepo *repository.NotificationRepository
cache cache.Cache
}
// NewNotificationService 创建通知服务
func NewNotificationService(notificationRepo *repository.NotificationRepository) *NotificationService {
return &NotificationService{
notificationRepo: notificationRepo,
cache: cache.GetCache(),
}
}
// Create 创建通知
func (s *NotificationService) Create(ctx context.Context, userID string, notificationType model.NotificationType, title, content string) (*model.Notification, error) {
notification := &model.Notification{
UserID: userID,
Type: notificationType,
Title: title,
Content: content,
IsRead: false,
}
err := s.notificationRepo.Create(notification)
if err != nil {
return nil, err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return notification, nil
}
// GetByUserID 获取用户通知
func (s *NotificationService) GetByUserID(ctx context.Context, userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) {
return s.notificationRepo.GetByUserID(userID, page, pageSize, unreadOnly)
}
// MarkAsRead 标记为已读
func (s *NotificationService) MarkAsRead(ctx context.Context, id string) error {
err := s.notificationRepo.MarkAsRead(id)
if err != nil {
return err
}
// 注意这里无法获取userID所以不在缓存中失效
// 调用方应该使用MarkAsReadWithUserID方法
return nil
}
// MarkAsReadWithUserID 标记为已读带用户ID用于缓存失效
func (s *NotificationService) MarkAsReadWithUserID(ctx context.Context, id, userID string) error {
err := s.notificationRepo.MarkAsRead(id)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return nil
}
// MarkAllAsRead 标记所有为已读
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID string) error {
err := s.notificationRepo.MarkAllAsRead(userID)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return nil
}
// Delete 删除通知
func (s *NotificationService) Delete(ctx context.Context, id string) error {
return s.notificationRepo.Delete(id)
}
// GetUnreadCount 获取未读数量(带缓存)
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID string) (int64, error) {
cacheSettings := cache.GetSettings()
unreadTTL := cacheSettings.UnreadCountTTL
if unreadTTL <= 0 {
unreadTTL = NotificationUnreadCountTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = NotificationNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = NotificationCacheJitter
}
// 生成缓存键
cacheKey := cache.UnreadSystemKey(userID)
return cache.GetOrLoadTyped[int64](
s.cache,
cacheKey,
unreadTTL,
jitter,
nullTTL,
func() (int64, error) {
return s.notificationRepo.GetUnreadCount(userID)
},
)
}
// DeleteNotification 删除通知(带用户验证)
func (s *NotificationService) DeleteNotification(ctx context.Context, id, userID string) error {
// 先检查通知是否属于该用户
notification, err := s.notificationRepo.GetByID(id)
if err != nil {
return err
}
if notification.UserID != userID {
return ErrUnauthorizedNotification
}
err = s.notificationRepo.Delete(id)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return nil
}
// ClearAllNotifications 清空所有通知
func (s *NotificationService) ClearAllNotifications(ctx context.Context, userID string) error {
err := s.notificationRepo.DeleteAllByUserID(userID)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
return nil
}
// 错误定义
var ErrUnauthorizedNotification = &ServiceError{Code: 403, Message: "unauthorized to delete this notification"}

View File

@@ -0,0 +1,103 @@
package service
import (
"context"
"log"
"strings"
"carrot_bbs/internal/pkg/openai"
)
// PostModerationRejectedError 帖子审核拒绝错误
type PostModerationRejectedError struct {
Reason string
}
func (e *PostModerationRejectedError) Error() string {
if strings.TrimSpace(e.Reason) == "" {
return "post rejected by moderation"
}
return "post rejected by moderation: " + e.Reason
}
// UserMessage 返回给前端的用户可读文案
func (e *PostModerationRejectedError) UserMessage() string {
if strings.TrimSpace(e.Reason) == "" {
return "内容未通过审核,请修改后重试"
}
return strings.TrimSpace(e.Reason)
}
// CommentModerationRejectedError 评论审核拒绝错误
type CommentModerationRejectedError struct {
Reason string
}
func (e *CommentModerationRejectedError) Error() string {
if strings.TrimSpace(e.Reason) == "" {
return "comment rejected by moderation"
}
return "comment rejected by moderation: " + e.Reason
}
// UserMessage 返回给前端的用户可读文案
func (e *CommentModerationRejectedError) UserMessage() string {
if strings.TrimSpace(e.Reason) == "" {
return "评论未通过审核,请修改后重试"
}
return strings.TrimSpace(e.Reason)
}
type PostAIService struct {
openAIClient openai.Client
}
func NewPostAIService(openAIClient openai.Client) *PostAIService {
return &PostAIService{
openAIClient: openAIClient,
}
}
func (s *PostAIService) IsEnabled() bool {
return s != nil && s.openAIClient != nil && s.openAIClient.IsEnabled()
}
// ModeratePost 审核帖子内容,返回 nil 表示通过
func (s *PostAIService) ModeratePost(ctx context.Context, title, content string, images []string) error {
if !s.IsEnabled() {
return nil
}
approved, reason, err := s.openAIClient.ModeratePost(ctx, title, content, images)
if err != nil {
if s.openAIClient.Config().StrictModeration {
return err
}
log.Printf("[WARN] AI moderation failed, fallback allow: %v", err)
return nil
}
if !approved {
return &PostModerationRejectedError{Reason: reason}
}
return nil
}
// ModerateComment 审核评论内容,返回 nil 表示通过
func (s *PostAIService) ModerateComment(ctx context.Context, content string, images []string) error {
if !s.IsEnabled() {
return nil
}
approved, reason, err := s.openAIClient.ModerateComment(ctx, content, images)
if err != nil {
if s.openAIClient.Config().StrictModeration {
return err
}
log.Printf("[WARN] AI comment moderation failed, fallback allow: %v", err)
return nil
}
if !approved {
return &CommentModerationRejectedError{Reason: reason}
}
return nil
}

View File

@@ -0,0 +1,593 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"strings"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/gorse"
"carrot_bbs/internal/repository"
)
// 缓存TTL常量
const (
PostListTTL = 30 * time.Second // 帖子列表缓存30秒
PostListNullTTL = 5 * time.Second
PostListJitterRatio = 0.15
anonymousViewUserID = "_anon_view"
)
// PostService 帖子服务
type PostService struct {
postRepo *repository.PostRepository
systemMessageService SystemMessageService
cache cache.Cache
gorseClient gorse.Client
postAIService *PostAIService
}
// NewPostService 创建帖子服务
func NewPostService(postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *PostService {
return &PostService{
postRepo: postRepo,
systemMessageService: systemMessageService,
cache: cache.GetCache(),
gorseClient: gorseClient,
postAIService: postAIService,
}
}
// PostListResult 帖子列表缓存结果
type PostListResult struct {
Posts []*model.Post
Total int64
}
// Create 创建帖子
func (s *PostService) Create(ctx context.Context, userID, title, content string, images []string) (*model.Post, error) {
post := &model.Post{
UserID: userID,
Title: title,
Content: content,
Status: model.PostStatusPending,
}
err := s.postRepo.Create(post, images)
if err != nil {
return nil, err
}
// 失效帖子列表缓存
cache.InvalidatePostList(s.cache)
// 同步到Gorse推荐系统异步
go s.reviewPostAsync(post.ID, userID, title, content, images)
// 重新查询以获取关联的 User 和 Images
return s.postRepo.GetByID(post.ID)
}
func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) {
// 未启用AI时直接发布
if s.postAIService == nil || !s.postAIService.IsEnabled() {
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
}
return
}
err := s.postAIService.ModeratePost(context.Background(), title, content, images)
if err != nil {
var rejectedErr *PostModerationRejectedError
if errors.As(err, &rejectedErr) {
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
}
s.notifyModerationRejected(userID, rejectedErr.Reason)
return
}
// 规则审核不可用时降级为发布避免长时间pending
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
}
log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err)
return
}
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
return
}
if s.gorseClient.IsEnabled() {
post, getErr := s.postRepo.GetByID(postID)
if getErr != nil {
log.Printf("[WARN] Failed to load published post for gorse sync: %v", getErr)
return
}
categories := s.buildPostCategories(post)
comment := post.Title
textToEmbed := post.Title + " " + post.Content
if upsertErr := s.gorseClient.UpsertItemWithEmbedding(context.Background(), post.ID, categories, comment, textToEmbed); upsertErr != nil {
log.Printf("[WARN] Failed to upsert item to Gorse: %v", upsertErr)
}
}
}
func (s *PostService) notifyModerationRejected(userID, reason string) {
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
return
}
content := "您发布的帖子未通过AI审核请修改后重试。"
if strings.TrimSpace(reason) != "" {
content = fmt.Sprintf("您发布的帖子未通过AI审核原因%s。请修改后重试。", reason)
}
go func() {
if err := s.systemMessageService.SendSystemAnnouncement(
context.Background(),
[]string{userID},
"帖子审核未通过",
content,
); err != nil {
log.Printf("[WARN] Failed to send moderation reject notification: %v", err)
}
}()
}
// GetByID 根据ID获取帖子
func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, error) {
return s.postRepo.GetByID(id)
}
// Update 更新帖子
func (s *PostService) Update(ctx context.Context, post *model.Post) error {
err := s.postRepo.Update(post)
if err != nil {
return err
}
// 失效帖子详情缓存和列表缓存
cache.InvalidatePostDetail(s.cache, post.ID)
cache.InvalidatePostList(s.cache)
return nil
}
// Delete 删除帖子
func (s *PostService) Delete(ctx context.Context, id string) error {
err := s.postRepo.Delete(id)
if err != nil {
return err
}
// 失效帖子详情缓存和列表缓存
cache.InvalidatePostDetail(s.cache, id)
cache.InvalidatePostList(s.cache)
// 从Gorse中删除帖子异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.DeleteItem(context.Background(), id); err != nil {
log.Printf("[WARN] Failed to delete item from Gorse: %v", err)
}
}
}()
return nil
}
// List 获取帖子列表(带缓存)
func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
cacheSettings := cache.GetSettings()
postListTTL := cacheSettings.PostListTTL
if postListTTL <= 0 {
postListTTL = PostListTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = PostListNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = PostListJitterRatio
}
// 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染)
cacheKey := cache.PostListKey("latest", userID, page, pageSize)
result, err := cache.GetOrLoadTyped[*PostListResult](
s.cache,
cacheKey,
postListTTL,
jitter,
nullTTL,
func() (*PostListResult, error) {
posts, total, err := s.postRepo.List(page, pageSize, userID)
if err != nil {
return nil, err
}
return &PostListResult{Posts: posts, Total: total}, nil
},
)
if err != nil {
return nil, 0, err
}
if result == nil {
return []*model.Post{}, 0, nil
}
// 兼容历史脏缓存:旧缓存序列化会丢失 Post.User导致前端显示“匿名用户”
// 这里检测并回源重建一次缓存,避免在 TTL 内持续返回缺失作者的数据。
missingAuthor := false
for _, post := range result.Posts {
if post != nil && post.UserID != "" && post.User == nil {
missingAuthor = true
break
}
}
if missingAuthor {
posts, total, loadErr := s.postRepo.List(page, pageSize, userID)
if loadErr != nil {
return nil, 0, loadErr
}
result = &PostListResult{Posts: posts, Total: total}
cache.SetWithJitter(s.cache, cacheKey, result, postListTTL, jitter)
}
return result.Posts, result.Total, nil
}
// GetLatestPosts 获取最新帖子(语义化别名)
func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) {
return s.List(ctx, page, pageSize, userID)
}
// GetUserPosts 获取用户帖子
func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
return s.postRepo.GetUserPosts(userID, page, pageSize)
}
// Like 点赞
func (s *PostService) Like(ctx context.Context, postID, userID string) error {
// 获取帖子信息用于发送通知
post, err := s.postRepo.GetByID(postID)
if err != nil {
return err
}
err = s.postRepo.Like(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
// 发送点赞通知(不给自己发通知)
if s.systemMessageService != nil && post.UserID != userID {
go func() {
notifyErr := s.systemMessageService.SendLikeNotification(context.Background(), post.UserID, userID, postID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr)
} else {
fmt.Printf("[DEBUG] Like notification sent successfully\n")
}
}()
}
// 推送点赞行为到Gorse异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
log.Printf("[WARN] Failed to insert like feedback to Gorse: %v", err)
}
}
}()
return nil
}
// Unlike 取消点赞
func (s *PostService) Unlike(ctx context.Context, postID, userID string) error {
err := s.postRepo.Unlike(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
// 删除Gorse中的点赞反馈异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil {
log.Printf("[WARN] Failed to delete like feedback from Gorse: %v", err)
}
}
}()
return nil
}
// IsLiked 检查是否点赞
func (s *PostService) IsLiked(ctx context.Context, postID, userID string) bool {
return s.postRepo.IsLiked(postID, userID)
}
// Favorite 收藏
func (s *PostService) Favorite(ctx context.Context, postID, userID string) error {
// 获取帖子信息用于发送通知
post, err := s.postRepo.GetByID(postID)
if err != nil {
return err
}
err = s.postRepo.Favorite(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
// 发送收藏通知(不给自己发通知)
if s.systemMessageService != nil && post.UserID != userID {
go func() {
notifyErr := s.systemMessageService.SendFavoriteNotification(context.Background(), post.UserID, userID, postID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending favorite notification: %v\n", notifyErr)
} else {
fmt.Printf("[DEBUG] Favorite notification sent successfully\n")
}
}()
}
// 推送收藏行为到Gorse异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
log.Printf("[WARN] Failed to insert favorite feedback to Gorse: %v", err)
}
}
}()
return nil
}
// Unfavorite 取消收藏
func (s *PostService) Unfavorite(ctx context.Context, postID, userID string) error {
err := s.postRepo.Unfavorite(postID, userID)
if err != nil {
return err
}
// 失效帖子详情缓存
cache.InvalidatePostDetail(s.cache, postID)
// 删除Gorse中的收藏反馈异步
go func() {
if s.gorseClient.IsEnabled() {
if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil {
log.Printf("[WARN] Failed to delete favorite feedback from Gorse: %v", err)
}
}
}()
return nil
}
// IsFavorited 检查是否收藏
func (s *PostService) IsFavorited(ctx context.Context, postID, userID string) bool {
return s.postRepo.IsFavorited(postID, userID)
}
// IncrementViews 增加帖子观看量并同步到Gorse
func (s *PostService) IncrementViews(ctx context.Context, postID, userID string) error {
if err := s.postRepo.IncrementViews(postID); err != nil {
return err
}
// 同步浏览行为到Gorse异步
go func() {
if !s.gorseClient.IsEnabled() {
return
}
feedbackUserID := userID
if feedbackUserID == "" {
feedbackUserID = anonymousViewUserID
}
if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeRead, feedbackUserID, postID); err != nil {
log.Printf("[WARN] Failed to insert read feedback to Gorse: %v", err)
}
}()
return nil
}
// GetFavorites 获取收藏列表
func (s *PostService) GetFavorites(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
return s.postRepo.GetFavorites(userID, page, pageSize)
}
// Search 搜索帖子
func (s *PostService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.Post, int64, error) {
return s.postRepo.Search(keyword, page, pageSize)
}
// GetFollowingPosts 获取关注用户的帖子(带缓存)
func (s *PostService) GetFollowingPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
cacheSettings := cache.GetSettings()
postListTTL := cacheSettings.PostListTTL
if postListTTL <= 0 {
postListTTL = PostListTTL
}
nullTTL := cacheSettings.NullTTL
if nullTTL <= 0 {
nullTTL = PostListNullTTL
}
jitter := cacheSettings.JitterRatio
if jitter <= 0 {
jitter = PostListJitterRatio
}
// 生成缓存键
cacheKey := cache.PostListKey("follow", userID, page, pageSize)
result, err := cache.GetOrLoadTyped[*PostListResult](
s.cache,
cacheKey,
postListTTL,
jitter,
nullTTL,
func() (*PostListResult, error) {
posts, total, err := s.postRepo.GetFollowingPosts(userID, page, pageSize)
if err != nil {
return nil, err
}
return &PostListResult{Posts: posts, Total: total}, nil
},
)
if err != nil {
return nil, 0, err
}
if result == nil {
return []*model.Post{}, 0, nil
}
return result.Posts, result.Total, nil
}
// GetHotPosts 获取热门帖子使用Gorse非个性化推荐
func (s *PostService) GetHotPosts(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
// 如果Gorse启用使用自定义的非个性化推荐器
if s.gorseClient.IsEnabled() {
offset := (page - 1) * pageSize
// 使用 most_liked_weekly 推荐器获取周热门
// 多取1条用于判断是否还有下一页
itemIDs, err := s.gorseClient.GetNonPersonalized(ctx, "most_liked_weekly", pageSize+1, offset, "")
if err != nil {
log.Printf("[WARN] Gorse GetNonPersonalized failed: %v, fallback to database", err)
return s.getHotPostsFromDB(ctx, page, pageSize)
}
if len(itemIDs) > 0 {
hasNext := len(itemIDs) > pageSize
if hasNext {
itemIDs = itemIDs[:pageSize]
}
posts, err := s.postRepo.GetByIDs(itemIDs)
if err != nil {
return nil, 0, err
}
// 近似 total当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
estimatedTotal := int64(offset + len(posts))
if hasNext {
estimatedTotal = int64(offset + pageSize + 1)
}
return posts, estimatedTotal, nil
}
}
// 降级:从数据库获取
return s.getHotPostsFromDB(ctx, page, pageSize)
}
// getHotPostsFromDB 从数据库获取热门帖子(降级路径)
func (s *PostService) getHotPostsFromDB(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) {
// 直接查询数据库不再使用本地缓存Gorse失败降级时使用
posts, total, err := s.postRepo.GetHotPosts(page, pageSize)
if err != nil {
return nil, 0, err
}
return posts, total, nil
}
// GetRecommendedPosts 获取推荐帖子
func (s *PostService) GetRecommendedPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) {
// 如果Gorse未启用或用户未登录降级为热门帖子
if !s.gorseClient.IsEnabled() || userID == "" {
return s.GetHotPosts(ctx, page, pageSize)
}
// 计算偏移量
offset := (page - 1) * pageSize
// 从Gorse获取推荐列表
// 多取1条用于判断是否还有下一页
itemIDs, err := s.gorseClient.GetRecommend(ctx, userID, pageSize+1, offset)
if err != nil {
log.Printf("[WARN] Gorse recommendation failed: %v, fallback to hot posts", err)
return s.GetHotPosts(ctx, page, pageSize)
}
// 如果没有推荐结果,降级为热门帖子
if len(itemIDs) == 0 {
return s.GetHotPosts(ctx, page, pageSize)
}
hasNext := len(itemIDs) > pageSize
if hasNext {
itemIDs = itemIDs[:pageSize]
}
// 根据ID列表查询帖子详情
posts, err := s.postRepo.GetByIDs(itemIDs)
if err != nil {
return nil, 0, err
}
// 近似 total当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估
estimatedTotal := int64(offset + len(posts))
if hasNext {
estimatedTotal = int64(offset + pageSize + 1)
}
return posts, estimatedTotal, nil
}
// buildPostCategories 构建帖子的类别标签
func (s *PostService) buildPostCategories(post *model.Post) []string {
var categories []string
// 热度标签
if post.ViewsCount > 1000 {
categories = append(categories, "hot_high")
} else if post.ViewsCount > 100 {
categories = append(categories, "hot_medium")
}
// 点赞标签
if post.LikesCount > 100 {
categories = append(categories, "likes_100+")
} else if post.LikesCount > 50 {
categories = append(categories, "likes_50+")
} else if post.LikesCount > 10 {
categories = append(categories, "likes_10+")
}
// 评论标签
if post.CommentsCount > 50 {
categories = append(categories, "comments_50+")
} else if post.CommentsCount > 10 {
categories = append(categories, "comments_10+")
}
// 时间标签
age := time.Since(post.CreatedAt)
if age < 24*time.Hour {
categories = append(categories, "today")
} else if age < 7*24*time.Hour {
categories = append(categories, "this_week")
} else if age < 30*24*time.Hour {
categories = append(categories, "this_month")
}
return categories
}

View File

@@ -0,0 +1,575 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
)
// 推送相关常量
const (
// DefaultPushTimeout 默认推送超时时间
DefaultPushTimeout = 30 * time.Second
// MaxRetryCount 最大重试次数
MaxRetryCount = 3
// DefaultExpiredTime 默认消息过期时间24小时
DefaultExpiredTime = 24 * time.Hour
// PushQueueSize 推送队列大小
PushQueueSize = 1000
)
// PushPriority 推送优先级
type PushPriority int
const (
PriorityLow PushPriority = 1 // 低优先级(营销消息等)
PriorityNormal PushPriority = 5 // 普通优先级(系统通知)
PriorityHigh PushPriority = 8 // 高优先级(聊天消息)
PriorityCritical PushPriority = 10 // 最高优先级(重要系统通知)
)
// PushService 推送服务接口
type PushService interface {
// 推送核心方法
PushMessage(ctx context.Context, userID string, message *model.Message) error
PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error
// 系统消息推送
PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error
PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error
PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error
// 系统通知推送(新接口,使用独立的 SystemNotification 模型)
PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error
// 设备管理
RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error
UnregisterDevice(ctx context.Context, deviceID string) error
UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error
// 推送记录管理
CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error)
GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error)
// 后台任务
StartPushWorker(ctx context.Context)
StopPushWorker()
}
// pushServiceImpl 推送服务实现
type pushServiceImpl struct {
pushRepo *repository.PushRecordRepository
deviceRepo *repository.DeviceTokenRepository
messageRepo *repository.MessageRepository
wsManager *websocket.WebSocketManager
// 推送队列
pushQueue chan *pushTask
stopChan chan struct{}
}
// pushTask 推送任务
type pushTask struct {
userID string
message *model.Message
priority int
}
// NewPushService 创建推送服务
func NewPushService(
pushRepo *repository.PushRecordRepository,
deviceRepo *repository.DeviceTokenRepository,
messageRepo *repository.MessageRepository,
wsManager *websocket.WebSocketManager,
) PushService {
return &pushServiceImpl{
pushRepo: pushRepo,
deviceRepo: deviceRepo,
messageRepo: messageRepo,
wsManager: wsManager,
pushQueue: make(chan *pushTask, PushQueueSize),
stopChan: make(chan struct{}),
}
}
// PushMessage 推送消息给用户
func (s *pushServiceImpl) PushMessage(ctx context.Context, userID string, message *model.Message) error {
return s.PushToUser(ctx, userID, message, int(PriorityNormal))
}
// PushToUser 带优先级的推送
func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error {
// 首先尝试WebSocket推送实时推送
if s.pushViaWebSocket(ctx, userID, message) {
// WebSocket推送成功记录推送状态
record, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelWebSocket)
if err != nil {
return fmt.Errorf("failed to create push record: %w", err)
}
record.MarkPushed()
if err := s.pushRepo.Update(record); err != nil {
return fmt.Errorf("failed to update push record: %w", err)
}
return nil
}
// WebSocket推送失败加入推送队列等待移动端推送
select {
case s.pushQueue <- &pushTask{
userID: userID,
message: message,
priority: priority,
}:
return nil
default:
// 队列已满,直接创建待推送记录
_, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelFCM)
if err != nil {
return fmt.Errorf("failed to create pending push record: %w", err)
}
return errors.New("push queue is full, message queued for later delivery")
}
}
// pushViaWebSocket 通过WebSocket推送消息
// 返回true表示推送成功false表示用户不在线
func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
// 判断是否为系统消息/通知消息
if message.IsSystemMessage() || message.Category == model.CategoryNotification {
// 使用 NotificationMessage 格式推送系统通知
// 从 segments 中提取文本内容
content := dto.ExtractTextContentFromModel(message.Segments)
notification := &websocket.NotificationMessage{
ID: fmt.Sprintf("%s", message.ID),
Type: string(message.SystemType),
Content: content,
Extra: make(map[string]interface{}),
CreatedAt: message.CreatedAt.UnixMilli(),
}
// 填充额外数据
if message.ExtraData != nil {
notification.Extra["actor_id"] = message.ExtraData.ActorID
notification.Extra["actor_name"] = message.ExtraData.ActorName
notification.Extra["avatar_url"] = message.ExtraData.AvatarURL
notification.Extra["target_id"] = message.ExtraData.TargetID
notification.Extra["target_type"] = message.ExtraData.TargetType
notification.Extra["action_url"] = message.ExtraData.ActionURL
notification.Extra["action_time"] = message.ExtraData.ActionTime
// 设置触发用户信息
if message.ExtraData.ActorID > 0 {
notification.TriggerUser = &websocket.NotificationUser{
ID: fmt.Sprintf("%d", message.ExtraData.ActorID),
Username: message.ExtraData.ActorName,
Avatar: message.ExtraData.AvatarURL,
}
}
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// 构建普通聊天消息的 WebSocket 消息 - 使用新的 WSEventResponse 格式
// 获取会话类型 (private/group)
detailType := "private"
if message.ConversationID != "" {
// 从会话中获取类型,需要查询数据库或从消息中判断
// 这里暂时默认为 privategroup 类型需要额外逻辑
}
// 直接使用 message.Segments
segments := message.Segments
event := &dto.WSEventResponse{
ID: fmt.Sprintf("%s", message.ID),
Time: message.CreatedAt.UnixMilli(),
Type: "message",
DetailType: detailType,
Seq: fmt.Sprintf("%d", message.Seq),
Segments: segments,
SenderID: message.SenderID,
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// pushViaFCM 通过FCM推送预留接口
func (s *pushServiceImpl) pushViaFCM(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
// TODO: 实现FCM推送
// 1. 构建FCM消息
// 2. 调用Firebase Admin SDK发送消息
// 3. 处理发送结果
return errors.New("FCM push not implemented")
}
// pushViaAPNs 通过APNs推送预留接口
func (s *pushServiceImpl) pushViaAPNs(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error {
// TODO: 实现APNs推送
// 1. 构建APNs消息
// 2. 调用APNs SDK发送消息
// 3. 处理发送结果
return errors.New("APNs push not implemented")
}
// RegisterDevice 注册设备
func (s *pushServiceImpl) RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error {
deviceToken := &model.DeviceToken{
UserID: userID,
DeviceID: deviceID,
DeviceType: deviceType,
PushToken: pushToken,
IsActive: true,
}
deviceToken.UpdateLastUsed()
return s.deviceRepo.Upsert(deviceToken)
}
// UnregisterDevice 注销设备
func (s *pushServiceImpl) UnregisterDevice(ctx context.Context, deviceID string) error {
return s.deviceRepo.Deactivate(deviceID)
}
// UpdateDeviceToken 更新设备Token
func (s *pushServiceImpl) UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error {
deviceToken, err := s.deviceRepo.GetByDeviceID(deviceID)
if err != nil {
return fmt.Errorf("device not found: %w", err)
}
deviceToken.PushToken = newPushToken
deviceToken.Activate()
return s.deviceRepo.Update(deviceToken)
}
// CreatePushRecord 创建推送记录
func (s *pushServiceImpl) CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error) {
expiredAt := time.Now().Add(DefaultExpiredTime)
record := &model.PushRecord{
UserID: userID,
MessageID: messageID,
PushChannel: channel,
PushStatus: model.PushStatusPending,
MaxRetry: MaxRetryCount,
ExpiredAt: &expiredAt,
}
if err := s.pushRepo.Create(record); err != nil {
return nil, fmt.Errorf("failed to create push record: %w", err)
}
return record, nil
}
// GetPendingPushes 获取待推送记录
func (s *pushServiceImpl) GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error) {
return s.pushRepo.GetByUserID(userID, 100, 0)
}
// StartPushWorker 启动推送工作协程
func (s *pushServiceImpl) StartPushWorker(ctx context.Context) {
go s.processPushQueue()
go s.retryFailedPushes()
}
// StopPushWorker 停止推送工作协程
func (s *pushServiceImpl) StopPushWorker() {
close(s.stopChan)
}
// processPushQueue 处理推送队列
func (s *pushServiceImpl) processPushQueue() {
for {
select {
case <-s.stopChan:
return
case task := <-s.pushQueue:
s.processPushTask(task)
}
}
}
// processPushTask 处理单个推送任务
func (s *pushServiceImpl) processPushTask(task *pushTask) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultPushTimeout)
defer cancel()
// 获取用户活跃设备
devices, err := s.deviceRepo.GetActiveByUserID(task.userID)
if err != nil || len(devices) == 0 {
// 没有可用设备,创建待推送记录
s.CreatePushRecord(ctx, task.userID, task.message.ID, model.PushChannelFCM)
return
}
// 对每个设备创建推送记录并尝试推送
for _, device := range devices {
record, err := s.CreatePushRecord(ctx, task.userID, task.message.ID, s.getChannelForDevice(device))
if err != nil {
continue
}
var pushErr error
switch {
case device.IsIOS():
pushErr = s.pushViaAPNs(ctx, device, task.message)
case device.IsAndroid():
pushErr = s.pushViaFCM(ctx, device, task.message)
default:
// Web设备只支持WebSocket
continue
}
if pushErr != nil {
record.MarkFailed(pushErr.Error())
} else {
record.MarkPushed()
}
s.pushRepo.Update(record)
}
}
// getChannelForDevice 根据设备类型获取推送通道
func (s *pushServiceImpl) getChannelForDevice(device *model.DeviceToken) model.PushChannel {
switch device.DeviceType {
case model.DeviceTypeIOS:
return model.PushChannelAPNs
case model.DeviceTypeAndroid:
return model.PushChannelFCM
default:
return model.PushChannelWebSocket
}
}
// retryFailedPushes 重试失败的推送
func (s *pushServiceImpl) retryFailedPushes() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.stopChan:
return
case <-ticker.C:
s.doRetry()
}
}
}
// doRetry 执行重试
func (s *pushServiceImpl) doRetry() {
ctx := context.Background()
// 获取失败待重试的推送
records, err := s.pushRepo.GetFailedPushesForRetry(100)
if err != nil {
return
}
for _, record := range records {
// 检查是否过期
if record.IsExpired() {
record.MarkExpired()
s.pushRepo.Update(record)
continue
}
// 获取消息
message, err := s.messageRepo.GetMessageByID(record.MessageID)
if err != nil {
record.MarkFailed("message not found")
s.pushRepo.Update(record)
continue
}
// 尝试WebSocket推送
if s.pushViaWebSocket(ctx, record.UserID, message) {
record.MarkDelivered()
s.pushRepo.Update(record)
continue
}
// 获取设备并尝试移动端推送
if record.DeviceToken != "" {
device, err := s.deviceRepo.GetByPushToken(record.DeviceToken)
if err != nil {
record.MarkFailed("device not found")
s.pushRepo.Update(record)
continue
}
var pushErr error
switch {
case device.IsIOS():
pushErr = s.pushViaAPNs(ctx, device, message)
case device.IsAndroid():
pushErr = s.pushViaFCM(ctx, device, message)
}
if pushErr != nil {
record.MarkFailed(pushErr.Error())
} else {
record.MarkPushed()
}
s.pushRepo.Update(record)
}
}
}
// PushSystemMessage 推送系统消息
func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error {
// 首先尝试WebSocket推送
if s.pushSystemViaWebSocket(ctx, userID, msgType, title, content, data) {
return nil
}
// 用户不在线,创建待推送记录(移动端上线后可通过其他方式获取)
// 系统消息通常不需要离线推送,客户端上线后会主动拉取
return errors.New("user is offline, system message will be available on next sync")
}
// pushSystemViaWebSocket 通过WebSocket推送系统消息
func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
sysMsg := &websocket.SystemMessage{
Type: msgType,
Title: title,
Content: content,
Data: data,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// PushNotification 推送通知消息
func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error {
// 首先尝试WebSocket推送
if s.pushNotificationViaWebSocket(ctx, userID, notification) {
return nil
}
// 用户不在线,创建待推送记录
// 通知消息可以等用户上线后拉取
return errors.New("user is offline, notification will be available on next sync")
}
// pushNotificationViaWebSocket 通过WebSocket推送通知消息
func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
if notification.CreatedAt == 0 {
notification.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification)
s.wsManager.SendToUser(userID, wsMsg)
return true
}
// PushAnnouncement 广播公告消息
func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error {
if s.wsManager == nil {
return errors.New("websocket manager not available")
}
if announcement.CreatedAt == 0 {
announcement.CreatedAt = time.Now().UnixMilli()
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement)
s.wsManager.Broadcast(wsMsg)
return nil
}
// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型)
func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error {
// 首先尝试WebSocket推送
if s.pushSystemNotificationViaWebSocket(ctx, userID, notification) {
return nil
}
// 用户不在线,系统通知已存储在数据库中,用户上线后会主动拉取
return nil
}
// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知
func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool {
if s.wsManager == nil {
return false
}
if !s.wsManager.IsUserOnline(userID) {
return false
}
// 构建 WebSocket 通知消息
wsNotification := &websocket.NotificationMessage{
ID: fmt.Sprintf("%d", notification.ID),
Type: string(notification.Type),
Title: notification.Title,
Content: notification.Content,
Extra: make(map[string]interface{}),
CreatedAt: notification.CreatedAt.UnixMilli(),
}
// 填充额外数据
if notification.ExtraData != nil {
wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr
wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName
wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL
wsNotification.Extra["target_id"] = notification.ExtraData.TargetID
wsNotification.Extra["target_type"] = notification.ExtraData.TargetType
wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL
wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime
// 设置触发用户信息
if notification.ExtraData.ActorIDStr != "" {
wsNotification.TriggerUser = &websocket.NotificationUser{
ID: notification.ExtraData.ActorIDStr,
Username: notification.ExtraData.ActorName,
Avatar: notification.ExtraData.AvatarURL,
}
}
}
wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification)
s.wsManager.SendToUser(userID, wsMsg)
return true
}

View File

@@ -0,0 +1,559 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log"
"regexp"
"strings"
"sync"
"time"
"unicode/utf8"
"carrot_bbs/internal/model"
redisclient "carrot_bbs/internal/pkg/redis"
"gorm.io/gorm"
)
// ==================== DFA 敏感词过滤实现 ====================
// SensitiveNode 敏感词树节点
type SensitiveNode struct {
// 子节点映射
Children map[rune]*SensitiveNode
// 是否为敏感词结尾
IsEnd bool
// 敏感词信息(仅在 IsEnd 为 true 时有效)
Word string
Level model.SensitiveWordLevel
Category model.SensitiveWordCategory
}
// NewSensitiveNode 创建新的敏感词节点
func NewSensitiveNode() *SensitiveNode {
return &SensitiveNode{
Children: make(map[rune]*SensitiveNode),
IsEnd: false,
}
}
// SensitiveWordTree 敏感词树
type SensitiveWordTree struct {
root *SensitiveNode
wordCount int
mu sync.RWMutex
lastReload time.Time
}
// NewSensitiveWordTree 创建新的敏感词树
func NewSensitiveWordTree() *SensitiveWordTree {
return &SensitiveWordTree{
root: NewSensitiveNode(),
wordCount: 0,
lastReload: time.Now(),
}
}
// AddWord 添加敏感词到树中
func (t *SensitiveWordTree) AddWord(word string, level model.SensitiveWordLevel, category model.SensitiveWordCategory) {
if word == "" {
return
}
t.mu.Lock()
defer t.mu.Unlock()
node := t.root
// 转换为小写进行匹配(不区分大小写)
lowerWord := strings.ToLower(word)
runes := []rune(lowerWord)
for _, r := range runes {
child, exists := node.Children[r]
if !exists {
child = NewSensitiveNode()
node.Children[r] = child
}
node = child
}
// 如果不是已存在的敏感词,则计数+1
if !node.IsEnd {
t.wordCount++
}
node.IsEnd = true
node.Word = word
node.Level = level
node.Category = category
}
// RemoveWord 从树中移除敏感词
func (t *SensitiveWordTree) RemoveWord(word string) {
if word == "" {
return
}
t.mu.Lock()
defer t.mu.Unlock()
lowerWord := strings.ToLower(word)
runes := []rune(lowerWord)
// 查找节点
node := t.root
for _, r := range runes {
child, exists := node.Children[r]
if !exists {
return // 敏感词不存在
}
node = child
}
if node.IsEnd {
node.IsEnd = false
node.Word = ""
t.wordCount--
}
}
// Check 检查文本是否包含敏感词,返回是否包含及敏感词列表
func (t *SensitiveWordTree) Check(text string) (bool, []string) {
if text == "" {
return false, nil
}
t.mu.RLock()
defer t.mu.RUnlock()
var foundWords []string
runes := []rune(strings.ToLower(text))
length := len(runes)
// 用于标记已找到的敏感词位置,避免重复计算
marked := make([]bool, length)
for i := 0; i < length; i++ {
// 从当前位置开始搜索
node := t.root
matchEnd := -1
matchWord := ""
for j := i; j < length; j++ {
child, exists := node.Children[runes[j]]
if !exists {
break
}
node = child
if node.IsEnd {
matchEnd = j
matchWord = node.Word
}
}
// 标记找到的敏感词位置
if matchEnd >= 0 && !marked[i] {
for k := i; k <= matchEnd; k++ {
marked[k] = true
}
foundWords = append(foundWords, matchWord)
}
}
return len(foundWords) > 0, foundWords
}
// Replace 替换文本中的敏感词
func (t *SensitiveWordTree) Replace(text string, repl string) string {
if text == "" {
return text
}
t.mu.RLock()
defer t.mu.RUnlock()
runes := []rune(text)
length := len(runes)
result := make([]rune, 0, length)
// 用于标记已替换的位置
marked := make([]bool, length)
for i := 0; i < length; i++ {
if marked[i] {
continue
}
// 从当前位置开始搜索
node := t.root
matchEnd := -1
for j := i; j < length; j++ {
child, exists := node.Children[runes[j]]
if !exists {
break
}
node = child
if node.IsEnd {
matchEnd = j
}
}
if matchEnd >= 0 {
// 标记已替换的位置
for k := i; k <= matchEnd; k++ {
marked[k] = true
}
// 追加替换符
replRunes := []rune(repl)
result = append(result, replRunes...)
// 跳过已匹配的字符
i = matchEnd
} else {
// 追加原字符
result = append(result, runes[i])
}
}
return string(result)
}
// WordCount 获取敏感词数量
func (t *SensitiveWordTree) WordCount() int {
t.mu.RLock()
defer t.mu.RUnlock()
return t.wordCount
}
// ==================== 敏感词服务实现 ====================
// SensitiveService 敏感词服务接口
type SensitiveService interface {
// Check 检查文本是否包含敏感词
Check(ctx context.Context, text string) (bool, []string)
// Replace 替换敏感词
Replace(ctx context.Context, text string, repl string) string
// AddWord 添加敏感词
AddWord(ctx context.Context, word string, category string, level int) error
// RemoveWord 移除敏感词
RemoveWord(ctx context.Context, word string) error
// Reload 重新加载敏感词库
Reload(ctx context.Context) error
// GetWordCount 获取敏感词数量
GetWordCount(ctx context.Context) int
}
// sensitiveServiceImpl 敏感词服务实现
type sensitiveServiceImpl struct {
tree *SensitiveWordTree
db *gorm.DB
redis *redisclient.Client
config *SensitiveConfig
mu sync.RWMutex
replaceStr string
}
// SensitiveConfig 敏感词服务配置
type SensitiveConfig struct {
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
ReplaceStr string `mapstructure:"replace_str" yaml:"replace_str"`
// 最小匹配长度
MinMatchLen int `mapstructure:"min_match_len" yaml:"min_match_len"`
// 是否从数据库加载
LoadFromDB bool `mapstructure:"load_from_db" yaml:"load_from_db"`
// 是否从Redis加载
LoadFromRedis bool `mapstructure:"load_from_redis" yaml:"load_from_redis"`
// Redis键前缀
RedisKeyPrefix string `mapstructure:"redis_key_prefix" yaml:"redis_key_prefix"`
}
// NewSensitiveService 创建敏感词服务
func NewSensitiveService(db *gorm.DB, redisClient *redisclient.Client, config *SensitiveConfig) SensitiveService {
s := &sensitiveServiceImpl{
tree: NewSensitiveWordTree(),
db: db,
redis: redisClient,
config: config,
replaceStr: config.ReplaceStr,
}
// 如果未设置替换符,默认使用 ***
if s.replaceStr == "" {
s.replaceStr = "***"
}
// 初始化加载敏感词
if config.LoadFromDB {
if err := s.loadFromDB(context.Background()); err != nil {
log.Printf("Failed to load sensitive words from database: %v", err)
}
}
if config.LoadFromRedis && redisClient != nil {
if err := s.loadFromRedis(context.Background()); err != nil {
log.Printf("Failed to load sensitive words from redis: %v", err)
}
}
return s
}
// Check 检查文本是否包含敏感词
func (s *sensitiveServiceImpl) Check(ctx context.Context, text string) (bool, []string) {
if !s.config.Enabled {
return false, nil
}
if text == "" {
return false, nil
}
return s.tree.Check(text)
}
// Replace 替换敏感词
func (s *sensitiveServiceImpl) Replace(ctx context.Context, text string, repl string) string {
if !s.config.Enabled {
return text
}
if text == "" {
return text
}
// 如果未指定替换符,使用默认替换符
if repl == "" {
repl = s.replaceStr
}
return s.tree.Replace(text, repl)
}
// AddWord 添加敏感词
func (s *sensitiveServiceImpl) AddWord(ctx context.Context, word string, category string, level int) error {
if word == "" {
return fmt.Errorf("word cannot be empty")
}
// 转换为敏感词级别
wordLevel := model.SensitiveWordLevel(level)
if wordLevel < 1 || wordLevel > 3 {
wordLevel = model.SensitiveWordLevelLow
}
// 转换为敏感词分类
wordCategory := model.SensitiveWordCategory(category)
if wordCategory == "" {
wordCategory = model.SensitiveWordCategoryOther
}
// 添加到树
s.tree.AddWord(word, wordLevel, wordCategory)
// 持久化到数据库
if s.db != nil {
sensitiveWord := model.SensitiveWord{
Word: word,
Category: wordCategory,
Level: wordLevel,
IsActive: true,
}
// 使用 upsert 逻辑
var existing model.SensitiveWord
result := s.db.Where("word = ?", word).First(&existing)
if result.Error == gorm.ErrRecordNotFound {
if err := s.db.Create(&sensitiveWord).Error; err != nil {
log.Printf("Failed to save sensitive word to database: %v", err)
}
} else if result.Error == nil {
// 更新已存在的记录
existing.Category = wordCategory
existing.Level = wordLevel
existing.IsActive = true
if err := s.db.Save(&existing).Error; err != nil {
log.Printf("Failed to update sensitive word in database: %v", err)
}
}
}
// 同步到 Redis
if s.redis != nil && s.config.RedisKeyPrefix != "" {
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
data := map[string]interface{}{
"word": word,
"category": category,
"level": level,
}
jsonData, _ := json.Marshal(data)
s.redis.Set(ctx, key, jsonData, 0)
}
return nil
}
// RemoveWord 移除敏感词
func (s *sensitiveServiceImpl) RemoveWord(ctx context.Context, word string) error {
if word == "" {
return fmt.Errorf("word cannot be empty")
}
// 从树中移除
s.tree.RemoveWord(word)
// 从数据库中标记为不活跃
if s.db != nil {
result := s.db.Model(&model.SensitiveWord{}).Where("word = ?", word).Update("is_active", false)
if result.Error != nil {
log.Printf("Failed to deactivate sensitive word in database: %v", result.Error)
}
}
// 从 Redis 中删除
if s.redis != nil && s.config.RedisKeyPrefix != "" {
key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word)
s.redis.Del(ctx, key)
}
return nil
}
// Reload 重新加载敏感词库
func (s *sensitiveServiceImpl) Reload(ctx context.Context) error {
// 清空现有树
s.tree = NewSensitiveWordTree()
// 从数据库加载
if s.config.LoadFromDB {
if err := s.loadFromDB(ctx); err != nil {
return fmt.Errorf("failed to load from database: %w", err)
}
}
// 从 Redis 加载
if s.config.LoadFromRedis && s.redis != nil {
if err := s.loadFromRedis(ctx); err != nil {
return fmt.Errorf("failed to load from redis: %w", err)
}
}
return nil
}
// GetWordCount 获取敏感词数量
func (s *sensitiveServiceImpl) GetWordCount(ctx context.Context) int {
return s.tree.WordCount()
}
// loadFromDB 从数据库加载敏感词
func (s *sensitiveServiceImpl) loadFromDB(ctx context.Context) error {
if s.db == nil {
return nil
}
var words []model.SensitiveWord
if err := s.db.Where("is_active = ?", true).Find(&words).Error; err != nil {
return err
}
for _, word := range words {
s.tree.AddWord(word.Word, word.Level, word.Category)
}
log.Printf("Loaded %d sensitive words from database", len(words))
return nil
}
// loadFromRedis 从 Redis 加载敏感词
func (s *sensitiveServiceImpl) loadFromRedis(ctx context.Context) error {
if s.redis == nil || s.config.RedisKeyPrefix == "" {
return nil
}
// 使用 SCAN 命令代替 KEYS避免阻塞
pattern := fmt.Sprintf("%s:*", s.config.RedisKeyPrefix)
var cursor uint64
for {
keys, nextCursor, err := s.redis.GetClient().Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return err
}
for _, key := range keys {
data, err := s.redis.Get(ctx, key)
if err != nil {
continue
}
var wordData map[string]interface{}
if err := json.Unmarshal([]byte(data), &wordData); err != nil {
continue
}
word, _ := wordData["word"].(string)
category, _ := wordData["category"].(string)
level, _ := wordData["level"].(float64)
if word != "" {
s.tree.AddWord(word, model.SensitiveWordLevel(int(level)), model.SensitiveWordCategory(category))
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return nil
}
// ==================== 辅助函数 ====================
// ContainsSensitiveWord 快速检查文本是否包含敏感词
func ContainsSensitiveWord(text string, tree *SensitiveWordTree) bool {
if tree == nil || text == "" {
return false
}
hasSensitive, _ := tree.Check(text)
return hasSensitive
}
// FilterSensitiveWords 过滤敏感词并返回替换后的文本
func FilterSensitiveWords(text string, tree *SensitiveWordTree, repl string) string {
if tree == nil || text == "" {
return text
}
if repl == "" {
repl = "***"
}
return tree.Replace(text, repl)
}
// ValidateTextLength 验证文本长度是否合法
func ValidateTextLength(text string, minLen, maxLen int) bool {
length := utf8.RuneCountInString(text)
return length >= minLen && length <= maxLen
}
// SanitizeText 清理文本,移除多余空白字符
func SanitizeText(text string) string {
// 替换多个连续空白字符为单个空格
spaceReg := regexp.MustCompile(`\s+`)
text = spaceReg.ReplaceAllString(text, " ")
// 去除首尾空白
return strings.TrimSpace(text)
}
// ==================== 默认敏感词列表 ====================
// DefaultSensitiveWords 返回默认敏感词列表(示例)
func DefaultSensitiveWords() map[string]struct{} {
return map[string]struct{}{
// 示例敏感词,实际需要从数据库或配置加载
"测试敏感词1": {},
"测试敏感词2": {},
"测试敏感词3": {},
}
}

View File

@@ -0,0 +1,139 @@
package service
import (
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
"errors"
"net/url"
"strings"
)
var (
ErrStickerAlreadyExists = errors.New("sticker already exists")
ErrInvalidStickerURL = errors.New("invalid sticker url")
)
// StickerService 自定义表情服务接口
type StickerService interface {
// 获取用户的所有表情
GetUserStickers(userID string) ([]model.UserSticker, error)
// 添加表情
AddSticker(userID string, url string, width, height int) (*model.UserSticker, error)
// 删除表情
DeleteSticker(userID string, stickerID string) error
// 检查表情是否已存在
CheckExists(userID string, url string) (bool, error)
// 重新排序
ReorderStickers(userID string, orders map[string]int) error
// 获取用户表情数量
GetStickerCount(userID string) (int64, error)
}
// stickerService 自定义表情服务实现
type stickerService struct {
stickerRepo repository.StickerRepository
}
// NewStickerService 创建自定义表情服务
func NewStickerService(stickerRepo repository.StickerRepository) StickerService {
return &stickerService{
stickerRepo: stickerRepo,
}
}
// GetUserStickers 获取用户的所有表情
func (s *stickerService) GetUserStickers(userID string) ([]model.UserSticker, error) {
stickers, err := s.stickerRepo.GetByUserID(userID)
if err != nil {
return nil, err
}
// 兼容历史脏数据:过滤本地文件 URI避免客户端加载 file:// 报错
filtered := make([]model.UserSticker, 0, len(stickers))
for _, sticker := range stickers {
if isValidStickerURL(sticker.URL) {
filtered = append(filtered, sticker)
}
}
return filtered, nil
}
// AddSticker 添加表情
func (s *stickerService) AddSticker(userID string, url string, width, height int) (*model.UserSticker, error) {
if !isValidStickerURL(url) {
return nil, ErrInvalidStickerURL
}
// 检查是否已存在
exists, err := s.stickerRepo.Exists(userID, url)
if err != nil {
return nil, err
}
if exists {
return nil, ErrStickerAlreadyExists
}
// 获取当前数量用于设置排序
count, err := s.stickerRepo.CountByUserID(userID)
if err != nil {
return nil, err
}
sticker := &model.UserSticker{
UserID: userID,
URL: url,
Width: width,
Height: height,
SortOrder: int(count), // 新表情添加到末尾
}
if err := s.stickerRepo.Create(sticker); err != nil {
return nil, err
}
return sticker, nil
}
func isValidStickerURL(raw string) bool {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return false
}
parsed, err := url.Parse(trimmed)
if err != nil {
return false
}
scheme := strings.ToLower(parsed.Scheme)
return scheme == "http" || scheme == "https"
}
// DeleteSticker 删除表情
func (s *stickerService) DeleteSticker(userID string, stickerID string) error {
// 先检查表情是否属于该用户
sticker, err := s.stickerRepo.GetByID(stickerID)
if err != nil {
return err
}
if sticker.UserID != userID {
return errors.New("sticker not found")
}
return s.stickerRepo.Delete(stickerID)
}
// CheckExists 检查表情是否已存在
func (s *stickerService) CheckExists(userID string, url string) (bool, error) {
return s.stickerRepo.Exists(userID, url)
}
// ReorderStickers 重新排序
func (s *stickerService) ReorderStickers(userID string, orders map[string]int) error {
return s.stickerRepo.BatchUpdateSortOrder(userID, orders)
}
// GetStickerCount 获取用户表情数量
func (s *stickerService) GetStickerCount(userID string) (int64, error) {
return s.stickerRepo.CountByUserID(userID)
}

View File

@@ -0,0 +1,462 @@
package service
import (
"context"
"fmt"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/utils"
"carrot_bbs/internal/repository"
)
// SystemMessageService 系统消息服务接口
type SystemMessageService interface {
// 发送互动通知
SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error
SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error
SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error
SendFollowNotification(ctx context.Context, userID string, operatorID string) error
SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error
SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error
SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error
SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error
// 发送系统公告
SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error
SendBroadcastAnnouncement(ctx context.Context, title string, content string) error
}
type systemMessageServiceImpl struct {
notifyRepo *repository.SystemNotificationRepository
pushService PushService
userRepo *repository.UserRepository
postRepo *repository.PostRepository
cache cache.Cache
}
// NewSystemMessageService 创建系统消息服务
func NewSystemMessageService(
notifyRepo *repository.SystemNotificationRepository,
pushService PushService,
userRepo *repository.UserRepository,
postRepo *repository.PostRepository,
) SystemMessageService {
return &systemMessageServiceImpl{
notifyRepo: notifyRepo,
pushService: pushService,
userRepo: userRepo,
postRepo: postRepo,
cache: cache.GetCache(),
}
}
// SendLikeNotification 发送点赞通知
func (s *systemMessageServiceImpl) SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: postTitle,
TargetType: "post",
ActionURL: fmt.Sprintf("/posts/%s", postID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 赞了「%s」", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikePost, content, extraData)
if err != nil {
return fmt.Errorf("failed to create like notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendCommentNotification 发送评论通知
func (s *systemMessageServiceImpl) SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: postTitle,
TargetType: "comment",
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 评论了「%s」", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyComment, content, extraData)
if err != nil {
return fmt.Errorf("failed to create comment notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendReplyNotification 发送回复通知
func (s *systemMessageServiceImpl) SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: replyID,
TargetTitle: postTitle,
TargetType: "reply",
ActionURL: fmt.Sprintf("/posts/%s?comment=%s&reply=%s", postID, commentID, replyID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 回复了您在「%s」的评论", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyReply, content, extraData)
if err != nil {
return fmt.Errorf("failed to create reply notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendFollowNotification 发送关注通知
func (s *systemMessageServiceImpl) SendFollowNotification(ctx context.Context, userID string, operatorID string) error {
fmt.Printf("[DEBUG] SendFollowNotification: userID=%s, operatorID=%s\n", userID, operatorID)
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
fmt.Printf("[DEBUG] SendFollowNotification: getActorInfo error: %v\n", err)
return err
}
fmt.Printf("[DEBUG] SendFollowNotification: actorName=%s, avatarURL=%s\n", actorName, avatarURL)
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: "",
TargetType: "user",
ActionURL: fmt.Sprintf("/users/%s", operatorID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 关注了你", actorName)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyFollow, content, extraData)
if err != nil {
return fmt.Errorf("failed to create follow notification: %w", err)
}
fmt.Printf("[DEBUG] SendFollowNotification: notification created, ID=%d, Content=%s\n", notification.ID, notification.Content)
// 推送通知
pushErr := s.pushService.PushSystemNotification(ctx, userID, notification)
if pushErr != nil {
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification error: %v\n", pushErr)
} else {
fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification success\n")
}
return pushErr
}
// SendFavoriteNotification 发送收藏通知
func (s *systemMessageServiceImpl) SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: postTitle,
TargetType: "post",
ActionURL: fmt.Sprintf("/posts/%s", postID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 收藏了「%s」", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyFavoritePost, content, extraData)
if err != nil {
return fmt.Errorf("failed to create favorite notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendLikeCommentNotification 发送评论点赞通知
func (s *systemMessageServiceImpl) SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 截取评论内容预览最多50字
preview := commentContent
runes := []rune(preview)
if len(runes) > 50 {
preview = string(runes[:50]) + "..."
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: preview,
TargetType: "comment",
ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 赞了您的评论", actorName)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeComment, content, extraData)
if err != nil {
return fmt.Errorf("failed to create like comment notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendLikeReplyNotification 发送回复点赞通知
func (s *systemMessageServiceImpl) SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 截取回复内容预览最多50字
preview := replyContent
runes := []rune(preview)
if len(runes) > 50 {
preview = string(runes[:50]) + "..."
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: preview,
TargetType: "reply",
ActionURL: fmt.Sprintf("/posts/%s?reply=%s", postID, replyID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 赞了您的回复", actorName)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeReply, content, extraData)
if err != nil {
return fmt.Errorf("failed to create like reply notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendMentionNotification 发送@提及通知
func (s *systemMessageServiceImpl) SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error {
// 获取操作者信息
actorName, avatarURL, err := s.getActorInfo(ctx, operatorID)
if err != nil {
return err
}
// 获取帖子标题
postTitle, err := s.getPostTitle(postID)
if err != nil {
postTitle = "您的帖子"
}
extraData := &model.SystemNotificationExtra{
ActorIDStr: operatorID,
ActorName: actorName,
AvatarURL: avatarURL,
TargetID: postID,
TargetTitle: postTitle,
TargetType: "post",
ActionURL: fmt.Sprintf("/posts/%s", postID),
ActionTime: time.Now().Format(time.RFC3339),
}
content := fmt.Sprintf("%s 在「%s」中提到了你", actorName, postTitle)
// 创建通知
notification, err := s.createNotification(ctx, userID, model.SysNotifyMention, content, extraData)
if err != nil {
return fmt.Errorf("failed to create mention notification: %w", err)
}
// 推送通知
return s.pushService.PushSystemNotification(ctx, userID, notification)
}
// SendSystemAnnouncement 发送系统公告给指定用户
func (s *systemMessageServiceImpl) SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error {
for _, userID := range userIDs {
extraData := &model.SystemNotificationExtra{
TargetType: "announcement",
ActionTime: time.Now().Format(time.RFC3339),
}
notification, err := s.createNotification(ctx, userID, model.SysNotifyAnnounce, fmt.Sprintf("【%s】%s", title, content), extraData)
if err != nil {
continue // 单个失败不影响其他用户
}
// 推送通知(使用高优先级)
if err := s.pushService.PushSystemNotification(ctx, userID, notification); err != nil {
continue
}
}
return nil
}
// SendBroadcastAnnouncement 发送广播公告给所有在线用户
func (s *systemMessageServiceImpl) SendBroadcastAnnouncement(ctx context.Context, title string, content string) error {
// TODO: 实现广播公告
// 1. 获取所有在线用户
// 2. 批量发送公告
// 3. 对于离线用户,存储为待推送记录
return fmt.Errorf("broadcast announcement not implemented")
}
// createNotification 创建系统通知(存储到独立表)
func (s *systemMessageServiceImpl) createNotification(ctx context.Context, userID string, notifyType model.SystemNotificationType, content string, extraData *model.SystemNotificationExtra) (*model.SystemNotification, error) {
fmt.Printf("[DEBUG] createNotification: userID=%s, notifyType=%s\n", userID, notifyType)
// 生成雪花算法ID
id, err := utils.GetSnowflake().GenerateID()
if err != nil {
fmt.Printf("[DEBUG] createNotification: failed to generate ID: %v\n", err)
return nil, fmt.Errorf("failed to generate notification ID: %w", err)
}
notification := &model.SystemNotification{
ID: id,
ReceiverID: userID,
Type: notifyType,
Content: content,
ExtraData: extraData,
IsRead: false,
}
fmt.Printf("[DEBUG] createNotification: notification created with ID=%d, ReceiverID=%s\n", id, userID)
// 保存通知到数据库
if err := s.notifyRepo.Create(notification); err != nil {
fmt.Printf("[DEBUG] createNotification: failed to save notification: %v\n", err)
return nil, fmt.Errorf("failed to save notification: %w", err)
}
// 失效系统消息未读数缓存
cache.InvalidateUnreadSystem(s.cache, userID)
fmt.Printf("[DEBUG] createNotification: notification saved successfully, ID=%d\n", notification.ID)
return notification, nil
}
// getActorInfo 获取操作者信息
func (s *systemMessageServiceImpl) getActorInfo(ctx context.Context, operatorID string) (string, string, error) {
// 从用户仓储获取用户信息
if s.userRepo != nil {
user, err := s.userRepo.GetByID(operatorID)
if err != nil {
fmt.Printf("[DEBUG] getActorInfo: failed to get user %s: %v\n", operatorID, err)
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil // 返回默认值,不阻断流程
}
avatar := utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar)
return user.Nickname, avatar, nil
}
// 如果没有用户仓储,返回默认值
return "用户", utils.GenerateDefaultAvatarURL("用户"), nil
}
// getPostTitle 获取帖子标题
func (s *systemMessageServiceImpl) getPostTitle(postID string) (string, error) {
if s.postRepo == nil {
if len(postID) >= 8 {
return fmt.Sprintf("帖子#%s", postID[:8]), nil
}
return fmt.Sprintf("帖子#%s", postID), nil
}
post, err := s.postRepo.GetByID(postID)
if err != nil {
if len(postID) >= 8 {
return fmt.Sprintf("帖子#%s", postID[:8]), nil
}
return fmt.Sprintf("帖子#%s", postID), nil
}
if post.Title != "" {
return post.Title, nil
}
// 如果没有标题返回内容前20个字符
if len(post.Content) > 20 {
return post.Content[:20] + "...", nil
}
return post.Content, nil
}

View File

@@ -0,0 +1,273 @@
package service
import (
"bytes"
"context"
"crypto/sha256"
"fmt"
"image"
"image/jpeg"
"image/png"
"io"
"mime"
"mime/multipart"
"net/http"
"path/filepath"
"strings"
"carrot_bbs/internal/pkg/s3"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
)
// UploadService 上传服务
type UploadService struct {
s3Client *s3.Client
userService *UserService
}
// NewUploadService 创建上传服务
func NewUploadService(s3Client *s3.Client, userService *UserService) *UploadService {
return &UploadService{
s3Client: s3Client,
userService: userService,
}
}
// UploadImage 上传图片
func (s *UploadService) UploadImage(ctx context.Context, file *multipart.FileHeader) (string, error) {
processedData, contentType, ext, err := prepareImageForUpload(file)
if err != nil {
return "", err
}
// 压缩后再计算哈希,确保同一压缩结果映射同一对象名
hash := sha256.Sum256(processedData)
hashStr := fmt.Sprintf("%x", hash)
objectName := fmt.Sprintf("images/%s%s", hashStr, ext)
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
if err != nil {
return "", fmt.Errorf("failed to upload to S3: %w", err)
}
return url, nil
}
// getExtFromContentType 根据Content-Type获取文件扩展名
func getExtFromContentType(contentType string) string {
baseType, _, err := mime.ParseMediaType(contentType)
if err == nil && baseType != "" {
contentType = baseType
}
switch contentType {
case "image/jpg", "image/jpeg":
return ".jpg"
case "image/png":
return ".png"
case "image/gif":
return ".gif"
case "image/webp":
return ".webp"
case "image/bmp", "image/x-ms-bmp":
return ".bmp"
case "image/tiff":
return ".tiff"
default:
return ""
}
}
// UploadAvatar 上传头像
func (s *UploadService) UploadAvatar(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
processedData, contentType, ext, err := prepareImageForUpload(file)
if err != nil {
return "", err
}
// 压缩后再计算哈希
hash := sha256.Sum256(processedData)
hashStr := fmt.Sprintf("%x", hash)
objectName := fmt.Sprintf("avatars/%s%s", hashStr, ext)
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
if err != nil {
return "", fmt.Errorf("failed to upload to S3: %w", err)
}
// 更新用户头像
if s.userService != nil {
user, err := s.userService.GetUserByID(ctx, userID)
if err == nil && user != nil {
user.Avatar = url
err = s.userService.UpdateUser(ctx, user)
if err != nil {
// 更新失败不影响上传结果,只记录日志
fmt.Printf("[UploadAvatar] failed to update user avatar: %v\n", err)
}
}
}
return url, nil
}
// UploadCover 上传头图(个人主页封面)
func (s *UploadService) UploadCover(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) {
processedData, contentType, ext, err := prepareImageForUpload(file)
if err != nil {
return "", err
}
// 压缩后再计算哈希
hash := sha256.Sum256(processedData)
hashStr := fmt.Sprintf("%x", hash)
objectName := fmt.Sprintf("covers/%s%s", hashStr, ext)
url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType)
if err != nil {
return "", fmt.Errorf("failed to upload to S3: %w", err)
}
// 更新用户头图
if s.userService != nil {
user, err := s.userService.GetUserByID(ctx, userID)
if err == nil && user != nil {
user.CoverURL = url
err = s.userService.UpdateUser(ctx, user)
if err != nil {
// 更新失败不影响上传结果,只记录日志
fmt.Printf("[UploadCover] failed to update user cover: %v\n", err)
}
}
}
return url, nil
}
// GetURL 获取文件URL
func (s *UploadService) GetURL(ctx context.Context, objectName string) (string, error) {
return s.s3Client.GetURL(ctx, objectName)
}
// Delete 删除文件
func (s *UploadService) Delete(ctx context.Context, objectName string) error {
return s.s3Client.Delete(ctx, objectName)
}
func prepareImageForUpload(file *multipart.FileHeader) ([]byte, string, string, error) {
f, err := file.Open()
if err != nil {
return nil, "", "", fmt.Errorf("failed to open file: %w", err)
}
defer f.Close()
originalData, err := io.ReadAll(f)
if err != nil {
return nil, "", "", fmt.Errorf("failed to read file: %w", err)
}
// 优先从文件字节探测真实类型,避免前端压缩/转码后 header 与实际格式不一致
detectedType := normalizeImageContentType(http.DetectContentType(originalData))
headerType := normalizeImageContentType(file.Header.Get("Content-Type"))
contentType := detectedType
if contentType == "" || contentType == "application/octet-stream" {
contentType = headerType
}
compressedData, compressedType, err := compressImageData(originalData, contentType)
if err != nil {
// 压缩失败时回退到原图,保证上传可用性
compressedData = originalData
compressedType = contentType
}
if compressedType == "" {
compressedType = contentType
}
if compressedType == "" {
compressedType = http.DetectContentType(compressedData)
}
ext := getExtFromContentType(compressedType)
if ext == "" {
ext = strings.ToLower(filepath.Ext(file.Filename))
}
if ext == "" {
// 最终兜底,避免对象名无扩展名导致 URL 语义不明确
ext = ".jpg"
}
return compressedData, compressedType, ext, nil
}
func compressImageData(data []byte, contentType string) ([]byte, string, error) {
contentType = normalizeImageContentType(contentType)
// GIF/WebP 等格式先保留原图,避免动画和透明通道丢失
if contentType == "image/gif" || contentType == "image/webp" {
return data, contentType, nil
}
if contentType != "image/jpeg" &&
contentType != "image/png" &&
contentType != "image/bmp" &&
contentType != "image/x-ms-bmp" &&
contentType != "image/tiff" {
return data, contentType, nil
}
img, _, err := image.Decode(bytes.NewReader(data))
if err != nil {
return nil, "", fmt.Errorf("failed to decode image: %w", err)
}
var buf bytes.Buffer
switch contentType {
case "image/png":
encoder := png.Encoder{CompressionLevel: png.BestCompression}
if err := encoder.Encode(&buf, img); err != nil {
return nil, "", fmt.Errorf("failed to encode png: %w", err)
}
return buf.Bytes(), "image/png", nil
default:
// BMP/TIFF 等无损大图统一压缩为 JPEG控制体积
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 82}); err != nil {
return nil, "", fmt.Errorf("failed to encode jpeg: %w", err)
}
return buf.Bytes(), "image/jpeg", nil
}
}
func normalizeImageContentType(contentType string) string {
if contentType == "" {
return ""
}
baseType, _, err := mime.ParseMediaType(contentType)
if err == nil && baseType != "" {
contentType = baseType
}
switch strings.ToLower(contentType) {
case "image/jpg":
return "image/jpeg"
case "image/jpeg":
return "image/jpeg"
case "image/png":
return "image/png"
case "image/gif":
return "image/gif"
case "image/webp":
return "image/webp"
case "image/bmp", "image/x-ms-bmp":
return "image/bmp"
case "image/tiff":
return "image/tiff"
default:
return contentType
}
}

View File

@@ -0,0 +1,592 @@
package service
import (
"context"
"fmt"
"strings"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/utils"
"carrot_bbs/internal/repository"
)
// UserService 用户服务
type UserService struct {
userRepo *repository.UserRepository
systemMessageService SystemMessageService
emailCodeService EmailCodeService
}
// NewUserService 创建用户服务
func NewUserService(
userRepo *repository.UserRepository,
systemMessageService SystemMessageService,
emailService EmailService,
cacheBackend cache.Cache,
) *UserService {
return &UserService{
userRepo: userRepo,
systemMessageService: systemMessageService,
emailCodeService: NewEmailCodeService(emailService, cacheBackend),
}
}
// SendRegisterCode 发送注册验证码
func (s *UserService) SendRegisterCode(ctx context.Context, email string) error {
user, err := s.userRepo.GetByEmail(email)
if err == nil && user != nil {
return ErrEmailExists
}
return s.emailCodeService.SendCode(ctx, CodePurposeRegister, email)
}
// SendPasswordResetCode 发送找回密码验证码
func (s *UserService) SendPasswordResetCode(ctx context.Context, email string) error {
user, err := s.userRepo.GetByEmail(email)
if err != nil || user == nil {
return ErrUserNotFound
}
return s.emailCodeService.SendCode(ctx, CodePurposePasswordReset, email)
}
// SendCurrentUserEmailVerifyCode 发送当前用户邮箱验证验证码
func (s *UserService) SendCurrentUserEmailVerifyCode(ctx context.Context, userID, email string) error {
user, err := s.userRepo.GetByID(userID)
if err != nil || user == nil {
return ErrUserNotFound
}
targetEmail := strings.TrimSpace(email)
if targetEmail == "" && user.Email != nil {
targetEmail = strings.TrimSpace(*user.Email)
}
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
return ErrInvalidEmail
}
if user.EmailVerified && user.Email != nil && strings.EqualFold(strings.TrimSpace(*user.Email), targetEmail) {
return ErrEmailAlreadyVerified
}
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
return ErrEmailExists
}
return s.emailCodeService.SendCode(ctx, CodePurposeEmailVerify, targetEmail)
}
// VerifyCurrentUserEmail 验证当前用户邮箱
func (s *UserService) VerifyCurrentUserEmail(ctx context.Context, userID, email, verificationCode string) error {
user, err := s.userRepo.GetByID(userID)
if err != nil || user == nil {
return ErrUserNotFound
}
targetEmail := strings.TrimSpace(email)
if targetEmail == "" && user.Email != nil {
targetEmail = strings.TrimSpace(*user.Email)
}
if targetEmail == "" || !utils.ValidateEmail(targetEmail) {
return ErrInvalidEmail
}
if err := s.emailCodeService.VerifyCode(CodePurposeEmailVerify, targetEmail, verificationCode); err != nil {
return err
}
existingUser, queryErr := s.userRepo.GetByEmail(targetEmail)
if queryErr == nil && existingUser != nil && existingUser.ID != userID {
return ErrEmailExists
}
user.Email = &targetEmail
user.EmailVerified = true
return s.userRepo.Update(user)
}
// SendChangePasswordCode 发送修改密码验证码
func (s *UserService) SendChangePasswordCode(ctx context.Context, userID string) error {
user, err := s.userRepo.GetByID(userID)
if err != nil || user == nil {
return ErrUserNotFound
}
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
return ErrEmailNotBound
}
return s.emailCodeService.SendCode(ctx, CodePurposeChangePassword, *user.Email)
}
// Register 用户注册
func (s *UserService) Register(ctx context.Context, username, email, password, nickname, phone, verificationCode string) (*model.User, error) {
// 验证用户名
if !utils.ValidateUsername(username) {
return nil, ErrInvalidUsername
}
// 注册必须提供邮箱并完成验证码校验
if email == "" || !utils.ValidateEmail(email) {
return nil, ErrInvalidEmail
}
if err := s.emailCodeService.VerifyCode(CodePurposeRegister, email, verificationCode); err != nil {
return nil, err
}
// 验证密码
if !utils.ValidatePassword(password) {
return nil, ErrWeakPassword
}
// 验证手机号(如果提供)
if phone != "" && !utils.ValidatePhone(phone) {
return nil, ErrInvalidPhone
}
// 检查用户名是否已存在
existingUser, err := s.userRepo.GetByUsername(username)
if err == nil && existingUser != nil {
return nil, ErrUsernameExists
}
// 检查邮箱是否已存在(如果提供)
if email != "" {
existingUser, err = s.userRepo.GetByEmail(email)
if err == nil && existingUser != nil {
return nil, ErrEmailExists
}
}
// 检查手机号是否已存在(如果提供)
if phone != "" {
existingUser, err = s.userRepo.GetByPhone(phone)
if err == nil && existingUser != nil {
return nil, ErrPhoneExists
}
}
// 密码哈希
hashedPassword, err := utils.HashPassword(password)
if err != nil {
return nil, err
}
// 创建用户
user := &model.User{
Username: username,
Nickname: nickname,
EmailVerified: true,
PasswordHash: hashedPassword,
Status: model.UserStatusActive,
}
// 如果提供了邮箱,设置指针值
if email != "" {
user.Email = &email
}
// 如果提供了手机号,设置指针值
if phone != "" {
user.Phone = &phone
}
err = s.userRepo.Create(user)
if err != nil {
return nil, err
}
return user, nil
}
// Login 用户登录
func (s *UserService) Login(ctx context.Context, account, password string) (*model.User, error) {
account = strings.TrimSpace(account)
var (
user *model.User
err error
)
if utils.ValidateEmail(account) {
user, err = s.userRepo.GetByEmail(account)
} else if utils.ValidatePhone(account) {
user, err = s.userRepo.GetByPhone(account)
} else {
user, err = s.userRepo.GetByUsername(account)
}
if err != nil || user == nil {
return nil, ErrInvalidCredentials
}
if !utils.CheckPasswordHash(password, user.PasswordHash) {
return nil, ErrInvalidCredentials
}
if user.Status != model.UserStatusActive {
return nil, ErrUserBanned
}
return user, nil
}
// GetUserByID 根据ID获取用户
func (s *UserService) GetUserByID(ctx context.Context, id string) (*model.User, error) {
return s.userRepo.GetByID(id)
}
// GetUserPostCount 获取用户帖子数(实时计算)
func (s *UserService) GetUserPostCount(ctx context.Context, userID string) (int64, error) {
return s.userRepo.GetPostsCount(userID)
}
// GetUserPostCountBatch 批量获取用户帖子数(实时计算)
func (s *UserService) GetUserPostCountBatch(ctx context.Context, userIDs []string) (map[string]int64, error) {
return s.userRepo.GetPostsCountBatch(userIDs)
}
// GetUserByIDWithFollowingStatus 根据ID获取用户包含当前用户是否关注的状态
func (s *UserService) GetUserByIDWithFollowingStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, error) {
user, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, false, err
}
// 如果查询的是当前用户自己,不需要检查关注状态
if userID == currentUserID {
return user, false, nil
}
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
if err != nil {
return user, false, err
}
return user, isFollowing, nil
}
// GetUserByIDWithMutualFollowStatus 根据ID获取用户包含双向关注状态
func (s *UserService) GetUserByIDWithMutualFollowStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, bool, error) {
user, err := s.userRepo.GetByID(userID)
if err != nil {
return nil, false, false, err
}
// 如果查询的是当前用户自己,不需要检查关注状态
if userID == currentUserID {
return user, false, false, nil
}
// 当前用户是否关注了该用户
isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID)
if err != nil {
return user, false, false, err
}
// 该用户是否关注了当前用户
isFollowingMe, err := s.userRepo.IsFollowing(userID, currentUserID)
if err != nil {
return user, isFollowing, false, err
}
return user, isFollowing, isFollowingMe, nil
}
// UpdateUser 更新用户
func (s *UserService) UpdateUser(ctx context.Context, user *model.User) error {
return s.userRepo.Update(user)
}
// GetFollowers 获取粉丝
func (s *UserService) GetFollowers(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
return s.userRepo.GetFollowers(userID, page, pageSize)
}
// GetFollowing 获取关注
func (s *UserService) GetFollowing(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) {
return s.userRepo.GetFollowing(userID, page, pageSize)
}
// FollowUser 关注用户
func (s *UserService) FollowUser(ctx context.Context, followerID, followeeID string) error {
fmt.Printf("[DEBUG] FollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
blocked, err := s.userRepo.IsBlockedEitherDirection(followerID, followeeID)
if err != nil {
return err
}
if blocked {
return ErrUserBlocked
}
// 检查是否已经关注
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
if err != nil {
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
return err
}
if isFollowing {
fmt.Printf("[DEBUG] Already following, skip creation\n")
return nil // 已经关注,直接返回成功
}
// 创建关注关系
follow := &model.Follow{
FollowerID: followerID,
FollowingID: followeeID,
}
err = s.userRepo.CreateFollow(follow)
if err != nil {
fmt.Printf("[DEBUG] CreateFollow error: %v\n", err)
return err
}
fmt.Printf("[DEBUG] Follow record created successfully\n")
// 刷新关注者的关注数(通过实际计数,更可靠)
err = s.userRepo.RefreshFollowingCount(followerID)
if err != nil {
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
// 不回滚,计数可以通过其他方式修复
}
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
err = s.userRepo.RefreshFollowersCount(followeeID)
if err != nil {
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
// 不回滚,计数可以通过其他方式修复
}
// 发送关注通知给被关注者
if s.systemMessageService != nil {
// 异步发送通知,不阻塞主流程
go func() {
notifyErr := s.systemMessageService.SendFollowNotification(context.Background(), followeeID, followerID)
if notifyErr != nil {
fmt.Printf("[DEBUG] Error sending follow notification: %v\n", notifyErr)
} else {
fmt.Printf("[DEBUG] Follow notification sent successfully to %s\n", followeeID)
}
}()
}
fmt.Printf("[DEBUG] FollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
return nil
}
// UnfollowUser 取消关注用户
func (s *UserService) UnfollowUser(ctx context.Context, followerID, followeeID string) error {
fmt.Printf("[DEBUG] UnfollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID)
// 检查是否已经关注
isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID)
if err != nil {
fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err)
return err
}
if !isFollowing {
fmt.Printf("[DEBUG] Not following, skip deletion\n")
return nil // 没有关注,直接返回成功
}
// 删除关注关系
err = s.userRepo.DeleteFollow(followerID, followeeID)
if err != nil {
fmt.Printf("[DEBUG] DeleteFollow error: %v\n", err)
return err
}
fmt.Printf("[DEBUG] Follow record deleted successfully\n")
// 刷新关注者的关注数(通过实际计数,更可靠)
err = s.userRepo.RefreshFollowingCount(followerID)
if err != nil {
fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err)
}
// 刷新被关注者的粉丝数(通过实际计数,更可靠)
err = s.userRepo.RefreshFollowersCount(followeeID)
if err != nil {
fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err)
}
fmt.Printf("[DEBUG] UnfollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID)
return nil
}
// BlockUser 拉黑用户,并自动清理双向关注/粉丝关系
func (s *UserService) BlockUser(ctx context.Context, blockerID, blockedID string) error {
if blockerID == blockedID {
return ErrInvalidOperation
}
return s.userRepo.BlockUserAndCleanupRelations(blockerID, blockedID)
}
// UnblockUser 取消拉黑
func (s *UserService) UnblockUser(ctx context.Context, blockerID, blockedID string) error {
if blockerID == blockedID {
return ErrInvalidOperation
}
return s.userRepo.UnblockUser(blockerID, blockedID)
}
// GetBlockedUsers 获取黑名单列表
func (s *UserService) GetBlockedUsers(ctx context.Context, blockerID string, page, pageSize int) ([]*model.User, int64, error) {
return s.userRepo.GetBlockedUsers(blockerID, page, pageSize)
}
// IsBlocked 检查当前用户是否已拉黑目标用户
func (s *UserService) IsBlocked(ctx context.Context, blockerID, blockedID string) (bool, error) {
return s.userRepo.IsBlocked(blockerID, blockedID)
}
// GetFollowingList 获取关注列表(字符串参数版本)
func (s *UserService) GetFollowingList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
// 转换字符串参数为整数
pageInt := 1
pageSizeInt := 20
if page != "" {
_, err := fmt.Sscanf(page, "%d", &pageInt)
if err != nil {
pageInt = 1
}
}
if pageSize != "" {
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
if err != nil {
pageSizeInt = 20
}
}
users, _, err := s.userRepo.GetFollowing(userID, pageInt, pageSizeInt)
return users, err
}
// GetFollowersList 获取粉丝列表(字符串参数版本)
func (s *UserService) GetFollowersList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) {
// 转换字符串参数为整数
pageInt := 1
pageSizeInt := 20
if page != "" {
_, err := fmt.Sscanf(page, "%d", &pageInt)
if err != nil {
pageInt = 1
}
}
if pageSize != "" {
_, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt)
if err != nil {
pageSizeInt = 20
}
}
users, _, err := s.userRepo.GetFollowers(userID, pageInt, pageSizeInt)
return users, err
}
// GetMutualFollowStatus 批量获取双向关注状态
func (s *UserService) GetMutualFollowStatus(ctx context.Context, currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
return s.userRepo.GetMutualFollowStatus(currentUserID, targetUserIDs)
}
// CheckUsernameAvailable 检查用户名是否可用
func (s *UserService) CheckUsernameAvailable(ctx context.Context, username string) (bool, error) {
user, err := s.userRepo.GetByUsername(username)
if err != nil {
return true, nil // 用户不存在,可用
}
return user == nil, nil
}
// ChangePassword 修改密码
func (s *UserService) ChangePassword(ctx context.Context, userID, oldPassword, newPassword, verificationCode string) error {
// 获取用户
user, err := s.userRepo.GetByID(userID)
if err != nil {
return ErrUserNotFound
}
if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
return ErrEmailNotBound
}
if err := s.emailCodeService.VerifyCode(CodePurposeChangePassword, *user.Email, verificationCode); err != nil {
return err
}
// 验证旧密码
if !utils.CheckPasswordHash(oldPassword, user.PasswordHash) {
return ErrInvalidCredentials
}
// 哈希新密码
hashedPassword, err := utils.HashPassword(newPassword)
if err != nil {
return err
}
// 更新密码
user.PasswordHash = hashedPassword
return s.userRepo.Update(user)
}
// ResetPasswordByEmail 通过邮箱重置密码
func (s *UserService) ResetPasswordByEmail(ctx context.Context, email, verificationCode, newPassword string) error {
email = strings.TrimSpace(email)
if !utils.ValidateEmail(email) {
return ErrInvalidEmail
}
if !utils.ValidatePassword(newPassword) {
return ErrWeakPassword
}
if err := s.emailCodeService.VerifyCode(CodePurposePasswordReset, email, verificationCode); err != nil {
return err
}
user, err := s.userRepo.GetByEmail(email)
if err != nil || user == nil {
return ErrUserNotFound
}
hashedPassword, err := utils.HashPassword(newPassword)
if err != nil {
return err
}
user.PasswordHash = hashedPassword
return s.userRepo.Update(user)
}
// Search 搜索用户
func (s *UserService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.User, int64, error) {
return s.userRepo.Search(keyword, page, pageSize)
}
// 错误定义
var (
ErrInvalidUsername = &ServiceError{Code: 400, Message: "invalid username"}
ErrInvalidEmail = &ServiceError{Code: 400, Message: "invalid email"}
ErrInvalidPhone = &ServiceError{Code: 400, Message: "invalid phone number"}
ErrWeakPassword = &ServiceError{Code: 400, Message: "password too weak"}
ErrUsernameExists = &ServiceError{Code: 400, Message: "username already exists"}
ErrEmailExists = &ServiceError{Code: 400, Message: "email already exists"}
ErrPhoneExists = &ServiceError{Code: 400, Message: "phone number already exists"}
ErrUserNotFound = &ServiceError{Code: 404, Message: "user not found"}
ErrUserBanned = &ServiceError{Code: 403, Message: "user is banned"}
ErrUserBlocked = &ServiceError{Code: 403, Message: "blocked relationship exists"}
ErrInvalidOperation = &ServiceError{Code: 400, Message: "invalid operation"}
ErrEmailServiceUnavailable = &ServiceError{Code: 503, Message: "email service unavailable"}
ErrVerificationCodeTooFrequent = &ServiceError{Code: 429, Message: "verification code sent too frequently"}
ErrVerificationCodeInvalid = &ServiceError{Code: 400, Message: "invalid verification code"}
ErrVerificationCodeExpired = &ServiceError{Code: 400, Message: "verification code expired"}
ErrVerificationCodeUnavailable = &ServiceError{Code: 500, Message: "verification code storage unavailable"}
ErrEmailAlreadyVerified = &ServiceError{Code: 400, Message: "email already verified"}
ErrEmailNotBound = &ServiceError{Code: 400, Message: "email not bound"}
)
// ServiceError 服务错误
type ServiceError struct {
Code int
Message string
}
func (e *ServiceError) Error() string {
return e.Message
}
var ErrInvalidCredentials = &ServiceError{Code: 401, Message: "invalid username or password"}

Some files were not shown because too many files have changed in this diff Show More