commit 4d8f2ec9974df8a45c5bc20c79d9ad577203398c Author: lan Date: Mon Mar 9 21:28:58 2026 +0800 Initial backend repository commit. Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..13a73bb --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +# Build artifacts +server +*.tar + +# Runtime files +data/ +logs/ +*.log + +# Local docker metadata +.last_docker_tag +.last_docker_tar + +# Local environment files +.env +.env.* diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..c802f7b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,28 @@ +# 使用Debian bookworm-slim作为生产镜像 +FROM debian:bookworm-slim + +# 安装必要的运行时依赖 +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + tzdata \ + libsqlite3-0 \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +# 设置工作目录 +WORKDIR /app + +# 从宿主机复制编译好的二进制文件 +COPY server ./carrot_bbs + +# 复制配置文件 +COPY configs ./configs + +# 给可执行文件添加执行权限 +RUN chmod +x ./carrot_bbs + +# 暴露端口 +EXPOSE 8080 + +# 默认命令 +CMD ["./carrot_bbs"] diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..d686eec --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,172 @@ +# 服务器配置 +# 环境变量: APP_SERVER_HOST, APP_SERVER_PORT, APP_SERVER_MODE +server: + host: 0.0.0.0 + port: 8080 + mode: debug + +# 数据库配置 +# 环境变量: +# SQLite: APP_DATABASE_SQLITE_PATH +# Postgres: APP_DATABASE_POSTGRES_HOST, APP_DATABASE_POSTGRES_PORT, APP_DATABASE_POSTGRES_USER, +# APP_DATABASE_POSTGRES_PASSWORD, APP_DATABASE_POSTGRES_DBNAME +database: + type: sqlite # sqlite 或 postgres + sqlite: + path: ./data/carrot_bbs.db + postgres: + host: localhost + port: 5432 + user: postgres + password: postgres + dbname: carrot_bbs + sslmode: disable + max_idle_conns: 10 + max_open_conns: 100 + log_level: warn + slow_threshold_ms: 200 + ignore_record_not_found: true + parameterized_queries: true + +# Redis配置 +# 环境变量: +# 类型: APP_REDIS_TYPE (miniredis/redis) +# Redis: APP_REDIS_REDIS_HOST, APP_REDIS_REDIS_PORT, APP_REDIS_REDIS_PASSWORD, APP_REDIS_REDIS_DB +# Miniredis: APP_REDIS_MINIREDIS_HOST, APP_REDIS_MINIREDIS_PORT +redis: + type: miniredis # miniredis 或 redis + redis: + host: localhost + port: 6379 + password: "" + db: 0 + miniredis: + host: localhost + port: 6379 + pool_size: 10 + +# 缓存配置 +# 环境变量: +# APP_CACHE_ENABLED, APP_CACHE_KEY_PREFIX, APP_CACHE_DEFAULT_TTL, APP_CACHE_NULL_TTL +# APP_CACHE_JITTER_RATIO, APP_CACHE_DISABLE_FLUSHDB +# APP_CACHE_MODULES_POST_LIST_TTL, APP_CACHE_MODULES_CONVERSATION_TTL +# APP_CACHE_MODULES_UNREAD_COUNT_TTL, APP_CACHE_MODULES_GROUP_MEMBERS_TTL +cache: + enabled: true + key_prefix: "" + default_ttl: 30 + null_ttl: 5 + jitter_ratio: 0.1 + disable_flushdb: true + modules: + post_list_ttl: 30 + conversation_ttl: 60 + unread_count_ttl: 30 + group_members_ttl: 120 + +# S3对象存储配置 +# 环境变量: APP_S3_ENDPOINT, APP_S3_ACCESS_KEY, APP_S3_SECRET_KEY, APP_S3_BUCKET, APP_S3_DOMAIN +s3: + endpoint: "" + access_key: "" + secret_key: "" + bucket: "" + use_ssl: true + region: us-east-1 + domain: "" +# JWT配置 +# 环境变量: APP_JWT_SECRET +jwt: + secret: your-jwt-secret-key-change-in-production + access_token_expire: 86400 # 24 hours in seconds + refresh_token_expire: 604800 # 7 days in seconds + +log: + level: info + encoding: json + output_paths: + - stdout + - ./logs/app.log + +rate_limit: + enabled: true + requests_per_minute: 60 + +upload: + max_file_size: 10485760 # 10MB + allowed_types: + - image/jpeg + - image/png + - image/gif + - image/webp + +# 敏感词过滤配置 +sensitive: + enabled: true + replace_str: "***" + min_match_len: 1 + load_from_db: true + load_from_redis: false + redis_key_prefix: "sensitive_words" + +# 内容审核服务配置 +audit: + enabled: false # 暂时关闭第三方审核 + # 审核服务提供商: local, aliyun, tencent, baidu + provider: "local" + auto_audit: true + timeout: 30 + # 阿里云配置 + aliyun_access_key: "" + aliyun_secret_key: "" + aliyun_region: "cn-shanghai" + # 腾讯云配置 + tencent_secret_id: "" + tencent_secret_key: "" + # 百度云配置 + baidu_api_key: "" + baidu_secret_key: "" + +# Gorse推荐系统配置 +# 环境变量: APP_GORSE_ADDRESS, APP_GORSE_API_KEY, APP_GORSE_DASHBOARD, APP_GORSE_IMPORT_PASSWORD +gorse: + enabled: false + address: "" # Gorse server地址 + api_key: "" # API密钥 + dashboard: "" # Gorse dashboard地址 + import_password: "" # 导入数据密码 + embedding_api_key: "" + embedding_url: "https://api.littlelan.cn/v1/embeddings" + embedding_model: "BAAI/bge-m3" + +# OpenAI兼容接口配置(用于帖子审核,支持图文) +# 环境变量: +# APP_OPENAI_ENABLED, APP_OPENAI_BASE_URL, APP_OPENAI_API_KEY +# APP_OPENAI_MODERATION_MODEL, APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST +# APP_OPENAI_REQUEST_TIMEOUT, APP_OPENAI_STRICT_MODERATION +openai: + enabled: true + base_url: "https://api.littlelan.cn/" + api_key: "" + moderation_model: "qwen3.5-122b" + moderation_max_images_per_request: 1 + request_timeout: 30 + strict_moderation: false + +# SMTP发信配置(gomail.v2) +# 环境变量: +# APP_EMAIL_ENABLED, APP_EMAIL_HOST, APP_EMAIL_PORT +# APP_EMAIL_USERNAME, APP_EMAIL_PASSWORD +# APP_EMAIL_FROM_ADDRESS, APP_EMAIL_FROM_NAME +# APP_EMAIL_USE_TLS, APP_EMAIL_INSECURE_SKIP_VERIFY, APP_EMAIL_TIMEOUT +email: + enabled: false + host: "" + port: 587 + username: "" + password: "" + from_address: "" + from_name: "Carrot BBS" + use_tls: true + insecure_skip_verify: false + timeout: 15 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..6f4aec4 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,78 @@ +version: '3.8' + +services: + # 开发环境使用 SQLite 和 miniredis,不需要外部数据库服务 + # 如需使用 PostgreSQL 和 Redis,切换 config.yaml 中的配置并取消注释以下服务 + + # PostgreSQL (生产环境使用) + # postgres: + # image: postgres:15-alpine + # container_name: carrot_bbs_postgres + # environment: + # POSTGRES_USER: postgres + # POSTGRES_PASSWORD: postgres + # POSTGRES_DB: carrot_bbs + # ports: + # - "5432:5432" + # volumes: + # - postgres_data:/var/lib/postgresql/data + # healthcheck: + # test: ["CMD-SHELL", "pg_isready -U postgres"] + # interval: 10s + # timeout: 5s + # retries: 5 + + # Redis (生产环境使用) + # redis: + # image: redis:7-alpine + # container_name: carrot_bbs_redis + # ports: + # - "6379:6379" + # volumes: + # - redis_data:/data + # healthcheck: + # test: ["CMD", "redis-cli", "ping"] + # interval: 10s + # timeout: 5s + # retries: 5 + + # MinIO (对象存储,可选) + minio: + image: minio/minio:latest + container_name: carrot_bbs_minio + environment: + MINIO_ROOT_USER: minioadmin + MINIO_ROOT_PASSWORD: minioadmin + ports: + - "9000:9000" + - "9001:9001" + volumes: + - minio_data:/data + command: server /data --console-address ":9001" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 30s + timeout: 20s + retries: 3 + + app: + build: + context: . + dockerfile: Dockerfile + container_name: carrot_bbs_app + ports: + - "8080:8080" + depends_on: + minio: + condition: service_healthy + environment: + - CONFIG_PATH=/app/configs/config.yaml + volumes: + - ./configs:/app/configs + - ./logs:/app/logs + - ./data:/app/data + +volumes: + # postgres_data: + # redis_data: + minio_data: diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0466572 --- /dev/null +++ b/go.mod @@ -0,0 +1,86 @@ +module carrot_bbs + +go 1.25 + +require ( + github.com/alicebob/miniredis/v2 v2.31.0 + github.com/gin-gonic/gin v1.9.1 + github.com/golang-jwt/jwt/v5 v5.2.0 + github.com/google/uuid v1.5.0 + github.com/gorilla/websocket v1.5.3 + github.com/gorse-io/gorse-go v0.5.0-alpha.3 + github.com/minio/minio-go/v7 v7.0.66 + github.com/redis/go-redis/v9 v9.3.1 + github.com/spf13/viper v1.18.2 + go.uber.org/zap v1.26.0 + golang.org/x/crypto v0.17.0 + golang.org/x/image v0.24.0 + gorm.io/driver/postgres v1.5.4 + gorm.io/driver/sqlite v1.5.4 + gorm.io/gorm v1.25.5 +) + +require ( + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect + github.com/bytedance/sonic v1.9.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/felixge/httpsnoop v1.0.3 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-logr/logr v1.2.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.17.4 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/minio/md5-simd v1.1.2 // indirect + github.com/minio/sha256-simd v1.0.1 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/rs/xid v1.5.0 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + github.com/yuin/gopher-lua v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 // indirect + go.opentelemetry.io/otel v1.11.1 // indirect + go.opentelemetry.io/otel/metric v0.33.0 // indirect + go.opentelemetry.io/otel/trace v1.11.1 // indirect + go.uber.org/multierr v1.10.0 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/net v0.19.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/text v0.22.0 // indirect + google.golang.org/protobuf v1.31.0 // indirect + gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect + gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..573d2b2 --- /dev/null +++ b/go.sum @@ -0,0 +1,216 @@ +github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.31.0 h1:ObEFUNlJwoIiyjxdrYF0QIDE7qXcLc7D3WpSH4c22PU= +github.com/alicebob/miniredis/v2 v2.31.0/go.mod h1:UB/T2Uztp7MlFSDakaX1sTXUv5CASoprx0wulRT6HBg= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= +github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= +github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= +github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorse-io/gorse-go v0.5.0-alpha.3 h1:GR/OWzq016VyFyTTxgQWeayGahRVzB1cGFIW/AaShC4= +github.com/gorse-io/gorse-go v0.5.0-alpha.3/go.mod h1:ZxmVHzZPKm5pmEIlqaRDwK0rkfTRHlfziO033XZ+RW0= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= +github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= +github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= +github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw= +github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs= +github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM= +github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= +github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds= +github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE= +github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 h1:aUEBEdCa6iamGzg6fuYxDA8ThxvOG240mAvWDU+XLio= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4/go.mod h1:l2MdsbKTocpPS5nQZscqTR9jd8u96VYZdcpF8Sye7mA= +go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4= +go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE= +go.opentelemetry.io/otel/metric v0.33.0 h1:xQAyl7uGEYvrLAiV/09iTJlp1pZnQ9Wl793qbVvED1E= +go.opentelemetry.io/otel/metric v0.33.0/go.mod h1:QlTYc+EnYNq/M2mNk1qDDMRLpqCOj2f/r5c7Fd5FYaI= +go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ= +go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk= +go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= +go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= +go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ= +golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= +gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= +gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= +gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= +gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0= +gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..cbb225d --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,604 @@ +package cache + +import ( + "context" + "encoding/json" + "fmt" + "log" + "math/rand" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" + + redisPkg "carrot_bbs/internal/pkg/redis" +) + +// Cache 缓存接口 +type Cache interface { + // Set 设置缓存值,支持TTL + Set(key string, value interface{}, ttl time.Duration) + // Get 获取缓存值 + Get(key string) (interface{}, bool) + // Delete 删除缓存 + Delete(key string) + // DeleteByPrefix 根据前缀删除缓存 + DeleteByPrefix(prefix string) + // Clear 清空所有缓存 + Clear() + // Exists 检查键是否存在 + Exists(key string) bool + // Increment 增加计数器的值 + Increment(key string) int64 + // IncrementBy 增加指定值 + IncrementBy(key string, value int64) int64 +} + +// cacheItem 缓存项(用于内存缓存降级) +type cacheItem struct { + value interface{} + expiration int64 // 过期时间戳(纳秒) +} + +const nullMarkerValue = "__carrot_cache_null__" + +type cacheMetrics struct { + hit atomic.Int64 + miss atomic.Int64 + decodeError atomic.Int64 + setError atomic.Int64 + invalidate atomic.Int64 +} + +var metrics cacheMetrics +var loadLocks sync.Map + +type MetricsSnapshot struct { + Hit int64 + Miss int64 + DecodeError int64 + SetError int64 + Invalidate int64 +} + +type Settings struct { + Enabled bool + KeyPrefix string + DefaultTTL time.Duration + NullTTL time.Duration + JitterRatio float64 + PostListTTL time.Duration + ConversationTTL time.Duration + UnreadCountTTL time.Duration + GroupMembersTTL time.Duration + DisableFlushDB bool +} + +var settings = Settings{ + Enabled: true, + DefaultTTL: 30 * time.Second, + NullTTL: 5 * time.Second, + JitterRatio: 0.1, + PostListTTL: 30 * time.Second, + ConversationTTL: 60 * time.Second, + UnreadCountTTL: 30 * time.Second, + GroupMembersTTL: 120 * time.Second, + DisableFlushDB: true, +} + +func Configure(s Settings) { + settings.Enabled = s.Enabled + if s.KeyPrefix != "" { + settings.KeyPrefix = s.KeyPrefix + } + if s.DefaultTTL > 0 { + settings.DefaultTTL = s.DefaultTTL + } + if s.NullTTL > 0 { + settings.NullTTL = s.NullTTL + } + if s.JitterRatio > 0 { + settings.JitterRatio = s.JitterRatio + } + if s.PostListTTL > 0 { + settings.PostListTTL = s.PostListTTL + } + if s.ConversationTTL > 0 { + settings.ConversationTTL = s.ConversationTTL + } + if s.UnreadCountTTL > 0 { + settings.UnreadCountTTL = s.UnreadCountTTL + } + if s.GroupMembersTTL > 0 { + settings.GroupMembersTTL = s.GroupMembersTTL + } + settings.DisableFlushDB = s.DisableFlushDB +} + +func GetSettings() Settings { + return settings +} + +func normalizeKey(key string) string { + if settings.KeyPrefix == "" { + return key + } + return settings.KeyPrefix + ":" + key +} + +func normalizePrefix(prefix string) string { + if settings.KeyPrefix == "" { + return prefix + } + return settings.KeyPrefix + ":" + prefix +} + +func GetMetricsSnapshot() MetricsSnapshot { + return MetricsSnapshot{ + Hit: metrics.hit.Load(), + Miss: metrics.miss.Load(), + DecodeError: metrics.decodeError.Load(), + SetError: metrics.setError.Load(), + Invalidate: metrics.invalidate.Load(), + } +} + +// isExpired 检查是否过期 +func (item *cacheItem) isExpired() bool { + if item.expiration == 0 { + return false + } + return time.Now().UnixNano() > item.expiration +} + +// MemoryCache 内存缓存实现(降级使用) +type MemoryCache struct { + items sync.Map + // cleanupInterval 清理过期缓存的间隔 + cleanupInterval time.Duration + // stopCleanup 停止清理协程的通道 + stopCleanup chan struct{} +} + +// NewMemoryCache 创建内存缓存 +func NewMemoryCache() *MemoryCache { + c := &MemoryCache{ + cleanupInterval: 1 * time.Minute, + stopCleanup: make(chan struct{}), + } + // 启动后台清理协程 + go c.cleanup() + return c +} + +// Set 设置缓存值 +func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) { + key = normalizeKey(key) + var expiration int64 + if ttl > 0 { + expiration = time.Now().Add(ttl).UnixNano() + } + c.items.Store(key, &cacheItem{ + value: value, + expiration: expiration, + }) +} + +// Get 获取缓存值 +func (c *MemoryCache) Get(key string) (interface{}, bool) { + key = normalizeKey(key) + val, ok := c.items.Load(key) + if !ok { + return nil, false + } + + item := val.(*cacheItem) + if item.isExpired() { + c.items.Delete(key) + return nil, false + } + + return item.value, true +} + +// Delete 删除缓存 +func (c *MemoryCache) Delete(key string) { + key = normalizeKey(key) + metrics.invalidate.Add(1) + c.items.Delete(key) +} + +// DeleteByPrefix 根据前缀删除缓存 +func (c *MemoryCache) DeleteByPrefix(prefix string) { + prefix = normalizePrefix(prefix) + c.items.Range(func(key, value interface{}) bool { + if keyStr, ok := key.(string); ok { + if strings.HasPrefix(keyStr, prefix) { + metrics.invalidate.Add(1) + c.items.Delete(key) + } + } + return true + }) +} + +// Clear 清空所有缓存 +func (c *MemoryCache) Clear() { + c.items.Range(func(key, value interface{}) bool { + metrics.invalidate.Add(1) + c.items.Delete(key) + return true + }) +} + +// Exists 检查键是否存在 +func (c *MemoryCache) Exists(key string) bool { + _, ok := c.Get(key) + return ok +} + +// Increment 增加计数器的值 +func (c *MemoryCache) Increment(key string) int64 { + return c.IncrementBy(key, 1) +} + +// IncrementBy 增加指定值 +func (c *MemoryCache) IncrementBy(key string, value int64) int64 { + key = normalizeKey(key) + for { + val, ok := c.items.Load(key) + if !ok { + // 键不存在,创建新值 + c.items.Store(key, &cacheItem{ + value: value, + expiration: 0, + }) + return value + } + + item := val.(*cacheItem) + if item.isExpired() { + // 已过期,创建新值 + c.items.Store(key, &cacheItem{ + value: value, + expiration: 0, + }) + return value + } + + // 尝试更新 + currentValue, ok := item.value.(int64) + if !ok { + // 类型不匹配,覆盖为新值 + c.items.Store(key, &cacheItem{ + value: value, + expiration: item.expiration, + }) + return value + } + + newValue := currentValue + value + // 使用 CAS 操作确保并发安全 + if c.items.CompareAndSwap(key, val, &cacheItem{ + value: newValue, + expiration: item.expiration, + }) { + return newValue + } + // CAS 失败,重试 + } +} + +// cleanup 定期清理过期缓存 +func (c *MemoryCache) cleanup() { + ticker := time.NewTicker(c.cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.cleanExpired() + case <-c.stopCleanup: + return + } + } +} + +// cleanExpired 清理过期缓存 +func (c *MemoryCache) cleanExpired() { + count := 0 + c.items.Range(func(key, value interface{}) bool { + item := value.(*cacheItem) + if item.isExpired() { + c.items.Delete(key) + count++ + } + return true + }) + if count > 0 { + log.Printf("[Cache] Cleaned %d expired items", count) + } +} + +// Stop 停止缓存清理协程 +func (c *MemoryCache) Stop() { + close(c.stopCleanup) +} + +// RedisCache Redis缓存实现 +type RedisCache struct { + client *redisPkg.Client + ctx context.Context +} + +// NewRedisCache 创建Redis缓存 +func NewRedisCache(client *redisPkg.Client) *RedisCache { + return &RedisCache{ + client: client, + ctx: context.Background(), + } +} + +// Set 设置缓存值 +func (c *RedisCache) Set(key string, value interface{}, ttl time.Duration) { + key = normalizeKey(key) + // 将值序列化为JSON + data, err := json.Marshal(value) + if err != nil { + metrics.setError.Add(1) + log.Printf("[RedisCache] Failed to marshal value for key %s: %v", key, err) + return + } + + if err := c.client.Set(c.ctx, key, data, ttl); err != nil { + metrics.setError.Add(1) + log.Printf("[RedisCache] Failed to set key %s: %v", key, err) + } +} + +// Get 获取缓存值 +func (c *RedisCache) Get(key string) (interface{}, bool) { + key = normalizeKey(key) + data, err := c.client.Get(c.ctx, key) + if err != nil { + if err == redis.Nil { + return nil, false + } + log.Printf("[RedisCache] Failed to get key %s: %v", key, err) + return nil, false + } + + // 返回原始字符串,由调用侧决定如何解码为目标类型 + return data, true +} + +// Delete 删除缓存 +func (c *RedisCache) Delete(key string) { + key = normalizeKey(key) + metrics.invalidate.Add(1) + if err := c.client.Del(c.ctx, key); err != nil { + log.Printf("[RedisCache] Failed to delete key %s: %v", key, err) + } +} + +// DeleteByPrefix 根据前缀删除缓存 +func (c *RedisCache) DeleteByPrefix(prefix string) { + prefix = normalizePrefix(prefix) + // 使用原生客户端执行SCAN命令 + rdb := c.client.GetClient() + var cursor uint64 + for { + keys, nextCursor, err := rdb.Scan(c.ctx, cursor, prefix+"*", 100).Result() + if err != nil { + log.Printf("[RedisCache] Failed to scan keys with prefix %s: %v", prefix, err) + return + } + + if len(keys) > 0 { + metrics.invalidate.Add(int64(len(keys))) + if err := c.client.Del(c.ctx, keys...); err != nil { + log.Printf("[RedisCache] Failed to delete keys with prefix %s: %v", prefix, err) + } + } + + cursor = nextCursor + if cursor == 0 { + break + } + } +} + +// Clear 清空所有缓存 +func (c *RedisCache) Clear() { + if settings.DisableFlushDB { + log.Printf("[RedisCache] Skip FlushDB because cache.disable_flushdb=true") + return + } + metrics.invalidate.Add(1) + rdb := c.client.GetClient() + if err := rdb.FlushDB(c.ctx).Err(); err != nil { + log.Printf("[RedisCache] Failed to clear cache: %v", err) + } +} + +// Exists 检查键是否存在 +func (c *RedisCache) Exists(key string) bool { + key = normalizeKey(key) + n, err := c.client.Exists(c.ctx, key) + if err != nil { + log.Printf("[RedisCache] Failed to check existence of key %s: %v", key, err) + return false + } + return n > 0 +} + +// Increment 增加计数器的值 +func (c *RedisCache) Increment(key string) int64 { + return c.IncrementBy(key, 1) +} + +// IncrementBy 增加指定值 +func (c *RedisCache) IncrementBy(key string, value int64) int64 { + key = normalizeKey(key) + rdb := c.client.GetClient() + result, err := rdb.IncrBy(c.ctx, key, value).Result() + if err != nil { + log.Printf("[RedisCache] Failed to increment key %s: %v", key, err) + return 0 + } + return result +} + +// 全局缓存实例 +var globalCache Cache +var once sync.Once + +// InitCache 初始化全局缓存实例(使用Redis) +func InitCache(redisClient *redisPkg.Client) { + once.Do(func() { + if redisClient != nil { + globalCache = NewRedisCache(redisClient) + log.Println("[Cache] Initialized Redis cache") + } else { + globalCache = NewMemoryCache() + log.Println("[Cache] Initialized Memory cache (Redis not available)") + } + }) +} + +// GetCache 获取全局缓存实例 +func GetCache() Cache { + if globalCache == nil { + // 如果未初始化,返回内存缓存作为降级 + log.Println("[Cache] Warning: Cache not initialized, using Memory cache") + return NewMemoryCache() + } + return globalCache +} + +// GetRedisClient 从缓存中获取Redis客户端(仅在Redis模式下有效) +func GetRedisClient() (*redisPkg.Client, error) { + if redisCache, ok := globalCache.(*RedisCache); ok { + return redisCache.client, nil + } + return nil, fmt.Errorf("cache is not using Redis backend") +} + +func SetWithJitter(c Cache, key string, value interface{}, ttl time.Duration, jitterRatio float64) { + if !settings.Enabled { + return + } + c.Set(key, value, ApplyTTLJitter(ttl, jitterRatio)) +} + +func SetNull(c Cache, key string, ttl time.Duration) { + if !settings.Enabled { + return + } + c.Set(key, nullMarkerValue, ttl) +} + +func ApplyTTLJitter(ttl time.Duration, jitterRatio float64) time.Duration { + if ttl <= 0 || jitterRatio <= 0 { + return ttl + } + if jitterRatio > 1 { + jitterRatio = 1 + } + maxJitter := int64(float64(ttl) * jitterRatio) + if maxJitter <= 0 { + return ttl + } + delta := rand.Int63n(maxJitter + 1) + return ttl + time.Duration(delta) +} + +func GetTyped[T any](c Cache, key string) (T, bool) { + var zero T + if !settings.Enabled { + return zero, false + } + raw, ok := c.Get(key) + if !ok { + metrics.miss.Add(1) + return zero, false + } + if str, ok := raw.(string); ok && str == nullMarkerValue { + metrics.hit.Add(1) + return zero, false + } + + if typed, ok := raw.(T); ok { + metrics.hit.Add(1) + return typed, true + } + + var out T + switch v := raw.(type) { + case string: + if err := json.Unmarshal([]byte(v), &out); err != nil { + metrics.decodeError.Add(1) + return zero, false + } + metrics.hit.Add(1) + return out, true + case []byte: + if err := json.Unmarshal(v, &out); err != nil { + metrics.decodeError.Add(1) + return zero, false + } + metrics.hit.Add(1) + return out, true + default: + data, err := json.Marshal(v) + if err != nil { + metrics.decodeError.Add(1) + return zero, false + } + if err := json.Unmarshal(data, &out); err != nil { + metrics.decodeError.Add(1) + return zero, false + } + metrics.hit.Add(1) + return out, true + } +} + +func GetOrLoadTyped[T any]( + c Cache, + key string, + ttl time.Duration, + jitterRatio float64, + nullTTL time.Duration, + loader func() (T, error), +) (T, error) { + if cached, ok := GetTyped[T](c, key); ok { + return cached, nil + } + + lockValue, _ := loadLocks.LoadOrStore(key, &sync.Mutex{}) + lock := lockValue.(*sync.Mutex) + lock.Lock() + defer lock.Unlock() + + if cached, ok := GetTyped[T](c, key); ok { + return cached, nil + } + + loaded, err := loader() + if err != nil { + var zero T + return zero, err + } + + encoded, marshalErr := json.Marshal(loaded) + if marshalErr == nil && string(encoded) == "null" && nullTTL > 0 { + SetNull(c, key, nullTTL) + return loaded, nil + } + + SetWithJitter(c, key, loaded, ttl, jitterRatio) + return loaded, nil +} diff --git a/internal/cache/keys.go b/internal/cache/keys.go new file mode 100644 index 0000000..0cfd8a0 --- /dev/null +++ b/internal/cache/keys.go @@ -0,0 +1,147 @@ +package cache + +import ( + "fmt" +) + +// 缓存键前缀常量 +const ( + // 帖子相关 + PrefixPostList = "posts:list" + PrefixPost = "posts:detail" + + // 会话相关 + PrefixConversationList = "conversations:list" + PrefixConversationDetail = "conversations:detail" + + // 群组相关 + PrefixGroupMembers = "groups:members" + PrefixGroupInfo = "groups:info" + + // 未读数相关 + PrefixUnreadSystem = "unread:system" + PrefixUnreadConversation = "unread:conversation" + PrefixUnreadDetail = "unread:detail" + + // 用户相关 + PrefixUserInfo = "users:info" + PrefixUserMe = "users:me" +) + +// PostListKey 生成帖子列表缓存键 +// postType: 帖子类型 (recommend, hot, follow, latest) +// page: 页码 +// pageSize: 每页数量 +// userID: 用户维度(仅在个性化列表如 follow 场景使用) +func PostListKey(postType string, userID string, page, pageSize int) string { + if userID == "" { + return fmt.Sprintf("%s:%s:%d:%d", PrefixPostList, postType, page, pageSize) + } + return fmt.Sprintf("%s:%s:%s:%d:%d", PrefixPostList, postType, userID, page, pageSize) +} + +// PostDetailKey 生成帖子详情缓存键 +func PostDetailKey(postID string) string { + return fmt.Sprintf("%s:%s", PrefixPost, postID) +} + +// ConversationListKey 生成会话列表缓存键 +func ConversationListKey(userID string, page, pageSize int) string { + return fmt.Sprintf("%s:%s:%d:%d", PrefixConversationList, userID, page, pageSize) +} + +// ConversationDetailKey 生成会话详情缓存键 +func ConversationDetailKey(conversationID, userID string) string { + return fmt.Sprintf("%s:%s:%s", PrefixConversationDetail, conversationID, userID) +} + +// GroupMembersKey 生成群组成员缓存键 +func GroupMembersKey(groupID string, page, pageSize int) string { + return fmt.Sprintf("%s:%s:page:%d:size:%d", PrefixGroupMembers, groupID, page, pageSize) +} + +// GroupMembersAllKey 生成群组全量成员ID列表缓存键 +func GroupMembersAllKey(groupID string) string { + return fmt.Sprintf("%s:all:%s", PrefixGroupMembers, groupID) +} + +// GroupInfoKey 生成群组信息缓存键 +func GroupInfoKey(groupID string) string { + return fmt.Sprintf("%s:%s", PrefixGroupInfo, groupID) +} + +// UnreadSystemKey 生成系统消息未读数缓存键 +func UnreadSystemKey(userID string) string { + return fmt.Sprintf("%s:%s", PrefixUnreadSystem, userID) +} + +// UnreadConversationKey 生成会话未读总数缓存键 +func UnreadConversationKey(userID string) string { + return fmt.Sprintf("%s:%s", PrefixUnreadConversation, userID) +} + +// UnreadDetailKey 生成单个会话未读数缓存键 +func UnreadDetailKey(userID, conversationID string) string { + return fmt.Sprintf("%s:%s:%s", PrefixUnreadDetail, userID, conversationID) +} + +// UserInfoKey 生成用户信息缓存键 +func UserInfoKey(userID string) string { + return fmt.Sprintf("%s:%s", PrefixUserInfo, userID) +} + +// UserMeKey 生成当前用户信息缓存键 +func UserMeKey(userID string) string { + return fmt.Sprintf("%s:%s", PrefixUserMe, userID) +} + +// InvalidatePostList 失效帖子列表缓存 +func InvalidatePostList(cache Cache) { + cache.DeleteByPrefix(PrefixPostList) +} + +// InvalidatePostDetail 失效帖子详情缓存 +func InvalidatePostDetail(cache Cache, postID string) { + cache.Delete(PostDetailKey(postID)) +} + +// InvalidateConversationList 失效会话列表缓存 +func InvalidateConversationList(cache Cache, userID string) { + cache.DeleteByPrefix(PrefixConversationList + ":" + userID + ":") +} + +// InvalidateConversationDetail 失效会话详情缓存 +func InvalidateConversationDetail(cache Cache, conversationID, userID string) { + cache.Delete(ConversationDetailKey(conversationID, userID)) +} + +// InvalidateGroupMembers 失效群组成员缓存 +func InvalidateGroupMembers(cache Cache, groupID string) { + cache.DeleteByPrefix(PrefixGroupMembers + ":" + groupID) +} + +// InvalidateGroupInfo 失效群组信息缓存 +func InvalidateGroupInfo(cache Cache, groupID string) { + cache.Delete(GroupInfoKey(groupID)) +} + +// InvalidateUnreadSystem 失效系统消息未读数缓存 +func InvalidateUnreadSystem(cache Cache, userID string) { + cache.Delete(UnreadSystemKey(userID)) +} + +// InvalidateUnreadConversation 失效会话未读数缓存 +func InvalidateUnreadConversation(cache Cache, userID string) { + cache.Delete(UnreadConversationKey(userID)) +} + +// InvalidateUnreadDetail 失效单个会话未读数缓存 +func InvalidateUnreadDetail(cache Cache, userID, conversationID string) { + cache.Delete(UnreadDetailKey(userID, conversationID)) +} + +// InvalidateUserInfo 失效用户信息缓存 +func InvalidateUserInfo(cache Cache, userID string) { + cache.Delete(UserInfoKey(userID)) + cache.Delete(UserMeKey(userID)) +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..056991b --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,393 @@ +package config + +import ( + "context" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "github.com/redis/go-redis/v9" + "github.com/spf13/viper" +) + +type Config struct { + Server ServerConfig `mapstructure:"server"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Cache CacheConfig `mapstructure:"cache"` + S3 S3Config `mapstructure:"s3"` + JWT JWTConfig `mapstructure:"jwt"` + Log LogConfig `mapstructure:"log"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Upload UploadConfig `mapstructure:"upload"` + Gorse GorseConfig `mapstructure:"gorse"` + OpenAI OpenAIConfig `mapstructure:"openai"` + Email EmailConfig `mapstructure:"email"` +} + +type ServerConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Mode string `mapstructure:"mode"` +} + +type DatabaseConfig struct { + Type string `mapstructure:"type"` + SQLite SQLiteConfig `mapstructure:"sqlite"` + Postgres PostgresConfig `mapstructure:"postgres"` + MaxIdleConns int `mapstructure:"max_idle_conns"` + MaxOpenConns int `mapstructure:"max_open_conns"` + LogLevel string `mapstructure:"log_level"` + SlowThresholdMs int `mapstructure:"slow_threshold_ms"` + IgnoreRecordNotFound bool `mapstructure:"ignore_record_not_found"` + ParameterizedQueries bool `mapstructure:"parameterized_queries"` +} + +type SQLiteConfig struct { + Path string `mapstructure:"path"` +} + +type PostgresConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + User string `mapstructure:"user"` + Password string `mapstructure:"password"` + DBName string `mapstructure:"dbname"` + SSLMode string `mapstructure:"sslmode"` +} + +func (d PostgresConfig) DSN() string { + return fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, + ) +} + +type RedisConfig struct { + Type string `mapstructure:"type"` + Redis RedisServerConfig `mapstructure:"redis"` + Miniredis MiniredisConfig `mapstructure:"miniredis"` + PoolSize int `mapstructure:"pool_size"` +} + +type CacheConfig struct { + Enabled bool `mapstructure:"enabled"` + KeyPrefix string `mapstructure:"key_prefix"` + DefaultTTL int `mapstructure:"default_ttl"` + NullTTL int `mapstructure:"null_ttl"` + JitterRatio float64 `mapstructure:"jitter_ratio"` + DisableFlushDB bool `mapstructure:"disable_flushdb"` + Modules CacheModuleTTL `mapstructure:"modules"` +} + +type CacheModuleTTL struct { + PostList int `mapstructure:"post_list_ttl"` + Conversation int `mapstructure:"conversation_ttl"` + UnreadCount int `mapstructure:"unread_count_ttl"` + GroupMembers int `mapstructure:"group_members_ttl"` +} + +type RedisServerConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Password string `mapstructure:"password"` + DB int `mapstructure:"db"` +} + +func (r RedisServerConfig) Addr() string { + return fmt.Sprintf("%s:%d", r.Host, r.Port) +} + +type MiniredisConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` +} + +type S3Config struct { + Endpoint string `mapstructure:"endpoint"` + AccessKey string `mapstructure:"access_key"` + SecretKey string `mapstructure:"secret_key"` + Bucket string `mapstructure:"bucket"` + UseSSL bool `mapstructure:"use_ssl"` + Region string `mapstructure:"region"` + Domain string `mapstructure:"domain"` // 自定义域名,如 s3.carrot.skin +} + +type JWTConfig struct { + Secret string `mapstructure:"secret"` + AccessTokenExpire time.Duration `mapstructure:"access_token_expire"` + RefreshTokenExpire time.Duration `mapstructure:"refresh_token_expire"` +} + +type LogConfig struct { + Level string `mapstructure:"level"` + Encoding string `mapstructure:"encoding"` + OutputPaths []string `mapstructure:"output_paths"` +} + +type RateLimitConfig struct { + Enabled bool `mapstructure:"enabled"` + RequestsPerMinute int `mapstructure:"requests_per_minute"` +} + +type UploadConfig struct { + MaxFileSize int64 `mapstructure:"max_file_size"` + AllowedTypes []string `mapstructure:"allowed_types"` +} + +type GorseConfig struct { + Address string `mapstructure:"address"` + APIKey string `mapstructure:"api_key"` + Enabled bool `mapstructure:"enabled"` + Dashboard string `mapstructure:"dashboard"` + ImportPassword string `mapstructure:"import_password"` + EmbeddingAPIKey string `mapstructure:"embedding_api_key"` + EmbeddingURL string `mapstructure:"embedding_url"` + EmbeddingModel string `mapstructure:"embedding_model"` +} + +type OpenAIConfig struct { + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + APIKey string `mapstructure:"api_key"` + ModerationModel string `mapstructure:"moderation_model"` + ModerationMaxImagesPerRequest int `mapstructure:"moderation_max_images_per_request"` + RequestTimeout int `mapstructure:"request_timeout"` + StrictModeration bool `mapstructure:"strict_moderation"` +} + +type EmailConfig struct { + Enabled bool `mapstructure:"enabled"` + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + FromAddress string `mapstructure:"from_address"` + FromName string `mapstructure:"from_name"` + UseTLS bool `mapstructure:"use_tls"` + InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` + Timeout int `mapstructure:"timeout"` +} + +func Load(configPath string) (*Config, error) { + viper.SetConfigFile(configPath) + viper.SetConfigType("yaml") + + // 启用环境变量支持 + viper.SetEnvPrefix("APP") + viper.AutomaticEnv() + // 允许环境变量使用下划线或连字符 + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_")) + + // Set default values + viper.SetDefault("server.port", 8080) + viper.SetDefault("server.mode", "debug") + viper.SetDefault("server.host", "0.0.0.0") + viper.SetDefault("database.type", "sqlite") + viper.SetDefault("database.sqlite.path", "./data/carrot_bbs.db") + viper.SetDefault("database.max_idle_conns", 10) + viper.SetDefault("database.max_open_conns", 100) + viper.SetDefault("database.log_level", "warn") + viper.SetDefault("database.slow_threshold_ms", 200) + viper.SetDefault("database.ignore_record_not_found", true) + viper.SetDefault("database.parameterized_queries", true) + viper.SetDefault("redis.type", "miniredis") + viper.SetDefault("redis.redis.host", "localhost") + viper.SetDefault("redis.redis.port", 6379) + viper.SetDefault("redis.redis.password", "") + viper.SetDefault("redis.redis.db", 0) + viper.SetDefault("redis.miniredis.host", "localhost") + viper.SetDefault("redis.miniredis.port", 6379) + viper.SetDefault("redis.pool_size", 10) + viper.SetDefault("cache.enabled", true) + viper.SetDefault("cache.key_prefix", "") + viper.SetDefault("cache.default_ttl", 30) + viper.SetDefault("cache.null_ttl", 5) + viper.SetDefault("cache.jitter_ratio", 0.1) + viper.SetDefault("cache.disable_flushdb", true) + viper.SetDefault("cache.modules.post_list_ttl", 30) + viper.SetDefault("cache.modules.conversation_ttl", 60) + viper.SetDefault("cache.modules.unread_count_ttl", 30) + viper.SetDefault("cache.modules.group_members_ttl", 120) + viper.SetDefault("jwt.secret", "your-jwt-secret-key-change-in-production") + viper.SetDefault("jwt.access_token_expire", 86400) + viper.SetDefault("jwt.refresh_token_expire", 604800) + viper.SetDefault("log.level", "info") + viper.SetDefault("log.encoding", "json") + viper.SetDefault("log.output_paths", []string{"stdout", "./logs/app.log"}) + viper.SetDefault("rate_limit.enabled", true) + viper.SetDefault("rate_limit.requests_per_minute", 60) + viper.SetDefault("upload.max_file_size", 10485760) + viper.SetDefault("upload.allowed_types", []string{"image/jpeg", "image/png", "image/gif", "image/webp"}) + viper.SetDefault("s3.endpoint", "") + viper.SetDefault("s3.access_key", "") + viper.SetDefault("s3.secret_key", "") + viper.SetDefault("s3.bucket", "") + viper.SetDefault("s3.use_ssl", true) + viper.SetDefault("s3.region", "us-east-1") + viper.SetDefault("s3.domain", "") + viper.SetDefault("sensitive.enabled", true) + viper.SetDefault("sensitive.replace_str", "***") + viper.SetDefault("audit.enabled", false) + viper.SetDefault("audit.provider", "local") + viper.SetDefault("gorse.enabled", false) + viper.SetDefault("gorse.address", "http://localhost:8087") + viper.SetDefault("gorse.api_key", "") + viper.SetDefault("gorse.dashboard", "http://localhost:8088") + viper.SetDefault("gorse.import_password", "") + viper.SetDefault("gorse.embedding_api_key", "") + viper.SetDefault("gorse.embedding_url", "https://api.littlelan.cn/v1/embeddings") + viper.SetDefault("gorse.embedding_model", "BAAI/bge-m3") + viper.SetDefault("openai.enabled", true) + viper.SetDefault("openai.base_url", "https://api.littlelan.cn/") + viper.SetDefault("openai.api_key", "") + viper.SetDefault("openai.moderation_model", "qwen3.5-122b") + viper.SetDefault("openai.moderation_max_images_per_request", 1) + viper.SetDefault("openai.request_timeout", 30) + viper.SetDefault("openai.strict_moderation", false) + viper.SetDefault("email.enabled", false) + viper.SetDefault("email.host", "") + viper.SetDefault("email.port", 587) + viper.SetDefault("email.username", "") + viper.SetDefault("email.password", "") + viper.SetDefault("email.from_address", "") + viper.SetDefault("email.from_name", "Carrot BBS") + viper.SetDefault("email.use_tls", true) + viper.SetDefault("email.insecure_skip_verify", false) + viper.SetDefault("email.timeout", 15) + + if err := viper.ReadInConfig(); err != nil { + return nil, fmt.Errorf("failed to read config: %w", err) + } + + var cfg Config + if err := viper.Unmarshal(&cfg); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %w", err) + } + + // Convert seconds to duration + cfg.JWT.AccessTokenExpire = time.Duration(viper.GetInt("jwt.access_token_expire")) * time.Second + cfg.JWT.RefreshTokenExpire = time.Duration(viper.GetInt("jwt.refresh_token_expire")) * time.Second + + // 环境变量覆盖(显式处理敏感配置) + cfg.JWT.Secret = getEnvOrDefault("APP_JWT_SECRET", cfg.JWT.Secret) + cfg.Database.SQLite.Path = getEnvOrDefault("APP_DATABASE_SQLITE_PATH", cfg.Database.SQLite.Path) + cfg.Database.Postgres.Host = getEnvOrDefault("APP_DATABASE_POSTGRES_HOST", cfg.Database.Postgres.Host) + cfg.Database.Postgres.Port, _ = strconv.Atoi(getEnvOrDefault("APP_DATABASE_POSTGRES_PORT", fmt.Sprintf("%d", cfg.Database.Postgres.Port))) + cfg.Database.Postgres.User = getEnvOrDefault("APP_DATABASE_POSTGRES_USER", cfg.Database.Postgres.User) + cfg.Database.Postgres.Password = getEnvOrDefault("APP_DATABASE_POSTGRES_PASSWORD", cfg.Database.Postgres.Password) + cfg.Database.Postgres.DBName = getEnvOrDefault("APP_DATABASE_POSTGRES_DBNAME", cfg.Database.Postgres.DBName) + cfg.Database.LogLevel = getEnvOrDefault("APP_DATABASE_LOG_LEVEL", cfg.Database.LogLevel) + cfg.Database.SlowThresholdMs, _ = strconv.Atoi(getEnvOrDefault("APP_DATABASE_SLOW_THRESHOLD_MS", fmt.Sprintf("%d", cfg.Database.SlowThresholdMs))) + cfg.Database.IgnoreRecordNotFound, _ = strconv.ParseBool(getEnvOrDefault("APP_DATABASE_IGNORE_RECORD_NOT_FOUND", fmt.Sprintf("%t", cfg.Database.IgnoreRecordNotFound))) + cfg.Database.ParameterizedQueries, _ = strconv.ParseBool(getEnvOrDefault("APP_DATABASE_PARAMETERIZED_QUERIES", fmt.Sprintf("%t", cfg.Database.ParameterizedQueries))) + cfg.Redis.Redis.Host = getEnvOrDefault("APP_REDIS_REDIS_HOST", cfg.Redis.Redis.Host) + cfg.Redis.Redis.Port, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_REDIS_PORT", fmt.Sprintf("%d", cfg.Redis.Redis.Port))) + cfg.Redis.Redis.Password = getEnvOrDefault("APP_REDIS_REDIS_PASSWORD", cfg.Redis.Redis.Password) + cfg.Redis.Redis.DB, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_REDIS_DB", fmt.Sprintf("%d", cfg.Redis.Redis.DB))) + cfg.Redis.Miniredis.Host = getEnvOrDefault("APP_REDIS_MINIREDIS_HOST", cfg.Redis.Miniredis.Host) + cfg.Redis.Miniredis.Port, _ = strconv.Atoi(getEnvOrDefault("APP_REDIS_MINIREDIS_PORT", fmt.Sprintf("%d", cfg.Redis.Miniredis.Port))) + cfg.Redis.Type = getEnvOrDefault("APP_REDIS_TYPE", cfg.Redis.Type) + cfg.Cache.KeyPrefix = getEnvOrDefault("APP_CACHE_KEY_PREFIX", cfg.Cache.KeyPrefix) + cfg.Cache.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_CACHE_ENABLED", fmt.Sprintf("%t", cfg.Cache.Enabled))) + cfg.Cache.DisableFlushDB, _ = strconv.ParseBool(getEnvOrDefault("APP_CACHE_DISABLE_FLUSHDB", fmt.Sprintf("%t", cfg.Cache.DisableFlushDB))) + cfg.Cache.DefaultTTL, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_DEFAULT_TTL", fmt.Sprintf("%d", cfg.Cache.DefaultTTL))) + cfg.Cache.NullTTL, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_NULL_TTL", fmt.Sprintf("%d", cfg.Cache.NullTTL))) + cfg.Cache.JitterRatio, _ = strconv.ParseFloat(getEnvOrDefault("APP_CACHE_JITTER_RATIO", fmt.Sprintf("%.2f", cfg.Cache.JitterRatio)), 64) + cfg.Cache.Modules.PostList, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_POST_LIST_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.PostList))) + cfg.Cache.Modules.Conversation, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_CONVERSATION_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.Conversation))) + cfg.Cache.Modules.UnreadCount, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_UNREAD_COUNT_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.UnreadCount))) + cfg.Cache.Modules.GroupMembers, _ = strconv.Atoi(getEnvOrDefault("APP_CACHE_MODULES_GROUP_MEMBERS_TTL", fmt.Sprintf("%d", cfg.Cache.Modules.GroupMembers))) + cfg.S3.Endpoint = getEnvOrDefault("APP_S3_ENDPOINT", cfg.S3.Endpoint) + cfg.S3.AccessKey = getEnvOrDefault("APP_S3_ACCESS_KEY", cfg.S3.AccessKey) + cfg.S3.SecretKey = getEnvOrDefault("APP_S3_SECRET_KEY", cfg.S3.SecretKey) + cfg.S3.Bucket = getEnvOrDefault("APP_S3_BUCKET", cfg.S3.Bucket) + cfg.S3.Domain = getEnvOrDefault("APP_S3_DOMAIN", cfg.S3.Domain) + cfg.Server.Host = getEnvOrDefault("APP_SERVER_HOST", cfg.Server.Host) + cfg.Server.Port, _ = strconv.Atoi(getEnvOrDefault("APP_SERVER_PORT", fmt.Sprintf("%d", cfg.Server.Port))) + cfg.Server.Mode = getEnvOrDefault("APP_SERVER_MODE", cfg.Server.Mode) + cfg.Gorse.Address = getEnvOrDefault("APP_GORSE_ADDRESS", cfg.Gorse.Address) + cfg.Gorse.APIKey = getEnvOrDefault("APP_GORSE_API_KEY", cfg.Gorse.APIKey) + cfg.Gorse.Dashboard = getEnvOrDefault("APP_GORSE_DASHBOARD", cfg.Gorse.Dashboard) + cfg.Gorse.ImportPassword = getEnvOrDefault("APP_GORSE_IMPORT_PASSWORD", cfg.Gorse.ImportPassword) + cfg.Gorse.EmbeddingAPIKey = getEnvOrDefault("APP_GORSE_EMBEDDING_API_KEY", cfg.Gorse.EmbeddingAPIKey) + cfg.Gorse.EmbeddingURL = getEnvOrDefault("APP_GORSE_EMBEDDING_URL", cfg.Gorse.EmbeddingURL) + cfg.Gorse.EmbeddingModel = getEnvOrDefault("APP_GORSE_EMBEDDING_MODEL", cfg.Gorse.EmbeddingModel) + cfg.OpenAI.BaseURL = getEnvOrDefault("APP_OPENAI_BASE_URL", cfg.OpenAI.BaseURL) + cfg.OpenAI.APIKey = getEnvOrDefault("APP_OPENAI_API_KEY", cfg.OpenAI.APIKey) + cfg.OpenAI.ModerationModel = getEnvOrDefault("APP_OPENAI_MODERATION_MODEL", cfg.OpenAI.ModerationModel) + cfg.OpenAI.ModerationMaxImagesPerRequest, _ = strconv.Atoi(getEnvOrDefault("APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST", fmt.Sprintf("%d", cfg.OpenAI.ModerationMaxImagesPerRequest))) + cfg.OpenAI.RequestTimeout, _ = strconv.Atoi(getEnvOrDefault("APP_OPENAI_REQUEST_TIMEOUT", fmt.Sprintf("%d", cfg.OpenAI.RequestTimeout))) + cfg.OpenAI.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_OPENAI_ENABLED", fmt.Sprintf("%t", cfg.OpenAI.Enabled))) + cfg.OpenAI.StrictModeration, _ = strconv.ParseBool(getEnvOrDefault("APP_OPENAI_STRICT_MODERATION", fmt.Sprintf("%t", cfg.OpenAI.StrictModeration))) + cfg.Email.Enabled, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_ENABLED", fmt.Sprintf("%t", cfg.Email.Enabled))) + cfg.Email.Host = getEnvOrDefault("APP_EMAIL_HOST", cfg.Email.Host) + cfg.Email.Port, _ = strconv.Atoi(getEnvOrDefault("APP_EMAIL_PORT", fmt.Sprintf("%d", cfg.Email.Port))) + cfg.Email.Username = getEnvOrDefault("APP_EMAIL_USERNAME", cfg.Email.Username) + cfg.Email.Password = getEnvOrDefault("APP_EMAIL_PASSWORD", cfg.Email.Password) + cfg.Email.FromAddress = getEnvOrDefault("APP_EMAIL_FROM_ADDRESS", cfg.Email.FromAddress) + cfg.Email.FromName = getEnvOrDefault("APP_EMAIL_FROM_NAME", cfg.Email.FromName) + cfg.Email.UseTLS, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_USE_TLS", fmt.Sprintf("%t", cfg.Email.UseTLS))) + cfg.Email.InsecureSkipVerify, _ = strconv.ParseBool(getEnvOrDefault("APP_EMAIL_INSECURE_SKIP_VERIFY", fmt.Sprintf("%t", cfg.Email.InsecureSkipVerify))) + cfg.Email.Timeout, _ = strconv.Atoi(getEnvOrDefault("APP_EMAIL_TIMEOUT", fmt.Sprintf("%d", cfg.Email.Timeout))) + + return &cfg, nil +} + +// getEnvOrDefault 获取环境变量值,如果未设置则返回默认值 +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +// NewRedis 创建Redis客户端(真实Redis) +func NewRedis(cfg *RedisConfig) (*redis.Client, error) { + client := redis.NewClient(&redis.Options{ + Addr: cfg.Redis.Addr(), + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + PoolSize: cfg.PoolSize, + }) + + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("failed to connect to redis: %w", err) + } + + return client, nil +} + +// NewS3 创建S3客户端 +func NewS3(cfg *S3Config) (*minio.Client, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + client, err := minio.New(cfg.Endpoint, &minio.Options{ + Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""), + Secure: cfg.UseSSL, + }) + if err != nil { + return nil, fmt.Errorf("failed to create S3 client: %w", err) + } + + exists, err := client.BucketExists(ctx, cfg.Bucket) + if err != nil { + return nil, fmt.Errorf("failed to check bucket: %w", err) + } + + if !exists { + if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{ + Region: cfg.Region, + }); err != nil { + return nil, fmt.Errorf("failed to create bucket: %w", err) + } + } + + return client, nil +} diff --git a/internal/dto/converter.go b/internal/dto/converter.go new file mode 100644 index 0000000..4b82c0d --- /dev/null +++ b/internal/dto/converter.go @@ -0,0 +1,885 @@ +package dto + +import ( + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/utils" + "context" + "encoding/json" + "strconv" +) + +// ==================== User 转换 ==================== + +// getAvatarOrDefault 获取头像URL,如果为空则返回在线头像生成服务的URL +func getAvatarOrDefault(user *model.User) string { + return utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar) +} + +// ConvertUserToResponse 将User转换为UserResponse +func ConvertUserToResponse(user *model.User) *UserResponse { + if user == nil { + return nil + } + return &UserResponse{ + ID: user.ID, + Username: user.Username, + Nickname: user.Nickname, + Email: user.Email, + Phone: user.Phone, + EmailVerified: user.EmailVerified, + Avatar: getAvatarOrDefault(user), + CoverURL: user.CoverURL, + Bio: user.Bio, + Website: user.Website, + Location: user.Location, + PostsCount: user.PostsCount, + FollowersCount: user.FollowersCount, + FollowingCount: user.FollowingCount, + CreatedAt: FormatTime(user.CreatedAt), + } +} + +// ConvertUserToResponseWithFollowing 将User转换为UserResponse(包含关注状态) +func ConvertUserToResponseWithFollowing(user *model.User, isFollowing bool) *UserResponse { + if user == nil { + return nil + } + return &UserResponse{ + ID: user.ID, + Username: user.Username, + Nickname: user.Nickname, + Email: user.Email, + Phone: user.Phone, + EmailVerified: user.EmailVerified, + Avatar: getAvatarOrDefault(user), + CoverURL: user.CoverURL, + Bio: user.Bio, + Website: user.Website, + Location: user.Location, + PostsCount: user.PostsCount, + FollowersCount: user.FollowersCount, + FollowingCount: user.FollowingCount, + IsFollowing: isFollowing, + IsFollowingMe: false, // 默认false,需要单独计算 + CreatedAt: FormatTime(user.CreatedAt), + } +} + +// ConvertUserToResponseWithPostsCount 将User转换为UserResponse(使用实时计算的帖子数量) +func ConvertUserToResponseWithPostsCount(user *model.User, postsCount int) *UserResponse { + if user == nil { + return nil + } + return &UserResponse{ + ID: user.ID, + Username: user.Username, + Nickname: user.Nickname, + Email: user.Email, + Phone: user.Phone, + EmailVerified: user.EmailVerified, + Avatar: getAvatarOrDefault(user), + CoverURL: user.CoverURL, + Bio: user.Bio, + Website: user.Website, + Location: user.Location, + PostsCount: postsCount, + FollowersCount: user.FollowersCount, + FollowingCount: user.FollowingCount, + CreatedAt: FormatTime(user.CreatedAt), + } +} + +// ConvertUserToResponseWithMutualFollow 将User转换为UserResponse(包含双向关注状态) +func ConvertUserToResponseWithMutualFollow(user *model.User, isFollowing, isFollowingMe bool) *UserResponse { + if user == nil { + return nil + } + return &UserResponse{ + ID: user.ID, + Username: user.Username, + Nickname: user.Nickname, + Email: user.Email, + Phone: user.Phone, + EmailVerified: user.EmailVerified, + Avatar: getAvatarOrDefault(user), + CoverURL: user.CoverURL, + Bio: user.Bio, + Website: user.Website, + Location: user.Location, + PostsCount: user.PostsCount, + FollowersCount: user.FollowersCount, + FollowingCount: user.FollowingCount, + IsFollowing: isFollowing, + IsFollowingMe: isFollowingMe, + CreatedAt: FormatTime(user.CreatedAt), + } +} + +// ConvertUserToResponseWithMutualFollowAndPostsCount 将User转换为UserResponse(包含双向关注状态和实时计算的帖子数量) +func ConvertUserToResponseWithMutualFollowAndPostsCount(user *model.User, isFollowing, isFollowingMe bool, postsCount int) *UserResponse { + if user == nil { + return nil + } + return &UserResponse{ + ID: user.ID, + Username: user.Username, + Nickname: user.Nickname, + Email: user.Email, + Phone: user.Phone, + EmailVerified: user.EmailVerified, + Avatar: getAvatarOrDefault(user), + CoverURL: user.CoverURL, + Bio: user.Bio, + Website: user.Website, + Location: user.Location, + PostsCount: postsCount, + FollowersCount: user.FollowersCount, + FollowingCount: user.FollowingCount, + IsFollowing: isFollowing, + IsFollowingMe: isFollowingMe, + CreatedAt: FormatTime(user.CreatedAt), + } +} + +// ConvertUserToDetailResponse 将User转换为UserDetailResponse +func ConvertUserToDetailResponse(user *model.User) *UserDetailResponse { + if user == nil { + return nil + } + return &UserDetailResponse{ + ID: user.ID, + Username: user.Username, + Nickname: user.Nickname, + Email: user.Email, + EmailVerified: user.EmailVerified, + Avatar: getAvatarOrDefault(user), + CoverURL: user.CoverURL, + Bio: user.Bio, + Website: user.Website, + Location: user.Location, + PostsCount: user.PostsCount, + FollowersCount: user.FollowersCount, + FollowingCount: user.FollowingCount, + IsVerified: user.IsVerified, + CreatedAt: FormatTime(user.CreatedAt), + } +} + +// ConvertUserToDetailResponseWithPostsCount 将User转换为UserDetailResponse(使用实时计算的帖子数量) +func ConvertUserToDetailResponseWithPostsCount(user *model.User, postsCount int) *UserDetailResponse { + if user == nil { + return nil + } + return &UserDetailResponse{ + ID: user.ID, + Username: user.Username, + Nickname: user.Nickname, + Email: user.Email, + EmailVerified: user.EmailVerified, + Phone: user.Phone, // 仅当前用户自己可见 + Avatar: getAvatarOrDefault(user), + CoverURL: user.CoverURL, + Bio: user.Bio, + Website: user.Website, + Location: user.Location, + PostsCount: postsCount, + FollowersCount: user.FollowersCount, + FollowingCount: user.FollowingCount, + IsVerified: user.IsVerified, + CreatedAt: FormatTime(user.CreatedAt), + } +} + +// ConvertUsersToResponse 将User列表转换为响应列表 +func ConvertUsersToResponse(users []*model.User) []*UserResponse { + result := make([]*UserResponse, 0, len(users)) + for _, user := range users { + result = append(result, ConvertUserToResponse(user)) + } + return result +} + +// ConvertUsersToResponseWithMutualFollow 将User列表转换为响应列表(包含双向关注状态) +// followingStatusMap: key是用户ID,value是[isFollowing, isFollowingMe] +func ConvertUsersToResponseWithMutualFollow(users []*model.User, followingStatusMap map[string][2]bool) []*UserResponse { + result := make([]*UserResponse, 0, len(users)) + for _, user := range users { + status, ok := followingStatusMap[user.ID] + if ok { + result = append(result, ConvertUserToResponseWithMutualFollow(user, status[0], status[1])) + } else { + result = append(result, ConvertUserToResponse(user)) + } + } + return result +} + +// ConvertUsersToResponseWithMutualFollowAndPostsCount 将User列表转换为响应列表(包含双向关注状态和实时计算的帖子数量) +// followingStatusMap: key是用户ID,value是[isFollowing, isFollowingMe] +// postsCountMap: key是用户ID,value是帖子数量 +func ConvertUsersToResponseWithMutualFollowAndPostsCount(users []*model.User, followingStatusMap map[string][2]bool, postsCountMap map[string]int64) []*UserResponse { + result := make([]*UserResponse, 0, len(users)) + for _, user := range users { + status, hasStatus := followingStatusMap[user.ID] + postsCount, hasPostsCount := postsCountMap[user.ID] + + // 如果没有帖子数量,使用数据库中的值 + if !hasPostsCount { + postsCount = int64(user.PostsCount) + } + + if hasStatus { + result = append(result, ConvertUserToResponseWithMutualFollowAndPostsCount(user, status[0], status[1], int(postsCount))) + } else { + result = append(result, ConvertUserToResponseWithPostsCount(user, int(postsCount))) + } + } + return result +} + +// ==================== Post 转换 ==================== + +// ConvertPostImageToResponse 将PostImage转换为PostImageResponse +func ConvertPostImageToResponse(img *model.PostImage) PostImageResponse { + if img == nil { + return PostImageResponse{} + } + return PostImageResponse{ + ID: img.ID, + URL: img.URL, + ThumbnailURL: img.ThumbnailURL, + Width: img.Width, + Height: img.Height, + } +} + +// ConvertPostImagesToResponse 将PostImage列表转换为响应列表 +func ConvertPostImagesToResponse(images []model.PostImage) []PostImageResponse { + result := make([]PostImageResponse, 0, len(images)) + for i := range images { + result = append(result, ConvertPostImageToResponse(&images[i])) + } + return result +} + +// ConvertPostToResponse 将Post转换为PostResponse(列表用) +func ConvertPostToResponse(post *model.Post, isLiked, isFavorited bool) *PostResponse { + if post == nil { + return nil + } + + images := make([]PostImageResponse, 0) + for _, img := range post.Images { + images = append(images, ConvertPostImageToResponse(&img)) + } + + var author *UserResponse + if post.User != nil { + author = ConvertUserToResponse(post.User) + } + + return &PostResponse{ + ID: post.ID, + UserID: post.UserID, + Title: post.Title, + Content: post.Content, + Images: images, + LikesCount: post.LikesCount, + CommentsCount: post.CommentsCount, + FavoritesCount: post.FavoritesCount, + SharesCount: post.SharesCount, + ViewsCount: post.ViewsCount, + IsPinned: post.IsPinned, + IsLocked: post.IsLocked, + IsVote: post.IsVote, + CreatedAt: FormatTime(post.CreatedAt), + Author: author, + IsLiked: isLiked, + IsFavorited: isFavorited, + } +} + +// ConvertPostToDetailResponse 将Post转换为PostDetailResponse +func ConvertPostToDetailResponse(post *model.Post, isLiked, isFavorited bool) *PostDetailResponse { + if post == nil { + return nil + } + + images := make([]PostImageResponse, 0) + for _, img := range post.Images { + images = append(images, ConvertPostImageToResponse(&img)) + } + + var author *UserResponse + if post.User != nil { + author = ConvertUserToResponse(post.User) + } + + return &PostDetailResponse{ + ID: post.ID, + UserID: post.UserID, + Title: post.Title, + Content: post.Content, + Images: images, + Status: string(post.Status), + LikesCount: post.LikesCount, + CommentsCount: post.CommentsCount, + FavoritesCount: post.FavoritesCount, + SharesCount: post.SharesCount, + ViewsCount: post.ViewsCount, + IsPinned: post.IsPinned, + IsLocked: post.IsLocked, + IsVote: post.IsVote, + CreatedAt: FormatTime(post.CreatedAt), + UpdatedAt: FormatTime(post.UpdatedAt), + Author: author, + IsLiked: isLiked, + IsFavorited: isFavorited, + } +} + +// ConvertPostsToResponse 将Post列表转换为响应列表(每个帖子独立检查点赞/收藏状态) +func ConvertPostsToResponse(posts []*model.Post, isLikedMap, isFavoritedMap map[string]bool) []*PostResponse { + result := make([]*PostResponse, 0, len(posts)) + for _, post := range posts { + isLiked := false + isFavorited := false + if isLikedMap != nil { + isLiked = isLikedMap[post.ID] + } + if isFavoritedMap != nil { + isFavorited = isFavoritedMap[post.ID] + } + result = append(result, ConvertPostToResponse(post, isLiked, isFavorited)) + } + return result +} + +// ==================== Comment 转换 ==================== + +// ConvertCommentToResponse 将Comment转换为CommentResponse +func ConvertCommentToResponse(comment *model.Comment, isLiked bool) *CommentResponse { + if comment == nil { + return nil + } + + var author *UserResponse + if comment.User != nil { + author = ConvertUserToResponse(comment.User) + } + + // 转换子回复(扁平化结构) + var replies []*CommentResponse + if len(comment.Replies) > 0 { + replies = make([]*CommentResponse, 0, len(comment.Replies)) + for _, reply := range comment.Replies { + replies = append(replies, ConvertCommentToResponse(reply, false)) + } + } + + // TargetID 就是 ParentID,前端根据这个 ID 找到被回复用户的昵称 + var targetID *string + if comment.ParentID != nil && *comment.ParentID != "" { + targetID = comment.ParentID + } + + // 解析图片JSON + var images []CommentImageResponse + if comment.Images != "" { + var urlList []string + if err := json.Unmarshal([]byte(comment.Images), &urlList); err == nil { + images = make([]CommentImageResponse, 0, len(urlList)) + for _, url := range urlList { + images = append(images, CommentImageResponse{URL: url}) + } + } + } + + return &CommentResponse{ + ID: comment.ID, + PostID: comment.PostID, + UserID: comment.UserID, + ParentID: comment.ParentID, + RootID: comment.RootID, + Content: comment.Content, + Images: images, + LikesCount: comment.LikesCount, + RepliesCount: comment.RepliesCount, + CreatedAt: FormatTime(comment.CreatedAt), + Author: author, + IsLiked: isLiked, + TargetID: targetID, + Replies: replies, + } +} + +// ConvertCommentsToResponse 将Comment列表转换为响应列表 +func ConvertCommentsToResponse(comments []*model.Comment, isLiked bool) []*CommentResponse { + result := make([]*CommentResponse, 0, len(comments)) + for _, comment := range comments { + result = append(result, ConvertCommentToResponse(comment, isLiked)) + } + return result +} + +// IsLikedChecker 点赞状态检查器接口 +type IsLikedChecker interface { + IsLiked(ctx context.Context, commentID, userID string) bool +} + +// ConvertCommentToResponseWithUser 将Comment转换为CommentResponse(根据用户ID检查点赞状态) +func ConvertCommentToResponseWithUser(comment *model.Comment, userID string, checker IsLikedChecker) *CommentResponse { + if comment == nil { + return nil + } + + // 检查当前用户是否点赞了该评论 + isLiked := false + if userID != "" && checker != nil { + isLiked = checker.IsLiked(context.Background(), comment.ID, userID) + } + + var author *UserResponse + if comment.User != nil { + author = ConvertUserToResponse(comment.User) + } + + // 转换子回复(扁平化结构),递归检查点赞状态 + var replies []*CommentResponse + if len(comment.Replies) > 0 { + replies = make([]*CommentResponse, 0, len(comment.Replies)) + for _, reply := range comment.Replies { + replies = append(replies, ConvertCommentToResponseWithUser(reply, userID, checker)) + } + } + + // TargetID 就是 ParentID,前端根据这个 ID 找到被回复用户的昵称 + var targetID *string + if comment.ParentID != nil && *comment.ParentID != "" { + targetID = comment.ParentID + } + + // 解析图片JSON + var images []CommentImageResponse + if comment.Images != "" { + var urlList []string + if err := json.Unmarshal([]byte(comment.Images), &urlList); err == nil { + images = make([]CommentImageResponse, 0, len(urlList)) + for _, url := range urlList { + images = append(images, CommentImageResponse{URL: url}) + } + } + } + + return &CommentResponse{ + ID: comment.ID, + PostID: comment.PostID, + UserID: comment.UserID, + ParentID: comment.ParentID, + RootID: comment.RootID, + Content: comment.Content, + Images: images, + LikesCount: comment.LikesCount, + RepliesCount: comment.RepliesCount, + CreatedAt: FormatTime(comment.CreatedAt), + Author: author, + IsLiked: isLiked, + TargetID: targetID, + Replies: replies, + } +} + +// ConvertCommentsToResponseWithUser 将Comment列表转换为响应列表(根据用户ID检查点赞状态) +func ConvertCommentsToResponseWithUser(comments []*model.Comment, userID string, checker IsLikedChecker) []*CommentResponse { + result := make([]*CommentResponse, 0, len(comments)) + for _, comment := range comments { + result = append(result, ConvertCommentToResponseWithUser(comment, userID, checker)) + } + return result +} + +// ==================== Notification 转换 ==================== + +// ConvertNotificationToResponse 将Notification转换为NotificationResponse +func ConvertNotificationToResponse(notification *model.Notification) *NotificationResponse { + if notification == nil { + return nil + } + return &NotificationResponse{ + ID: notification.ID, + UserID: notification.UserID, + Type: string(notification.Type), + Title: notification.Title, + Content: notification.Content, + Data: notification.Data, + IsRead: notification.IsRead, + CreatedAt: FormatTime(notification.CreatedAt), + } +} + +// ConvertNotificationsToResponse 将Notification列表转换为响应列表 +func ConvertNotificationsToResponse(notifications []*model.Notification) []*NotificationResponse { + result := make([]*NotificationResponse, 0, len(notifications)) + for _, n := range notifications { + result = append(result, ConvertNotificationToResponse(n)) + } + return result +} + +// ==================== Message 转换 ==================== + +// ConvertMessageToResponse 将Message转换为MessageResponse +func ConvertMessageToResponse(message *model.Message) *MessageResponse { + if message == nil { + return nil + } + + // 直接使用 segments,不需要解析 + segments := make(model.MessageSegments, len(message.Segments)) + for i, seg := range message.Segments { + segments[i] = model.MessageSegment{ + Type: seg.Type, + Data: seg.Data, + } + } + + return &MessageResponse{ + ID: message.ID, + ConversationID: message.ConversationID, + SenderID: message.SenderID, + Seq: message.Seq, + Segments: segments, + ReplyToID: message.ReplyToID, + Status: string(message.Status), + Category: string(message.Category), + CreatedAt: FormatTime(message.CreatedAt), + } +} + +// ConvertConversationToResponse 将Conversation转换为ConversationResponse +// participants: 会话参与者列表(用户信息) +// unreadCount: 当前用户的未读消息数 +// lastMessage: 最后一条消息 +func ConvertConversationToResponse(conv *model.Conversation, participants []*model.User, unreadCount int, lastMessage *model.Message, isPinned bool) *ConversationResponse { + if conv == nil { + return nil + } + + var participantResponses []*UserResponse + for _, p := range participants { + participantResponses = append(participantResponses, ConvertUserToResponse(p)) + } + + // 转换群组信息 + var groupResponse *GroupResponse + if conv.Group != nil { + groupResponse = GroupToResponse(conv.Group) + } + + return &ConversationResponse{ + ID: conv.ID, + Type: string(conv.Type), + IsPinned: isPinned, + Group: groupResponse, + LastSeq: conv.LastSeq, + LastMessage: ConvertMessageToResponse(lastMessage), + LastMessageAt: FormatTimePointer(conv.LastMsgTime), + UnreadCount: unreadCount, + Participants: participantResponses, + CreatedAt: FormatTime(conv.CreatedAt), + UpdatedAt: FormatTime(conv.UpdatedAt), + } +} + +// ConvertConversationToDetailResponse 将Conversation转换为ConversationDetailResponse +func ConvertConversationToDetailResponse(conv *model.Conversation, participants []*model.User, unreadCount int64, lastMessage *model.Message, myLastReadSeq int64, otherLastReadSeq int64, isPinned bool) *ConversationDetailResponse { + if conv == nil { + return nil + } + + var participantResponses []*UserResponse + for _, p := range participants { + participantResponses = append(participantResponses, ConvertUserToResponse(p)) + } + + return &ConversationDetailResponse{ + ID: conv.ID, + Type: string(conv.Type), + IsPinned: isPinned, + LastSeq: conv.LastSeq, + LastMessage: ConvertMessageToResponse(lastMessage), + LastMessageAt: FormatTimePointer(conv.LastMsgTime), + UnreadCount: unreadCount, + Participants: participantResponses, + MyLastReadSeq: myLastReadSeq, + OtherLastReadSeq: otherLastReadSeq, + CreatedAt: FormatTime(conv.CreatedAt), + UpdatedAt: FormatTime(conv.UpdatedAt), + } +} + +// ConvertMessagesToResponse 将Message列表转换为响应列表 +func ConvertMessagesToResponse(messages []*model.Message) []*MessageResponse { + result := make([]*MessageResponse, 0, len(messages)) + for _, msg := range messages { + result = append(result, ConvertMessageToResponse(msg)) + } + return result +} + +// ConvertConversationsToResponse 将Conversation列表转换为响应列表 +func ConvertConversationsToResponse(convs []*model.Conversation) []*ConversationResponse { + result := make([]*ConversationResponse, 0, len(convs)) + for _, conv := range convs { + result = append(result, ConvertConversationToResponse(conv, nil, 0, nil, false)) + } + return result +} + +// ==================== PushRecord 转换 ==================== + +// PushRecordToResponse 将PushRecord转换为PushRecordResponse +func PushRecordToResponse(record *model.PushRecord) *PushRecordResponse { + if record == nil { + return nil + } + resp := &PushRecordResponse{ + ID: record.ID, + MessageID: record.MessageID, + PushChannel: string(record.PushChannel), + PushStatus: string(record.PushStatus), + RetryCount: record.RetryCount, + CreatedAt: record.CreatedAt, + } + if record.PushedAt != nil { + resp.PushedAt = *record.PushedAt + } + if record.DeliveredAt != nil { + resp.DeliveredAt = *record.DeliveredAt + } + return resp +} + +// PushRecordsToResponse 将PushRecord列表转换为响应列表 +func PushRecordsToResponse(records []*model.PushRecord) []*PushRecordResponse { + result := make([]*PushRecordResponse, 0, len(records)) + for _, record := range records { + result = append(result, PushRecordToResponse(record)) + } + return result +} + +// ==================== DeviceToken 转换 ==================== + +// DeviceTokenToResponse 将DeviceToken转换为DeviceTokenResponse +func DeviceTokenToResponse(token *model.DeviceToken) *DeviceTokenResponse { + if token == nil { + return nil + } + resp := &DeviceTokenResponse{ + ID: token.ID, + DeviceID: token.DeviceID, + DeviceType: string(token.DeviceType), + IsActive: token.IsActive, + DeviceName: token.DeviceName, + CreatedAt: token.CreatedAt, + } + if token.LastUsedAt != nil { + resp.LastUsedAt = *token.LastUsedAt + } + return resp +} + +// DeviceTokensToResponse 将DeviceToken列表转换为响应列表 +func DeviceTokensToResponse(tokens []*model.DeviceToken) []*DeviceTokenResponse { + result := make([]*DeviceTokenResponse, 0, len(tokens)) + for _, token := range tokens { + result = append(result, DeviceTokenToResponse(token)) + } + return result +} + +// ==================== SystemMessage 转换 ==================== + +// SystemMessageToResponse 将Message转换为SystemMessageResponse +func SystemMessageToResponse(msg *model.Message) *SystemMessageResponse { + if msg == nil { + return nil + } + + // 从 segments 中提取文本内容 + content := ExtractTextContentFromModel(msg.Segments) + + resp := &SystemMessageResponse{ + ID: msg.ID, + SenderID: msg.SenderID, + ReceiverID: "", // 系统消息的接收者需要从上下文获取 + Content: content, + Category: string(msg.Category), + SystemType: string(msg.SystemType), + CreatedAt: msg.CreatedAt, + } + if msg.ExtraData != nil { + resp.ExtraData = map[string]interface{}{ + "actor_id": msg.ExtraData.ActorID, + "actor_name": msg.ExtraData.ActorName, + "avatar_url": msg.ExtraData.AvatarURL, + "target_id": msg.ExtraData.TargetID, + "target_type": msg.ExtraData.TargetType, + "action_url": msg.ExtraData.ActionURL, + "action_time": msg.ExtraData.ActionTime, + } + } + return resp +} + +// SystemMessagesToResponse 将Message列表转换为SystemMessageResponse列表 +func SystemMessagesToResponse(messages []*model.Message) []*SystemMessageResponse { + result := make([]*SystemMessageResponse, 0, len(messages)) + for _, msg := range messages { + result = append(result, SystemMessageToResponse(msg)) + } + return result +} + +// SystemNotificationToResponse 将SystemNotification转换为SystemMessageResponse +func SystemNotificationToResponse(n *model.SystemNotification) *SystemMessageResponse { + if n == nil { + return nil + } + resp := &SystemMessageResponse{ + ID: strconv.FormatInt(n.ID, 10), + SenderID: model.SystemSenderIDStr, // 系统发送者 + ReceiverID: n.ReceiverID, + Content: n.Content, + Category: "notification", + SystemType: string(n.Type), + IsRead: n.IsRead, + CreatedAt: n.CreatedAt, + } + if n.ExtraData != nil { + resp.ExtraData = map[string]interface{}{ + "actor_id": n.ExtraData.ActorID, + "actor_id_str": n.ExtraData.ActorIDStr, + "actor_name": n.ExtraData.ActorName, + "avatar_url": n.ExtraData.AvatarURL, + "target_id": n.ExtraData.TargetID, + "target_title": n.ExtraData.TargetTitle, + "target_type": n.ExtraData.TargetType, + "action_url": n.ExtraData.ActionURL, + "action_time": n.ExtraData.ActionTime, + "group_id": n.ExtraData.GroupID, + "group_name": n.ExtraData.GroupName, + "group_avatar": n.ExtraData.GroupAvatar, + "group_description": n.ExtraData.GroupDescription, + "flag": n.ExtraData.Flag, + "request_type": n.ExtraData.RequestType, + "request_status": n.ExtraData.RequestStatus, + "reason": n.ExtraData.Reason, + "target_user_id": n.ExtraData.TargetUserID, + "target_user_name": n.ExtraData.TargetUserName, + "target_user_avatar": n.ExtraData.TargetUserAvatar, + } + } + return resp +} + +// SystemNotificationsToResponse 将SystemNotification列表转换为SystemMessageResponse列表 +func SystemNotificationsToResponse(notifications []*model.SystemNotification) []*SystemMessageResponse { + result := make([]*SystemMessageResponse, 0, len(notifications)) + for _, n := range notifications { + result = append(result, SystemNotificationToResponse(n)) + } + return result +} + +// ==================== Group 转换 ==================== + +// GroupToResponse 将Group转换为GroupResponse +func GroupToResponse(group *model.Group) *GroupResponse { + if group == nil { + return nil + } + return &GroupResponse{ + ID: group.ID, + Name: group.Name, + Avatar: group.Avatar, + Description: group.Description, + OwnerID: group.OwnerID, + MemberCount: group.MemberCount, + MaxMembers: group.MaxMembers, + JoinType: int(group.JoinType), + MuteAll: group.MuteAll, + CreatedAt: FormatTime(group.CreatedAt), + } +} + +// GroupsToResponse 将Group列表转换为GroupResponse列表 +func GroupsToResponse(groups []model.Group) []*GroupResponse { + result := make([]*GroupResponse, 0, len(groups)) + for i := range groups { + result = append(result, GroupToResponse(&groups[i])) + } + return result +} + +// GroupMemberToResponse 将GroupMember转换为GroupMemberResponse +func GroupMemberToResponse(member *model.GroupMember) *GroupMemberResponse { + if member == nil { + return nil + } + return &GroupMemberResponse{ + ID: member.ID, + GroupID: member.GroupID, + UserID: member.UserID, + Role: member.Role, + Nickname: member.Nickname, + Muted: member.Muted, + JoinTime: FormatTime(member.JoinTime), + } +} + +// GroupMemberToResponseWithUser 将GroupMember转换为GroupMemberResponse(包含用户信息) +func GroupMemberToResponseWithUser(member *model.GroupMember, user *model.User) *GroupMemberResponse { + if member == nil { + return nil + } + resp := GroupMemberToResponse(member) + if user != nil { + resp.User = ConvertUserToResponse(user) + } + return resp +} + +// GroupMembersToResponse 将GroupMember列表转换为GroupMemberResponse列表 +func GroupMembersToResponse(members []model.GroupMember) []*GroupMemberResponse { + result := make([]*GroupMemberResponse, 0, len(members)) + for i := range members { + result = append(result, GroupMemberToResponse(&members[i])) + } + return result +} + +// GroupAnnouncementToResponse 将GroupAnnouncement转换为GroupAnnouncementResponse +func GroupAnnouncementToResponse(announcement *model.GroupAnnouncement) *GroupAnnouncementResponse { + if announcement == nil { + return nil + } + return &GroupAnnouncementResponse{ + ID: announcement.ID, + GroupID: announcement.GroupID, + Content: announcement.Content, + AuthorID: announcement.AuthorID, + IsPinned: announcement.IsPinned, + CreatedAt: FormatTime(announcement.CreatedAt), + } +} + +// GroupAnnouncementsToResponse 将GroupAnnouncement列表转换为GroupAnnouncementResponse列表 +func GroupAnnouncementsToResponse(announcements []model.GroupAnnouncement) []*GroupAnnouncementResponse { + result := make([]*GroupAnnouncementResponse, 0, len(announcements)) + for i := range announcements { + result = append(result, GroupAnnouncementToResponse(&announcements[i])) + } + return result +} diff --git a/internal/dto/dto.go b/internal/dto/dto.go new file mode 100644 index 0000000..d0d13a3 --- /dev/null +++ b/internal/dto/dto.go @@ -0,0 +1,819 @@ +package dto + +import ( + "carrot_bbs/internal/model" + "time" +) + +// ==================== User DTOs ==================== + +// UserResponse 用户信息响应 +type UserResponse struct { + ID string `json:"id"` + Username string `json:"username"` + Nickname string `json:"nickname"` + Email *string `json:"email,omitempty"` + Phone *string `json:"phone,omitempty"` + EmailVerified bool `json:"email_verified"` + Avatar string `json:"avatar"` + CoverURL string `json:"cover_url"` // 头图URL + Bio string `json:"bio"` + Website string `json:"website"` + Location string `json:"location"` + PostsCount int `json:"posts_count"` + FollowersCount int `json:"followers_count"` + FollowingCount int `json:"following_count"` + IsFollowing bool `json:"is_following"` // 当前用户是否关注了该用户 + IsFollowingMe bool `json:"is_following_me"` // 该用户是否关注了当前用户 + CreatedAt string `json:"created_at"` +} + +// UserDetailResponse 用户详情响应 +type UserDetailResponse struct { + ID string `json:"id"` + Username string `json:"username"` + Nickname string `json:"nickname"` + Email *string `json:"email"` + EmailVerified bool `json:"email_verified"` + Phone *string `json:"phone,omitempty"` // 仅当前用户自己可见 + Avatar string `json:"avatar"` + CoverURL string `json:"cover_url"` // 头图URL + Bio string `json:"bio"` + Website string `json:"website"` + Location string `json:"location"` + PostsCount int `json:"posts_count"` + FollowersCount int `json:"followers_count"` + FollowingCount int `json:"following_count"` + IsVerified bool `json:"is_verified"` + IsFollowing bool `json:"is_following"` // 当前用户是否关注了该用户 + IsFollowingMe bool `json:"is_following_me"` // 该用户是否关注了当前用户 + CreatedAt string `json:"created_at"` +} + +// ==================== Post DTOs ==================== + +// PostImageResponse 帖子图片响应 +type PostImageResponse struct { + ID string `json:"id"` + URL string `json:"url"` + ThumbnailURL string `json:"thumbnail_url"` + Width int `json:"width"` + Height int `json:"height"` +} + +// PostResponse 帖子响应(列表用) +type PostResponse struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Title string `json:"title"` + Content string `json:"content"` + Images []PostImageResponse `json:"images"` + LikesCount int `json:"likes_count"` + CommentsCount int `json:"comments_count"` + FavoritesCount int `json:"favorites_count"` + SharesCount int `json:"shares_count"` + ViewsCount int `json:"views_count"` + IsPinned bool `json:"is_pinned"` + IsLocked bool `json:"is_locked"` + IsVote bool `json:"is_vote"` + CreatedAt string `json:"created_at"` + Author *UserResponse `json:"author"` + IsLiked bool `json:"is_liked"` + IsFavorited bool `json:"is_favorited"` +} + +// PostDetailResponse 帖子详情响应 +type PostDetailResponse struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Title string `json:"title"` + Content string `json:"content"` + Images []PostImageResponse `json:"images"` + Status string `json:"status"` + LikesCount int `json:"likes_count"` + CommentsCount int `json:"comments_count"` + FavoritesCount int `json:"favorites_count"` + SharesCount int `json:"shares_count"` + ViewsCount int `json:"views_count"` + IsPinned bool `json:"is_pinned"` + IsLocked bool `json:"is_locked"` + IsVote bool `json:"is_vote"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Author *UserResponse `json:"author"` + IsLiked bool `json:"is_liked"` + IsFavorited bool `json:"is_favorited"` +} + +// ==================== Comment DTOs ==================== + +// CommentImageResponse 评论图片响应 +type CommentImageResponse struct { + URL string `json:"url"` +} + +// CommentResponse 评论响应(扁平化结构,类似B站/抖音) +// 第一层级正常展示,第二三四五层级在第一层级的评论区扁平展示 +type CommentResponse struct { + ID string `json:"id"` + PostID string `json:"post_id"` + UserID string `json:"user_id"` + ParentID *string `json:"parent_id"` + RootID *string `json:"root_id"` + Content string `json:"content"` + Images []CommentImageResponse `json:"images"` + LikesCount int `json:"likes_count"` + RepliesCount int `json:"replies_count"` + CreatedAt string `json:"created_at"` + Author *UserResponse `json:"author"` + IsLiked bool `json:"is_liked"` + TargetID *string `json:"target_id,omitempty"` // 被回复的评论ID,前端根据此ID找到被回复用户的昵称 + Replies []*CommentResponse `json:"replies,omitempty"` // 子回复列表(扁平化,所有层级都在这里) +} + +// ==================== Notification DTOs ==================== + +// NotificationResponse 通知响应 +type NotificationResponse struct { + ID string `json:"id"` + UserID string `json:"user_id"` + Type string `json:"type"` + Title string `json:"title"` + Content string `json:"content"` + Data string `json:"data"` + IsRead bool `json:"is_read"` + CreatedAt string `json:"created_at"` +} + +// ==================== Message Segment DTOs ==================== + +// SegmentType Segment类型 +type SegmentType string + +const ( + SegmentTypeText SegmentType = "text" + SegmentTypeImage SegmentType = "image" + SegmentTypeVoice SegmentType = "voice" + SegmentTypeVideo SegmentType = "video" + SegmentTypeFile SegmentType = "file" + SegmentTypeAt SegmentType = "at" + SegmentTypeReply SegmentType = "reply" + SegmentTypeFace SegmentType = "face" + SegmentTypeLink SegmentType = "link" +) + +// TextSegmentData 文本数据 +type TextSegmentData struct { + Text string `json:"text"` +} + +// ImageSegmentData 图片数据 +type ImageSegmentData struct { + URL string `json:"url"` + Width int `json:"width,omitempty"` + Height int `json:"height,omitempty"` + ThumbnailURL string `json:"thumbnail_url,omitempty"` + FileSize int64 `json:"file_size,omitempty"` +} + +// VoiceSegmentData 语音数据 +type VoiceSegmentData struct { + URL string `json:"url"` + Duration int `json:"duration,omitempty"` // 秒 + FileSize int64 `json:"file_size,omitempty"` +} + +// VideoSegmentData 视频数据 +type VideoSegmentData struct { + URL string `json:"url"` + Width int `json:"width,omitempty"` + Height int `json:"height,omitempty"` + Duration int `json:"duration,omitempty"` // 秒 + ThumbnailURL string `json:"thumbnail_url,omitempty"` + FileSize int64 `json:"file_size,omitempty"` +} + +// FileSegmentData 文件数据 +type FileSegmentData struct { + URL string `json:"url"` + Name string `json:"name"` + Size int64 `json:"size,omitempty"` + MimeType string `json:"mime_type,omitempty"` +} + +// AtSegmentData @数据 +type AtSegmentData struct { + UserID string `json:"user_id"` // "all" 表示@所有人 + Nickname string `json:"nickname,omitempty"` +} + +// ReplySegmentData 回复数据 +type ReplySegmentData struct { + ID string `json:"id"` // 被回复消息的ID +} + +// FaceSegmentData 表情数据 +type FaceSegmentData struct { + ID int `json:"id"` + Name string `json:"name,omitempty"` + URL string `json:"url,omitempty"` +} + +// LinkSegmentData 链接数据 +type LinkSegmentData struct { + URL string `json:"url"` + Title string `json:"title,omitempty"` + Description string `json:"description,omitempty"` + Image string `json:"image,omitempty"` +} + +// ==================== Message DTOs ==================== + +// MessageResponse 消息响应 +type MessageResponse struct { + ID string `json:"id"` + ConversationID string `json:"conversation_id"` + SenderID string `json:"sender_id"` + Seq int64 `json:"seq"` + Segments model.MessageSegments `json:"segments"` // 消息链(必须) + ReplyToID *string `json:"reply_to_id,omitempty"` // 被回复消息的ID(用于关联查找) + Status string `json:"status"` + Category string `json:"category,omitempty"` // 消息类别:chat, notification, announcement + CreatedAt string `json:"created_at"` + Sender *UserResponse `json:"sender"` +} + +// ConversationResponse 会话响应 +type ConversationResponse struct { + ID string `json:"id"` + Type string `json:"type"` + IsPinned bool `json:"is_pinned"` + Group *GroupResponse `json:"group,omitempty"` + LastSeq int64 `json:"last_seq"` + LastMessage *MessageResponse `json:"last_message"` + LastMessageAt string `json:"last_message_at"` + UnreadCount int `json:"unread_count"` + Participants []*UserResponse `json:"participants,omitempty"` // 私聊时使用 + MemberCount int `json:"member_count,omitempty"` // 群聊时使用 + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +// ConversationParticipantResponse 会话参与者响应 +type ConversationParticipantResponse struct { + UserID string `json:"user_id"` + LastReadSeq int64 `json:"last_read_seq"` + Muted bool `json:"muted"` + IsPinned bool `json:"is_pinned"` +} + +// ==================== Auth DTOs ==================== + +// LoginResponse 登录响应 +type LoginResponse struct { + User *UserResponse `json:"user"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +// RegisterResponse 注册响应 +type RegisterResponse struct { + User *UserResponse `json:"user"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +// RefreshTokenResponse 刷新Token响应 +type RefreshTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +// ==================== Common DTOs ==================== + +// SuccessResponse 通用成功响应 +type SuccessResponse struct { + Success bool `json:"success"` + Message string `json:"message"` +} + +// AvailableResponse 可用性检查响应 +type AvailableResponse struct { + Available bool `json:"available"` +} + +// CountResponse 数量响应 +type CountResponse struct { + Count int `json:"count"` +} + +// URLResponse URL响应 +type URLResponse struct { + URL string `json:"url"` +} + +// ==================== Chat Request DTOs ==================== + +// CreateConversationRequest 创建会话请求 +type CreateConversationRequest struct { + UserID string `json:"user_id" binding:"required"` // 目标用户ID (UUID格式) +} + +// SendMessageRequest 发送消息请求 +type SendMessageRequest struct { + Segments model.MessageSegments `json:"segments" binding:"required"` // 消息链(必须) + ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID (string类型) +} + +// MarkReadRequest 标记已读请求 +type MarkReadRequest struct { + LastReadSeq int64 `json:"last_read_seq" binding:"required"` // 已读到的seq位置 +} + +// SetConversationPinnedRequest 设置会话置顶请求 +type SetConversationPinnedRequest struct { + ConversationID string `json:"conversation_id" binding:"required"` + IsPinned bool `json:"is_pinned"` +} + +// ==================== Chat Response DTOs ==================== + +// ConversationListResponse 会话列表响应 +type ConversationListResponse struct { + Conversations []*ConversationResponse `json:"list"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +// ConversationDetailResponse 会话详情响应 +type ConversationDetailResponse struct { + ID string `json:"id"` + Type string `json:"type"` + IsPinned bool `json:"is_pinned"` + LastSeq int64 `json:"last_seq"` + LastMessage *MessageResponse `json:"last_message"` + LastMessageAt string `json:"last_message_at"` + UnreadCount int64 `json:"unread_count"` + Participants []*UserResponse `json:"participants"` + MyLastReadSeq int64 `json:"my_last_read_seq"` // 当前用户的已读位置 + OtherLastReadSeq int64 `json:"other_last_read_seq"` // 对方用户的已读位置 + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +// UnreadCountResponse 未读数响应 +type UnreadCountResponse struct { + TotalUnreadCount int64 `json:"total_unread_count"` // 所有会话的未读总数 +} + +// ConversationUnreadCountResponse 单个会话未读数响应 +type ConversationUnreadCountResponse struct { + ConversationID string `json:"conversation_id"` + UnreadCount int64 `json:"unread_count"` +} + +// MessageListResponse 消息列表响应 +type MessageListResponse struct { + Messages []*MessageResponse `json:"messages"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +// MessageSyncResponse 消息同步响应(增量同步) +type MessageSyncResponse struct { + Messages []*MessageResponse `json:"messages"` + HasMore bool `json:"has_more"` +} + +// ==================== 设备Token DTOs ==================== + +// RegisterDeviceRequest 注册设备请求 +type RegisterDeviceRequest struct { + DeviceID string `json:"device_id" binding:"required"` + DeviceType string `json:"device_type" binding:"required,oneof=ios android web"` + PushToken string `json:"push_token"` + DeviceName string `json:"device_name"` +} + +// DeviceTokenResponse 设备Token响应 +type DeviceTokenResponse struct { + ID int64 `json:"id"` + DeviceID string `json:"device_id"` + DeviceType string `json:"device_type"` + IsActive bool `json:"is_active"` + DeviceName string `json:"device_name"` + LastUsedAt time.Time `json:"last_used_at,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// ==================== 推送记录 DTOs ==================== + +// PushRecordResponse 推送记录响应 +type PushRecordResponse struct { + ID int64 `json:"id"` + MessageID string `json:"message_id"` + PushChannel string `json:"push_channel"` + PushStatus string `json:"push_status"` + RetryCount int `json:"retry_count"` + PushedAt time.Time `json:"pushed_at,omitempty"` + DeliveredAt time.Time `json:"delivered_at,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// PushRecordListResponse 推送记录列表响应 +type PushRecordListResponse struct { + Records []*PushRecordResponse `json:"records"` + Total int64 `json:"total"` +} + +// ==================== 系统消息 DTOs ==================== + +// SystemMessageResponse 系统消息响应 +type SystemMessageResponse struct { + ID string `json:"id"` + SenderID string `json:"sender_id"` + ReceiverID string `json:"receiver_id"` + Content string `json:"content"` + Category string `json:"category"` + SystemType string `json:"system_type"` + ExtraData map[string]interface{} `json:"extra_data,omitempty"` + IsRead bool `json:"is_read"` + CreatedAt time.Time `json:"created_at"` +} + +// SystemMessageListResponse 系统消息列表响应 +type SystemMessageListResponse struct { + Messages []*SystemMessageResponse `json:"messages"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +// SystemUnreadCountResponse 系统消息未读数响应 +type SystemUnreadCountResponse struct { + UnreadCount int64 `json:"unread_count"` +} + +// ==================== 时间格式化 ==================== + +// FormatTime 格式化时间 +func FormatTime(t time.Time) string { + if t.IsZero() { + return "" + } + return t.Format("2006-01-02T15:04:05Z07:00") +} + +// FormatTimePointer 格式化时间指针 +func FormatTimePointer(t *time.Time) string { + if t == nil { + return "" + } + return FormatTime(*t) +} + +// ==================== Group DTOs ==================== + +// CreateGroupRequest 创建群组请求 +type CreateGroupRequest struct { + Name string `json:"name" binding:"required,max=50"` + Description string `json:"description" binding:"max=500"` + MemberIDs []string `json:"member_ids"` +} + +// UpdateGroupRequest 更新群组请求 +type UpdateGroupRequest struct { + Name string `json:"name" binding:"omitempty,max=50"` + Description string `json:"description" binding:"omitempty,max=500"` + Avatar string `json:"avatar" binding:"omitempty,url"` +} + +// InviteMembersRequest 邀请成员请求 +type InviteMembersRequest struct { + MemberIDs []string `json:"member_ids" binding:"required,min=1"` +} + +// TransferOwnerRequest 转让群主请求 +type TransferOwnerRequest struct { + NewOwnerID string `json:"new_owner_id" binding:"required"` +} + +// SetRoleRequest 设置角色请求 +type SetRoleRequest struct { + Role string `json:"role" binding:"required,oneof=admin member"` +} + +// SetNicknameRequest 设置昵称请求 +type SetNicknameRequest struct { + Nickname string `json:"nickname" binding:"max=50"` +} + +// MuteMemberRequest 禁言成员请求 +type MuteMemberRequest struct { + Muted bool `json:"muted"` +} + +// SetMuteAllRequest 设置全员禁言请求 +type SetMuteAllRequest struct { + MuteAll bool `json:"mute_all"` +} + +// SetJoinTypeRequest 设置加群方式请求 +type SetJoinTypeRequest struct { + JoinType int `json:"join_type" binding:"min=0,max=2"` +} + +// CreateAnnouncementRequest 创建群公告请求 +type CreateAnnouncementRequest struct { + Content string `json:"content" binding:"required,max=2000"` +} + +// GroupResponse 群组响应 +type GroupResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Avatar string `json:"avatar"` + Description string `json:"description"` + OwnerID string `json:"owner_id"` + MemberCount int `json:"member_count"` + MaxMembers int `json:"max_members"` + JoinType int `json:"join_type"` + MuteAll bool `json:"mute_all"` + CreatedAt string `json:"created_at"` +} + +// GroupMemberResponse 群成员响应 +type GroupMemberResponse struct { + ID string `json:"id"` + GroupID string `json:"group_id"` + UserID string `json:"user_id"` + Role string `json:"role"` + Nickname string `json:"nickname"` + Muted bool `json:"muted"` + JoinTime string `json:"join_time"` + User *UserResponse `json:"user,omitempty"` +} + +// GroupAnnouncementResponse 群公告响应 +type GroupAnnouncementResponse struct { + ID string `json:"id"` + GroupID string `json:"group_id"` + Content string `json:"content"` + AuthorID string `json:"author_id"` + IsPinned bool `json:"is_pinned"` + CreatedAt string `json:"created_at"` +} + +// GroupListResponse 群组列表响应 +type GroupListResponse struct { + List []*GroupResponse `json:"list"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +// GroupMemberListResponse 群成员列表响应 +type GroupMemberListResponse struct { + List []*GroupMemberResponse `json:"list"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +// GroupAnnouncementListResponse 群公告列表响应 +type GroupAnnouncementListResponse struct { + List []*GroupAnnouncementResponse `json:"list"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +// ==================== WebSocket Event DTOs ==================== + +// WSEventResponse WebSocket事件响应结构体 +// 用于后端推送消息给前端的标准格式 +type WSEventResponse struct { + ID string `json:"id"` // 事件唯一ID (UUID) + Time int64 `json:"time"` // 时间戳 (毫秒) + Type string `json:"type"` // 事件类型 (message, notification, system等) + DetailType string `json:"detail_type"` // 详细类型 (private, group, like, comment等) + Seq string `json:"seq"` // 消息序列号 + Segments model.MessageSegments `json:"segments"` // 消息段数组 + SenderID string `json:"sender_id"` // 发送者用户ID +} + +// ==================== WebSocket Request DTOs ==================== + +// SendMessageParams 发送消息参数(用于 REST API) +type SendMessageParams struct { + DetailType string `json:"detail_type"` // 消息类型: private, group + ConversationID string `json:"conversation_id"` // 会话ID + Segments model.MessageSegments `json:"segments"` // 消息内容(消息段数组) + ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID +} + +// DeleteMsgParams 撤回消息参数 +type DeleteMsgParams struct { + MessageID string `json:"message_id"` // 消息ID +} + +// ==================== Group Action Params ==================== + +// SetGroupKickParams 群组踢人参数 +type SetGroupKickParams struct { + GroupID string `json:"group_id"` // 群组ID + UserID string `json:"user_id"` // 被踢用户ID + RejectAddRequest bool `json:"reject_add_request"` // 是否拒绝再次加群 +} + +// SetGroupBanParams 群组单人禁言参数 +type SetGroupBanParams struct { + GroupID string `json:"group_id"` // 群组ID + UserID string `json:"user_id"` // 被禁言用户ID + Duration int64 `json:"duration"` // 禁言时长(秒),0表示解除禁言 +} + +// SetGroupWholeBanParams 群组全员禁言参数 +type SetGroupWholeBanParams struct { + GroupID string `json:"group_id"` // 群组ID + Enable bool `json:"enable"` // 是否开启全员禁言 +} + +// SetGroupAdminParams 群组设置管理员参数 +type SetGroupAdminParams struct { + GroupID string `json:"group_id"` // 群组ID + UserID string `json:"user_id"` // 被设置的用户ID + Enable bool `json:"enable"` // 是否设置为管理员 +} + +// SetGroupNameParams 设置群名参数 +type SetGroupNameParams struct { + GroupID string `json:"group_id"` // 群组ID + GroupName string `json:"group_name"` // 新群名 +} + +// SetGroupAvatarParams 设置群头像参数 +type SetGroupAvatarParams struct { + GroupID string `json:"group_id"` // 群组ID + Avatar string `json:"avatar"` // 头像URL +} + +// SetGroupLeaveParams 退出群组参数 +type SetGroupLeaveParams struct { + GroupID string `json:"group_id"` // 群组ID +} + +// SetGroupAddRequestParams 处理加群请求参数 +type SetGroupAddRequestParams struct { + Flag string `json:"flag"` // 加群请求的flag标识 + Approve bool `json:"approve"` // 是否同意 + Reason string `json:"reason"` // 拒绝理由(当approve为false时) +} + +// GetConversationListParams 获取会话列表参数 +type GetConversationListParams struct { + Page int `json:"page"` // 页码 + PageSize int `json:"page_size"` // 每页数量 +} + +// GetGroupInfoParams 获取群信息参数 +type GetGroupInfoParams struct { + GroupID string `json:"group_id"` // 群组ID +} + +// GetGroupMemberListParams 获取群成员列表参数 +type GetGroupMemberListParams struct { + GroupID string `json:"group_id"` // 群组ID + Page int `json:"page"` // 页码 + PageSize int `json:"page_size"` // 每页数量 +} + +// ==================== Conversation Action Params ==================== + +// CreateConversationParams 创建会话参数 +type CreateConversationParams struct { + UserID string `json:"user_id"` // 目标用户ID(私聊) +} + +// MarkReadParams 标记已读参数 +type MarkReadParams struct { + ConversationID string `json:"conversation_id"` // 会话ID + LastReadSeq int64 `json:"last_read_seq"` // 最后已读消息序号 +} + +// SetConversationPinnedParams 设置会话置顶参数 +type SetConversationPinnedParams struct { + ConversationID string `json:"conversation_id"` // 会话ID + IsPinned bool `json:"is_pinned"` // 是否置顶 +} + +// ==================== Group Action Params (Additional) ==================== + +// CreateGroupParams 创建群组参数 +type CreateGroupParams struct { + Name string `json:"name"` // 群名 + Description string `json:"description,omitempty"` // 群描述 + MemberIDs []string `json:"member_ids,omitempty"` // 初始成员ID列表 +} + +// GetUserGroupsParams 获取用户群组列表参数 +type GetUserGroupsParams struct { + Page int `json:"page"` // 页码 + PageSize int `json:"page_size"` // 每页数量 +} + +// TransferOwnerParams 转让群主参数 +type TransferOwnerParams struct { + GroupID string `json:"group_id"` // 群组ID + NewOwnerID string `json:"new_owner_id"` // 新群主ID +} + +// InviteMembersParams 邀请成员参数 +type InviteMembersParams struct { + GroupID string `json:"group_id"` // 群组ID + MemberIDs []string `json:"member_ids"` // 被邀请的用户ID列表 +} + +// JoinGroupParams 加入群组参数 +type JoinGroupParams struct { + GroupID string `json:"group_id"` // 群组ID +} + +// SetNicknameParams 设置群内昵称参数 +type SetNicknameParams struct { + GroupID string `json:"group_id"` // 群组ID + Nickname string `json:"nickname"` // 群内昵称 +} + +// SetJoinTypeParams 设置加群方式参数 +type SetJoinTypeParams struct { + GroupID string `json:"group_id"` // 群组ID + JoinType int `json:"join_type"` // 加群方式:0-允许任何人加入,1-需要审批,2-不允许加入 +} + +// CreateAnnouncementParams 创建群公告参数 +type CreateAnnouncementParams struct { + GroupID string `json:"group_id"` // 群组ID + Content string `json:"content"` // 公告内容 +} + +// GetAnnouncementsParams 获取群公告列表参数 +type GetAnnouncementsParams struct { + GroupID string `json:"group_id"` // 群组ID + Page int `json:"page"` // 页码 + PageSize int `json:"page_size"` // 每页数量 +} + +// DeleteAnnouncementParams 删除群公告参数 +type DeleteAnnouncementParams struct { + GroupID string `json:"group_id"` // 群组ID + AnnouncementID string `json:"announcement_id"` // 公告ID +} + +// DissolveGroupParams 解散群组参数 +type DissolveGroupParams struct { + GroupID string `json:"group_id"` // 群组ID +} + +// GetMyMemberInfoParams 获取我在群组中的成员信息参数 +type GetMyMemberInfoParams struct { + GroupID string `json:"group_id"` // 群组ID +} + +// ==================== Vote DTOs ==================== + +// CreateVotePostRequest 创建投票帖子请求 +type CreateVotePostRequest struct { + Title string `json:"title" binding:"required,max=200"` + Content string `json:"content" binding:"max=2000"` + CommunityID string `json:"community_id"` + Images []string `json:"images"` + VoteOptions []string `json:"vote_options" binding:"required,min=2,max=10"` // 投票选项,至少2个最多10个 +} + +// VoteOptionDTO 投票选项DTO +type VoteOptionDTO struct { + ID string `json:"id"` + Content string `json:"content"` + VotesCount int `json:"votes_count"` +} + +// VoteResultDTO 投票结果DTO +type VoteResultDTO struct { + Options []VoteOptionDTO `json:"options"` + TotalVotes int `json:"total_votes"` + HasVoted bool `json:"has_voted"` + VotedOptionID string `json:"voted_option_id,omitempty"` +} + +// ==================== WebSocket Response DTOs ==================== + +// WSResponse WebSocket响应结构体 +type WSResponse struct { + Success bool `json:"success"` // 是否成功 + Action string `json:"action"` // 响应原action + Data interface{} `json:"data,omitempty"` // 响应数据 + Error string `json:"error,omitempty"` // 错误信息 +} diff --git a/internal/dto/segment.go b/internal/dto/segment.go new file mode 100644 index 0000000..40d8131 --- /dev/null +++ b/internal/dto/segment.go @@ -0,0 +1,362 @@ +package dto + +import ( + "encoding/json" + "fmt" + + "carrot_bbs/internal/model" +) + +// ParseSegmentData 解析Segment数据到目标结构体 +func ParseSegmentData(segment model.MessageSegment, target interface{}) error { + dataBytes, err := json.Marshal(segment.Data) + if err != nil { + return err + } + return json.Unmarshal(dataBytes, target) +} + +// NewTextSegment 创建文本Segment +func NewTextSegment(content string) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeText), + Data: map[string]interface{}{"text": content}, + } +} + +// NewImageSegment 创建图片Segment +func NewImageSegment(url string, width, height int, thumbnailURL string) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeImage), + Data: map[string]interface{}{ + "url": url, + "width": width, + "height": height, + "thumbnail_url": thumbnailURL, + }, + } +} + +// NewImageSegmentWithSize 创建带文件大小的图片Segment +func NewImageSegmentWithSize(url string, width, height int, thumbnailURL string, fileSize int64) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeImage), + Data: map[string]interface{}{ + "url": url, + "width": width, + "height": height, + "thumbnail_url": thumbnailURL, + "file_size": fileSize, + }, + } +} + +// NewVoiceSegment 创建语音Segment +func NewVoiceSegment(url string, duration int, fileSize int64) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeVoice), + Data: map[string]interface{}{ + "url": url, + "duration": duration, + "file_size": fileSize, + }, + } +} + +// NewVideoSegment 创建视频Segment +func NewVideoSegment(url string, width, height, duration int, thumbnailURL string, fileSize int64) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeVideo), + Data: map[string]interface{}{ + "url": url, + "width": width, + "height": height, + "duration": duration, + "thumbnail_url": thumbnailURL, + "file_size": fileSize, + }, + } +} + +// NewFileSegment 创建文件Segment +func NewFileSegment(url, name string, size int64, mimeType string) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeFile), + Data: map[string]interface{}{ + "url": url, + "name": name, + "size": size, + "mime_type": mimeType, + }, + } +} + +// NewAtSegment 创建@Segment,只存储 user_id,昵称由前端根据群成员列表实时解析 +func NewAtSegment(userID string) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeAt), + Data: map[string]interface{}{ + "user_id": userID, + }, + } +} + +// NewAtAllSegment 创建@所有人Segment +func NewAtAllSegment() model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeAt), + Data: map[string]interface{}{ + "user_id": "all", + }, + } +} + +// NewReplySegment 创建回复Segment +func NewReplySegment(messageID string) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeReply), + Data: map[string]interface{}{"id": messageID}, + } +} + +// NewFaceSegment 创建表情Segment +func NewFaceSegment(id int, name, url string) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeFace), + Data: map[string]interface{}{ + "id": id, + "name": name, + "url": url, + }, + } +} + +// NewLinkSegment 创建链接Segment +func NewLinkSegment(url, title, description, image string) model.MessageSegment { + return model.MessageSegment{ + Type: string(SegmentTypeLink), + Data: map[string]interface{}{ + "url": url, + "title": title, + "description": description, + "image": image, + }, + } +} + +// ExtractTextContentFromJSON 从JSON格式的segments中提取纯文本内容 +// 用于从数据库读取的 []byte 格式的 segments +// 已废弃:现在数据库直接存储 model.MessageSegments 类型 +func ExtractTextContentFromJSON(segmentsJSON []byte) string { + if len(segmentsJSON) == 0 { + return "" + } + + var segments model.MessageSegments + if err := json.Unmarshal(segmentsJSON, &segments); err != nil { + return "" + } + + return ExtractTextContentFromModel(segments) +} + +// ExtractTextContentFromModel 从 model.MessageSegments 中提取纯文本内容 +func ExtractTextContentFromModel(segments model.MessageSegments) string { + var result string + for _, segment := range segments { + switch segment.Type { + case "text": + if text, ok := segment.Data["text"].(string); ok { + result += text + } + case "at": + userID, _ := segment.Data["user_id"].(string) + if userID == "all" { + result += "@所有人 " + } else if userID != "" { + // 昵称由前端实时解析,后端文本提取仅用于推送通知兜底 + result += "@某人 " + } + case "image": + result += "[图片]" + case "voice": + result += "[语音]" + case "video": + result += "[视频]" + case "file": + if name, ok := segment.Data["name"].(string); ok && name != "" { + result += fmt.Sprintf("[文件:%s]", name) + } else { + result += "[文件]" + } + case "face": + if name, ok := segment.Data["name"].(string); ok && name != "" { + result += fmt.Sprintf("[%s]", name) + } else { + result += "[表情]" + } + case "link": + if title, ok := segment.Data["title"].(string); ok && title != "" { + result += fmt.Sprintf("[链接:%s]", title) + } else { + result += "[链接]" + } + } + } + return result +} + +// ExtractTextContent 从消息链中提取纯文本内容 +// 用于搜索、通知展示等场景 +func ExtractTextContent(segments model.MessageSegments) string { + return ExtractTextContentFromModel(segments) +} + +// ExtractMentionedUsers 从消息链中提取被@的用户ID列表 +// 不包括 "all"(@所有人) +func ExtractMentionedUsers(segments model.MessageSegments) []string { + var userIDs []string + seen := make(map[string]bool) + + for _, segment := range segments { + if segment.Type == string(SegmentTypeAt) { + userID, _ := segment.Data["user_id"].(string) + if userID != "all" && userID != "" && !seen[userID] { + userIDs = append(userIDs, userID) + seen[userID] = true + } + } + } + return userIDs +} + +// IsAtAll 检查消息是否@了所有人 +func IsAtAll(segments model.MessageSegments) bool { + for _, segment := range segments { + if segment.Type == string(SegmentTypeAt) { + if userID, ok := segment.Data["user_id"].(string); ok && userID == "all" { + return true + } + } + } + return false +} + +// GetReplyMessageID 从消息链中获取被回复的消息ID +// 如果没有回复segment,返回空字符串 +func GetReplyMessageID(segments model.MessageSegments) string { + for _, segment := range segments { + if segment.Type == string(SegmentTypeReply) { + if id, ok := segment.Data["id"].(string); ok { + return id + } + } + } + return "" +} + +// BuildSegmentsFromContent 从旧版content构建segments +// 用于兼容旧版本消息 +func BuildSegmentsFromContent(contentType, content string, mediaURL *string) model.MessageSegments { + var segments model.MessageSegments + + switch contentType { + case "text": + segments = append(segments, NewTextSegment(content)) + case "image": + if mediaURL != nil { + segments = append(segments, NewImageSegment(*mediaURL, 0, 0, "")) + } + case "voice": + if mediaURL != nil { + segments = append(segments, NewVoiceSegment(*mediaURL, 0, 0)) + } + case "video": + if mediaURL != nil { + segments = append(segments, NewVideoSegment(*mediaURL, 0, 0, 0, "", 0)) + } + case "file": + if mediaURL != nil { + segments = append(segments, NewFileSegment(*mediaURL, content, 0, "")) + } + default: + // 默认当作文本处理 + if content != "" { + segments = append(segments, NewTextSegment(content)) + } + } + + return segments +} + +// HasSegmentType 检查消息链中是否包含指定类型的segment +func HasSegmentType(segments model.MessageSegments, segmentType SegmentType) bool { + for _, segment := range segments { + if segment.Type == string(segmentType) { + return true + } + } + return false +} + +// GetSegmentsByType 获取消息链中所有指定类型的segment +func GetSegmentsByType(segments model.MessageSegments, segmentType SegmentType) []model.MessageSegment { + var result []model.MessageSegment + for _, segment := range segments { + if segment.Type == string(segmentType) { + result = append(result, segment) + } + } + return result +} + +// GetFirstImageURL 获取消息链中第一张图片的URL +// 如果没有图片,返回空字符串 +func GetFirstImageURL(segments model.MessageSegments) string { + for _, segment := range segments { + if segment.Type == string(SegmentTypeImage) { + if url, ok := segment.Data["url"].(string); ok { + return url + } + } + } + return "" +} + +// GetFirstMediaURL 获取消息链中第一个媒体文件的URL(图片/视频/语音/文件) +// 用于兼容旧版本API +func GetFirstMediaURL(segments model.MessageSegments) string { + for _, segment := range segments { + switch segment.Type { + case string(SegmentTypeImage), string(SegmentTypeVideo), string(SegmentTypeVoice), string(SegmentTypeFile): + if url, ok := segment.Data["url"].(string); ok { + return url + } + } + } + return "" +} + +// DetermineContentType 从消息链推断消息类型(用于兼容旧版本) +func DetermineContentType(segments model.MessageSegments) string { + if len(segments) == 0 { + return "text" + } + + // 优先检查媒体类型 + for _, segment := range segments { + switch segment.Type { + case string(SegmentTypeImage): + return "image" + case string(SegmentTypeVideo): + return "video" + case string(SegmentTypeVoice): + return "voice" + case string(SegmentTypeFile): + return "file" + } + } + + // 默认返回text + return "text" +} diff --git a/internal/handler/comment_handler.go b/internal/handler/comment_handler.go new file mode 100644 index 0000000..05ec327 --- /dev/null +++ b/internal/handler/comment_handler.go @@ -0,0 +1,253 @@ +package handler + +import ( + "encoding/json" + "errors" + "strconv" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// CommentHandler 评论处理器 +type CommentHandler struct { + commentService *service.CommentService +} + +// NewCommentHandler 创建评论处理器 +func NewCommentHandler(commentService *service.CommentService) *CommentHandler { + return &CommentHandler{ + commentService: commentService, + } +} + +// Create 创建评论 +func (h *CommentHandler) Create(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + type CreateRequest struct { + PostID string `json:"post_id" binding:"required"` + Content string `json:"content"` // 内容可选,允许只发图片 + ParentID *string `json:"parent_id"` + Images []string `json:"images"` // 图片URL列表 + } + + var req CreateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 验证:评论必须有内容或图片 + if req.Content == "" && len(req.Images) == 0 { + response.BadRequest(c, "评论内容或图片不能同时为空") + return + } + + // 将图片列表转换为JSON字符串 + var imagesJSON string + if len(req.Images) > 0 { + imagesBytes, _ := json.Marshal(req.Images) + imagesJSON = string(imagesBytes) + } + + comment, err := h.commentService.Create(c.Request.Context(), req.PostID, userID, req.Content, req.ParentID, imagesJSON, req.Images) + if err != nil { + var moderationErr *service.CommentModerationRejectedError + if errors.As(err, &moderationErr) { + response.BadRequest(c, moderationErr.UserMessage()) + return + } + response.InternalServerError(c, "failed to create comment") + return + } + + response.Success(c, dto.ConvertCommentToResponse(comment, false)) +} + +// GetByID 获取单条评论详情 +func (h *CommentHandler) GetByID(c *gin.Context) { + userID := c.GetString("user_id") + id := c.Param("id") + + comment, err := h.commentService.GetByID(c.Request.Context(), id) + if err != nil { + response.NotFound(c, "comment not found") + return + } + + resp := dto.ConvertCommentToResponse(comment, h.commentService.IsLiked(c.Request.Context(), id, userID)) + response.Success(c, resp) +} + +// GetByPostID 获取帖子评论 +func (h *CommentHandler) GetByPostID(c *gin.Context) { + userID := c.GetString("user_id") + postID := c.Param("id") + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + comments, total, err := h.commentService.GetByPostID(c.Request.Context(), postID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get comments") + return + } + + // 转换为响应结构,检查每个评论的点赞状态 + commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService) + + response.Paginated(c, commentResponses, total, page, pageSize) +} + +// GetReplies 获取回复 +func (h *CommentHandler) GetReplies(c *gin.Context) { + userID := c.GetString("user_id") + parentID := c.Param("id") + + comments, err := h.commentService.GetReplies(c.Request.Context(), parentID) + if err != nil { + response.InternalServerError(c, "failed to get replies") + return + } + + // 转换为响应结构,检查每个回复的点赞状态 + commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService) + + response.Success(c, commentResponses) +} + +// GetRepliesByRootID 根据根评论ID分页获取回复(扁平化) +func (h *CommentHandler) GetRepliesByRootID(c *gin.Context) { + userID := c.GetString("user_id") + rootID := c.Param("id") + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) + + replies, total, err := h.commentService.GetRepliesByRootID(c.Request.Context(), rootID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get replies") + return + } + + // 转换为响应结构,检查每个回复的点赞状态 + replyResponses := dto.ConvertCommentsToResponseWithUser(replies, userID, h.commentService) + + response.Paginated(c, replyResponses, total, page, pageSize) +} + +// Update 更新评论 +func (h *CommentHandler) Update(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + + comment, err := h.commentService.GetByID(c.Request.Context(), id) + if err != nil { + response.NotFound(c, "comment not found") + return + } + + if comment.UserID != userID { + response.Forbidden(c, "cannot update others' comment") + return + } + + type UpdateRequest struct { + Content string `json:"content" binding:"required"` + } + + var req UpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + comment.Content = req.Content + + err = h.commentService.Update(c.Request.Context(), comment) + if err != nil { + response.InternalServerError(c, "failed to update comment") + return + } + + response.Success(c, dto.ConvertCommentToResponse(comment, false)) +} + +// Delete 删除评论 +func (h *CommentHandler) Delete(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + + comment, err := h.commentService.GetByID(c.Request.Context(), id) + if err != nil { + response.NotFound(c, "comment not found") + return + } + + if comment.UserID != userID { + response.Forbidden(c, "cannot delete others' comment") + return + } + + err = h.commentService.Delete(c.Request.Context(), id) + if err != nil { + response.InternalServerError(c, "failed to delete comment") + return + } + + response.SuccessWithMessage(c, "comment deleted", nil) +} + +// Like 点赞评论 +func (h *CommentHandler) Like(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + + err := h.commentService.Like(c.Request.Context(), id, userID) + if err != nil { + response.InternalServerError(c, "failed to like comment") + return + } + + response.SuccessWithMessage(c, "liked", nil) +} + +// Unlike 取消点赞评论 +func (h *CommentHandler) Unlike(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + + err := h.commentService.Unlike(c.Request.Context(), id, userID) + if err != nil { + response.InternalServerError(c, "failed to unlike comment") + return + } + + response.SuccessWithMessage(c, "unliked", nil) +} diff --git a/internal/handler/gorse_handler.go b/internal/handler/gorse_handler.go new file mode 100644 index 0000000..ec89118 --- /dev/null +++ b/internal/handler/gorse_handler.go @@ -0,0 +1,234 @@ +package handler + +import ( + "context" + "log" + "strings" + "time" + + "carrot_bbs/internal/config" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/gorse" + "carrot_bbs/internal/pkg/response" + + gorseio "github.com/gorse-io/gorse-go" + "github.com/gin-gonic/gin" +) + +// GorseHandler Gorse推荐处理器 +type GorseHandler struct { + importPassword string + gorseConfig config.GorseConfig +} + +// NewGorseHandler 创建Gorse处理器 +func NewGorseHandler(cfg config.GorseConfig) *GorseHandler { + return &GorseHandler{ + importPassword: cfg.ImportPassword, + gorseConfig: cfg, + } +} + +// ImportRequest 导入请求 +type ImportRequest struct { + Password string `json:"password"` +} + +// ImportData 导入数据到Gorse +// POST /api/v1/gorse/import +func (h *GorseHandler) ImportData(c *gin.Context) { + // 验证密码 + if h.importPassword == "" { + response.BadRequest(c, "Gorse import is disabled") + return + } + + var req ImportRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "invalid request body") + return + } + + if req.Password != h.importPassword { + response.Unauthorized(c, "invalid password") + return + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Minute) + defer cancel() + + stats, err := h.importAllData(ctx) + if err != nil { + log.Printf("[ERROR] gorse import failed: %v", err) + response.InternalServerError(c, "gorse import failed: "+err.Error()) + return + } + + response.Success(c, gin.H{ + "message": "import completed", + "status": "done", + "stats": stats, + }) +} + +// GetStatus 获取Gorse状态 +// GET /api/v1/gorse/status +func (h *GorseHandler) GetStatus(c *gin.Context) { + // 返回Gorse连接状态和配置信息 + hasPassword := h.importPassword != "" + response.Success(c, gin.H{ + "enabled": h.gorseConfig.Enabled, + "has_password": hasPassword, + "address": h.gorseConfig.Address, + "api_key": strings.Repeat("*", 8), // 不返回实际APIKey + }) +} + +func (h *GorseHandler) importAllData(ctx context.Context) (gin.H, error) { + gorseClient := gorseio.NewGorseClient(h.gorseConfig.Address, h.gorseConfig.APIKey) + gorse.InitEmbeddingWithConfig(h.gorseConfig.EmbeddingAPIKey, h.gorseConfig.EmbeddingURL, h.gorseConfig.EmbeddingModel) + + stats := gin.H{ + "items": 0, + "users": 0, + "likes": 0, + "favorites": 0, + "comments": 0, + } + + // 导入帖子 + var posts []model.Post + if err := model.DB.Find(&posts).Error; err != nil { + return nil, err + } + for _, post := range posts { + embedding, err := gorse.GetEmbedding(strings.TrimSpace(post.Title + " " + post.Content)) + if err != nil { + log.Printf("[WARN] get embedding failed for post %s: %v", post.ID, err) + embedding = make([]float64, 1024) + } + _, err = gorseClient.InsertItem(ctx, gorseio.Item{ + ItemId: post.ID, + IsHidden: post.DeletedAt.Valid, + Categories: buildPostCategories(&post), + Comment: post.Title, + Timestamp: post.CreatedAt.UTC().Truncate(time.Second), + Labels: map[string]any{ + "embedding": embedding, + }, + }) + if err != nil { + log.Printf("[WARN] import item failed (%s): %v", post.ID, err) + continue + } + stats["items"] = stats["items"].(int) + 1 + } + + // 导入用户 + var users []model.User + if err := model.DB.Find(&users).Error; err != nil { + return nil, err + } + for _, user := range users { + _, err := gorseClient.InsertUser(ctx, gorseio.User{ + UserId: user.ID, + Labels: map[string]any{ + "posts_count": user.PostsCount, + "followers_count": user.FollowersCount, + "following_count": user.FollowingCount, + }, + Comment: user.Nickname, + }) + if err != nil { + log.Printf("[WARN] import user failed (%s): %v", user.ID, err) + continue + } + stats["users"] = stats["users"].(int) + 1 + } + + // 导入点赞 + var likes []model.PostLike + if err := model.DB.Find(&likes).Error; err != nil { + return nil, err + } + for _, like := range likes { + _, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{ + FeedbackType: string(gorse.FeedbackTypeLike), + UserId: like.UserID, + ItemId: like.PostID, + Timestamp: like.CreatedAt.UTC().Truncate(time.Second), + }}) + if err != nil { + log.Printf("[WARN] import like failed (%s/%s): %v", like.UserID, like.PostID, err) + continue + } + stats["likes"] = stats["likes"].(int) + 1 + } + + // 导入收藏 + var favorites []model.Favorite + if err := model.DB.Find(&favorites).Error; err != nil { + return nil, err + } + for _, fav := range favorites { + _, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{ + FeedbackType: string(gorse.FeedbackTypeStar), + UserId: fav.UserID, + ItemId: fav.PostID, + Timestamp: fav.CreatedAt.UTC().Truncate(time.Second), + }}) + if err != nil { + log.Printf("[WARN] import favorite failed (%s/%s): %v", fav.UserID, fav.PostID, err) + continue + } + stats["favorites"] = stats["favorites"].(int) + 1 + } + + // 导入评论(按用户-帖子去重) + var comments []model.Comment + if err := model.DB.Where("status = ?", model.CommentStatusPublished).Find(&comments).Error; err != nil { + return nil, err + } + seen := make(map[string]struct{}) + for _, cm := range comments { + key := cm.UserID + ":" + cm.PostID + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + _, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{ + FeedbackType: string(gorse.FeedbackTypeComment), + UserId: cm.UserID, + ItemId: cm.PostID, + Timestamp: cm.CreatedAt.UTC().Truncate(time.Second), + }}) + if err != nil { + log.Printf("[WARN] import comment failed (%s/%s): %v", cm.UserID, cm.PostID, err) + continue + } + stats["comments"] = stats["comments"].(int) + 1 + } + + return stats, nil +} + +func buildPostCategories(post *model.Post) []string { + var categories []string + if post.ViewsCount > 1000 { + categories = append(categories, "hot_high") + } else if post.ViewsCount > 100 { + categories = append(categories, "hot_medium") + } + if post.LikesCount > 100 { + categories = append(categories, "likes_100+") + } else if post.LikesCount > 10 { + categories = append(categories, "likes_10+") + } + age := time.Since(post.CreatedAt) + if age < 24*time.Hour { + categories = append(categories, "today") + } else if age < 7*24*time.Hour { + categories = append(categories, "this_week") + } + return categories +} \ No newline at end of file diff --git a/internal/handler/group_handler.go b/internal/handler/group_handler.go new file mode 100644 index 0000000..b487ddc --- /dev/null +++ b/internal/handler/group_handler.go @@ -0,0 +1,1801 @@ +package handler + +import ( + "log" + "strconv" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// GroupHandler 群组处理器 +type GroupHandler struct { + groupService service.GroupService + userService *service.UserService +} + +// NewGroupHandler 创建群组处理器 +func NewGroupHandler(groupService service.GroupService, userService *service.UserService) *GroupHandler { + return &GroupHandler{ + groupService: groupService, + userService: userService, + } +} + +// parseUserID 从上下文获取用户ID(UUID格式) +func parseUserID(c *gin.Context) string { + return c.GetString("user_id") +} + +// parseGroupID 从路径参数获取群组ID +func parseGroupID(c *gin.Context) string { + return c.Param("id") +} + +// parseUserIDFromPath 从路径参数获取用户ID(UUID格式) +func parseUserIDFromPath(c *gin.Context) string { + return c.Param("userId") +} + +// parseAnnouncementID 从路径参数获取公告ID +func parseAnnouncementID(c *gin.Context) string { + return c.Param("announcementId") +} + +// ==================== 群组管理 ==================== + +// CreateGroup 创建群组 +// POST /api/groups +func (h *GroupHandler) CreateGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var req dto.CreateGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + group, err := h.groupService.CreateGroup(userID, req.Name, req.Description, req.MemberIDs) + if err != nil { + response.InternalServerError(c, err.Error()) + return + } + + response.Success(c, dto.GroupToResponse(group)) +} + +// GetGroup 获取群组详情 +// GET /api/groups/:id +func (h *GroupHandler) GetGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + group, err := h.groupService.GetGroupByID(groupID) + if err != nil { + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 实时计算群成员数量 + memberCount, _ := h.groupService.GetMemberCount(groupID) + + // 创建响应并设置实时计算的member_count + resp := dto.GroupToResponse(group) + resp.MemberCount = memberCount + + response.Success(c, resp) +} + +// UpdateGroup 更新群组信息 +// PUT /api/groups/:id +func (h *GroupHandler) UpdateGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + var req dto.UpdateGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updates := make(map[string]interface{}) + if req.Name != "" { + updates["name"] = req.Name + } + if req.Description != "" { + updates["description"] = req.Description + } + if req.Avatar != "" { + updates["avatar"] = req.Avatar + } + + if err := h.groupService.UpdateGroup(userID, groupID, updates); err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "没有权限修改群组信息") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 获取更新后的群组信息 + group, _ := h.groupService.GetGroupByID(groupID) + response.Success(c, dto.GroupToResponse(group)) +} + +// DissolveGroup 解散群组 +// DELETE /api/groups/:id +func (h *GroupHandler) DissolveGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + if err := h.groupService.DissolveGroup(userID, groupID); err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以解散群组") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "群组已解散", nil) +} + +// TransferOwner 转让群主 +// POST /api/groups/:id/transfer +func (h *GroupHandler) TransferOwner(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + var req dto.TransferOwnerRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.groupService.TransferOwner(userID, groupID, req.NewOwnerID); err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以转让群主") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNotGroupMember { + response.BadRequest(c, "新群主必须是群成员") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "群主已转让", nil) +} + +// GetUserGroups 获取用户的群组列表 +// GET /api/groups +func (h *GroupHandler) GetUserGroups(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + groups, total, err := h.groupService.GetUserGroups(userID, page, pageSize) + if err != nil { + response.InternalServerError(c, err.Error()) + return + } + + response.Paginated(c, dto.GroupsToResponse(groups), total, page, pageSize) +} + +// ==================== 成员管理 ==================== + +// InviteMembers 邀请成员加入群组 +// POST /api/groups/:id/members/invite +func (h *GroupHandler) InviteMembers(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + var req dto.InviteMembersRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.groupService.InviteMembers(userID, groupID, req.MemberIDs); err != nil { + if err == service.ErrNotGroupMember { + response.Forbidden(c, "只有群成员可以邀请他人") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrGroupFull { + response.BadRequest(c, "群已满") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "邀请成功", nil) +} + +// JoinGroup 加入群组 +// POST /api/groups/:id/join +func (h *GroupHandler) JoinGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + if err := h.groupService.JoinGroup(userID, groupID); err != nil { + if err == service.ErrCannotJoin { + response.Forbidden(c, "该群不允许加入") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrAlreadyMember { + response.BadRequest(c, "已经是群成员") + return + } + if err == service.ErrGroupFull { + response.BadRequest(c, "群已满") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "加入成功", nil) +} + +// LeaveGroup 退出群组 +// POST /api/groups/:id/leave +func (h *GroupHandler) LeaveGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + if err := h.groupService.LeaveGroup(userID, groupID); err != nil { + if err == service.ErrNotGroupMember { + response.BadRequest(c, "不是群成员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "已退出群组", nil) +} + +// RemoveMember 移除群成员 +// DELETE /api/groups/:id/members/:userId +func (h *GroupHandler) RemoveMember(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + targetUserID := parseUserIDFromPath(c) + if targetUserID == "" { + response.BadRequest(c, "invalid user id") + return + } + + if err := h.groupService.RemoveMember(userID, groupID, targetUserID); err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以移除成员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNotGroupMember { + response.BadRequest(c, "该用户不是群成员") + return + } + if err == service.ErrCannotRemoveOwner { + response.Forbidden(c, "不能移除群主") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "已移除成员", nil) +} + +// GetMembers 获取群成员列表 +// GET /api/groups/:id/members +func (h *GroupHandler) GetMembers(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "50")) + + members, total, err := h.groupService.GetMembers(groupID, page, pageSize) + if err != nil { + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 转换为响应格式,并预加载用户信息 + result := make([]*dto.GroupMemberResponse, 0, len(members)) + for _, member := range members { + memberResp := dto.GroupMemberToResponse(&member) + // 预加载用户信息 + user, _ := h.userService.GetUserByID(c.Request.Context(), member.UserID) + if user != nil { + memberResp.User = dto.ConvertUserToResponse(user) + } + result = append(result, memberResp) + } + + response.Paginated(c, result, total, page, pageSize) +} + +// ==================== RESTful Action 端点 ==================== + +// HandleCreateGroup 创建群组 +// POST /api/v1/groups/create +func (h *GroupHandler) HandleCreateGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.CreateGroupParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + group, err := h.groupService.CreateGroup(userID, params.Name, params.Description, params.MemberIDs) + if err != nil { + response.InternalServerError(c, err.Error()) + return + } + + response.Success(c, dto.GroupToResponse(group)) +} + +// HandleGetUserGroups 获取用户群组列表 +// GET /api/v1/groups/list +func (h *GroupHandler) HandleGetUserGroups(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + groups, total, err := h.groupService.GetUserGroups(userID, page, pageSize) + if err != nil { + response.InternalServerError(c, err.Error()) + return + } + + response.Paginated(c, dto.GroupsToResponse(groups), total, page, pageSize) +} + +// HandleGetMyMemberInfo 获取我在群组中的成员信息 +// GET /api/v1/groups/get_my_info?group_id=xxx +func (h *GroupHandler) HandleGetMyMemberInfo(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := c.Query("group_id") + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + // 获取群组信息 + group, err := h.groupService.GetGroupByID(groupID) + if err != nil { + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 获取当前用户的成员信息 + member, err := h.groupService.GetMember(groupID, userID) + if err != nil { + response.NotFound(c, "不是群成员") + return + } + + // 构建响应 + memberResp := dto.GroupMemberToResponse(member) + + // 预加载用户信息 + user, _ := h.userService.GetUserByID(c.Request.Context(), userID) + if user != nil { + memberResp.User = dto.ConvertUserToResponse(user) + } + + // 添加群组禁言状态信息 + response.Success(c, map[string]interface{}{ + "member": memberResp, + "mute_all": group.MuteAll, + "is_muted": member.Muted || group.MuteAll, + "can_speak": !member.Muted && !group.MuteAll, + }) +} + +// HandleDissolveGroup 解散群组 +// POST /api/v1/groups/dissolve +func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.DissolveGroupParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + if err := h.groupService.DissolveGroup(userID, params.GroupID); err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以解散群组") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "群组已解散", nil) +} + +// HandleTransferOwner 转让群主 +// POST /api/v1/groups/transfer +func (h *GroupHandler) HandleTransferOwner(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.TransferOwnerParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + if params.NewOwnerID == "" { + response.BadRequest(c, "new_owner_id is required") + return + } + + if err := h.groupService.TransferOwner(userID, params.GroupID, params.NewOwnerID); err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以转让群主") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNotGroupMember { + response.BadRequest(c, "新群主必须是群成员") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "群主已转让", nil) +} + +// HandleInviteMembers 邀请成员加入群组 +// POST /api/v1/groups/invite_members +func (h *GroupHandler) HandleInviteMembers(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.InviteMembersParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + if err := h.groupService.InviteMembers(userID, params.GroupID, params.MemberIDs); err != nil { + if err == service.ErrNotGroupMember { + response.Forbidden(c, "只有群成员可以邀请他人") + return + } + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以邀请他人") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNoEligibleInvitee { + response.BadRequest(c, "暂无可邀请对象(需互相关注且未在群内)") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "邀请请求已处理", nil) +} + +// HandleJoinGroup 加入群组 +// POST /api/v1/groups/join +func (h *GroupHandler) HandleJoinGroup(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.JoinGroupParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + if err := h.groupService.JoinGroup(userID, params.GroupID); err != nil { + if err == service.ErrJoinRequestPending { + response.SuccessWithMessage(c, "申请已提交,等待群主/管理员审批", nil) + return + } + if err == service.ErrCannotJoin { + response.Forbidden(c, "该群不允许加入") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrAlreadyMember { + response.BadRequest(c, "已经是群成员") + return + } + if err == service.ErrGroupFull { + response.BadRequest(c, "群已满") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "加入成功", nil) +} + +// HandleSetNickname 设置群内昵称 +// POST /api/v1/groups/set_nickname +func (h *GroupHandler) HandleSetNickname(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetNicknameParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + if err := h.groupService.SetMemberNickname(userID, params.GroupID, params.Nickname); err != nil { + if err == service.ErrNotGroupMember { + response.BadRequest(c, "不是群成员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "昵称已更新", nil) +} + +// HandleSetJoinType 设置加群方式 +// POST /api/v1/groups/set_join_type +func (h *GroupHandler) HandleSetJoinType(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetJoinTypeParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + if err := h.groupService.SetJoinType(userID, params.GroupID, params.JoinType); err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以设置加群方式") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + joinTypeStr := "允许任何人加入" + switch params.JoinType { + case int(model.JoinTypeApproval): + joinTypeStr = "需要审批" + case int(model.JoinTypeForbidden): + joinTypeStr = "不允许加入" + } + + response.SuccessWithMessage(c, "加群方式已更新为: "+joinTypeStr, nil) +} + +// HandleCreateAnnouncement 创建群公告 +// POST /api/v1/groups/create_announcement +func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.CreateAnnouncementParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + announcement, err := h.groupService.CreateAnnouncement(userID, params.GroupID, params.Content) + if err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以发布公告") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.Success(c, dto.GroupAnnouncementToResponse(announcement)) +} + +// HandleGetAnnouncements 获取群公告列表 +// GET /api/v1/groups/get_announcements?group_id=xxx +func (h *GroupHandler) HandleGetAnnouncements(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := c.Query("group_id") + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + announcements, total, err := h.groupService.GetAnnouncements(groupID, page, pageSize) + if err != nil { + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.Paginated(c, dto.GroupAnnouncementsToResponse(announcements), total, page, pageSize) +} + +// HandleDeleteAnnouncement 删除群公告 +// POST /api/v1/groups/delete_announcement +func (h *GroupHandler) HandleDeleteAnnouncement(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.DeleteAnnouncementParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + if params.AnnouncementID == "" { + response.BadRequest(c, "announcement_id is required") + return + } + + if err := h.groupService.DeleteAnnouncement(userID, params.AnnouncementID); err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以删除公告") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "公告已删除", nil) +} + +// GetMyMemberInfo 获取当前用户在群组中的成员信息 +// GET /api/groups/:id/me +func (h *GroupHandler) GetMyMemberInfo(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + // 获取群组信息 + group, err := h.groupService.GetGroupByID(groupID) + if err != nil { + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 获取当前用户的成员信息 + member, err := h.groupService.GetMember(groupID, userID) + if err != nil { + response.NotFound(c, "不是群成员") + return + } + + // 构建响应 + memberResp := dto.GroupMemberToResponse(member) + + // 预加载用户信息 + user, _ := h.userService.GetUserByID(c.Request.Context(), userID) + if user != nil { + memberResp.User = dto.ConvertUserToResponse(user) + } + + // 添加群组禁言状态信息 + response.Success(c, map[string]interface{}{ + "member": memberResp, + "mute_all": group.MuteAll, + "is_muted": member.Muted || group.MuteAll, + "can_speak": !member.Muted && !group.MuteAll, + }) +} + +// SetMemberRole 设置成员角色 +// PUT /api/groups/:id/members/:userId/role +func (h *GroupHandler) SetMemberRole(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + targetUserID := parseUserIDFromPath(c) + if targetUserID == "" { + response.BadRequest(c, "invalid user id") + return + } + + var req dto.SetRoleRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.groupService.SetMemberRole(userID, groupID, targetUserID, req.Role); err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以设置成员角色") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNotGroupMember { + response.BadRequest(c, "该用户不是群成员") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "角色已更新", nil) +} + +// SetNickname 设置群内昵称 +// PUT /api/groups/:id/nickname +func (h *GroupHandler) SetNickname(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + var req dto.SetNicknameRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.groupService.SetMemberNickname(userID, groupID, req.Nickname); err != nil { + if err == service.ErrNotGroupMember { + response.BadRequest(c, "不是群成员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "昵称已更新", nil) +} + +// MuteMember 禁言/解禁成员 +// PUT /api/groups/:id/members/:userId/mute +func (h *GroupHandler) MuteMember(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + targetUserID := parseUserIDFromPath(c) + if targetUserID == "" { + response.BadRequest(c, "invalid user id") + return + } + + var req dto.MuteMemberRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.groupService.MuteMember(userID, groupID, targetUserID, req.Muted); err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以禁言成员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNotGroupMember { + response.BadRequest(c, "该用户不是群成员") + return + } + if err == service.ErrCannotMuteOwner { + response.Forbidden(c, "不能禁言群主") + return + } + response.InternalServerError(c, err.Error()) + return + } + + if req.Muted { + response.SuccessWithMessage(c, "已禁言该成员", nil) + } else { + response.SuccessWithMessage(c, "已解除禁言", nil) + } +} + +// ==================== 群设置 ==================== + +// SetMuteAll 设置全员禁言 +// PUT /api/groups/:id/mute-all +func (h *GroupHandler) SetMuteAll(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + var req dto.SetMuteAllRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.groupService.SetMuteAll(userID, groupID, req.MuteAll); err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以设置全员禁言") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + if req.MuteAll { + response.SuccessWithMessage(c, "已开启全员禁言", nil) + } else { + response.SuccessWithMessage(c, "已关闭全员禁言", nil) + } +} + +// SetJoinType 设置加群方式 +// PUT /api/groups/:id/join-type +func (h *GroupHandler) SetJoinType(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + var req dto.SetJoinTypeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.groupService.SetJoinType(userID, groupID, req.JoinType); err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以设置加群方式") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + joinTypeStr := "允许任何人加入" + switch req.JoinType { + case int(model.JoinTypeApproval): + joinTypeStr = "需要审批" + case int(model.JoinTypeForbidden): + joinTypeStr = "不允许加入" + } + + response.SuccessWithMessage(c, "加群方式已更新为: "+joinTypeStr, nil) +} + +// ==================== 群公告 ==================== + +// CreateAnnouncement 创建群公告 +// POST /api/groups/:id/announcements +func (h *GroupHandler) CreateAnnouncement(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + var req dto.CreateAnnouncementRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + announcement, err := h.groupService.CreateAnnouncement(userID, groupID, req.Content) + if err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以发布公告") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.Success(c, dto.GroupAnnouncementToResponse(announcement)) +} + +// GetAnnouncements 获取群公告列表 +// GET /api/groups/:id/announcements +func (h *GroupHandler) GetAnnouncements(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + announcements, total, err := h.groupService.GetAnnouncements(groupID, page, pageSize) + if err != nil { + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.Paginated(c, dto.GroupAnnouncementsToResponse(announcements), total, page, pageSize) +} + +// DeleteAnnouncement 删除群公告 +// DELETE /api/groups/:id/announcements/:announcementId +func (h *GroupHandler) DeleteAnnouncement(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "invalid group id") + return + } + + announcementID := parseAnnouncementID(c) + if announcementID == "" { + response.BadRequest(c, "invalid announcement id") + return + } + + if err := h.groupService.DeleteAnnouncement(userID, announcementID); err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以删除公告") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "公告已删除", nil) +} + +// ==================== RESTful Action 端点 ==================== + +// HandleSetGroupKick 群组踢人 +// POST /api/v1/groups/set_group_kick +func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupKickParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + if params.UserID == "" { + response.BadRequest(c, "user_id is required") + return + } + + // 使用 RemoveMember 方法 + err := h.groupService.RemoveMember(userID, params.GroupID, params.UserID) + if err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以移除成员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNotGroupMember { + response.BadRequest(c, "该用户不是群成员") + return + } + if err == service.ErrCannotRemoveOwner { + response.Forbidden(c, "不能移除群主") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "已移除成员", nil) +} + +// HandleSetGroupBan 群组单人禁言 +// POST /api/v1/groups/set_group_ban +func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupBanParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + if params.UserID == "" { + response.BadRequest(c, "user_id is required") + return + } + + // duration > 0 或 duration = -1 表示禁言,duration = 0 表示解除禁言 + muted := params.Duration != 0 + log.Printf("[HandleSetGroupBan] 开始禁言操作: userID=%s, groupID=%s, targetUserID=%s, duration=%d, muted=%v", userID, params.GroupID, params.UserID, params.Duration, muted) + err := h.groupService.MuteMember(userID, params.GroupID, params.UserID, muted) + if err != nil { + log.Printf("[HandleSetGroupBan] 禁言操作失败: %v", err) + } else { + log.Printf("[HandleSetGroupBan] 禁言操作成功") + } + if err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "只有群主或管理员可以禁言成员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNotGroupMember { + response.BadRequest(c, "该用户不是群成员") + return + } + if err == service.ErrCannotMuteOwner { + response.Forbidden(c, "不能禁言群主") + return + } + response.InternalServerError(c, err.Error()) + return + } + + if muted { + response.SuccessWithMessage(c, "已禁言该成员", nil) + } else { + response.SuccessWithMessage(c, "已解除禁言", nil) + } +} + +// HandleSetGroupWholeBan 群组全员禁言 +// POST /api/v1/groups/set_group_whole_ban +func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupWholeBanParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + err := h.groupService.SetMuteAll(userID, params.GroupID, params.Enable) + if err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以设置全员禁言") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + if params.Enable { + response.SuccessWithMessage(c, "已开启全员禁言", nil) + } else { + response.SuccessWithMessage(c, "已关闭全员禁言", nil) + } +} + +// HandleSetGroupAdmin 群组设置管理员 +// POST /api/v1/groups/set_group_admin +func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupAdminParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + if params.UserID == "" { + response.BadRequest(c, "user_id is required") + return + } + + // 根据 enable 参数设置角色 + role := model.GroupRoleMember + if params.Enable { + role = model.GroupRoleAdmin + } + + err := h.groupService.SetMemberRole(userID, params.GroupID, params.UserID, role) + if err != nil { + if err == service.ErrNotGroupOwner { + response.Forbidden(c, "只有群主可以设置管理员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + if err == service.ErrNotGroupMember { + response.BadRequest(c, "该用户不是群成员") + return + } + response.InternalServerError(c, err.Error()) + return + } + + if params.Enable { + response.SuccessWithMessage(c, "已设置为管理员", nil) + } else { + response.SuccessWithMessage(c, "已取消管理员", nil) + } +} + +// HandleSetGroupName 设置群名 +// POST /api/v1/groups/set_group_name +func (h *GroupHandler) HandleSetGroupName(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupNameParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + if params.GroupName == "" { + response.BadRequest(c, "group_name is required") + return + } + + updates := map[string]interface{}{ + "name": params.GroupName, + } + + err := h.groupService.UpdateGroup(userID, params.GroupID, updates) + if err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "没有权限修改群组信息") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 获取更新后的群组信息 + group, _ := h.groupService.GetGroupByID(params.GroupID) + response.Success(c, dto.GroupToResponse(group)) +} + +// HandleSetGroupAvatar 设置群头像 +// POST /api/v1/groups/set_group_avatar +func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupAvatarParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + if params.Avatar == "" { + response.BadRequest(c, "avatar is required") + return + } + + updates := map[string]interface{}{ + "avatar": params.Avatar, + } + + err := h.groupService.UpdateGroup(userID, params.GroupID, updates) + if err != nil { + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "没有权限修改群组信息") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 获取更新后的群组信息 + group, _ := h.groupService.GetGroupByID(params.GroupID) + response.Success(c, dto.GroupToResponse(group)) +} + +// HandleSetGroupLeave 退出群组 +// POST /api/v1/groups/set_group_leave +func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupLeaveParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.GroupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + err := h.groupService.LeaveGroup(userID, params.GroupID) + if err != nil { + if err == service.ErrNotGroupMember { + response.BadRequest(c, "不是群成员") + return + } + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "已退出群组", nil) +} + +// HandleSetGroupAddRequest 处理加群审批 +// POST /api/v1/groups/set_group_add_request +func (h *GroupHandler) HandleSetGroupAddRequest(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupAddRequestParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.Flag == "" { + response.BadRequest(c, "flag is required") + return + } + + if err := h.groupService.SetGroupAddRequest(userID, params.Flag, params.Approve, params.Reason); err != nil { + if err == service.ErrGroupRequestNotFound { + response.NotFound(c, "加群申请不存在") + return + } + if err == service.ErrGroupRequestHandled { + response.BadRequest(c, "该加群申请已处理") + return + } + if err == service.ErrNotGroupAdmin { + response.Forbidden(c, "仅群主或管理员可审批") + return + } + if err == service.ErrGroupFull { + response.BadRequest(c, "群已满") + return + } + response.InternalServerError(c, err.Error()) + return + } + + if params.Approve { + response.SuccessWithMessage(c, "已同意加群申请", nil) + return + } + response.SuccessWithMessage(c, "已拒绝加群申请", nil) +} + +// HandleRespondInvite 处理群邀请响应 +// POST /api/v1/groups/respond_invite +func (h *GroupHandler) HandleRespondInvite(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetGroupAddRequestParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + if params.Flag == "" { + response.BadRequest(c, "flag is required") + return + } + + if err := h.groupService.RespondInvite(userID, params.Flag, params.Approve, params.Reason); err != nil { + if err == service.ErrGroupRequestNotFound { + response.NotFound(c, "邀请不存在") + return + } + if err == service.ErrGroupRequestHandled { + response.BadRequest(c, "邀请已处理") + return + } + if err == service.ErrNotRequestTarget { + response.Forbidden(c, "无权处理该邀请") + return + } + if err == service.ErrGroupFull { + response.BadRequest(c, "群已满") + return + } + response.InternalServerError(c, err.Error()) + return + } + + if params.Approve { + response.SuccessWithMessage(c, "已接受邀请", nil) + return + } + response.SuccessWithMessage(c, "已拒绝邀请", nil) +} + +// HandleGetGroupInfo 获取群信息 +// GET /api/v1/groups/get?group_id=xxx +func (h *GroupHandler) HandleGetGroupInfo(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := c.Query("group_id") + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + group, err := h.groupService.GetGroupByID(groupID) + if err != nil { + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 实时计算群成员数量 + memberCount, _ := h.groupService.GetMemberCount(groupID) + + // 创建响应并设置实时计算的member_count + resp := dto.GroupToResponse(group) + resp.MemberCount = memberCount + + response.Success(c, resp) +} + +// HandleGetGroupMemberList 获取群成员列表 +// GET /api/v1/groups/get_members?group_id=xxx +func (h *GroupHandler) HandleGetGroupMemberList(c *gin.Context) { + userID := parseUserID(c) + if userID == "" { + response.Unauthorized(c, "") + return + } + + groupID := c.Query("group_id") + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "50")) + + members, total, err := h.groupService.GetMembers(groupID, page, pageSize) + if err != nil { + if err == service.ErrGroupNotFound { + response.NotFound(c, "群组不存在") + return + } + response.InternalServerError(c, err.Error()) + return + } + + // 转换为响应格式,并预加载用户信息 + result := make([]*dto.GroupMemberResponse, 0, len(members)) + for _, member := range members { + memberResp := dto.GroupMemberToResponse(&member) + // 预加载用户信息 + user, _ := h.userService.GetUserByID(c.Request.Context(), member.UserID) + if user != nil { + memberResp.User = dto.ConvertUserToResponse(user) + } + result = append(result, memberResp) + } + + response.Paginated(c, result, total, page, pageSize) +} diff --git a/internal/handler/message_handler.go b/internal/handler/message_handler.go new file mode 100644 index 0000000..3cab38d --- /dev/null +++ b/internal/handler/message_handler.go @@ -0,0 +1,879 @@ +package handler + +import ( + "context" + "fmt" + "strconv" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// MessageHandler 消息处理器 +type MessageHandler struct { + chatService service.ChatService + messageService *service.MessageService + userService *service.UserService + groupService service.GroupService +} + +// NewMessageHandler 创建消息处理器 +func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler { + return &MessageHandler{ + chatService: chatService, + messageService: messageService, + userService: userService, + groupService: groupService, + } +} + +// GetConversations 获取会话列表 +// GET /api/conversations +func (h *MessageHandler) GetConversations(c *gin.Context) { + userID := c.GetString("user_id") + // 添加调试日志 + if userID == "" { + response.Unauthorized(c, "") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get conversations") + return + } + + // 过滤掉系统会话(系统通知现在使用独立的表) + filteredConvs := make([]*model.Conversation, 0) + for _, conv := range convs { + if conv.ID != model.SystemConversationID { + filteredConvs = append(filteredConvs, conv) + } + } + + // 转换为响应格式 + result := make([]*dto.ConversationResponse, len(filteredConvs)) + for i, conv := range filteredConvs { + // 获取未读数 + unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID) + + // 获取最后一条消息 + var lastMessage *model.Message + messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1) + if len(messages) > 0 { + lastMessage = messages[0] + } + + // 群聊时返回member_count,私聊时返回participants + var resp *dto.ConversationResponse + myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID) + isPinned := myParticipant != nil && myParticipant.IsPinned + if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" { + // 群聊:实时计算群成员数量 + memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID) + // 创建响应并设置member_count + resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned) + resp.MemberCount = memberCount + } else { + // 私聊:获取参与者信息 + participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID) + resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned) + } + result[i] = resp + } + + // 更新 total 为过滤后的数量 + response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize) +} + +// CreateConversation 创建私聊会话 +// POST /api/conversations +func (h *MessageHandler) CreateConversation(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var req dto.CreateConversationRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 验证目标用户是否存在 + targetUser, err := h.userService.GetUserByID(c.Request.Context(), req.UserID) + if err != nil { + response.BadRequest(c, "target user not found") + return + } + + // 不能和自己创建会话 + if userID == req.UserID { + response.BadRequest(c, "cannot create conversation with yourself") + return + } + + conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, req.UserID) + if err != nil { + response.InternalServerError(c, "failed to create conversation") + return + } + + // 获取参与者信息 + participants := []*model.User{targetUser} + myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID) + isPinned := myParticipant != nil && myParticipant.IsPinned + + response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned)) +} + +// GetConversationByID 获取会话详情 +// GET /api/conversations/:id +func (h *MessageHandler) GetConversationByID(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + conversationIDStr := c.Param("id") + fmt.Printf("[DEBUG] GetConversationByID: conversationIDStr = %s\n", conversationIDStr) + conversationID, err := service.ParseConversationID(conversationIDStr) + if err != nil { + fmt.Printf("[DEBUG] GetConversationByID: failed to parse conversation ID: %v\n", err) + response.BadRequest(c, "invalid conversation id") + return + } + fmt.Printf("[DEBUG] GetConversationByID: conversationID = %s\n", conversationID) + + conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 获取未读数 + unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID) + + // 获取参与者信息 + participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID) + + // 获取当前用户的已读位置 + myLastReadSeq := int64(0) + isPinned := false + allParticipants, _ := h.messageService.GetConversationParticipants(conversationID) + for _, p := range allParticipants { + if p.UserID == userID { + myLastReadSeq = p.LastReadSeq + isPinned = p.IsPinned + break + } + } + + // 获取对方用户的已读位置 + otherLastReadSeq := int64(0) + response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned)) +} + +// GetMessages 获取消息列表 +// GET /api/conversations/:id/messages +func (h *MessageHandler) GetMessages(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + conversationIDStr := c.Param("id") + conversationID, err := service.ParseConversationID(conversationIDStr) + if err != nil { + response.BadRequest(c, "invalid conversation id") + return + } + + // 检查是否使用增量同步(after_seq参数) + afterSeqStr := c.Query("after_seq") + if afterSeqStr != "" { + // 增量同步模式 + afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64) + if err != nil { + response.BadRequest(c, "invalid after_seq") + return + } + + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20")) + + messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 转换为响应格式 + result := dto.ConvertMessagesToResponse(messages) + + response.Success(c, &dto.MessageSyncResponse{ + Messages: result, + HasMore: len(messages) == limit, + }) + return + } + + // 检查是否使用历史消息加载(before_seq参数) + beforeSeqStr := c.Query("before_seq") + if beforeSeqStr != "" { + // 加载更早的消息(下拉加载更多) + beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64) + if err != nil { + response.BadRequest(c, "invalid before_seq") + return + } + + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20")) + + messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 转换为响应格式 + result := dto.ConvertMessagesToResponse(messages) + + response.Success(c, &dto.MessageSyncResponse{ + Messages: result, + HasMore: len(messages) == limit, + }) + return + } + + // 分页模式 + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 转换为响应格式 + result := dto.ConvertMessagesToResponse(messages) + + response.Paginated(c, result, total, page, pageSize) +} + +// SendMessage 发送消息 +// POST /api/conversations/:id/messages +func (h *MessageHandler) SendMessage(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + conversationIDStr := c.Param("id") + fmt.Printf("[DEBUG] SendMessage: conversationIDStr = %s\n", conversationIDStr) + conversationID, err := service.ParseConversationID(conversationIDStr) + if err != nil { + fmt.Printf("[DEBUG] SendMessage: failed to parse conversation ID: %v\n", err) + response.BadRequest(c, "invalid conversation id") + return + } + fmt.Printf("[DEBUG] SendMessage: conversationID = %s, userID = %s\n", conversationID, userID) + + var req dto.SendMessageRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 直接使用 segments + msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, req.Segments, req.ReplyToID) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.Success(c, dto.ConvertMessageToResponse(msg)) +} + +// HandleSendMessage RESTful 风格的发送消息端点 +// POST /api/v1/conversations/send_message +// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]} +func (h *MessageHandler) HandleSendMessage(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SendMessageParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 验证参数 + if params.ConversationID == "" { + response.BadRequest(c, "conversation_id is required") + return + } + if params.DetailType == "" { + response.BadRequest(c, "detail_type is required") + return + } + if params.Segments == nil || len(params.Segments) == 0 { + response.BadRequest(c, "segments is required") + return + } + + // 发送消息 + msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 构建 WSEventResponse 格式响应 + wsResponse := dto.WSEventResponse{ + ID: msg.ID, + Time: msg.CreatedAt.UnixMilli(), + Type: "message", + DetailType: params.DetailType, + Seq: strconv.FormatInt(msg.Seq, 10), + Segments: params.Segments, + SenderID: userID, + } + + response.Success(c, wsResponse) +} + +// HandleDeleteMsg 撤回消息 +// POST /api/v1/messages/delete_msg +// 请求体格式: {"message_id": "xxx"} +func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.DeleteMsgParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 验证参数 + if params.MessageID == "" { + response.BadRequest(c, "message_id is required") + return + } + + // 撤回消息 + err := h.chatService.RecallMessage(c.Request.Context(), params.MessageID, userID) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "消息已撤回", nil) +} + +// HandleGetConversationList 获取会话列表 +// GET /api/v1/conversations/list +func (h *MessageHandler) HandleGetConversationList(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get conversations") + return + } + + // 过滤掉系统会话(系统通知现在使用独立的表) + filteredConvs := make([]*model.Conversation, 0) + for _, conv := range convs { + if conv.ID != model.SystemConversationID { + filteredConvs = append(filteredConvs, conv) + } + } + + // 转换为响应格式 + result := make([]*dto.ConversationResponse, len(filteredConvs)) + for i, conv := range filteredConvs { + // 获取未读数 + unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID) + + // 获取最后一条消息 + var lastMessage *model.Message + messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1) + if len(messages) > 0 { + lastMessage = messages[0] + } + + // 群聊时返回member_count,私聊时返回participants + var resp *dto.ConversationResponse + myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID) + isPinned := myParticipant != nil && myParticipant.IsPinned + if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" { + // 群聊:实时计算群成员数量 + memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID) + // 创建响应并设置member_count + resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned) + resp.MemberCount = memberCount + } else { + // 私聊:获取参与者信息 + participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID) + resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned) + } + result[i] = resp + } + + response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize) +} + +// HandleDeleteConversationForSelf 仅自己删除会话 +// DELETE /api/v1/conversations/:id/self +func (h *MessageHandler) HandleDeleteConversationForSelf(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + conversationID := c.Param("id") + if conversationID == "" { + response.BadRequest(c, "conversation id is required") + return + } + + if err := h.chatService.DeleteConversationForSelf(c.Request.Context(), conversationID, userID); err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "conversation deleted for self", nil) +} + +// MarkAsRead 标记为已读 +// POST /api/conversations/:id/read +func (h *MessageHandler) MarkAsRead(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + conversationIDStr := c.Param("id") + conversationID, err := service.ParseConversationID(conversationIDStr) + if err != nil { + response.BadRequest(c, "invalid conversation id") + return + } + + var req dto.MarkReadRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "last_read_seq is required") + return + } + + err = h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "marked as read", nil) +} + +// GetUnreadCount 获取未读消息总数 +// GET /api/conversations/unread/count +func (h *MessageHandler) GetUnreadCount(c *gin.Context) { + userID := c.GetString("user_id") + // 添加调试日志 + fmt.Printf("[DEBUG] GetUnreadCount: user_id from context = %q\n", userID) + + if userID == "" { + fmt.Printf("[DEBUG] GetUnreadCount: user_id is empty, returning 401\n") + response.Unauthorized(c, "") + return + } + + count, err := h.chatService.GetAllUnreadCount(c.Request.Context(), userID) + if err != nil { + response.InternalServerError(c, "failed to get unread count") + return + } + + response.Success(c, &dto.UnreadCountResponse{ + TotalUnreadCount: count, + }) +} + +// GetConversationUnreadCount 获取单个会话的未读数 +// GET /api/conversations/:id/unread/count +func (h *MessageHandler) GetConversationUnreadCount(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + conversationIDStr := c.Param("id") + conversationID, err := service.ParseConversationID(conversationIDStr) + if err != nil { + response.BadRequest(c, "invalid conversation id") + return + } + + count, err := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.Success(c, &dto.ConversationUnreadCountResponse{ + ConversationID: conversationID, + UnreadCount: count, + }) +} + +// RecallMessage 撤回消息 +// POST /api/messages/:id/recall +func (h *MessageHandler) RecallMessage(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + messageIDStr := c.Param("id") + // 直接使用 string 类型的 messageID + err := h.chatService.RecallMessage(c.Request.Context(), messageIDStr, userID) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "message recalled", nil) +} + +// DeleteMessage 删除消息 +// DELETE /api/messages/:id +func (h *MessageHandler) DeleteMessage(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + messageIDStr := c.Param("id") + // 直接使用 string 类型的 messageID + err := h.chatService.DeleteMessage(c.Request.Context(), messageIDStr, userID) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "message deleted", nil) +} + +// 辅助函数:验证内容类型 +func isValidContentType(contentType model.ContentType) bool { + switch contentType { + case model.ContentTypeText, model.ContentTypeImage, model.ContentTypeVideo, model.ContentTypeAudio, model.ContentTypeFile: + return true + default: + return false + } +} + +// 辅助函数:获取会话参与者信息 +func (h *MessageHandler) getConversationParticipants(ctx context.Context, conversationID string, currentUserID string) ([]*model.User, error) { + // 从repository获取参与者列表 + participants, err := h.messageService.GetConversationParticipants(conversationID) + if err != nil { + return nil, err + } + + // 获取参与者用户信息 + var users []*model.User + for _, p := range participants { + // 跳过当前用户 + if p.UserID == currentUserID { + continue + } + user, err := h.userService.GetUserByID(ctx, p.UserID) + if err != nil { + continue + } + users = append(users, user) + } + return users, nil +} + +// 获取当前用户在会话中的参与者信息 +func (h *MessageHandler) getMyConversationParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) { + participants, err := h.messageService.GetConversationParticipants(conversationID) + if err != nil { + return nil, err + } + for _, p := range participants { + if p.UserID == userID { + return p, nil + } + } + return nil, nil +} + +// ==================== RESTful Action 端点 ==================== + +// HandleCreateConversation 创建会话 +// POST /api/v1/conversations/create +func (h *MessageHandler) HandleCreateConversation(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.CreateConversationParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 验证目标用户是否存在 + targetUser, err := h.userService.GetUserByID(c.Request.Context(), params.UserID) + if err != nil { + response.BadRequest(c, "target user not found") + return + } + + // 不能和自己创建会话 + if userID == params.UserID { + response.BadRequest(c, "cannot create conversation with yourself") + return + } + + conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, params.UserID) + if err != nil { + response.InternalServerError(c, "failed to create conversation") + return + } + + // 获取参与者信息 + participants := []*model.User{targetUser} + myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID) + isPinned := myParticipant != nil && myParticipant.IsPinned + + response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned)) +} + +// HandleGetConversation 获取会话详情 +// GET /api/v1/conversations/get?conversation_id=xxx +func (h *MessageHandler) HandleGetConversation(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + conversationID := c.Query("conversation_id") + if conversationID == "" { + response.BadRequest(c, "conversation_id is required") + return + } + + conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 获取未读数 + unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID) + + // 获取参与者信息 + participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID) + + // 获取当前用户的已读位置 + myLastReadSeq := int64(0) + isPinned := false + allParticipants, _ := h.messageService.GetConversationParticipants(conversationID) + for _, p := range allParticipants { + if p.UserID == userID { + myLastReadSeq = p.LastReadSeq + isPinned = p.IsPinned + break + } + } + + // 获取对方用户的已读位置 + otherLastReadSeq := int64(0) + response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned)) +} + +// HandleGetMessages 获取会话消息 +// GET /api/v1/conversations/get_messages?conversation_id=xxx +func (h *MessageHandler) HandleGetMessages(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + conversationID := c.Query("conversation_id") + if conversationID == "" { + response.BadRequest(c, "conversation_id is required") + return + } + + // 检查是否使用增量同步(after_seq参数) + afterSeqStr := c.Query("after_seq") + if afterSeqStr != "" { + // 增量同步模式 + afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64) + if err != nil { + response.BadRequest(c, "invalid after_seq") + return + } + + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) + + messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 转换为响应格式 + result := dto.ConvertMessagesToResponse(messages) + + response.Success(c, &dto.MessageSyncResponse{ + Messages: result, + HasMore: len(messages) == limit, + }) + return + } + + // 检查是否使用历史消息加载(before_seq参数) + beforeSeqStr := c.Query("before_seq") + if beforeSeqStr != "" { + // 加载更早的消息(下拉加载更多) + beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64) + if err != nil { + response.BadRequest(c, "invalid before_seq") + return + } + + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20")) + + messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 转换为响应格式 + result := dto.ConvertMessagesToResponse(messages) + + response.Success(c, &dto.MessageSyncResponse{ + Messages: result, + HasMore: len(messages) == limit, + }) + return + } + + // 分页模式 + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 转换为响应格式 + result := dto.ConvertMessagesToResponse(messages) + + response.Paginated(c, result, total, page, pageSize) +} + +// HandleMarkRead 标记已读 +// POST /api/v1/conversations/mark_read +func (h *MessageHandler) HandleMarkRead(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.MarkReadParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.ConversationID == "" { + response.BadRequest(c, "conversation_id is required") + return + } + + err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "marked as read", nil) +} + +// HandleSetConversationPinned 设置会话置顶 +// POST /api/v1/conversations/set_pinned +func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var params dto.SetConversationPinnedParams + if err := c.ShouldBindJSON(¶ms); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if params.ConversationID == "" { + response.BadRequest(c, "conversation_id is required") + return + } + + if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil { + response.BadRequest(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{ + "conversation_id": params.ConversationID, + "is_pinned": params.IsPinned, + }) +} diff --git a/internal/handler/notification_handler.go b/internal/handler/notification_handler.go new file mode 100644 index 0000000..9c24722 --- /dev/null +++ b/internal/handler/notification_handler.go @@ -0,0 +1,132 @@ +package handler + +import ( + "strconv" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// NotificationHandler 通知处理器 +type NotificationHandler struct { + notificationService *service.NotificationService +} + +// NewNotificationHandler 创建通知处理器 +func NewNotificationHandler(notificationService *service.NotificationService) *NotificationHandler { + return &NotificationHandler{ + notificationService: notificationService, + } +} + +// GetNotifications 获取通知列表 +func (h *NotificationHandler) GetNotifications(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + unreadOnly := c.Query("unread_only") == "true" + + notifications, total, err := h.notificationService.GetByUserID(c.Request.Context(), userID, page, pageSize, unreadOnly) + if err != nil { + response.InternalServerError(c, "failed to get notifications") + return + } + + response.Paginated(c, notifications, total, page, pageSize) +} + +// MarkAsRead 标记为已读 +func (h *NotificationHandler) MarkAsRead(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + + err := h.notificationService.MarkAsReadWithUserID(c.Request.Context(), id, userID) + if err != nil { + response.InternalServerError(c, "failed to mark as read") + return + } + + response.SuccessWithMessage(c, "marked as read", nil) +} + +// MarkAllAsRead 标记所有为已读 +func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + err := h.notificationService.MarkAllAsRead(c.Request.Context(), userID) + if err != nil { + response.InternalServerError(c, "failed to mark all as read") + return + } + + response.SuccessWithMessage(c, "all marked as read", nil) +} + +// GetUnreadCount 获取未读数量 +func (h *NotificationHandler) GetUnreadCount(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + count, err := h.notificationService.GetUnreadCount(c.Request.Context(), userID) + if err != nil { + response.InternalServerError(c, "failed to get unread count") + return + } + + response.Success(c, gin.H{"count": count}) +} + +// DeleteNotification 删除通知 +func (h *NotificationHandler) DeleteNotification(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + + err := h.notificationService.DeleteNotification(c.Request.Context(), id, userID) + if err != nil { + response.InternalServerError(c, "failed to delete notification") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// ClearAllNotifications 清空所有通知 +func (h *NotificationHandler) ClearAllNotifications(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + err := h.notificationService.ClearAllNotifications(c.Request.Context(), userID) + if err != nil { + response.InternalServerError(c, "failed to clear notifications") + return + } + + response.Success(c, gin.H{"success": true}) +} diff --git a/internal/handler/post_handler.go b/internal/handler/post_handler.go new file mode 100644 index 0000000..7ecbb0e --- /dev/null +++ b/internal/handler/post_handler.go @@ -0,0 +1,511 @@ +package handler + +import ( + "errors" + "fmt" + "strconv" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// PostHandler 帖子处理器 +type PostHandler struct { + postService *service.PostService + userService *service.UserService +} + +// NewPostHandler 创建帖子处理器 +func NewPostHandler(postService *service.PostService, userService *service.UserService) *PostHandler { + return &PostHandler{ + postService: postService, + userService: userService, + } +} + +// Create 创建帖子 +func (h *PostHandler) Create(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + type CreateRequest struct { + Title string `json:"title" binding:"required"` + Content string `json:"content" binding:"required"` + Images []string `json:"images"` + } + + var req CreateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + post, err := h.postService.Create(c.Request.Context(), userID, req.Title, req.Content, req.Images) + if err != nil { + var moderationErr *service.PostModerationRejectedError + if errors.As(err, &moderationErr) { + response.BadRequest(c, moderationErr.UserMessage()) + return + } + response.InternalServerError(c, "failed to create post") + return + } + + response.Success(c, dto.ConvertPostToResponse(post, false, false)) +} + +// GetByID 获取帖子(不增加浏览量) +func (h *PostHandler) GetByID(c *gin.Context) { + id := c.Param("id") + + post, err := h.postService.GetByID(c.Request.Context(), id) + if err != nil { + response.NotFound(c, "post not found") + return + } + + // 非作者不可查看未发布内容 + currentUserID := c.GetString("user_id") + if post.Status != model.PostStatusPublished && post.UserID != currentUserID { + response.NotFound(c, "post not found") + return + } + + // 注意:不再自动增加浏览量,浏览量通过 RecordView 端点单独记录 + + // 获取当前用户ID用于判断点赞和收藏状态 + fmt.Printf("[DEBUG] GetByID - postID: %s, currentUserID: %s\n", id, currentUserID) + + var isLiked, isFavorited bool + if currentUserID != "" { + isLiked = h.postService.IsLiked(c.Request.Context(), id, currentUserID) + isFavorited = h.postService.IsFavorited(c.Request.Context(), id, currentUserID) + fmt.Printf("[DEBUG] GetByID - isLiked: %v, isFavorited: %v\n", isLiked, isFavorited) + } else { + fmt.Printf("[DEBUG] GetByID - user not logged in, isLiked: false, isFavorited: false\n") + } + + // 如果有当前用户,检查与帖子作者的相互关注状态 + var authorWithFollowStatus *dto.UserResponse + if currentUserID != "" && post.User != nil { + _, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), post.UserID, currentUserID) + if err == nil { + authorWithFollowStatus = dto.ConvertUserToResponseWithMutualFollow(post.User, isFollowing, isFollowingMe) + } else { + // 如果出错,使用默认的author + authorWithFollowStatus = dto.ConvertUserToResponse(post.User) + } + } + + // 构建响应 + responseData := &dto.PostResponse{ + ID: post.ID, + UserID: post.UserID, + Title: post.Title, + Content: post.Content, + Images: dto.ConvertPostImagesToResponse(post.Images), + LikesCount: post.LikesCount, + CommentsCount: post.CommentsCount, + FavoritesCount: post.FavoritesCount, + SharesCount: post.SharesCount, + ViewsCount: post.ViewsCount, + IsPinned: post.IsPinned, + IsLocked: post.IsLocked, + IsVote: post.IsVote, + CreatedAt: dto.FormatTime(post.CreatedAt), + Author: authorWithFollowStatus, + IsLiked: isLiked, + IsFavorited: isFavorited, + } + + response.Success(c, responseData) +} + +// RecordView 记录帖子浏览(增加浏览量) +func (h *PostHandler) RecordView(c *gin.Context) { + id := c.Param("id") + userID := c.GetString("user_id") + + // 验证帖子存在 + _, err := h.postService.GetByID(c.Request.Context(), id) + if err != nil { + response.NotFound(c, "post not found") + return + } + + // 增加浏览量 + if err := h.postService.IncrementViews(c.Request.Context(), id, userID); err != nil { + fmt.Printf("[DEBUG] Failed to increment views for post %s: %v\n", id, err) + response.InternalServerError(c, "failed to record view") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// List 获取帖子列表 +func (h *PostHandler) List(c *gin.Context) { + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + userID := c.Query("user_id") + tab := c.Query("tab") // recommend, follow, hot, latest + + // 获取当前用户ID + currentUserID := c.GetString("user_id") + + var posts []*model.Post + var total int64 + var err error + + // 根据 tab 参数选择不同的获取方式 + switch tab { + case "follow": + // 获取关注用户的帖子,需要登录 + if currentUserID == "" { + response.Unauthorized(c, "请先登录") + return + } + posts, total, err = h.postService.GetFollowingPosts(c.Request.Context(), currentUserID, page, pageSize) + case "hot": + // 获取热门帖子 + posts, total, err = h.postService.GetHotPosts(c.Request.Context(), page, pageSize) + case "recommend": + // 推荐帖子(从Gorse获取个性化推荐) + posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize) + case "latest": + // 最新帖子 + posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID) + default: + // 默认获取最新帖子 + posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID) + } + + if err != nil { + response.InternalServerError(c, "failed to get posts") + return + } + + fmt.Printf("[DEBUG] List - tab: %s, currentUserID: %s, posts count: %d\n", tab, currentUserID, len(posts)) + + isLikedMap := make(map[string]bool) + isFavoritedMap := make(map[string]bool) + if currentUserID != "" { + for _, post := range posts { + isLiked := h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID) + isFavorited := h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID) + isLikedMap[post.ID] = isLiked + isFavoritedMap[post.ID] = isFavorited + fmt.Printf("[DEBUG] List - postID: %s, isLiked: %v, isFavorited: %v\n", post.ID, isLiked, isFavorited) + } + } else { + fmt.Printf("[DEBUG] List - user not logged in\n") + } + + // 转换为响应结构 + postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap) + + response.Paginated(c, postResponses, total, page, pageSize) +} + +// Update 更新帖子 +func (h *PostHandler) Update(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + + post, err := h.postService.GetByID(c.Request.Context(), id) + if err != nil { + response.NotFound(c, "post not found") + return + } + + if post.UserID != userID { + response.Forbidden(c, "cannot update others' post") + return + } + + type UpdateRequest struct { + Title string `json:"title"` + Content string `json:"content"` + } + + var req UpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if req.Title != "" { + post.Title = req.Title + } + if req.Content != "" { + post.Content = req.Content + } + + err = h.postService.Update(c.Request.Context(), post) + if err != nil { + response.InternalServerError(c, "failed to update post") + return + } + + currentUserID := c.GetString("user_id") + var isLiked, isFavorited bool + if currentUserID != "" { + isLiked = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID) + isFavorited = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID) + } + + response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited)) +} + +// Delete 删除帖子 +func (h *PostHandler) Delete(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + + post, err := h.postService.GetByID(c.Request.Context(), id) + if err != nil { + response.NotFound(c, "post not found") + return + } + + if post.UserID != userID { + response.Forbidden(c, "cannot delete others' post") + return + } + + err = h.postService.Delete(c.Request.Context(), id) + if err != nil { + response.InternalServerError(c, "failed to delete post") + return + } + + response.SuccessWithMessage(c, "post deleted", nil) +} + +// Like 点赞帖子 +func (h *PostHandler) Like(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + fmt.Printf("[DEBUG] Like - postID: %s, userID: %s\n", id, userID) + + err := h.postService.Like(c.Request.Context(), id, userID) + if err != nil { + response.InternalServerError(c, "failed to like post") + return + } + + // 获取更新后的帖子状态 + post, err := h.postService.GetByID(c.Request.Context(), id) + if err != nil { + response.InternalServerError(c, "failed to get post") + return + } + + isLiked := h.postService.IsLiked(c.Request.Context(), id, userID) + isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID) + fmt.Printf("[DEBUG] Like - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited) + + response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited)) +} + +// Unlike 取消点赞 +func (h *PostHandler) Unlike(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + fmt.Printf("[DEBUG] Unlike - postID: %s, userID: %s\n", id, userID) + + err := h.postService.Unlike(c.Request.Context(), id, userID) + if err != nil { + response.InternalServerError(c, "failed to unlike post") + return + } + + // 获取更新后的帖子状态 + post, err := h.postService.GetByID(c.Request.Context(), id) + if err != nil { + response.InternalServerError(c, "failed to get post") + return + } + + isLiked := h.postService.IsLiked(c.Request.Context(), id, userID) + isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID) + fmt.Printf("[DEBUG] Unlike - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited) + + response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited)) +} + +// Favorite 收藏帖子 +func (h *PostHandler) Favorite(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + fmt.Printf("[DEBUG] Favorite - postID: %s, userID: %s\n", id, userID) + + err := h.postService.Favorite(c.Request.Context(), id, userID) + if err != nil { + response.InternalServerError(c, "failed to favorite post") + return + } + + // 获取更新后的帖子状态 + post, err := h.postService.GetByID(c.Request.Context(), id) + if err != nil { + response.InternalServerError(c, "failed to get post") + return + } + + isLiked := h.postService.IsLiked(c.Request.Context(), id, userID) + isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID) + fmt.Printf("[DEBUG] Favorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited) + + response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited)) +} + +// Unfavorite 取消收藏 +func (h *PostHandler) Unfavorite(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + id := c.Param("id") + fmt.Printf("[DEBUG] Unfavorite - postID: %s, userID: %s\n", id, userID) + + err := h.postService.Unfavorite(c.Request.Context(), id, userID) + if err != nil { + response.InternalServerError(c, "failed to unfavorite post") + return + } + + // 获取更新后的帖子状态 + post, err := h.postService.GetByID(c.Request.Context(), id) + if err != nil { + response.InternalServerError(c, "failed to get post") + return + } + + isLiked := h.postService.IsLiked(c.Request.Context(), id, userID) + isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID) + fmt.Printf("[DEBUG] Unfavorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited) + + response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited)) +} + +// GetUserPosts 获取用户帖子列表 +func (h *PostHandler) GetUserPosts(c *gin.Context) { + userID := c.Param("id") + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get user posts") + return + } + + // 获取当前用户ID用于判断点赞和收藏状态 + currentUserID := c.GetString("user_id") + isLikedMap := make(map[string]bool) + isFavoritedMap := make(map[string]bool) + if currentUserID != "" { + for _, post := range posts { + isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID) + isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID) + } + } + + // 转换为响应结构 + postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap) + + response.Paginated(c, postResponses, total, page, pageSize) +} + +// GetFavorites 获取收藏列表 +func (h *PostHandler) GetFavorites(c *gin.Context) { + userID := c.Param("id") + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + posts, total, err := h.postService.GetFavorites(c.Request.Context(), userID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get favorites") + return + } + + // 获取当前用户ID用于判断点赞和收藏状态 + currentUserID := c.GetString("user_id") + isLikedMap := make(map[string]bool) + isFavoritedMap := make(map[string]bool) + if currentUserID != "" { + for _, post := range posts { + isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID) + isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID) + } + } + + // 转换为响应结构 + postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap) + + response.Paginated(c, postResponses, total, page, pageSize) +} + +// Search 搜索帖子 +func (h *PostHandler) Search(c *gin.Context) { + keyword := c.Query("keyword") + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + posts, total, err := h.postService.Search(c.Request.Context(), keyword, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to search posts") + return + } + + // 获取当前用户ID用于判断点赞和收藏状态 + currentUserID := c.GetString("user_id") + isLikedMap := make(map[string]bool) + isFavoritedMap := make(map[string]bool) + if currentUserID != "" { + for _, post := range posts { + isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID) + isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID) + } + } + + // 转换为响应结构 + postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap) + + response.Paginated(c, postResponses, total, page, pageSize) +} diff --git a/internal/handler/push_handler.go b/internal/handler/push_handler.go new file mode 100644 index 0000000..878f35a --- /dev/null +++ b/internal/handler/push_handler.go @@ -0,0 +1,157 @@ +package handler + +import ( + "carrot_bbs/internal/dto" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" + + "github.com/gin-gonic/gin" +) + +// PushHandler 推送处理器 +type PushHandler struct { + pushService service.PushService +} + +// NewPushHandler 创建推送处理器 +func NewPushHandler(pushService service.PushService) *PushHandler { + return &PushHandler{ + pushService: pushService, + } +} + +// RegisterDevice 注册设备 +// POST /api/v1/push/devices +func (h *PushHandler) RegisterDevice(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var req dto.RegisterDeviceRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 验证设备类型 + deviceType := model.DeviceType(req.DeviceType) + if !isValidDeviceType(deviceType) { + response.BadRequest(c, "invalid device type") + return + } + + err := h.pushService.RegisterDevice(c.Request.Context(), userID, req.DeviceID, deviceType, req.PushToken) + if err != nil { + response.InternalServerError(c, "failed to register device") + return + } + + response.SuccessWithMessage(c, "device registered successfully", nil) +} + +// UnregisterDevice 注销设备 +// DELETE /api/v1/push/devices/:device_id +func (h *PushHandler) UnregisterDevice(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + deviceID := c.Param("device_id") + if deviceID == "" { + response.BadRequest(c, "device_id is required") + return + } + + err := h.pushService.UnregisterDevice(c.Request.Context(), deviceID) + if err != nil { + response.InternalServerError(c, "failed to unregister device") + return + } + + response.SuccessWithMessage(c, "device unregistered successfully", nil) +} + +// GetMyDevices 获取当前用户的设备列表 +// GET /api/v1/push/devices +func (h *PushHandler) GetMyDevices(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + // 这里需要从DeviceTokenRepository获取设备列表 + // 由于PushService接口没有提供获取设备列表的方法,我们暂时返回空列表 + // TODO: 在PushService接口中添加GetUserDevices方法 + _ = userID // 避免未使用变量警告 + + response.Success(c, []*dto.DeviceTokenResponse{}) +} + +// GetPushRecords 获取推送记录 +// GET /api/v1/push/records +func (h *PushHandler) GetPushRecords(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + records, err := h.pushService.GetPendingPushes(c.Request.Context(), userID) + if err != nil { + response.InternalServerError(c, "failed to get push records") + return + } + + response.Success(c, &dto.PushRecordListResponse{ + Records: dto.PushRecordsToResponse(records), + Total: int64(len(records)), + }) +} + +// 辅助函数:验证设备类型 +func isValidDeviceType(deviceType model.DeviceType) bool { + switch deviceType { + case model.DeviceTypeIOS, model.DeviceTypeAndroid, model.DeviceTypeWeb: + return true + default: + return false + } +} + +// UpdateDeviceToken 更新设备推送Token +// PUT /api/v1/push/devices/:device_id/token +func (h *PushHandler) UpdateDeviceToken(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + deviceID := c.Param("device_id") + if deviceID == "" { + response.BadRequest(c, "device_id is required") + return + } + + var req struct { + PushToken string `json:"push_token" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + err := h.pushService.UpdateDeviceToken(c.Request.Context(), deviceID, req.PushToken) + if err != nil { + response.InternalServerError(c, "failed to update device token") + return + } + + response.SuccessWithMessage(c, "device token updated successfully", nil) +} diff --git a/internal/handler/sticker_handler.go b/internal/handler/sticker_handler.go new file mode 100644 index 0000000..bb9084a --- /dev/null +++ b/internal/handler/sticker_handler.go @@ -0,0 +1,164 @@ +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// StickerHandler 自定义表情处理器 +type StickerHandler struct { + stickerService service.StickerService +} + +// NewStickerHandler 创建自定义表情处理器 +func NewStickerHandler(stickerService service.StickerService) *StickerHandler { + return &StickerHandler{ + stickerService: stickerService, + } +} + +// GetStickersRequest 获取表情列表请求 +type GetStickersRequest struct { + UserID string `form:"user_id" binding:"required"` +} + +// AddStickerRequest 添加表情请求 +type AddStickerRequest struct { + URL string `json:"url" binding:"required"` + Width int `json:"width"` + Height int `json:"height"` +} + +// DeleteStickerRequest 删除表情请求 +type DeleteStickerRequest struct { + StickerID string `json:"sticker_id" binding:"required"` +} + +// ReorderStickersRequest 重新排序请求 +type ReorderStickersRequest struct { + Orders map[string]int `json:"orders" binding:"required"` +} + +// CheckStickerRequest 检查表情是否存在请求 +type CheckStickerRequest struct { + URL string `form:"url" binding:"required"` +} + +// GetStickers 获取用户的表情列表 +func (h *StickerHandler) GetStickers(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + stickers, err := h.stickerService.GetUserStickers(userID) + if err != nil { + response.InternalServerError(c, "failed to get stickers") + return + } + + response.Success(c, gin.H{"stickers": stickers}) +} + +// AddSticker 添加表情 +func (h *StickerHandler) AddSticker(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var req AddStickerRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + sticker, err := h.stickerService.AddSticker(userID, req.URL, req.Width, req.Height) + if err != nil { + if err == service.ErrStickerAlreadyExists { + response.Error(c, http.StatusConflict, "sticker already exists") + return + } + if err == service.ErrInvalidStickerURL { + response.BadRequest(c, "invalid sticker url, only http/https is allowed") + return + } + response.InternalServerError(c, err.Error()) + return + } + + response.Success(c, gin.H{"sticker": sticker}) +} + +// DeleteSticker 删除表情 +func (h *StickerHandler) DeleteSticker(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var req DeleteStickerRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.stickerService.DeleteSticker(userID, req.StickerID); err != nil { + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "sticker deleted successfully", nil) +} + +// ReorderStickers 重新排序表情 +func (h *StickerHandler) ReorderStickers(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var req ReorderStickersRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.stickerService.ReorderStickers(userID, req.Orders); err != nil { + response.InternalServerError(c, err.Error()) + return + } + + response.SuccessWithMessage(c, "stickers reordered successfully", nil) +} + +// CheckStickerExists 检查表情是否存在 +func (h *StickerHandler) CheckStickerExists(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + url := c.Query("url") + if url == "" { + response.BadRequest(c, "url is required") + return + } + + exists, err := h.stickerService.CheckExists(userID, url) + if err != nil { + response.InternalServerError(c, err.Error()) + return + } + + response.Success(c, gin.H{"exists": exists}) +} diff --git a/internal/handler/system_message_handler.go b/internal/handler/system_message_handler.go new file mode 100644 index 0000000..7c9b82c --- /dev/null +++ b/internal/handler/system_message_handler.go @@ -0,0 +1,154 @@ +package handler + +import ( + "strconv" + + "carrot_bbs/internal/cache" + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/repository" + "carrot_bbs/internal/service" +) + +// SystemMessageHandler 系统消息处理器 +type SystemMessageHandler struct { + systemMsgService service.SystemMessageService + notifyRepo *repository.SystemNotificationRepository +} + +// NewSystemMessageHandler 创建系统消息处理器 +func NewSystemMessageHandler( + systemMsgService service.SystemMessageService, + notifyRepo *repository.SystemNotificationRepository, +) *SystemMessageHandler { + return &SystemMessageHandler{ + systemMsgService: systemMsgService, + notifyRepo: notifyRepo, + } +} + +// GetSystemMessages 获取系统消息列表 +// GET /api/v1/messages/system +func (h *SystemMessageHandler) GetSystemMessages(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + // 获取当前用户的系统通知(从独立表中获取) + notifications, total, err := h.notifyRepo.GetByReceiverID(userID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get system messages") + return + } + + // 转换为响应格式 + result := make([]*dto.SystemMessageResponse, 0) + for _, n := range notifications { + resp := dto.SystemNotificationToResponse(n) + result = append(result, resp) + } + + response.Paginated(c, result, total, page, pageSize) +} + +// GetUnreadCount 获取系统消息未读数 +// GET /api/v1/messages/system/unread-count +func (h *SystemMessageHandler) GetUnreadCount(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + // 获取当前用户的未读通知数 + unreadCount, err := h.notifyRepo.GetUnreadCount(userID) + if err != nil { + response.InternalServerError(c, "failed to get unread count") + return + } + + response.Success(c, &dto.SystemUnreadCountResponse{ + UnreadCount: unreadCount, + }) +} + +// MarkAsRead 标记系统消息为已读 +// PUT /api/v1/messages/system/:id/read +func (h *SystemMessageHandler) MarkAsRead(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + notificationIDStr := c.Param("id") + notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64) + if err != nil { + response.BadRequest(c, "invalid notification id") + return + } + + // 标记为已读 + err = h.notifyRepo.MarkAsRead(notificationID, userID) + if err != nil { + response.InternalServerError(c, "failed to mark as read") + return + } + cache.InvalidateUnreadSystem(cache.GetCache(), userID) + + response.SuccessWithMessage(c, "marked as read", nil) +} + +// MarkAllAsRead 标记所有系统消息为已读 +// PUT /api/v1/messages/system/read-all +func (h *SystemMessageHandler) MarkAllAsRead(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + // 标记当前用户所有通知为已读 + err := h.notifyRepo.MarkAllAsRead(userID) + if err != nil { + response.InternalServerError(c, "failed to mark all as read") + return + } + cache.InvalidateUnreadSystem(cache.GetCache(), userID) + + response.SuccessWithMessage(c, "all messages marked as read", nil) +} + +// DeleteSystemMessage 删除系统消息 +// DELETE /api/v1/messages/system/:id +func (h *SystemMessageHandler) DeleteSystemMessage(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + notificationIDStr := c.Param("id") + notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64) + if err != nil { + response.BadRequest(c, "invalid notification id") + return + } + + // 删除通知 + err = h.notifyRepo.Delete(notificationID, userID) + if err != nil { + response.InternalServerError(c, "failed to delete notification") + return + } + cache.InvalidateUnreadSystem(cache.GetCache(), userID) + + response.SuccessWithMessage(c, "notification deleted", nil) +} diff --git a/internal/handler/upload_handler.go b/internal/handler/upload_handler.go new file mode 100644 index 0000000..b57409f --- /dev/null +++ b/internal/handler/upload_handler.go @@ -0,0 +1,90 @@ +package handler + +import ( + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// UploadHandler 上传处理器 +type UploadHandler struct { + uploadService *service.UploadService +} + +// NewUploadHandler 创建上传处理器 +func NewUploadHandler(uploadService *service.UploadService) *UploadHandler { + return &UploadHandler{ + uploadService: uploadService, + } +} + +// UploadImage 上传图片 +func (h *UploadHandler) UploadImage(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + file, err := c.FormFile("image") + if err != nil { + response.BadRequest(c, "image file is required") + return + } + + url, err := h.uploadService.UploadImage(c.Request.Context(), file) + if err != nil { + response.InternalServerError(c, "failed to upload image") + return + } + + response.Success(c, gin.H{"url": url}) +} + +// UploadAvatar 上传头像 +func (h *UploadHandler) UploadAvatar(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + file, err := c.FormFile("image") + + if err != nil { + response.BadRequest(c, "avatar file is required") + return + } + + url, err := h.uploadService.UploadAvatar(c.Request.Context(), userID, file) + if err != nil { + response.InternalServerError(c, "failed to upload avatar") + return + } + + response.Success(c, gin.H{"url": url}) +} + +// UploadCover 上传头图(个人主页封面) +func (h *UploadHandler) UploadCover(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + file, err := c.FormFile("image") + if err != nil { + response.BadRequest(c, "image file is required") + return + } + + url, err := h.uploadService.UploadCover(c.Request.Context(), userID, file) + if err != nil { + response.InternalServerError(c, "failed to upload cover") + return + } + + response.Success(c, gin.H{"url": url}) +} diff --git a/internal/handler/user_handler.go b/internal/handler/user_handler.go new file mode 100644 index 0000000..3faf4fb --- /dev/null +++ b/internal/handler/user_handler.go @@ -0,0 +1,705 @@ +package handler + +import ( + "fmt" + "strconv" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// UserHandler 用户处理器 +type UserHandler struct { + userService *service.UserService + jwtService *service.JWTService +} + +// NewUserHandler 创建用户处理器 +func NewUserHandler(userService *service.UserService) *UserHandler { + return &UserHandler{ + userService: userService, + } +} + +// Register 用户注册 +func (h *UserHandler) Register(c *gin.Context) { + type RegisterRequest struct { + Username string `json:"username" binding:"required"` + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` + Nickname string `json:"nickname" binding:"required"` + Phone string `json:"phone"` + VerificationCode string `json:"verification_code" binding:"required"` + } + + var req RegisterRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + user, err := h.userService.Register(c.Request.Context(), req.Username, req.Email, req.Password, req.Nickname, req.Phone, req.VerificationCode) + if err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to register") + return + } + + // 生成Token + accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username) + refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username) + + response.Success(c, gin.H{ + "user": dto.ConvertUserToResponse(user), + "token": accessToken, + "refresh_token": refreshToken, + }) +} + +// Login 用户登录 +func (h *UserHandler) Login(c *gin.Context) { + type LoginRequest struct { + Username string `json:"username"` + Account string `json:"account"` + Password string `json:"password" binding:"required"` + } + + var req LoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + account := req.Account + if account == "" { + account = req.Username + } + if account == "" { + response.BadRequest(c, "username or account is required") + return + } + + user, err := h.userService.Login(c.Request.Context(), account, req.Password) + if err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to login") + return + } + + // 生成Token + accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username) + refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username) + + response.Success(c, gin.H{ + "user": dto.ConvertUserToResponse(user), + "token": accessToken, + "refresh_token": refreshToken, + }) +} + +// SendRegisterCode 发送注册验证码 +func (h *UserHandler) SendRegisterCode(c *gin.Context) { + type SendCodeRequest struct { + Email string `json:"email" binding:"required,email"` + } + + var req SendCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.userService.SendRegisterCode(c.Request.Context(), req.Email); err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to send register verification code") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// SendPasswordResetCode 发送找回密码验证码 +func (h *UserHandler) SendPasswordResetCode(c *gin.Context) { + type SendCodeRequest struct { + Email string `json:"email" binding:"required,email"` + } + + var req SendCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.userService.SendPasswordResetCode(c.Request.Context(), req.Email); err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to send reset verification code") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// ResetPassword 找回密码并重置 +func (h *UserHandler) ResetPassword(c *gin.Context) { + type ResetPasswordRequest struct { + Email string `json:"email" binding:"required,email"` + VerificationCode string `json:"verification_code" binding:"required"` + NewPassword string `json:"new_password" binding:"required,min=6"` + } + + var req ResetPasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.userService.ResetPasswordByEmail(c.Request.Context(), req.Email, req.VerificationCode, req.NewPassword); err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to reset password") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// GetCurrentUser 获取当前用户 +func (h *UserHandler) GetCurrentUser(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + user, err := h.userService.GetUserByID(c.Request.Context(), userID) + if err != nil { + response.NotFound(c, "user not found") + return + } + + // 实时计算帖子数量 + postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID) + if err != nil { + // 如果获取失败,使用数据库中的值 + postsCount = int64(user.PostsCount) + } + + response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount))) +} + +// GetUserByID 获取指定用户 +func (h *UserHandler) GetUserByID(c *gin.Context) { + id := c.Param("id") + currentUserID := c.GetString("user_id") + + // 获取用户信息,包含双向关注状态 + user, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), id, currentUserID) + if err != nil { + response.NotFound(c, "user not found") + return + } + + // 实时计算帖子数量 + postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), id) + if err != nil { + // 如果获取失败,使用数据库中的值 + postsCount = int64(user.PostsCount) + } + + // 转换为响应格式,包含双向关注状态和实时计算的帖子数量 + userResponse := dto.ConvertUserToResponseWithMutualFollowAndPostsCount(user, isFollowing, isFollowingMe, int(postsCount)) + + response.Success(c, userResponse) +} + +// UpdateUser 更新用户 +func (h *UserHandler) UpdateUser(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + type UpdateRequest struct { + Nickname string `json:"nickname"` + Bio string `json:"bio"` + Website string `json:"website"` + Location string `json:"location"` + Avatar string `json:"avatar"` + Phone *string `json:"phone"` + Email *string `json:"email"` + } + + var req UpdateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + user, err := h.userService.GetUserByID(c.Request.Context(), userID) + if err != nil { + response.NotFound(c, "user not found") + return + } + + if req.Nickname != "" { + user.Nickname = req.Nickname + } + if req.Bio != "" { + user.Bio = req.Bio + } + if req.Website != "" { + user.Website = req.Website + } + if req.Location != "" { + user.Location = req.Location + } + if req.Avatar != "" { + user.Avatar = req.Avatar + } + if req.Phone != nil { + user.Phone = req.Phone + } + if req.Email != nil { + if user.Email == nil || *user.Email != *req.Email { + user.EmailVerified = false + } + user.Email = req.Email + } + + err = h.userService.UpdateUser(c.Request.Context(), user) + if err != nil { + response.InternalServerError(c, "failed to update user") + return + } + + // 实时计算帖子数量 + postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID) + if err != nil { + // 如果获取失败,使用数据库中的值 + postsCount = int64(user.PostsCount) + } + + response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount))) +} + +// SendEmailVerifyCode 发送当前用户邮箱验证码 +func (h *UserHandler) SendEmailVerifyCode(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + type SendCodeRequest struct { + Email string `json:"email" binding:"required,email"` + } + var req SendCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.userService.SendCurrentUserEmailVerifyCode(c.Request.Context(), userID, req.Email); err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to send email verify code") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// VerifyEmail 验证当前用户邮箱 +func (h *UserHandler) VerifyEmail(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + type VerifyEmailRequest struct { + Email string `json:"email" binding:"required,email"` + VerificationCode string `json:"verification_code" binding:"required"` + } + var req VerifyEmailRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.userService.VerifyCurrentUserEmail(c.Request.Context(), userID, req.Email, req.VerificationCode); err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to verify email") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// RefreshToken 刷新Token +func (h *UserHandler) RefreshToken(c *gin.Context) { + type RefreshRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` + } + + var req RefreshRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 解析 refresh token + claims, err := h.jwtService.ParseToken(req.RefreshToken) + if err != nil { + response.Unauthorized(c, "invalid refresh token") + return + } + + // 生成新 token + accessToken, _ := h.jwtService.GenerateAccessToken(claims.UserID, claims.Username) + refreshToken, _ := h.jwtService.GenerateRefreshToken(claims.UserID, claims.Username) + + response.Success(c, gin.H{ + "token": accessToken, + "refresh_token": refreshToken, + }) +} + +// SetJWTService 设置JWT服务 +func (h *UserHandler) SetJWTService(jwtSvc *service.JWTService) { + h.jwtService = jwtSvc +} + +// FollowUser 关注用户 +func (h *UserHandler) FollowUser(c *gin.Context) { + userID := c.Param("id") + currentUserID := c.GetString("user_id") + + if userID == currentUserID { + response.BadRequest(c, "cannot follow yourself") + return + } + + err := h.userService.FollowUser(c.Request.Context(), currentUserID, userID) + if err != nil { + response.InternalServerError(c, "failed to follow user") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// UnfollowUser 取消关注用户 +func (h *UserHandler) UnfollowUser(c *gin.Context) { + userID := c.Param("id") + currentUserID := c.GetString("user_id") + + err := h.userService.UnfollowUser(c.Request.Context(), currentUserID, userID) + if err != nil { + response.InternalServerError(c, "failed to unfollow user") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// BlockUser 拉黑用户 +func (h *UserHandler) BlockUser(c *gin.Context) { + targetUserID := c.Param("id") + currentUserID := c.GetString("user_id") + + if targetUserID == currentUserID { + response.BadRequest(c, "cannot block yourself") + return + } + + err := h.userService.BlockUser(c.Request.Context(), currentUserID, targetUserID) + if err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to block user") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// UnblockUser 取消拉黑 +func (h *UserHandler) UnblockUser(c *gin.Context) { + targetUserID := c.Param("id") + currentUserID := c.GetString("user_id") + + if targetUserID == currentUserID { + response.BadRequest(c, "cannot unblock yourself") + return + } + + err := h.userService.UnblockUser(c.Request.Context(), currentUserID, targetUserID) + if err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to unblock user") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// GetBlockedUsers 获取黑名单列表 +func (h *UserHandler) GetBlockedUsers(c *gin.Context) { + currentUserID := c.GetString("user_id") + if currentUserID == "" { + response.Unauthorized(c, "") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + if page <= 0 { + page = 1 + } + if pageSize <= 0 { + pageSize = 20 + } + + users, total, err := h.userService.GetBlockedUsers(c.Request.Context(), currentUserID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get blocked users") + return + } + + userIDs := make([]string, len(users)) + for i, u := range users { + userIDs[i] = u.ID + } + postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs) + userResponses := dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap) + response.Paginated(c, userResponses, total, page, pageSize) +} + +// GetBlockStatus 获取拉黑状态 +func (h *UserHandler) GetBlockStatus(c *gin.Context) { + targetUserID := c.Param("id") + currentUserID := c.GetString("user_id") + if currentUserID == "" { + response.Unauthorized(c, "") + return + } + if targetUserID == "" { + response.BadRequest(c, "target user id is required") + return + } + + isBlocked, err := h.userService.IsBlocked(c.Request.Context(), currentUserID, targetUserID) + if err != nil { + response.InternalServerError(c, "failed to get block status") + return + } + + response.Success(c, gin.H{"is_blocked": isBlocked}) +} + +// GetFollowingList 获取关注列表 +func (h *UserHandler) GetFollowingList(c *gin.Context) { + userID := c.Param("id") + currentUserID := c.GetString("user_id") + page := c.DefaultQuery("page", "1") + pageSize := c.DefaultQuery("page_size", "20") + + users, err := h.userService.GetFollowingList(c.Request.Context(), userID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get following list") + return + } + + // 如果已登录,获取双向关注状态和实时计算的帖子数量 + var userResponses []*dto.UserResponse + if currentUserID != "" && len(users) > 0 { + userIDs := make([]string, len(users)) + for i, u := range users { + userIDs[i] = u.ID + } + statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs) + postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs) + userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap) + } else if len(users) > 0 { + userIDs := make([]string, len(users)) + for i, u := range users { + userIDs[i] = u.ID + } + postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs) + userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap) + } else { + userResponses = dto.ConvertUsersToResponse(users) + } + + response.Success(c, gin.H{ + "list": userResponses, + }) +} + +// GetFollowersList 获取粉丝列表 +func (h *UserHandler) GetFollowersList(c *gin.Context) { + userID := c.Param("id") + currentUserID := c.GetString("user_id") + page := c.DefaultQuery("page", "1") + pageSize := c.DefaultQuery("page_size", "20") + + fmt.Printf("[DEBUG] GetFollowersList: userID=%s, currentUserID=%s\n", userID, currentUserID) + + users, err := h.userService.GetFollowersList(c.Request.Context(), userID, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to get followers list") + return + } + + fmt.Printf("[DEBUG] GetFollowersList: found %d users\n", len(users)) + + // 如果已登录,获取双向关注状态和实时计算的帖子数量 + var userResponses []*dto.UserResponse + if currentUserID != "" && len(users) > 0 { + userIDs := make([]string, len(users)) + for i, u := range users { + userIDs[i] = u.ID + } + fmt.Printf("[DEBUG] GetFollowersList: checking mutual follow status for userIDs=%v\n", userIDs) + statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs) + postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs) + userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap) + } else if len(users) > 0 { + userIDs := make([]string, len(users)) + for i, u := range users { + userIDs[i] = u.ID + } + postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs) + userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap) + } else { + userResponses = dto.ConvertUsersToResponse(users) + } + + response.Success(c, gin.H{ + "list": userResponses, + }) +} + +// CheckUsername 检查用户名是否可用 +func (h *UserHandler) CheckUsername(c *gin.Context) { + username := c.Query("username") + if username == "" { + response.BadRequest(c, "username is required") + return + } + + available, err := h.userService.CheckUsernameAvailable(c.Request.Context(), username) + if err != nil { + response.InternalServerError(c, "failed to check username") + return + } + + response.Success(c, gin.H{"available": available}) +} + +// ChangePassword 修改密码 +func (h *UserHandler) ChangePassword(c *gin.Context) { + currentUserID := c.GetString("user_id") + + type ChangePasswordRequest struct { + OldPassword string `json:"old_password" binding:"required"` + NewPassword string `json:"new_password" binding:"required,min=6"` + VerificationCode string `json:"verification_code" binding:"required"` + } + + var req ChangePasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + err := h.userService.ChangePassword(c.Request.Context(), currentUserID, req.OldPassword, req.NewPassword, req.VerificationCode) + if err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to change password") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// SendChangePasswordCode 发送修改密码验证码 +func (h *UserHandler) SendChangePasswordCode(c *gin.Context) { + currentUserID := c.GetString("user_id") + if currentUserID == "" { + response.Unauthorized(c, "") + return + } + + err := h.userService.SendChangePasswordCode(c.Request.Context(), currentUserID) + if err != nil { + if se, ok := err.(*service.ServiceError); ok { + response.Error(c, se.Code, se.Message) + return + } + response.InternalServerError(c, "failed to send change password code") + return + } + + response.Success(c, gin.H{"success": true}) +} + +// Search 搜索用户 +func (h *UserHandler) Search(c *gin.Context) { + keyword := c.Query("keyword") + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + users, total, err := h.userService.Search(c.Request.Context(), keyword, page, pageSize) + if err != nil { + response.InternalServerError(c, "failed to search users") + return + } + + // 获取实时计算的帖子数量 + var userResponses []*dto.UserResponse + if len(users) > 0 { + userIDs := make([]string, len(users)) + for i, u := range users { + userIDs[i] = u.ID + } + postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs) + userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap) + } else { + userResponses = dto.ConvertUsersToResponse(users) + } + + response.Paginated(c, userResponses, total, page, pageSize) +} diff --git a/internal/handler/vote_handler.go b/internal/handler/vote_handler.go new file mode 100644 index 0000000..5a19bb3 --- /dev/null +++ b/internal/handler/vote_handler.go @@ -0,0 +1,216 @@ +package handler + +import ( + "errors" + "net/http" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// VoteHandler 投票处理器 +type VoteHandler struct { + voteService *service.VoteService + postService *service.PostService +} + +// NewVoteHandler 创建投票处理器 +func NewVoteHandler(voteService *service.VoteService, postService *service.PostService) *VoteHandler { + return &VoteHandler{ + voteService: voteService, + postService: postService, + } +} + +// CreateVotePost 创建投票帖子 +// POST /api/v1/posts/vote +func (h *VoteHandler) CreateVotePost(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "请先登录") + return + } + + var req dto.CreateVotePostRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + post, err := h.voteService.CreateVotePost(c.Request.Context(), userID, &req) + if err != nil { + var moderationErr *service.PostModerationRejectedError + if errors.As(err, &moderationErr) { + response.BadRequest(c, moderationErr.UserMessage()) + return + } + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + + response.Success(c, post) +} + +// GetVoteResult 获取投票结果 +// GET /api/v1/posts/:id/vote +func (h *VoteHandler) GetVoteResult(c *gin.Context) { + postID := c.Param("id") + if postID == "" { + response.BadRequest(c, "帖子ID不能为空") + return + } + + // 验证帖子存在 + _, err := h.postService.GetByID(c.Request.Context(), postID) + if err != nil { + response.NotFound(c, "帖子不存在") + return + } + + // 获取当前用户ID(可选登录) + userID := c.GetString("user_id") + + // 如果用户未登录,返回不带has_voted的结果 + if userID == "" { + options, err := h.voteService.GetVoteOptions(postID) + if err != nil { + response.InternalServerError(c, "获取投票选项失败") + return + } + + // 计算总票数 + totalVotes := 0 + for _, option := range options { + totalVotes += option.VotesCount + } + + result := &dto.VoteResultDTO{ + Options: options, + TotalVotes: totalVotes, + HasVoted: false, + } + + response.Success(c, result) + return + } + + // 用户已登录,获取完整的投票结果 + result, err := h.voteService.GetVoteResult(postID, userID) + if err != nil { + response.InternalServerError(c, "获取投票结果失败") + return + } + + response.Success(c, result) +} + +// Vote 投票 +// POST /api/v1/posts/:id/vote +func (h *VoteHandler) Vote(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "请先登录") + return + } + + postID := c.Param("id") + if postID == "" { + response.BadRequest(c, "帖子ID不能为空") + return + } + + // 验证帖子存在 + _, err := h.postService.GetByID(c.Request.Context(), postID) + if err != nil { + response.NotFound(c, "帖子不存在") + return + } + + // 解析请求体 + var req struct { + OptionID string `json:"option_id" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + if err := h.voteService.Vote(c.Request.Context(), postID, userID, req.OptionID); err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + + response.Success(c, gin.H{"success": true}) +} + +// Unvote 取消投票 +// DELETE /api/v1/posts/:id/vote +func (h *VoteHandler) Unvote(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "请先登录") + return + } + + postID := c.Param("id") + if postID == "" { + response.BadRequest(c, "帖子ID不能为空") + return + } + + // 验证帖子存在 + _, err := h.postService.GetByID(c.Request.Context(), postID) + if err != nil { + response.NotFound(c, "帖子不存在") + return + } + + if err := h.voteService.Unvote(c.Request.Context(), postID, userID); err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + + response.Success(c, gin.H{"success": true}) +} + +// UpdateVoteOption 更新投票选项(仅作者) +// PUT /api/v1/vote-options/:id +func (h *VoteHandler) UpdateVoteOption(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "请先登录") + return + } + + optionID := c.Param("id") + if optionID == "" { + response.BadRequest(c, "选项ID不能为空") + return + } + + // 解析请求体 + var req struct { + Content string `json:"content" binding:"required"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + // 获取帖子ID(从查询参数或请求体中获取) + postID := c.Query("post_id") + if postID == "" { + response.BadRequest(c, "帖子ID不能为空") + return + } + + if err := h.voteService.UpdateVoteOption(c.Request.Context(), postID, optionID, userID, req.Content); err != nil { + response.Error(c, http.StatusForbidden, err.Error()) + return + } + + response.Success(c, gin.H{"success": true}) +} diff --git a/internal/handler/websocket_handler.go b/internal/handler/websocket_handler.go new file mode 100644 index 0000000..7c847d4 --- /dev/null +++ b/internal/handler/websocket_handler.go @@ -0,0 +1,866 @@ +package handler + +import ( + "context" + "encoding/json" + "log" + "net/http" + "strconv" + "strings" + "time" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/model" + ws "carrot_bbs/internal/pkg/websocket" + "carrot_bbs/internal/repository" + "carrot_bbs/internal/service" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // 允许所有来源,生产环境应该限制 + }, +} + +// WebSocketHandler WebSocket处理器 +type WebSocketHandler struct { + jwtService *service.JWTService + chatService service.ChatService + groupService service.GroupService + groupRepo repository.GroupRepository + userRepo *repository.UserRepository + wsManager *ws.WebSocketManager +} + +// NewWebSocketHandler 创建WebSocket处理器 +func NewWebSocketHandler( + jwtService *service.JWTService, + chatService service.ChatService, + groupService service.GroupService, + groupRepo repository.GroupRepository, + userRepo *repository.UserRepository, + wsManager *ws.WebSocketManager, +) *WebSocketHandler { + return &WebSocketHandler{ + jwtService: jwtService, + chatService: chatService, + groupService: groupService, + groupRepo: groupRepo, + userRepo: userRepo, + wsManager: wsManager, + } +} + +// HandleWebSocket 处理WebSocket连接 +func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) { + // 调试:打印请求头信息 + log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path) + log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s", + c.GetHeader("Connection"), + c.GetHeader("Upgrade")) + log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s", + c.GetHeader("Sec-WebSocket-Key"), + c.GetHeader("Sec-WebSocket-Version")) + + // 从query参数获取token + token := c.Query("token") + if token == "" { + // 尝试从header获取 + authHeader := c.GetHeader("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + token = strings.TrimPrefix(authHeader, "Bearer ") + } + } + + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"}) + return + } + + // 验证token + claims, err := h.jwtService.ParseToken(token) + if err != nil { + log.Printf("Invalid token: %v", err) + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) + return + } + + userID := claims.UserID + if userID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"}) + return + } + + // 升级HTTP连接为WebSocket连接 + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Printf("Failed to upgrade connection: %v", err) + log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s", + c.GetHeader("User-Agent"), + c.GetHeader("Content-Type")) + return + } + + // 如果用户已在线,先注销旧连接 + if h.wsManager.IsUserOnline(userID) { + log.Printf("[DEBUG] 用户 %s 已有在线连接,复用该连接", userID) + } else { + log.Printf("[DEBUG] 用户 %s 当前不在线,创建新连接", userID) + } + + // 创建客户端 + client := &ws.Client{ + ID: userID, + UserID: userID, + Conn: conn, + Send: make(chan []byte, 256), + Manager: h.wsManager, + } + + // 注册客户端 + h.wsManager.Register(client) + + // 启动读写协程 + go client.WritePump() + go h.handleMessages(client) + + log.Printf("[DEBUG] WebSocket连接建立: userID=%s, 当前在线=%v", userID, h.wsManager.IsUserOnline(userID)) +} + +// handleMessages 处理客户端消息 +// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳 +func (h *WebSocketHandler) handleMessages(client *ws.Client) { + defer func() { + h.wsManager.Unregister(client) + client.Conn.Close() + }() + + client.Conn.SetReadLimit(512 * 1024) // 512KB + client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒 + client.Conn.SetPongHandler(func(string) error { + client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒 + return nil + }) + + // 心跳定时器 - 服务端主动 ping 间隔保持 30 秒 + pingTicker := time.NewTicker(30 * time.Second) + defer pingTicker.Stop() + + for { + select { + case <-pingTicker.C: + // 发送心跳 + if err := client.SendPing(); err != nil { + log.Printf("Failed to send ping: %v", err) + return + } + default: + _, message, err := client.Conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket error: %v", err) + } + return + } + + var wsMsg ws.WSMessage + if err := json.Unmarshal(message, &wsMsg); err != nil { + log.Printf("Failed to unmarshal message: %v", err) + continue + } + + h.processMessage(client, &wsMsg) + } + } +} + +// processMessage 处理消息 +func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) { + switch msg.Type { + case ws.MessageTypePing: + // 响应心跳 + if err := client.SendPong(); err != nil { + log.Printf("Failed to send pong: %v", err) + } + + case ws.MessageTypePong: + // 客户端响应心跳 + + case ws.MessageTypeMessage: + // 处理聊天消息 + h.handleChatMessage(client, msg) + + case ws.MessageTypeTyping: + // 处理正在输入状态 + h.handleTyping(client, msg) + + case ws.MessageTypeRead: + // 处理已读回执 + h.handleReadReceipt(client, msg) + + // 群组消息处理 + case ws.MessageTypeGroupMessage: + // 处理群消息 + h.handleGroupMessage(client, msg) + + case ws.MessageTypeGroupTyping: + // 处理群输入状态 + h.handleGroupTyping(client, msg) + + case ws.MessageTypeGroupRead: + // 处理群消息已读 + h.handleGroupReadReceipt(client, msg) + + case ws.MessageTypeGroupRecall: + // 处理群消息撤回 + h.handleGroupRecall(client, msg) + + default: + log.Printf("Unknown message type: %s", msg.Type) + } +} + +// handleChatMessage 处理聊天消息 +func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) { + data, ok := msg.Data.(map[string]interface{}) + if !ok { + log.Printf("Invalid message data format") + return + } + + log.Printf("[DEBUG handleChatMessage] 完整data: %+v", data) + + conversationIDStr, _ := data["conversationId"].(string) + + if conversationIDStr == "" { + log.Printf("Missing conversationId") + return + } + + // 解析会话ID + conversationID, err := service.ParseConversationID(conversationIDStr) + if err != nil { + log.Printf("Invalid conversation ID: %v", err) + return + } + + // 解析 segments + var segments model.MessageSegments + if data["segments"] != nil { + segmentsBytes, err := json.Marshal(data["segments"]) + if err == nil { + json.Unmarshal(segmentsBytes, &segments) + } + } + + // 从 segments 中提取回复消息ID + replyToID := dto.GetReplyMessageID(segments) + var replyToIDPtr *string + if replyToID != "" { + replyToIDPtr = &replyToID + } + + // 发送消息 - 使用 segments + message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr) + if err != nil { + log.Printf("Failed to send message: %v", err) + // 发送错误消息 + errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{ + "error": "Failed to send message", + }) + if client.Send != nil { + msgBytes, _ := json.Marshal(errorMsg) + client.Send <- msgBytes + } + return + } + + // 发送确认消息(使用 meta 事件格式,包含完整的消息内容) + metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{ + "detail_type": ws.MetaDetailTypeAck, + "conversation_id": conversationID, + "id": message.ID, + "user_id": client.UserID, + "sender_id": client.UserID, + "seq": message.Seq, + "segments": message.Segments, + "created_at": message.CreatedAt.UnixMilli(), + }) + if client.Send != nil { + msgBytes, _ := json.Marshal(metaAckMsg) + log.Printf("[DEBUG handleChatMessage] 私聊 ack 消息: %s", string(msgBytes)) + log.Printf("[DEBUG handleChatMessage] message.Segments 类型: %T, 值: %+v", message.Segments, message.Segments) + client.Send <- msgBytes + } +} + +// handleTyping 处理正在输入状态 +func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) { + data, ok := msg.Data.(map[string]interface{}) + if !ok { + return + } + + conversationIDStr, _ := data["conversationId"].(string) + if conversationIDStr == "" { + return + } + + conversationID, err := service.ParseConversationID(conversationIDStr) + if err != nil { + return + } + + // 直接使用 string 类型的 userID + h.chatService.SendTyping(context.Background(), client.UserID, conversationID) +} + +// handleReadReceipt 处理已读回执 +func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) { + data, ok := msg.Data.(map[string]interface{}) + if !ok { + return + } + + conversationIDStr, _ := data["conversationId"].(string) + if conversationIDStr == "" { + return + } + + conversationID, err := service.ParseConversationID(conversationIDStr) + if err != nil { + return + } + + // 获取lastReadSeq + lastReadSeq, _ := data["lastReadSeq"].(float64) + if lastReadSeq == 0 { + return + } + + // 直接使用 string 类型的 userID 和 conversationID + if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil { + log.Printf("Failed to mark as read: %v", err) + } +} + +// ==================== 群组消息处理 ==================== + +// handleGroupMessage 处理群消息 +func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) { + // 打印接收到的消息类型和数据,用于调试 + log.Printf("[handleGroupMessage] Received message type: %s", msg.Type) + log.Printf("[handleGroupMessage] Message data: %+v", msg.Data) + + data, ok := msg.Data.(map[string]interface{}) + if !ok { + log.Printf("Invalid group message data format: data is not map[string]interface{}") + return + } + + // 解析群组ID(支持 camelCase 和 snake_case) + var groupIDFloat float64 + groupID := "" // 使用 groupID 作为最终变量名 + if val, ok := data["groupId"].(float64); ok { + groupIDFloat = val + groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64) + } else if val, ok := data["group_id"].(string); ok { + groupID = val + } + + if groupID == "" { + log.Printf("Missing groupId in group message") + return + } + + // 解析会话ID(支持 camelCase 和 snake_case) + var conversationID string + if val, ok := data["conversationId"].(string); ok { + conversationID = val + } else if val, ok := data["conversation_id"].(string); ok { + conversationID = val + } + if conversationID == "" { + log.Printf("Missing conversationId in group message") + return + } + + // 解析 segments + var segments model.MessageSegments + if data["segments"] != nil { + segmentsBytes, err := json.Marshal(data["segments"]) + if err == nil { + json.Unmarshal(segmentsBytes, &segments) + } + } + + // 解析@用户列表(支持 camelCase 和 snake_case) + var mentionUsers []uint64 + var mentionUsersInterface []interface{} + if val, ok := data["mentionUsers"].([]interface{}); ok { + mentionUsersInterface = val + } else if val, ok := data["mention_users"].([]interface{}); ok { + mentionUsersInterface = val + } + if len(mentionUsersInterface) > 0 { + for _, uid := range mentionUsersInterface { + if uidFloat, ok := uid.(float64); ok { + mentionUsers = append(mentionUsers, uint64(uidFloat)) + } else if uidStr, ok := uid.(string); ok { + // 处理字符串格式的用户ID + if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil { + mentionUsers = append(mentionUsers, uidInt) + } + } + } + } + + // 解析@所有人(支持 camelCase 和 snake_case) + var mentionAll bool + if val, ok := data["mentionAll"].(bool); ok { + mentionAll = val + } else if val, ok := data["mention_all"].(bool); ok { + mentionAll = val + } + + // 检查用户是否可以发送群消息(验证成员身份和禁言状态) + // client.UserID 已经是 string 格式的 UUID + if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil { + log.Printf("User cannot send group message: %v", err) + // 发送错误消息 + errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{ + "error": "Cannot send group message", + "reason": err.Error(), + "type": "group_message_error", + "groupId": groupID, + }) + if client.Send != nil { + msgBytes, _ := json.Marshal(errorMsg) + client.Send <- msgBytes + } + return + } + + // 检查@所有人权限(只有群主和管理员可以@所有人) + if mentionAll { + if !h.groupService.IsGroupAdmin(client.UserID, groupID) { + log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID) + mentionAll = false // 取消@所有人标记 + } + } + + // 创建消息 + message := &model.Message{ + ConversationID: conversationID, + SenderID: client.UserID, + Segments: segments, + Status: model.MessageStatusNormal, + MentionAll: mentionAll, + } + + // 序列化mentionUsers为JSON + if len(mentionUsers) > 0 { + mentionUsersJSON, _ := json.Marshal(mentionUsers) + message.MentionUsers = string(mentionUsersJSON) + } + + // 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播) + savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil) + if err != nil { + log.Printf("Failed to save group message: %v", err) + errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{ + "error": "Failed to save group message", + }) + if client.Send != nil { + msgBytes, _ := json.Marshal(errorMsg) + client.Send <- msgBytes + } + return + } + + // 更新消息的mention信息 + if len(mentionUsers) > 0 || mentionAll { + message.ID = savedMessage.ID + message.Seq = savedMessage.Seq + } + + // 构造群消息响应 + groupMsg := &ws.GroupMessage{ + ID: savedMessage.ID, + ConversationID: conversationID, + GroupID: groupID, + SenderID: client.UserID, + Seq: savedMessage.Seq, + Segments: segments, + MentionUsers: mentionUsers, + MentionAll: mentionAll, + CreatedAt: savedMessage.CreatedAt.UnixMilli(), + } + + // 广播消息给群组所有成员(排除发送者) + h.BroadcastGroupMessage(groupID, groupMsg, client.UserID) + + // 发送确认消息给发送者(作为meta事件) + // 使用 meta 事件格式发送 ack + log.Printf("[DEBUG HandleGroupMessageSend] 准备发送 ack 消息, userID=%s, messageID=%s, seq=%d", + client.UserID, savedMessage.ID, savedMessage.Seq) + + metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{ + "detail_type": ws.MetaDetailTypeAck, + "conversation_id": conversationID, + "group_id": groupID, + "id": savedMessage.ID, + "user_id": client.UserID, + "sender_id": client.UserID, + "seq": savedMessage.Seq, + "segments": segments, + "created_at": savedMessage.CreatedAt.UnixMilli(), + }) + if client.Send != nil { + msgBytes, _ := json.Marshal(metaAckMsg) + log.Printf("[DEBUG HandleGroupMessageSend] 发送 ack 消息到 channel, userID=%s, msg=%s", + client.UserID, string(msgBytes)) + client.Send <- msgBytes + } else { + log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID) + } + + // 处理@提及通知 + if len(mentionUsers) > 0 || mentionAll { + // 提取文本正文(不含 @ 部分) + textContent := dto.ExtractTextContentFromModel(segments) + // 在通知内容前拼接被@的真实昵称,通过群成员列表查找 + mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent) + h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll) + } +} + +// handleGroupTyping 处理群输入状态 +func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) { + data, ok := msg.Data.(map[string]interface{}) + if !ok { + return + } + + groupIDFloat, _ := data["groupId"].(float64) + if groupIDFloat == 0 { + return + } + groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64) + + isTyping, _ := data["isTyping"].(bool) + + // 验证用户是否是群成员 + // client.UserID 已经是 string 格式的 UUID + isMember, err := h.groupRepo.IsMember(groupID, client.UserID) + if err != nil || !isMember { + return + } + + // 获取用户信息 + user, err := h.userRepo.GetByID(client.UserID) + if err != nil { + return + } + + // 构造输入状态消息 + typingMsg := &ws.GroupTypingMessage{ + GroupID: groupID, + UserID: client.UserID, + Username: user.Username, + IsTyping: isTyping, + } + + // 广播给群组其他成员 + wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg) + h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID) +} + +// handleGroupReadReceipt 处理群消息已读回执 +func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) { + data, ok := msg.Data.(map[string]interface{}) + if !ok { + return + } + + conversationID, _ := data["conversationId"].(string) + if conversationID == "" { + return + } + + lastReadSeq, _ := data["lastReadSeq"].(float64) + if lastReadSeq == 0 { + return + } + + // 标记已读 + if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil { + log.Printf("Failed to mark group message as read: %v", err) + } +} + +// handleGroupRecall 处理群消息撤回 +func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) { + data, ok := msg.Data.(map[string]interface{}) + if !ok { + return + } + + messageID, _ := data["messageId"].(string) + if messageID == "" { + return + } + + groupIDFloat, _ := data["groupId"].(float64) + if groupIDFloat == 0 { + return + } + groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64) + + // 撤回消息 + if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil { + log.Printf("Failed to recall group message: %v", err) + errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{ + "error": "Failed to recall message", + }) + if client.Send != nil { + msgBytes, _ := json.Marshal(errorMsg) + client.Send <- msgBytes + } + return + } + + // 广播撤回通知给群组所有成员 + recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{ + "messageId": messageID, + "groupId": groupID, + "userId": client.UserID, + "timestamp": time.Now().UnixMilli(), + }) + h.BroadcastGroupNotice(groupID, recallNotice) +} + +// handleGroupMention 处理群消息@提及通知 +func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) { + // 如果@所有人,获取所有群成员 + if mentionAll { + members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) + if err != nil { + log.Printf("Failed to get group members for mention all: %v", err) + return + } + + for _, member := range members { + // 不通知发送者自己 + memberIDStr := member.UserID + if memberIDStr == senderID { + continue + } + + // 发送@提及通知 + mentionMsg := &ws.GroupMentionMessage{ + GroupID: groupID, + MessageID: messageID, + FromUserID: senderID, + Content: truncateContent(content, 50), + MentionAll: true, + CreatedAt: time.Now().UnixMilli(), + } + wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg) + h.wsManager.SendToUser(memberIDStr, wsMsg) + } + return + } + + // 处理特定用户的@提及 + for _, userID := range mentionUsers { + // userID 是 uint64,转换为 string + userIDStr := strconv.FormatUint(userID, 10) + if userIDStr == senderID { + continue // 不通知发送者自己 + } + + mentionMsg := &ws.GroupMentionMessage{ + GroupID: groupID, + MessageID: messageID, + FromUserID: senderID, + Content: truncateContent(content, 50), + MentionAll: false, + CreatedAt: time.Now().UnixMilli(), + } + wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg) + h.wsManager.SendToUser(userIDStr, wsMsg) + } +} + +// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称 +func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string { + var prefix string + if mentionAll { + prefix = "@所有人 " + } else if len(mentionUsers) > 0 { + // 查询群成员列表,找到被@用户的昵称 + members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) + if err == nil { + memberNickMap := make(map[string]string, len(members)) + for _, m := range members { + displayName := m.Nickname + if displayName == "" { + displayName = m.UserID + } + memberNickMap[m.UserID] = displayName + } + for _, uid := range mentionUsers { + uidStr := strconv.FormatUint(uid, 10) + if name, ok := memberNickMap[uidStr]; ok { + prefix += "@" + name + " " + } else { + prefix += "@某人 " + } + } + } else { + for range mentionUsers { + prefix += "@某人 " + } + } + } + return prefix + textBody +} + +// BroadcastGroupMessage 向群组所有成员广播消息 +func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) { + // 获取群组所有成员 + members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) + if err != nil { + log.Printf("Failed to get group members for broadcast: %v", err) + return + } + + // 创建WebSocket消息 + wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message) + + // 遍历成员,如果在线则发送消息 + for _, member := range members { + memberIDStr := member.UserID + + // 排除发送者 + if memberIDStr == excludeUserID { + continue + } + + // 发送消息 + h.wsManager.SendToUser(memberIDStr, wsMsg) + } +} + +// BroadcastGroupNotice 广播群组通知给所有成员 +func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) { + // 获取群组所有成员 + members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) + if err != nil { + log.Printf("Failed to get group members for notice broadcast: %v", err) + return + } + + // 遍历成员,如果在线则发送通知 + for _, member := range members { + memberIDStr := member.UserID + h.wsManager.SendToUser(memberIDStr, notice) + } +} + +// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户) +func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) { + // 获取群组所有成员 + members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) + if err != nil { + log.Printf("Failed to get group members for notice broadcast: %v", err) + return + } + + // 遍历成员,如果在线则发送通知 + for _, member := range members { + memberIDStr := member.UserID + if memberIDStr == excludeUserID { + continue + } + h.wsManager.SendToUser(memberIDStr, notice) + } +} + +// SendGroupMemberNotice 发送群成员变动通知 +func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) { + notice := &ws.GroupNoticeMessage{ + NoticeType: noticeType, + GroupID: groupID, + Data: data, + Timestamp: time.Now().UnixMilli(), + } + wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice) + h.BroadcastGroupNotice(groupID, wsMsg) +} + +// truncateContent 截断内容 +func truncateContent(content string, maxLen int) string { + if len(content) <= maxLen { + return content + } + return content[:maxLen] + "..." +} + +// BroadcastGroupTyping 向群组所有成员广播输入状态 +func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) { + // 获取群组所有成员 + members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) + if err != nil { + log.Printf("Failed to get group members for typing broadcast: %v", err) + return + } + + // 创建WebSocket消息 + wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg) + + // 遍历成员,如果在线则发送消息 + for _, member := range members { + memberIDStr := member.UserID + + // 排除指定用户 + if memberIDStr == excludeUserID { + continue + } + + // 发送消息 + h.wsManager.SendToUser(memberIDStr, wsMsg) + } +} + +// BroadcastGroupRead 向群组所有成员广播已读状态 +func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) { + // 获取群组所有成员 + members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000) + if err != nil { + log.Printf("Failed to get group members for read broadcast: %v", err) + return + } + + // 创建WebSocket消息 + wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg) + + // 遍历成员,如果在线则发送消息 + for _, member := range members { + memberIDStr := member.UserID + + // 排除指定用户 + if memberIDStr == excludeUserID { + continue + } + + // 发送消息 + h.wsManager.SendToUser(memberIDStr, wsMsg) + } +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 0000000..f1913fc --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,95 @@ +package middleware + +import ( + "fmt" + "strings" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +// Auth 认证中间件 +func Auth(jwtService *service.JWTService) gin.HandlerFunc { + return func(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + fmt.Printf("[DEBUG] Auth middleware: Authorization header = %q\n", authHeader) + + if authHeader == "" { + fmt.Printf("[DEBUG] Auth middleware: no Authorization header, returning 401\n") + response.Unauthorized(c, "authorization header is required") + c.Abort() + return + } + + // 提取Token + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" { + fmt.Printf("[DEBUG] Auth middleware: invalid Authorization header format\n") + response.Unauthorized(c, "invalid authorization header format") + c.Abort() + return + } + + token := parts[1] + fmt.Printf("[DEBUG] Auth middleware: token = %q\n", token[:min(20, len(token))]+"...") + + // 验证Token + claims, err := jwtService.ParseToken(token) + if err != nil { + fmt.Printf("[DEBUG] Auth middleware: failed to parse token: %v\n", err) + response.Unauthorized(c, "invalid token") + c.Abort() + return + } + + fmt.Printf("[DEBUG] Auth middleware: parsed claims, user_id = %q, username = %q\n", claims.UserID, claims.Username) + + // 将用户信息存入上下文 + c.Set("user_id", claims.UserID) + c.Set("username", claims.Username) + + c.Next() + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// OptionalAuth 可选认证中间件 +func OptionalAuth(jwtService *service.JWTService) gin.HandlerFunc { + return func(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.Next() + return + } + + // 提取Token + parts := strings.SplitN(authHeader, " ", 2) + if len(parts) != 2 || parts[0] != "Bearer" { + c.Next() + return + } + + token := parts[1] + + // 验证Token + claims, err := jwtService.ParseToken(token) + if err != nil { + c.Next() + return + } + + // 将用户信息存入上下文 + c.Set("user_id", claims.UserID) + c.Set("username", claims.Username) + + c.Next() + } +} diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go new file mode 100644 index 0000000..d6039bb --- /dev/null +++ b/internal/middleware/cors.go @@ -0,0 +1,46 @@ +package middleware + +import ( + "log" + "strings" + + "github.com/gin-gonic/gin" +) + +// CORS CORS中间件 +func CORS() gin.HandlerFunc { + return func(c *gin.Context) { + // 获取请求路径 + path := c.Request.URL.Path + + c.Header("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + // 添加 WebSocket 升级所需的头 + c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization, Connection, Upgrade, Sec-WebSocket-Key, Sec-WebSocket-Version, Sec-WebSocket-Protocol, Sec-WebSocket-Extensions") + c.Header("Access-Control-Expose-Headers", "Content-Length, Connection, Upgrade") + c.Header("Access-Control-Allow-Credentials", "true") + + // 处理 WebSocket 升级请求的预检 + if c.Request.Method == "OPTIONS" { + log.Printf("[CORS] OPTIONS 预检请求: %s", path) + c.AbortWithStatus(204) + return + } + + // 针对 WebSocket 路径的特殊处理 + if path == "/ws" { + connection := c.GetHeader("Connection") + upgrade := c.GetHeader("Upgrade") + log.Printf("[CORS] WebSocket 请求: Connection=%s, Upgrade=%s", connection, upgrade) + + // 检查是否是有效的 WebSocket 升级请求 + if strings.Contains(strings.ToLower(connection), "upgrade") && strings.ToLower(upgrade) == "websocket" { + log.Printf("[CORS] 有效的 WebSocket 升级请求") + } else { + log.Printf("[CORS] 警告: 不是有效的 WebSocket 升级请求!") + } + } + + c.Next() + } +} diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go new file mode 100644 index 0000000..a256b3f --- /dev/null +++ b/internal/middleware/logger.go @@ -0,0 +1,49 @@ +package middleware + +import ( + "time" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// Logger 日志中间件 +func Logger(logger *zap.Logger) gin.HandlerFunc { + return func(c *gin.Context) { + start := time.Now() + path := c.Request.URL.Path + + c.Next() + + latency := time.Since(start) + statusCode := c.Writer.Status() + + logger.Info("request", + zap.String("method", c.Request.Method), + zap.String("path", path), + zap.Int("status", statusCode), + zap.Duration("latency", latency), + zap.String("ip", c.ClientIP()), + zap.String("user-agent", c.Request.UserAgent()), + ) + } +} + +// Recovery 恢复中间件 +func Recovery(logger *zap.Logger) gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + logger.Error("panic recovered", + zap.Any("error", err), + ) + c.JSON(500, gin.H{ + "code": 500, + "message": "internal server error", + }) + c.Abort() + } + }() + c.Next() + } +} diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go new file mode 100644 index 0000000..8483651 --- /dev/null +++ b/internal/middleware/ratelimit.go @@ -0,0 +1,102 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +// RateLimiter 限流器 +type RateLimiter struct { + requests map[string][]time.Time + mu sync.Mutex + limit int + window time.Duration +} + +// NewRateLimiter 创建限流器 +func NewRateLimiter(limit int, window time.Duration) *RateLimiter { + rl := &RateLimiter{ + requests: make(map[string][]time.Time), + limit: limit, + window: window, + } + + // 定期清理过期的记录 + go func() { + for { + time.Sleep(window) + rl.cleanup() + } + }() + + return rl +} + +// cleanup 清理过期的记录 +func (rl *RateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for key, times := range rl.requests { + var valid []time.Time + for _, t := range times { + if now.Sub(t) < rl.window { + valid = append(valid, t) + } + } + if len(valid) == 0 { + delete(rl.requests, key) + } else { + rl.requests[key] = valid + } + } +} + +// isAllowed 检查是否允许请求 +func (rl *RateLimiter) isAllowed(key string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + times := rl.requests[key] + + // 过滤掉过期的 + var valid []time.Time + for _, t := range times { + if now.Sub(t) < rl.window { + valid = append(valid, t) + } + } + + if len(valid) >= rl.limit { + rl.requests[key] = valid + return false + } + + rl.requests[key] = append(valid, now) + return true +} + +// RateLimit 限流中间件 +func RateLimit(requestsPerMinute int) gin.HandlerFunc { + limiter := NewRateLimiter(requestsPerMinute, time.Minute) + + return func(c *gin.Context) { + ip := c.ClientIP() + + if !limiter.isAllowed(ip) { + c.JSON(http.StatusTooManyRequests, gin.H{ + "code": 429, + "message": "too many requests", + }) + c.Abort() + return + } + + c.Next() + } +} diff --git a/internal/model/audit_log.go b/internal/model/audit_log.go new file mode 100644 index 0000000..9428371 --- /dev/null +++ b/internal/model/audit_log.go @@ -0,0 +1,118 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// AuditTargetType 审核对象类型 +type AuditTargetType string + +const ( + AuditTargetTypePost AuditTargetType = "post" // 帖子 + AuditTargetTypeComment AuditTargetType = "comment" // 评论 + AuditTargetTypeMessage AuditTargetType = "message" // 私信 + AuditTargetTypeUser AuditTargetType = "user" // 用户资料 + AuditTargetTypeImage AuditTargetType = "image" // 图片 +) + +// AuditResult 审核结果 +type AuditResult string + +const ( + AuditResultPass AuditResult = "pass" // 通过 + AuditResultReview AuditResult = "review" // 需人工复审 + AuditResultBlock AuditResult = "block" // 违规拦截 + AuditResultUnknown AuditResult = "unknown" // 未知 +) + +// AuditRiskLevel 风险等级 +type AuditRiskLevel string + +const ( + AuditRiskLevelLow AuditRiskLevel = "low" // 低风险 + AuditRiskLevelMedium AuditRiskLevel = "medium" // 中风险 + AuditRiskLevelHigh AuditRiskLevel = "high" // 高风险 +) + +// AuditSource 审核来源 +type AuditSource string + +const ( + AuditSourceAuto AuditSource = "auto" // 自动审核 + AuditSourceManual AuditSource = "manual" // 人工审核 + AuditSourceCallback AuditSource = "callback" // 回调审核 +) + +// AuditLog 审核日志实体 +type AuditLog struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + TargetType AuditTargetType `json:"target_type" gorm:"type:varchar(50);index"` + TargetID string `json:"target_id" gorm:"type:varchar(255);index"` + Content string `json:"content" gorm:"type:text"` // 待审核内容 + ContentType string `json:"content_type" gorm:"type:varchar(50)"` // 内容类型: text, image + ContentURL string `json:"content_url" gorm:"type:text"` // 图片/文件URL + AuditType string `json:"audit_type" gorm:"type:varchar(50)"` // 审核类型: porn, violence, ad, political, fraud, gamble + Result AuditResult `json:"result" gorm:"type:varchar(50);index"` + RiskLevel AuditRiskLevel `json:"risk_level" gorm:"type:varchar(20)"` + Labels string `json:"labels" gorm:"type:text"` // JSON数组,标签列表 + Suggestion string `json:"suggestion" gorm:"type:varchar(50)"` // pass, review, block + Detail string `json:"detail" gorm:"type:text"` // 详细说明 + ThirdPartyID string `json:"third_party_id" gorm:"type:varchar(255)"` // 第三方审核服务返回的ID + Source AuditSource `json:"source" gorm:"type:varchar(20);default:auto"` + ReviewerID string `json:"reviewer_id" gorm:"type:varchar(255)"` // 审核人ID(人工审核时使用) + ReviewerName string `json:"reviewer_name" gorm:"type:varchar(100)"` // 审核人名称 + ReviewTime *time.Time `json:"review_time" gorm:"index"` // 审核时间 + UserID string `json:"user_id" gorm:"type:varchar(255);index"` // 内容发布者ID + UserIP string `json:"user_ip" gorm:"type:varchar(45)"` // 用户IP + Status string `json:"status" gorm:"type:varchar(20);default:pending"` // pending, completed, failed + RejectReason string `json:"reject_reason" gorm:"type:text"` // 拒绝原因(人工审核时使用) + ExtraData string `json:"extra_data" gorm:"type:text"` // 额外数据,JSON格式 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// BeforeCreate 创建前生成UUID +func (al *AuditLog) BeforeCreate(tx *gorm.DB) error { + if al.ID == "" { + al.ID = uuid.New().String() + } + return nil +} + +func (AuditLog) TableName() string { + return "audit_logs" +} + +// AuditLogRequest 创建审核日志请求 +type AuditLogRequest struct { + TargetType AuditTargetType `json:"target_type" validate:"required"` + TargetID string `json:"target_id" validate:"required"` + Content string `json:"content"` + ContentType string `json:"content_type"` + ContentURL string `json:"content_url"` + AuditType string `json:"audit_type"` + UserID string `json:"user_id"` + UserIP string `json:"user_ip"` +} + +// AuditLogListItem 审核日志列表项 +type AuditLogListItem struct { + ID string `json:"id"` + TargetType AuditTargetType `json:"target_type"` + TargetID string `json:"target_id"` + Content string `json:"content"` + ContentType string `json:"content_type"` + Result AuditResult `json:"result"` + RiskLevel AuditRiskLevel `json:"risk_level"` + Suggestion string `json:"suggestion"` + Source AuditSource `json:"source"` + ReviewerID string `json:"reviewer_id"` + ReviewTime *time.Time `json:"review_time"` + UserID string `json:"user_id"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/internal/model/comment.go b/internal/model/comment.go new file mode 100644 index 0000000..f8c5946 --- /dev/null +++ b/internal/model/comment.go @@ -0,0 +1,80 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// CommentStatus 评论状态 +type CommentStatus string + +const ( + CommentStatusDraft CommentStatus = "draft" + CommentStatusPending CommentStatus = "pending" + CommentStatusPublished CommentStatus = "published" + CommentStatusRejected CommentStatus = "rejected" + CommentStatusDeleted CommentStatus = "deleted" +) + +// Comment 评论实体 +type Comment struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + PostID string `json:"post_id" gorm:"type:varchar(36);not null;index:idx_comments_post_parent_status_created,priority:1"` + UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"` + ParentID *string `json:"parent_id" gorm:"type:varchar(36);index:idx_comments_post_parent_status_created,priority:2"` // 父评论 ID(支持嵌套) + RootID *string `json:"root_id" gorm:"type:varchar(36);index:idx_comments_root_status_created,priority:1"` // 根评论 ID(用于高效查询) + Content string `json:"content" gorm:"type:text;not null"` + Images string `json:"images" gorm:"type:text"` // 图片URL列表,JSON数组格式 + + // 关联 + User *User `json:"-" gorm:"foreignKey:UserID"` + Replies []*Comment `json:"-" gorm:"-"` // 子回复(手动加载,非 GORM 关联) + + // 审核状态 + Status CommentStatus `json:"status" gorm:"type:varchar(20);default:published;index:idx_comments_post_parent_status_created,priority:3;index:idx_comments_root_status_created,priority:2"` + + // 统计 + LikesCount int `json:"likes_count" gorm:"default:0"` + RepliesCount int `json:"replies_count" gorm:"default:0"` + + // 软删除 + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + + // 时间戳 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_comments_post_parent_status_created,priority:4,sort:asc;index:idx_comments_root_status_created,priority:3,sort:asc"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// BeforeCreate 创建前生成UUID +func (c *Comment) BeforeCreate(tx *gorm.DB) error { + if c.ID == "" { + c.ID = uuid.New().String() + } + return nil +} + +func (Comment) TableName() string { + return "comments" +} + +// CommentLike 评论点赞 +type CommentLike struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + CommentID string `json:"comment_id" gorm:"type:varchar(36);index;not null"` + UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` +} + +// BeforeCreate 创建前生成UUID +func (cl *CommentLike) BeforeCreate(tx *gorm.DB) error { + if cl.ID == "" { + cl.ID = uuid.New().String() + } + return nil +} + +func (CommentLike) TableName() string { + return "comment_likes" +} diff --git a/internal/model/conversation.go b/internal/model/conversation.go new file mode 100644 index 0000000..c8aae11 --- /dev/null +++ b/internal/model/conversation.go @@ -0,0 +1,68 @@ +package model + +import ( + "strconv" + "time" + + "gorm.io/gorm" + + "carrot_bbs/internal/pkg/utils" +) + +// ConversationType 会话类型 +type ConversationType string + +const ( + ConversationTypePrivate ConversationType = "private" // 私聊 + ConversationTypeGroup ConversationType = "group" // 群聊 + ConversationTypeSystem ConversationType = "system" // 系统通知会话 +) + +// Conversation 会话实体 +// 使用雪花算法ID(作为string存储)和seq机制实现消息排序和增量同步 +type Conversation struct { + ID string `gorm:"primaryKey;size:20" json:"id"` // 雪花算法ID(转为string避免精度丢失) + Type ConversationType `gorm:"type:varchar(20);default:'private'" json:"type"` // 会话类型 + GroupID *string `gorm:"index" json:"group_id,omitempty"` // 关联的群组ID(群聊时使用,string类型避免JS精度丢失);使用指针支持NULL值 + LastSeq int64 `gorm:"default:0" json:"last_seq"` // 最后一条消息的seq + LastMsgTime *time.Time `json:"last_msg_time,omitempty"` // 最后消息时间 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + + // 关联 - 使用 polymorphic 模式避免外键约束问题 + Participants []ConversationParticipant `gorm:"foreignKey:ConversationID" json:"participants,omitempty"` + Group *Group `gorm:"foreignKey:GroupID;references:ID" json:"group,omitempty"` +} + +// BeforeCreate 创建前生成雪花算法ID +func (c *Conversation) BeforeCreate(tx *gorm.DB) error { + if c.ID == "" { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + c.ID = strconv.FormatInt(id, 10) + } + return nil +} + +func (Conversation) TableName() string { + return "conversations" +} + +// ConversationParticipant 会话参与者 +type ConversationParticipant struct { + ID uint `gorm:"primaryKey" json:"id"` + ConversationID string `gorm:"not null;size:20;uniqueIndex:idx_conversation_user,priority:1;index:idx_cp_conversation_user,priority:1" json:"conversation_id"` // 雪花算法ID(string类型) + UserID string `gorm:"column:user_id;type:varchar(50);not null;uniqueIndex:idx_conversation_user,priority:2;index:idx_cp_conversation_user,priority:2;index:idx_cp_user_hidden_pinned_updated,priority:1" json:"user_id"` // UUID格式,与JWT中user_id保持一致 + LastReadSeq int64 `gorm:"default:0" json:"last_read_seq"` // 已读到的seq位置 + Muted bool `gorm:"default:false" json:"muted"` // 是否免打扰 + IsPinned bool `gorm:"default:false;index:idx_cp_user_hidden_pinned_updated,priority:3" json:"is_pinned"` // 是否置顶会话(用户维度) + HiddenAt *time.Time `gorm:"index:idx_cp_user_hidden_pinned_updated,priority:2" json:"hidden_at,omitempty"` // 仅自己删除会话时使用,收到新消息后自动恢复 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime;index:idx_cp_user_hidden_pinned_updated,priority:4,sort:desc"` +} + +func (ConversationParticipant) TableName() string { + return "conversation_participants" +} diff --git a/internal/model/device_token.go b/internal/model/device_token.go new file mode 100644 index 0000000..86bc363 --- /dev/null +++ b/internal/model/device_token.go @@ -0,0 +1,94 @@ +package model + +import ( + "time" + + "gorm.io/gorm" + + "carrot_bbs/internal/pkg/utils" +) + +// DeviceType 设备类型 +type DeviceType string + +const ( + DeviceTypeIOS DeviceType = "ios" // iOS设备 + DeviceTypeAndroid DeviceType = "android" // Android设备 + DeviceTypeWeb DeviceType = "web" // Web端 +) + +// DeviceToken 设备Token实体 +// 用于管理用户的多设备推送Token +type DeviceToken struct { + ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` // 雪花算法ID + UserID string `gorm:"column:user_id;type:varchar(50);index;not null" json:"user_id"` // 用户ID (UUID格式) + DeviceID string `gorm:"type:varchar(100);not null" json:"device_id"` // 设备唯一标识 + DeviceType DeviceType `gorm:"type:varchar(20);not null" json:"device_type"` // 设备类型 + PushToken string `gorm:"type:varchar(256);not null" json:"push_token"` // 推送Token(FCM/APNs等) + IsActive bool `gorm:"default:true" json:"is_active"` // 是否活跃 + DeviceName string `gorm:"type:varchar(100)" json:"device_name,omitempty"` // 设备名称(可选) + + // 时间戳 + LastUsedAt *time.Time `json:"last_used_at,omitempty"` // 最后使用时间 + + // 软删除 + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + + // 时间戳 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// BeforeCreate 创建前生成雪花算法ID +func (d *DeviceToken) BeforeCreate(tx *gorm.DB) error { + if d.ID == 0 { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + d.ID = id + } + return nil +} + +func (DeviceToken) TableName() string { + return "device_tokens" +} + +// UpdateLastUsed 更新最后使用时间 +func (d *DeviceToken) UpdateLastUsed() { + now := time.Now() + d.LastUsedAt = &now +} + +// Deactivate 停用设备 +func (d *DeviceToken) Deactivate() { + d.IsActive = false +} + +// Activate 激活设备 +func (d *DeviceToken) Activate() { + d.IsActive = true + now := time.Now() + d.LastUsedAt = &now +} + +// IsIOS 判断是否为iOS设备 +func (d *DeviceToken) IsIOS() bool { + return d.DeviceType == DeviceTypeIOS +} + +// IsAndroid 判断是否为Android设备 +func (d *DeviceToken) IsAndroid() bool { + return d.DeviceType == DeviceTypeAndroid +} + +// IsWeb 判断是否为Web端 +func (d *DeviceToken) IsWeb() bool { + return d.DeviceType == DeviceTypeWeb +} + +// SupportsMobilePush 判断是否支持手机推送 +func (d *DeviceToken) SupportsMobilePush() bool { + return d.DeviceType == DeviceTypeIOS || d.DeviceType == DeviceTypeAndroid +} diff --git a/internal/model/favorite.go b/internal/model/favorite.go new file mode 100644 index 0000000..521511c --- /dev/null +++ b/internal/model/favorite.go @@ -0,0 +1,28 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// Favorite 收藏 +type Favorite struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + PostID string `json:"post_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_favorite_post_user,priority:1"` + UserID string `json:"user_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_favorite_post_user,priority:2"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` +} + +// BeforeCreate 创建前生成UUID +func (f *Favorite) BeforeCreate(tx *gorm.DB) error { + if f.ID == "" { + f.ID = uuid.New().String() + } + return nil +} + +func (Favorite) TableName() string { + return "favorites" +} diff --git a/internal/model/follow.go b/internal/model/follow.go new file mode 100644 index 0000000..c77a822 --- /dev/null +++ b/internal/model/follow.go @@ -0,0 +1,28 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// Follow 关注关系 +type Follow struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + FollowerID string `json:"follower_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_follower_following"` // 关注者 + FollowingID string `json:"following_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_follower_following"` // 被关注者 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` +} + +// BeforeCreate 创建前生成UUID +func (f *Follow) BeforeCreate(tx *gorm.DB) error { + if f.ID == "" { + f.ID = uuid.New().String() + } + return nil +} + +func (Follow) TableName() string { + return "follows" +} diff --git a/internal/model/group.go b/internal/model/group.go new file mode 100644 index 0000000..05d1b36 --- /dev/null +++ b/internal/model/group.go @@ -0,0 +1,57 @@ +package model + +import ( + "strconv" + "time" + + "carrot_bbs/internal/pkg/utils" + + "gorm.io/gorm" +) + +// JoinType 群组加入类型 +type JoinType int + +const ( + JoinTypeAnyone JoinType = 0 // 允许任何人加入 + JoinTypeApproval JoinType = 1 // 需要审批 + JoinTypeForbidden JoinType = 2 // 不允许加入 +) + +// Group 群组模型 +type Group struct { + ID string `gorm:"primaryKey;size:20" json:"id"` + Name string `gorm:"size:50;not null" json:"name"` + Avatar string `gorm:"size:512" json:"avatar"` + Description string `gorm:"size:500" json:"description"` + OwnerID string `gorm:"type:varchar(36);not null;index" json:"owner_id"` + MemberCount int `gorm:"default:0" json:"member_count"` + MaxMembers int `gorm:"default:500" json:"max_members"` + JoinType JoinType `gorm:"default:0" json:"join_type"` // 0:允许任何人加入 1:需要审批 2:不允许加入 + MuteAll bool `gorm:"default:false" json:"mute_all"` // 全员禁言 + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// BeforeCreate 创建前生成雪花算法ID +func (g *Group) BeforeCreate(tx *gorm.DB) error { + if g.ID == "" { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + g.ID = strconv.FormatInt(id, 10) + } + return nil +} + +// GetIDInt 获取数字类型的ID(用于比较) +func (g *Group) GetIDInt() uint64 { + id, _ := strconv.ParseUint(g.ID, 10, 64) + return id +} + +// TableName 指定表名 +func (Group) TableName() string { + return "groups" +} diff --git a/internal/model/group_announcement.go b/internal/model/group_announcement.go new file mode 100644 index 0000000..c0f9703 --- /dev/null +++ b/internal/model/group_announcement.go @@ -0,0 +1,38 @@ +package model + +import ( + "strconv" + "time" + + "carrot_bbs/internal/pkg/utils" + + "gorm.io/gorm" +) + +// GroupAnnouncement 群公告模型 +type GroupAnnouncement struct { + ID string `gorm:"primaryKey;size:20" json:"id"` + GroupID string `gorm:"not null;index" json:"group_id"` + Content string `gorm:"type:text;not null" json:"content"` + AuthorID string `gorm:"type:varchar(36);not null" json:"author_id"` + IsPinned bool `gorm:"default:false" json:"is_pinned"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// BeforeCreate 创建前生成雪花算法ID +func (ga *GroupAnnouncement) BeforeCreate(tx *gorm.DB) error { + if ga.ID == "" { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + ga.ID = strconv.FormatInt(id, 10) + } + return nil +} + +// TableName 指定表名 +func (GroupAnnouncement) TableName() string { + return "group_announcements" +} diff --git a/internal/model/group_join_request.go b/internal/model/group_join_request.go new file mode 100644 index 0000000..27ec21a --- /dev/null +++ b/internal/model/group_join_request.go @@ -0,0 +1,59 @@ +package model + +import ( + "strconv" + "time" + + "carrot_bbs/internal/pkg/utils" + + "gorm.io/gorm" +) + +type GroupJoinRequestType string + +const ( + GroupJoinRequestTypeInvite GroupJoinRequestType = "invite" + GroupJoinRequestTypeJoinApply GroupJoinRequestType = "join_apply" +) + +type GroupJoinRequestStatus string + +const ( + GroupJoinRequestStatusPending GroupJoinRequestStatus = "pending" + GroupJoinRequestStatusAccepted GroupJoinRequestStatus = "accepted" + GroupJoinRequestStatusRejected GroupJoinRequestStatus = "rejected" + GroupJoinRequestStatusCancelled GroupJoinRequestStatus = "cancelled" + GroupJoinRequestStatusExpired GroupJoinRequestStatus = "expired" +) + +// GroupJoinRequest 统一保存邀请入群和主动加群申请 +type GroupJoinRequest struct { + ID string `gorm:"primaryKey;size:20" json:"id"` + Flag string `gorm:"size:64;uniqueIndex;not null" json:"flag"` + GroupID string `gorm:"not null;index;index:idx_gjr_group_target_type_status_created,priority:1" json:"group_id"` + InitiatorID string `gorm:"type:varchar(36);not null;index" json:"initiator_id"` + TargetUserID string `gorm:"type:varchar(36);not null;index;index:idx_gjr_group_target_type_status_created,priority:2" json:"target_user_id"` + RequestType GroupJoinRequestType `gorm:"size:20;not null;index;index:idx_gjr_group_target_type_status_created,priority:3" json:"request_type"` + Status GroupJoinRequestStatus `gorm:"size:20;not null;index;index:idx_gjr_group_target_type_status_created,priority:4" json:"status"` + Reason string `gorm:"size:500" json:"reason"` + ReviewerID string `gorm:"type:varchar(36);index" json:"reviewer_id"` + ReviewedAt *time.Time `json:"reviewed_at"` + ExpireAt *time.Time `json:"expire_at"` + CreatedAt time.Time `gorm:"index:idx_gjr_group_target_type_status_created,priority:5,sort:desc" json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (r *GroupJoinRequest) BeforeCreate(tx *gorm.DB) error { + if r.ID == "" { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + r.ID = strconv.FormatInt(id, 10) + } + return nil +} + +func (GroupJoinRequest) TableName() string { + return "group_join_requests" +} diff --git a/internal/model/group_member.go b/internal/model/group_member.go new file mode 100644 index 0000000..6be9639 --- /dev/null +++ b/internal/model/group_member.go @@ -0,0 +1,47 @@ +package model + +import ( + "strconv" + "time" + + "carrot_bbs/internal/pkg/utils" + + "gorm.io/gorm" +) + +// GroupMember 群成员模型 +type GroupMember struct { + ID string `gorm:"primaryKey;size:20" json:"id"` + GroupID string `gorm:"not null;uniqueIndex:idx_group_user" json:"group_id"` + UserID string `gorm:"type:varchar(36);not null;uniqueIndex:idx_group_user" json:"user_id"` + Role string `gorm:"size:20;default:'member'" json:"role"` // owner, admin, member + Nickname string `gorm:"size:50" json:"nickname"` // 群内昵称 + Muted bool `gorm:"default:false" json:"muted"` // 是否被禁言 + JoinTime time.Time `json:"join_time"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// BeforeCreate 创建前生成雪花算法ID +func (gm *GroupMember) BeforeCreate(tx *gorm.DB) error { + if gm.ID == "" { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + gm.ID = strconv.FormatInt(id, 10) + } + return nil +} + +// TableName 指定表名 +func (GroupMember) TableName() string { + return "group_members" +} + +// Role 常量 +const ( + GroupRoleOwner = "owner" + GroupRoleAdmin = "admin" + GroupRoleMember = "member" +) diff --git a/internal/model/init.go b/internal/model/init.go new file mode 100644 index 0000000..78cd78e --- /dev/null +++ b/internal/model/init.go @@ -0,0 +1,159 @@ +package model + +import ( + "fmt" + "log" + "os" + "time" + + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "carrot_bbs/internal/config" +) + +// DB 全局数据库连接 +var DB *gorm.DB + +// InitDB 初始化数据库连接 +func InitDB(cfg *config.DatabaseConfig) error { + var err error + var db *gorm.DB + gormLogger := logger.New( + log.New(os.Stdout, "\r\n", log.LstdFlags), + logger.Config{ + SlowThreshold: time.Duration(cfg.SlowThresholdMs) * time.Millisecond, + LogLevel: parseGormLogLevel(cfg.LogLevel), + IgnoreRecordNotFoundError: cfg.IgnoreRecordNotFound, + ParameterizedQueries: cfg.ParameterizedQueries, + Colorful: false, + }, + ) + + // 根据数据库类型选择驱动 + switch cfg.Type { + case "sqlite": + db, err = gorm.Open(sqlite.Open(cfg.SQLite.Path), &gorm.Config{ + Logger: gormLogger, + }) + case "postgres", "postgresql": + dsn := cfg.Postgres.DSN() + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: gormLogger, + }) + default: + // 默认使用PostgreSQL + dsn := cfg.Postgres.DSN() + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: gormLogger, + }) + } + + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + + DB = db + + // 配置连接池(SQLite不支持连接池配置,跳过) + if cfg.Type != "sqlite" { + sqlDB, err := DB.DB() + if err != nil { + return fmt.Errorf("failed to get database instance: %w", err) + } + sqlDB.SetMaxIdleConns(cfg.MaxIdleConns) + sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) + } + + // 自动迁移 + if err := autoMigrate(DB); err != nil { + return fmt.Errorf("failed to auto migrate: %w", err) + } + + log.Printf("Database connected (%s) and migrated successfully", cfg.Type) + return nil +} + +func parseGormLogLevel(level string) logger.LogLevel { + switch level { + case "silent": + return logger.Silent + case "error": + return logger.Error + case "warn": + return logger.Warn + case "info": + return logger.Info + default: + return logger.Warn + } +} + +// autoMigrate 自动迁移数据库表 +func autoMigrate(db *gorm.DB) error { + err := db.AutoMigrate( + // 用户相关 + &User{}, + + // 帖子相关 + &Post{}, + &PostImage{}, + + // 评论相关 + &Comment{}, + &CommentLike{}, + + // 消息相关(使用雪花算法ID和seq机制) + // 已读位置存储在 ConversationParticipant.LastReadSeq 中 + &Conversation{}, + &ConversationParticipant{}, + &Message{}, + + // 系统通知(独立表,每个用户只能看到自己的通知) + &SystemNotification{}, + + // 通知 + &Notification{}, + + // 推送中心相关 + &PushRecord{}, // 推送记录 + &DeviceToken{}, // 设备Token + + // 社交 + &Follow{}, + &UserBlock{}, + &PostLike{}, + &Favorite{}, + + // 投票 + &VoteOption{}, + &UserVote{}, + + // 敏感词和审核 + &SensitiveWord{}, + &AuditLog{}, + + // 群组相关 + &Group{}, + &GroupMember{}, + &GroupAnnouncement{}, + &GroupJoinRequest{}, + + // 自定义表情 + &UserSticker{}, + ) + if err != nil { + return err + } + + return nil +} + +// CloseDB 关闭数据库连接 +func CloseDB() { + if sqlDB, err := DB.DB(); err == nil { + sqlDB.Close() + } +} diff --git a/internal/model/like.go b/internal/model/like.go new file mode 100644 index 0000000..04fb9be --- /dev/null +++ b/internal/model/like.go @@ -0,0 +1,28 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// PostLike 帖子点赞 +type PostLike struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + PostID string `json:"post_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_post_like_user,priority:1"` + UserID string `json:"user_id" gorm:"type:varchar(36);not null;index;uniqueIndex:idx_post_like_user,priority:2"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` +} + +// BeforeCreate 创建前生成UUID +func (pl *PostLike) BeforeCreate(tx *gorm.DB) error { + if pl.ID == "" { + pl.ID = uuid.New().String() + } + return nil +} + +func (PostLike) TableName() string { + return "post_likes" +} diff --git a/internal/model/message.go b/internal/model/message.go new file mode 100644 index 0000000..fac1eea --- /dev/null +++ b/internal/model/message.go @@ -0,0 +1,205 @@ +package model + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "strconv" + "time" + + "gorm.io/gorm" + + "carrot_bbs/internal/pkg/utils" +) + +// 系统消息相关常量 +const ( + // SystemSenderID 系统消息发送者ID (类似QQ 10000) + SystemSenderID int64 = 10000 + + // SystemSenderIDStr 系统消息发送者ID字符串版本 + SystemSenderIDStr string = "10000" + + // SystemConversationID 系统通知会话ID (string类型) + SystemConversationID string = "9999999999" +) + +// ContentType 消息内容类型 +type ContentType string + +const ( + ContentTypeText ContentType = "text" + ContentTypeImage ContentType = "image" + ContentTypeVideo ContentType = "video" + ContentTypeAudio ContentType = "audio" + ContentTypeFile ContentType = "file" +) + +// MessageStatus 消息状态 +type MessageStatus string + +const ( + MessageStatusNormal MessageStatus = "normal" // 正常 + MessageStatusRecalled MessageStatus = "recalled" // 已撤回 + MessageStatusDeleted MessageStatus = "deleted" // 已删除 +) + +// MessageCategory 消息类别 +type MessageCategory string + +const ( + CategoryChat MessageCategory = "chat" // 普通聊天 + CategoryNotification MessageCategory = "notification" // 通知类消息 + CategoryAnnouncement MessageCategory = "announcement" // 系统公告 + CategoryMarketing MessageCategory = "marketing" // 营销消息 +) + +// SystemMessageType 系统消息类型 (对应原NotificationType) +type SystemMessageType string + +const ( + // 互动通知 + SystemTypeLikePost SystemMessageType = "like_post" // 点赞帖子 + SystemTypeLikeComment SystemMessageType = "like_comment" // 点赞评论 + SystemTypeComment SystemMessageType = "comment" // 评论 + SystemTypeReply SystemMessageType = "reply" // 回复 + SystemTypeFollow SystemMessageType = "follow" // 关注 + SystemTypeMention SystemMessageType = "mention" // @提及 + + // 系统消息 + SystemTypeSystem SystemMessageType = "system" // 系统通知 + SystemTypeAnnounce SystemMessageType = "announce" // 系统公告 +) + +// ExtraData 消息额外数据,用于存储系统消息的相关信息 +type ExtraData struct { + // 操作者信息 + ActorID int64 `json:"actor_id,omitempty"` // 操作者ID (数字格式,兼容旧数据) + ActorIDStr string `json:"actor_id_str,omitempty"` // 操作者ID (UUID字符串格式) + ActorName string `json:"actor_name,omitempty"` // 操作者名称 + AvatarURL string `json:"avatar_url,omitempty"` // 操作者头像 + + // 目标信息 + TargetID int64 `json:"target_id,omitempty"` // 目标ID(帖子ID、评论ID等) + TargetTitle string `json:"target_title,omitempty"` // 目标标题 + TargetType string `json:"target_type,omitempty"` // 目标类型(post/comment等) + + // 其他信息 + ActionURL string `json:"action_url,omitempty"` // 跳转链接 + ActionTime string `json:"action_time,omitempty"` // 操作时间 +} + +// Value 实现driver.Valuer接口,用于数据库存储 +func (e ExtraData) Value() (driver.Value, error) { + return json.Marshal(e) +} + +// Scan 实现sql.Scanner接口,用于数据库读取 +func (e *ExtraData) Scan(value interface{}) error { + if value == nil { + return nil + } + bytes, ok := value.([]byte) + if !ok { + return errors.New("type assertion to []byte failed") + } + return json.Unmarshal(bytes, e) +} + +// MessageSegmentData 单个消息段的数据 +type MessageSegmentData map[string]interface{} + +// MessageSegment 消息段 +type MessageSegment struct { + Type string `json:"type"` + Data MessageSegmentData `json:"data"` +} + +// MessageSegments 消息链类型 +type MessageSegments []MessageSegment + +// Value 实现driver.Valuer接口,用于数据库存储 +func (s MessageSegments) Value() (driver.Value, error) { + return json.Marshal(s) +} + +// Scan 实现sql.Scanner接口,用于数据库读取 +func (s *MessageSegments) Scan(value interface{}) error { + if value == nil { + *s = nil + return nil + } + bytes, ok := value.([]byte) + if !ok { + return errors.New("type assertion to []byte failed") + } + return json.Unmarshal(bytes, s) +} + +// Message 消息实体 +// 使用雪花算法ID(string类型)和seq机制实现消息排序和增量同步 +type Message struct { + ID string `gorm:"primaryKey;size:20" json:"id"` // 雪花算法ID(string类型) + ConversationID string `gorm:"not null;size:20;index:idx_msg_conversation_seq,priority:1" json:"conversation_id"` // 会话ID(string类型) + SenderID string `gorm:"column:sender_id;type:varchar(50);index;not null" json:"sender_id"` // 发送者ID (UUID格式) + Seq int64 `gorm:"not null;index:idx_msg_conversation_seq,priority:2" json:"seq"` // 会话内序号,用于排序和增量同步 + Segments MessageSegments `gorm:"type:json" json:"segments"` // 消息链(结构体数组) + ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID(string类型) + Status MessageStatus `gorm:"type:varchar(20);default:'normal'" json:"status"` // 消息状态 + + // 新增字段:消息分类和系统消息类型 + Category MessageCategory `gorm:"type:varchar(20);default:'chat'" json:"category"` // 消息分类 + SystemType SystemMessageType `gorm:"type:varchar(30)" json:"system_type,omitempty"` // 系统消息类型 + ExtraData *ExtraData `gorm:"type:json" json:"extra_data,omitempty"` // 额外数据(JSON格式) + + // @相关字段 + MentionUsers string `gorm:"type:text" json:"mention_users"` // @的用户ID列表,JSON数组 + MentionAll bool `gorm:"default:false" json:"mention_all"` // 是否@所有人 + + // 软删除 + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + + // 时间戳 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// SenderIDStr 返回发送者ID字符串(保持兼容性) +func (m *Message) SenderIDStr() string { + return m.SenderID +} + +// BeforeCreate 创建前生成雪花算法ID +func (m *Message) BeforeCreate(tx *gorm.DB) error { + if m.ID == "" { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + m.ID = strconv.FormatInt(id, 10) + } + return nil +} + +func (Message) TableName() string { + return "messages" +} + +// IsSystemMessage 判断是否为系统消息 +func (m *Message) IsSystemMessage() bool { + return m.SenderID == SystemSenderIDStr || m.Category == CategoryNotification || m.Category == CategoryAnnouncement +} + +// IsInteractionNotification 判断是否为互动通知 +func (m *Message) IsInteractionNotification() bool { + if m.Category != CategoryNotification { + return false + } + switch m.SystemType { + case SystemTypeLikePost, SystemTypeLikeComment, SystemTypeComment, + SystemTypeReply, SystemTypeFollow, SystemTypeMention: + return true + default: + return false + } +} diff --git a/internal/model/message_read.go b/internal/model/message_read.go new file mode 100644 index 0000000..00b3e24 --- /dev/null +++ b/internal/model/message_read.go @@ -0,0 +1,19 @@ +package model + +import ( + "time" +) + +// MessageRead 消息已读状态 +// 记录每个用户在每个会话中的已读位置 +type MessageRead struct { + ID uint `gorm:"primaryKey" json:"id"` + ConversationID int64 `gorm:"uniqueIndex:idx_conversation_user;not null" json:"conversation_id"` + UserID uint `gorm:"uniqueIndex:idx_conversation_user;not null" json:"user_id"` + LastReadSeq int64 `gorm:"not null" json:"last_read_seq"` // 已读到的seq位置 + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +func (MessageRead) TableName() string { + return "message_reads" +} diff --git a/internal/model/notification.go b/internal/model/notification.go new file mode 100644 index 0000000..3248e74 --- /dev/null +++ b/internal/model/notification.go @@ -0,0 +1,53 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// NotificationType 通知类型 +type NotificationType string + +const ( + NotificationTypeLikePost NotificationType = "like_post" + NotificationTypeLikeComment NotificationType = "like_comment" + NotificationTypeComment NotificationType = "comment" + NotificationTypeReply NotificationType = "reply" + NotificationTypeFollow NotificationType = "follow" + NotificationTypeMention NotificationType = "mention" + NotificationTypeSystem NotificationType = "system" +) + +// Notification 通知实体 +type Notification struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + UserID string `json:"user_id" gorm:"type:varchar(36);not null;index:idx_notifications_user_read_created,priority:1"` // 接收者 + Type NotificationType `json:"type" gorm:"type:varchar(30);not null"` + Title string `json:"title" gorm:"type:varchar(200);not null"` + Content string `json:"content" gorm:"type:text"` + Data string `json:"data" gorm:"type:jsonb"` // 相关数据(JSON) + + // 已读状态 + IsRead bool `json:"is_read" gorm:"default:false;index:idx_notifications_user_read_created,priority:2"` + ReadAt *time.Time `json:"read_at" gorm:"type:timestamp"` + + // 软删除 + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + + // 时间戳 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_notifications_user_read_created,priority:3,sort:desc"` +} + +// BeforeCreate 创建前生成UUID +func (n *Notification) BeforeCreate(tx *gorm.DB) error { + if n.ID == "" { + n.ID = uuid.New().String() + } + return nil +} + +func (Notification) TableName() string { + return "notifications" +} diff --git a/internal/model/post.go b/internal/model/post.go new file mode 100644 index 0000000..1c21e7d --- /dev/null +++ b/internal/model/post.go @@ -0,0 +1,100 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// PostStatus 帖子状态 +type PostStatus string + +const ( + PostStatusDraft PostStatus = "draft" + PostStatusPending PostStatus = "pending" // 待审核 + PostStatusPublished PostStatus = "published" + PostStatusRejected PostStatus = "rejected" + PostStatusDeleted PostStatus = "deleted" +) + +// Post 帖子实体 +type Post struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + UserID string `json:"user_id" gorm:"type:varchar(36);index;index:idx_posts_user_status_created,priority:1;not null"` + CommunityID string `json:"community_id" gorm:"type:varchar(36);index"` + Title string `json:"title" gorm:"type:varchar(200);not null"` + Content string `json:"content" gorm:"type:text;not null"` + + // 关联 + // User 需要参与缓存序列化;否则列表命中缓存后会丢失作者信息,前端退化为“匿名用户” + User *User `json:"user,omitempty" gorm:"foreignKey:UserID"` + Images []PostImage `json:"images" gorm:"foreignKey:PostID"` + + // 审核状态 + Status PostStatus `json:"status" gorm:"type:varchar(20);default:published;index:idx_posts_status_created,priority:1;index:idx_posts_user_status_created,priority:2"` + ReviewedAt *time.Time `json:"reviewed_at" gorm:"type:timestamp"` + ReviewedBy string `json:"reviewed_by" gorm:"type:varchar(50)"` + RejectReason string `json:"reject_reason" gorm:"type:varchar(500)"` + + // 统计 + LikesCount int `json:"likes_count" gorm:"column:likes_count;default:0"` + CommentsCount int `json:"comments_count" gorm:"column:comments_count;default:0"` + FavoritesCount int `json:"favorites_count" gorm:"column:favorites_count;default:0"` + SharesCount int `json:"shares_count" gorm:"column:shares_count;default:0"` + ViewsCount int `json:"views_count" gorm:"column:views_count;default:0"` + HotScore float64 `json:"hot_score" gorm:"column:hot_score;default:0;index:idx_posts_hot_score_created,priority:1"` + + // 置顶/锁定 + IsPinned bool `json:"is_pinned" gorm:"default:false"` + IsLocked bool `json:"is_locked" gorm:"default:false"` + IsDeleted bool `json:"-" gorm:"default:false"` + + // 投票 + IsVote bool `json:"is_vote" gorm:"column:is_vote;default:false"` + + // 软删除 + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + + // 时间戳 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_posts_status_created,priority:2,sort:desc;index:idx_posts_user_status_created,priority:3,sort:desc;index:idx_posts_hot_score_created,priority:2,sort:desc"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// BeforeCreate 创建前生成UUID +func (p *Post) BeforeCreate(tx *gorm.DB) error { + if p.ID == "" { + p.ID = uuid.New().String() + } + return nil +} + +func (Post) TableName() string { + return "posts" +} + +// PostImage 帖子图片 +type PostImage struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + PostID string `json:"post_id" gorm:"type:varchar(36);index;not null"` + URL string `json:"url" gorm:"type:text;not null"` + ThumbnailURL string `json:"thumbnail_url" gorm:"type:text"` + Width int `json:"width" gorm:"default:0"` + Height int `json:"height" gorm:"default:0"` + Size int64 `json:"size" gorm:"default:0"` // 文件大小(字节) + MimeType string `json:"mime_type" gorm:"type:varchar(50)"` + SortOrder int `json:"sort_order" gorm:"default:0"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` +} + +// BeforeCreate 创建前生成UUID +func (pi *PostImage) BeforeCreate(tx *gorm.DB) error { + if pi.ID == "" { + pi.ID = uuid.New().String() + } + return nil +} + +func (PostImage) TableName() string { + return "post_images" +} diff --git a/internal/model/push_record.go b/internal/model/push_record.go new file mode 100644 index 0000000..f29288e --- /dev/null +++ b/internal/model/push_record.go @@ -0,0 +1,129 @@ +package model + +import ( + "time" + + "gorm.io/gorm" + + "carrot_bbs/internal/pkg/utils" +) + +// PushChannel 推送通道类型 +type PushChannel string + +const ( + PushChannelWebSocket PushChannel = "websocket" // WebSocket推送 + PushChannelFCM PushChannel = "fcm" // Firebase Cloud Messaging + PushChannelAPNs PushChannel = "apns" // Apple Push Notification service + PushChannelHuawei PushChannel = "huawei" // 华为推送 +) + +// PushStatus 推送状态 +type PushStatus string + +const ( + PushStatusPending PushStatus = "pending" // 待推送 + PushStatusPushing PushStatus = "pushing" // 推送中 + PushStatusPushed PushStatus = "pushed" // 已推送(成功发送到推送服务) + PushStatusDelivered PushStatus = "delivered" // 已送达(客户端确认) + PushStatusFailed PushStatus = "failed" // 推送失败 + PushStatusExpired PushStatus = "expired" // 消息过期 +) + +// PushRecord 推送记录实体 +// 用于跟踪消息的推送状态,支持多设备推送和重试机制 +type PushRecord struct { + ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` // 雪花算法ID + UserID string `gorm:"column:user_id;type:varchar(50);index;not null" json:"user_id"` // 目标用户ID (UUID格式) + MessageID string `gorm:"index;not null;size:20" json:"message_id"` // 关联的消息ID (string类型) + PushChannel PushChannel `gorm:"type:varchar(20);not null" json:"push_channel"` // 推送通道 + PushStatus PushStatus `gorm:"type:varchar(20);not null;default:'pending'" json:"push_status"` // 推送状态 + + // 设备信息 + DeviceToken string `gorm:"type:varchar(256)" json:"device_token,omitempty"` // 设备Token(用于手机推送) + DeviceType string `gorm:"type:varchar(20)" json:"device_type,omitempty"` // 设备类型 (ios/android/web) + + // 重试机制 + RetryCount int `gorm:"default:0" json:"retry_count"` // 重试次数 + MaxRetry int `gorm:"default:3" json:"max_retry"` // 最大重试次数 + + // 时间戳 + PushedAt *time.Time `json:"pushed_at,omitempty"` // 推送时间 + DeliveredAt *time.Time `json:"delivered_at,omitempty"` // 送达时间 + ExpiredAt *time.Time `gorm:"index" json:"expired_at,omitempty"` // 过期时间 + + // 错误信息 + ErrorMessage string `gorm:"type:varchar(500)" json:"error_message,omitempty"` // 错误信息 + + // 软删除 + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + + // 时间戳 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// BeforeCreate 创建前生成雪花算法ID +func (r *PushRecord) BeforeCreate(tx *gorm.DB) error { + if r.ID == 0 { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + r.ID = id + } + return nil +} + +func (PushRecord) TableName() string { + return "push_records" +} + +// CanRetry 判断是否可以重试 +func (r *PushRecord) CanRetry() bool { + return r.RetryCount < r.MaxRetry && r.PushStatus != PushStatusDelivered && r.PushStatus != PushStatusExpired +} + +// IsExpired 判断是否已过期 +func (r *PushRecord) IsExpired() bool { + if r.ExpiredAt == nil { + return false + } + return time.Now().After(*r.ExpiredAt) +} + +// MarkPushing 标记为推送中 +func (r *PushRecord) MarkPushing() { + r.PushStatus = PushStatusPushing +} + +// MarkPushed 标记为已推送 +func (r *PushRecord) MarkPushed() { + now := time.Now() + r.PushStatus = PushStatusPushed + r.PushedAt = &now +} + +// MarkDelivered 标记为已送达 +func (r *PushRecord) MarkDelivered() { + now := time.Now() + r.PushStatus = PushStatusDelivered + r.DeliveredAt = &now +} + +// MarkFailed 标记为推送失败 +func (r *PushRecord) MarkFailed(errMsg string) { + r.PushStatus = PushStatusFailed + r.ErrorMessage = errMsg + r.RetryCount++ +} + +// MarkExpired 标记为已过期 +func (r *PushRecord) MarkExpired() { + r.PushStatus = PushStatusExpired +} + +// IncrementRetry 增加重试次数 +func (r *PushRecord) IncrementRetry() { + r.RetryCount++ +} diff --git a/internal/model/sensitive_word.go b/internal/model/sensitive_word.go new file mode 100644 index 0000000..e533beb --- /dev/null +++ b/internal/model/sensitive_word.go @@ -0,0 +1,77 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// SensitiveWordLevel 敏感词级别 +type SensitiveWordLevel int + +const ( + SensitiveWordLevelLow SensitiveWordLevel = 1 // 低危 + SensitiveWordLevelMedium SensitiveWordLevel = 2 // 中危 + SensitiveWordLevelHigh SensitiveWordLevel = 3 // 高危 +) + +// SensitiveWordCategory 敏感词分类 +type SensitiveWordCategory string + +const ( + SensitiveWordCategoryPolitical SensitiveWordCategory = "political" // 政治 + SensitiveWordCategoryPorn SensitiveWordCategory = "porn" // 色情 + SensitiveWordCategoryViolence SensitiveWordCategory = "violence" // 暴力 + SensitiveWordCategoryAd SensitiveWordCategory = "ad" // 广告 + SensitiveWordCategoryGamble SensitiveWordCategory = "gamble" // 赌博 + SensitiveWordCategoryFraud SensitiveWordCategory = "fraud" // 诈骗 + SensitiveWordCategoryOther SensitiveWordCategory = "other" // 其他 +) + +// SensitiveWord 敏感词实体 +type SensitiveWord struct { + ID string `gorm:"type:varchar(36);primaryKey"` + Word string `gorm:"type:varchar(255);uniqueIndex;not null"` + Category SensitiveWordCategory `gorm:"type:varchar(50);index"` + Level SensitiveWordLevel `gorm:"type:int;default:1"` + IsActive bool `gorm:"default:true"` + CreatedBy string `gorm:"type:varchar(255)"` + UpdatedBy string `gorm:"type:varchar(255)"` + Remark string `gorm:"type:text"` + CreatedAt time.Time `gorm:"autoCreateTime"` + UpdatedAt time.Time `gorm:"autoUpdateTime"` + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +// BeforeCreate 创建前生成UUID +func (sw *SensitiveWord) BeforeCreate(tx *gorm.DB) error { + if sw.ID == "" { + sw.ID = uuid.New().String() + } + return nil +} + +func (SensitiveWord) TableName() string { + return "sensitive_words" +} + +// SensitiveWordRequest 创建/更新敏感词请求 +type SensitiveWordRequest struct { + Word string `json:"word" validate:"required,min=1,max=255"` + Category SensitiveWordCategory `json:"category"` + Level SensitiveWordLevel `json:"level"` + Remark string `json:"remark"` + CreatedBy string `json:"-"` +} + +// SensitiveWordListItem 敏感词列表项(用于列表展示) +type SensitiveWordListItem struct { + ID string `json:"id"` + Word string `json:"word"` + Category SensitiveWordCategory `json:"category"` + Level SensitiveWordLevel `json:"level"` + IsActive bool `json:"is_active"` + CreatedAt time.Time `json:"created_at"` + Remark string `json:"remark"` +} diff --git a/internal/model/sticker.go b/internal/model/sticker.go new file mode 100644 index 0000000..f2963d3 --- /dev/null +++ b/internal/model/sticker.go @@ -0,0 +1,33 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// UserSticker 用户自定义表情 +type UserSticker struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + UserID string `json:"user_id" gorm:"type:varchar(36);not null;index:idx_user_stickers"` + URL string `json:"url" gorm:"type:text;not null"` + Width int `json:"width" gorm:"default:0"` + Height int `json:"height" gorm:"default:0"` + SortOrder int `json:"sort_order" gorm:"default:0;index:idx_user_stickers_sort"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// TableName 表名 +func (UserSticker) TableName() string { + return "user_stickers" +} + +// BeforeCreate 创建前生成UUID +func (s *UserSticker) BeforeCreate(tx *gorm.DB) error { + if s.ID == "" { + s.ID = uuid.New().String() + } + return nil +} diff --git a/internal/model/system_notification.go b/internal/model/system_notification.go new file mode 100644 index 0000000..ef55db8 --- /dev/null +++ b/internal/model/system_notification.go @@ -0,0 +1,127 @@ +package model + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "time" + + "gorm.io/gorm" + + "carrot_bbs/internal/pkg/utils" +) + +// SystemNotificationType 系统通知类型 +type SystemNotificationType string + +const ( + // 互动通知 + SysNotifyLikePost SystemNotificationType = "like_post" // 点赞帖子 + SysNotifyLikeComment SystemNotificationType = "like_comment" // 点赞评论 + SysNotifyComment SystemNotificationType = "comment" // 评论 + SysNotifyReply SystemNotificationType = "reply" // 回复 + SysNotifyFollow SystemNotificationType = "follow" // 关注 + SysNotifyMention SystemNotificationType = "mention" // @提及 + SysNotifyFavoritePost SystemNotificationType = "favorite_post" // 收藏帖子 + SysNotifyLikeReply SystemNotificationType = "like_reply" // 点赞回复 + + // 系统消息 + SysNotifySystem SystemNotificationType = "system" // 系统通知 + SysNotifyAnnounce SystemNotificationType = "announce" // 系统公告 + SysNotifyGroupInvite SystemNotificationType = "group_invite" // 群邀请 + SysNotifyGroupJoinApply SystemNotificationType = "group_join_apply" // 加群申请待审批 + SysNotifyGroupJoinApproved SystemNotificationType = "group_join_approved" // 加群申请通过 + SysNotifyGroupJoinRejected SystemNotificationType = "group_join_rejected" // 加群申请拒绝 +) + +// SystemNotificationExtra 额外数据 +type SystemNotificationExtra struct { + // 操作者信息 + ActorID int64 `json:"actor_id,omitempty"` + ActorIDStr string `json:"actor_id_str,omitempty"` + ActorName string `json:"actor_name,omitempty"` + AvatarURL string `json:"avatar_url,omitempty"` + + // 目标信息 + TargetID string `json:"target_id,omitempty"` // 改为string类型以支持UUID + TargetTitle string `json:"target_title,omitempty"` + TargetType string `json:"target_type,omitempty"` + + // 其他信息 + ActionURL string `json:"action_url,omitempty"` + ActionTime string `json:"action_time,omitempty"` + + // 群邀请/加群申请扩展字段 + GroupID string `json:"group_id,omitempty"` + GroupName string `json:"group_name,omitempty"` + GroupAvatar string `json:"group_avatar,omitempty"` + GroupDescription string `json:"group_description,omitempty"` + Flag string `json:"flag,omitempty"` + RequestType string `json:"request_type,omitempty"` + RequestStatus string `json:"request_status,omitempty"` + Reason string `json:"reason,omitempty"` + TargetUserID string `json:"target_user_id,omitempty"` + TargetUserName string `json:"target_user_name,omitempty"` + TargetUserAvatar string `json:"target_user_avatar,omitempty"` +} + +// Value 实现driver.Valuer接口 +func (e SystemNotificationExtra) Value() (driver.Value, error) { + return json.Marshal(e) +} + +// Scan 实现sql.Scanner接口 +func (e *SystemNotificationExtra) Scan(value interface{}) error { + if value == nil { + return nil + } + bytes, ok := value.([]byte) + if !ok { + return errors.New("type assertion to []byte failed") + } + return json.Unmarshal(bytes, e) +} + +// SystemNotification 系统通知(独立表,与消息完全分离) +// 每个用户只能看到自己的系统通知 +type SystemNotification struct { + ID int64 `gorm:"primaryKey;autoIncrement:false" json:"id"` + ReceiverID string `gorm:"column:receiver_id;type:varchar(50);not null;index:idx_sys_notifications_receiver_read_created,priority:1" json:"receiver_id"` // 接收者ID (UUID) + Type SystemNotificationType `gorm:"type:varchar(30);not null" json:"type"` // 通知类型 + Title string `gorm:"type:varchar(200)" json:"title,omitempty"` // 标题 + Content string `gorm:"type:text;not null" json:"content"` // 内容 + ExtraData *SystemNotificationExtra `gorm:"type:json" json:"extra_data,omitempty"` // 额外数据 + IsRead bool `gorm:"default:false;index:idx_sys_notifications_receiver_read_created,priority:2" json:"is_read"` // 是否已读 + ReadAt *time.Time `json:"read_at,omitempty"` // 阅读时间 + + // 软删除 + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` + + // 时间戳 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime;index:idx_sys_notifications_receiver_read_created,priority:3,sort:desc"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// BeforeCreate 创建前生成雪花算法ID +func (n *SystemNotification) BeforeCreate(tx *gorm.DB) error { + if n.ID == 0 { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return err + } + n.ID = id + } + return nil +} + +// TableName 指定表名 +func (SystemNotification) TableName() string { + return "system_notifications" +} + +// MarkAsRead 标记为已读 +func (n *SystemNotification) MarkAsRead() { + now := time.Now() + n.IsRead = true + n.ReadAt = &now +} diff --git a/internal/model/user.go b/internal/model/user.go new file mode 100644 index 0000000..a557103 --- /dev/null +++ b/internal/model/user.go @@ -0,0 +1,66 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// UserStatus 用户状态 +type UserStatus string + +const ( + UserStatusActive UserStatus = "active" + UserStatusBanned UserStatus = "banned" + UserStatusInactive UserStatus = "inactive" +) + +// User 用户实体 +type User struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + Username string `json:"username" gorm:"type:varchar(50);uniqueIndex;not null"` + Nickname string `json:"nickname" gorm:"type:varchar(100);not null"` + Email *string `json:"email" gorm:"type:varchar(255);uniqueIndex"` + Phone *string `json:"phone" gorm:"type:varchar(20);uniqueIndex"` + EmailVerified bool `json:"email_verified" gorm:"default:false"` + PasswordHash string `json:"-" gorm:"type:varchar(255);not null"` + Avatar string `json:"avatar" gorm:"type:text"` + CoverURL string `json:"cover_url" gorm:"type:text"` // 头图URL + Bio string `json:"bio" gorm:"type:text"` + Website string `json:"website" gorm:"type:varchar(255)"` + Location string `json:"location" gorm:"type:varchar(100)"` + + // 实名认证信息(可选) + RealName string `json:"real_name" gorm:"type:varchar(100)"` // 真实姓名 + IDCard string `json:"-" gorm:"type:varchar(18)"` // 身份证号(加密存储) + IsVerified bool `json:"is_verified" gorm:"default:false"` // 是否实名认证 + VerifiedAt *time.Time `json:"verified_at" gorm:"type:timestamp"` + + // 统计计数 + PostsCount int `json:"posts_count" gorm:"default:0"` + FollowersCount int `json:"followers_count" gorm:"default:0"` + FollowingCount int `json:"following_count" gorm:"default:0"` + + // 状态 + Status UserStatus `json:"status" gorm:"type:varchar(20);default:active"` + LastLoginAt *time.Time `json:"last_login_at" gorm:"type:timestamp"` + LastLoginIP string `json:"last_login_ip" gorm:"type:varchar(45)"` + + // 时间戳 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` +} + +// BeforeCreate 创建前生成UUID +func (u *User) BeforeCreate(tx *gorm.DB) error { + if u.ID == "" { + u.ID = uuid.New().String() + } + return nil +} + +func (User) TableName() string { + return "users" +} diff --git a/internal/model/user_block.go b/internal/model/user_block.go new file mode 100644 index 0000000..c4c9e62 --- /dev/null +++ b/internal/model/user_block.go @@ -0,0 +1,27 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// UserBlock 用户拉黑关系 +type UserBlock struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + BlockerID string `json:"blocker_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_blocker_blocked"` // 拉黑人 + BlockedID string `json:"blocked_id" gorm:"type:varchar(36);index;not null;uniqueIndex:idx_blocker_blocked"` // 被拉黑人 + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` +} + +func (b *UserBlock) BeforeCreate(tx *gorm.DB) error { + if b.ID == "" { + b.ID = uuid.New().String() + } + return nil +} + +func (UserBlock) TableName() string { + return "user_blocks" +} diff --git a/internal/model/vote.go b/internal/model/vote.go new file mode 100644 index 0000000..aff0983 --- /dev/null +++ b/internal/model/vote.go @@ -0,0 +1,52 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// VoteOption 投票选项 +type VoteOption struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + PostID string `json:"post_id" gorm:"type:varchar(36);index:idx_vote_option_post_sort,priority:1;not null"` + Content string `json:"content" gorm:"type:varchar(200);not null"` + SortOrder int `json:"sort_order" gorm:"default:0;index:idx_vote_option_post_sort,priority:2"` + VotesCount int `json:"votes_count" gorm:"default:0"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` + UpdatedAt time.Time `json:"updated_at" gorm:"autoUpdateTime"` +} + +// BeforeCreate 创建前生成UUID +func (vo *VoteOption) BeforeCreate(tx *gorm.DB) error { + if vo.ID == "" { + vo.ID = uuid.New().String() + } + return nil +} + +func (VoteOption) TableName() string { + return "vote_options" +} + +// UserVote 用户投票记录 +type UserVote struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + PostID string `json:"post_id" gorm:"type:varchar(36);index;uniqueIndex:idx_user_vote_post_user,priority:1;not null"` + UserID string `json:"user_id" gorm:"type:varchar(36);index;uniqueIndex:idx_user_vote_post_user,priority:2;not null"` + OptionID string `json:"option_id" gorm:"type:varchar(36);index;not null"` + CreatedAt time.Time `json:"created_at" gorm:"autoCreateTime"` +} + +// BeforeCreate 创建前生成UUID +func (uv *UserVote) BeforeCreate(tx *gorm.DB) error { + if uv.ID == "" { + uv.ID = uuid.New().String() + } + return nil +} + +func (UserVote) TableName() string { + return "user_votes" +} diff --git a/internal/pkg/avatar/avatar.go b/internal/pkg/avatar/avatar.go new file mode 100644 index 0000000..6e0eba5 --- /dev/null +++ b/internal/pkg/avatar/avatar.go @@ -0,0 +1,115 @@ +package avatar + +import ( + "encoding/base64" + "fmt" + "unicode/utf8" +) + +// 预定义一组好看的颜色 +var colors = []string{ + "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", + "#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F", + "#BB8FCE", "#85C1E9", "#F8B500", "#00CED1", + "#E74C3C", "#3498DB", "#2ECC71", "#9B59B6", + "#1ABC9C", "#F39C12", "#E67E22", "#16A085", +} + +// SVG模板 +const svgTemplate = ` + + %s +` + +// GenerateSVGAvatar 根据用户名生成SVG头像 +// username: 用户名 +// size: 头像尺寸(像素) +func GenerateSVGAvatar(username string, size int) string { + initials := getInitials(username) + color := stringToColor(username) + return fmt.Sprintf(svgTemplate, size, size, color, initials) +} + +// GenerateAvatarDataURI 生成Data URI格式的头像 +// 可以直接在HTML img标签或CSS background-image中使用 +func GenerateAvatarDataURI(username string, size int) string { + svg := GenerateSVGAvatar(username, size) + encoded := base64.StdEncoding.EncodeToString([]byte(svg)) + return fmt.Sprintf("data:image/svg+xml;base64,%s", encoded) +} + +// getInitials 获取用户名首字母 +// 中文取第一个字,英文取首字母(最多2个) +func getInitials(username string) string { + if username == "" { + return "?" + } + + // 检查是否是中文字符 + firstRune, _ := utf8.DecodeRuneInString(username) + if isChinese(firstRune) { + // 中文直接返回第一个字符 + return string(firstRune) + } + + // 英文处理:取前两个单词的首字母 + // 例如: "John Doe" -> "JD", "john" -> "J" + result := []rune{} + for i, r := range username { + if i == 0 { + result = append(result, toUpper(r)) + } else if r == ' ' || r == '_' || r == '-' { + // 找到下一个字符作为第二个首字母 + nextIdx := i + 1 + if nextIdx < len(username) { + nextRune, _ := utf8.DecodeRuneInString(username[nextIdx:]) + if nextRune != utf8.RuneError && nextRune != ' ' { + result = append(result, toUpper(nextRune)) + break + } + } + } + } + + if len(result) == 0 { + return "?" + } + + // 最多返回2个字符 + if len(result) > 2 { + result = result[:2] + } + + return string(result) +} + +// isChinese 判断是否是中文字符 +func isChinese(r rune) bool { + return r >= 0x4E00 && r <= 0x9FFF +} + +// toUpper 将字母转换为大写 +func toUpper(r rune) rune { + if r >= 'a' && r <= 'z' { + return r - 32 + } + return r +} + +// stringToColor 根据字符串生成颜色 +// 使用简单的哈希算法确保同一用户名每次生成的颜色一致 +func stringToColor(s string) string { + if s == "" { + return colors[0] + } + + hash := 0 + for _, r := range s { + hash = (hash*31 + int(r)) % len(colors) + } + if hash < 0 { + hash = -hash + } + + return colors[hash%len(colors)] +} diff --git a/internal/pkg/avatar/avatar_test.go b/internal/pkg/avatar/avatar_test.go new file mode 100644 index 0000000..c38d934 --- /dev/null +++ b/internal/pkg/avatar/avatar_test.go @@ -0,0 +1,118 @@ +package avatar + +import ( + "strings" + "testing" +) + +func TestGetInitials(t *testing.T) { + tests := []struct { + name string + username string + want string + }{ + {"中文用户名", "张三", "张"}, + {"英文用户名", "John", "J"}, + {"英文全名", "John Doe", "JD"}, + {"带下划线", "john_doe", "JD"}, + {"带连字符", "john-doe", "JD"}, + {"空字符串", "", "?"}, + {"小写英文", "alice", "A"}, + {"中文复合", "李小龙", "李"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getInitials(tt.username) + if got != tt.want { + t.Errorf("getInitials(%q) = %q, want %q", tt.username, got, tt.want) + } + }) + } +} + +func TestStringToColor(t *testing.T) { + // 测试同一用户名生成的颜色一致 + color1 := stringToColor("张三") + color2 := stringToColor("张三") + if color1 != color2 { + t.Errorf("stringToColor should return consistent colors for the same input") + } + + // 测试不同用户名生成不同颜色(大概率) + color3 := stringToColor("李四") + if color1 == color3 { + t.Logf("Warning: different usernames generated the same color (possible but unlikely)") + } + + // 测试空字符串 + color4 := stringToColor("") + if color4 == "" { + t.Errorf("stringToColor should return a color for empty string") + } + + // 验证颜色格式 + if !strings.HasPrefix(color4, "#") { + t.Errorf("stringToColor should return hex color format starting with #") + } +} + +func TestGenerateSVGAvatar(t *testing.T) { + svg := GenerateSVGAvatar("张三", 100) + + // 验证SVG结构 + if !strings.Contains(svg, "") { + t.Errorf("SVG should contain tag") + } + if !strings.Contains(svg, "width=\"100\"") { + t.Errorf("SVG should have width=100") + } + if !strings.Contains(svg, "height=\"100\"") { + t.Errorf("SVG should have height=100") + } + if !strings.Contains(svg, "张") { + t.Errorf("SVG should contain the initial character") + } +} + +func TestGenerateAvatarDataURI(t *testing.T) { + dataURI := GenerateAvatarDataURI("张三", 100) + + // 验证Data URI格式 + if !strings.HasPrefix(dataURI, "data:image/svg+xml;base64,") { + t.Errorf("Data URI should start with data:image/svg+xml;base64,") + } + + // 验证base64部分不为空 + parts := strings.Split(dataURI, ",") + if len(parts) != 2 { + t.Errorf("Data URI should have two parts separated by comma") + } + if parts[1] == "" { + t.Errorf("Base64 part should not be empty") + } +} + +func TestIsChinese(t *testing.T) { + tests := []struct { + r rune + want bool + }{ + {'中', true}, + {'文', true}, + {'a', false}, + {'Z', false}, + {'0', false}, + {'_', false}, + } + + for _, tt := range tests { + got := isChinese(tt.r) + if got != tt.want { + t.Errorf("isChinese(%q) = %v, want %v", tt.r, got, tt.want) + } + } +} diff --git a/internal/pkg/email/client.go b/internal/pkg/email/client.go new file mode 100644 index 0000000..1b2b63b --- /dev/null +++ b/internal/pkg/email/client.go @@ -0,0 +1,131 @@ +package email + +import ( + "context" + "crypto/tls" + "fmt" + "strings" + "time" + + gomail "gopkg.in/gomail.v2" +) + +// Message 发信参数 +type Message struct { + To []string + Cc []string + Bcc []string + ReplyTo []string + Subject string + TextBody string + HTMLBody string + Attachments []string +} + +type Client interface { + IsEnabled() bool + Config() Config + Send(ctx context.Context, msg Message) error +} + +type clientImpl struct { + cfg Config +} + +func NewClient(cfg Config) Client { + return &clientImpl{cfg: cfg} +} + +func (c *clientImpl) IsEnabled() bool { + return c.cfg.Enabled && + strings.TrimSpace(c.cfg.Host) != "" && + c.cfg.Port > 0 && + strings.TrimSpace(c.cfg.FromAddress) != "" +} + +func (c *clientImpl) Config() Config { + return c.cfg +} + +func (c *clientImpl) Send(ctx context.Context, msg Message) error { + if !c.IsEnabled() { + return fmt.Errorf("email client is disabled or misconfigured") + } + if len(msg.To) == 0 { + return fmt.Errorf("email recipient is empty") + } + if strings.TrimSpace(msg.Subject) == "" { + return fmt.Errorf("email subject is empty") + } + if strings.TrimSpace(msg.TextBody) == "" && strings.TrimSpace(msg.HTMLBody) == "" { + return fmt.Errorf("email body is empty") + } + + m := gomail.NewMessage() + m.SetAddressHeader("From", c.cfg.FromAddress, c.cfg.FromName) + m.SetHeader("To", msg.To...) + if len(msg.Cc) > 0 { + m.SetHeader("Cc", msg.Cc...) + } + if len(msg.Bcc) > 0 { + m.SetHeader("Bcc", msg.Bcc...) + } + if len(msg.ReplyTo) > 0 { + m.SetHeader("Reply-To", msg.ReplyTo...) + } + m.SetHeader("Subject", msg.Subject) + + if strings.TrimSpace(msg.TextBody) != "" && strings.TrimSpace(msg.HTMLBody) != "" { + m.SetBody("text/plain", msg.TextBody) + m.AddAlternative("text/html", msg.HTMLBody) + } else if strings.TrimSpace(msg.HTMLBody) != "" { + m.SetBody("text/html", msg.HTMLBody) + } else { + m.SetBody("text/plain", msg.TextBody) + } + + for _, attachment := range msg.Attachments { + if strings.TrimSpace(attachment) == "" { + continue + } + m.Attach(attachment) + } + + timeout := c.cfg.TimeoutSeconds + if timeout <= 0 { + timeout = 15 + } + dialer := gomail.NewDialer(c.cfg.Host, c.cfg.Port, c.cfg.Username, c.cfg.Password) + if c.cfg.UseTLS { + dialer.TLSConfig = &tls.Config{ + ServerName: c.cfg.Host, + InsecureSkipVerify: c.cfg.InsecureSkipVerify, + } + // 465 端口通常要求直接 TLS(Implicit TLS)。 + if c.cfg.Port == 465 { + dialer.SSL = true + } + } + + sendCtx := ctx + cancel := func() {} + if timeout > 0 { + sendCtx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + } + defer cancel() + + done := make(chan error, 1) + go func() { + done <- dialer.DialAndSend(m) + }() + + select { + case <-sendCtx.Done(): + return fmt.Errorf("send email canceled: %w", sendCtx.Err()) + case err := <-done: + if err != nil { + return fmt.Errorf("send email failed: %w", err) + } + return nil + } +} diff --git a/internal/pkg/email/config.go b/internal/pkg/email/config.go new file mode 100644 index 0000000..4083e15 --- /dev/null +++ b/internal/pkg/email/config.go @@ -0,0 +1,33 @@ +package email + +import "carrot_bbs/internal/config" + +// Config SMTP 邮件配置(由应用配置转换) +type Config struct { + Enabled bool + Host string + Port int + Username string + Password string + FromAddress string + FromName string + UseTLS bool + InsecureSkipVerify bool + TimeoutSeconds int +} + +// ConfigFromAppConfig 从应用配置转换 +func ConfigFromAppConfig(cfg *config.EmailConfig) Config { + return Config{ + Enabled: cfg.Enabled, + Host: cfg.Host, + Port: cfg.Port, + Username: cfg.Username, + Password: cfg.Password, + FromAddress: cfg.FromAddress, + FromName: cfg.FromName, + UseTLS: cfg.UseTLS, + InsecureSkipVerify: cfg.InsecureSkipVerify, + TimeoutSeconds: cfg.Timeout, + } +} diff --git a/internal/pkg/gorse/client.go b/internal/pkg/gorse/client.go new file mode 100644 index 0000000..9c9e3fc --- /dev/null +++ b/internal/pkg/gorse/client.go @@ -0,0 +1,286 @@ +package gorse + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "time" + + gorseio "github.com/gorse-io/gorse-go" +) + +// FeedbackType 反馈类型 +type FeedbackType string + +const ( + FeedbackTypeLike FeedbackType = "like" // 点赞 + FeedbackTypeStar FeedbackType = "star" // 收藏 + FeedbackTypeComment FeedbackType = "comment" // 评论 + FeedbackTypeRead FeedbackType = "read" // 浏览 +) + +// Score 非个性化推荐返回的评分项 +type Score struct { + Id string `json:"Id"` + Score float64 `json:"Score"` +} + +// Client Gorse客户端接口 +type Client interface { + // InsertFeedback 插入用户反馈 + InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error + // DeleteFeedback 删除用户反馈 + DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error + // GetRecommend 获取个性化推荐列表 + GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) + // GetNonPersonalized 获取非个性化推荐(通过名称) + GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) + // UpsertItem 插入或更新物品(无embedding) + UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error + // UpsertItemWithEmbedding 插入或更新物品(带embedding) + UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error + // DeleteItem 删除物品 + DeleteItem(ctx context.Context, itemID string) error + // UpsertUser 插入或更新用户 + UpsertUser(ctx context.Context, userID string, labels map[string]any) error + // IsEnabled 检查是否启用 + IsEnabled() bool +} + +// client Gorse客户端实现 +type client struct { + config Config + gorse *gorseio.GorseClient + httpClient *http.Client +} + +// NewClient 创建新的Gorse客户端 +func NewClient(cfg Config) Client { + if !cfg.Enabled { + return &noopClient{} + } + + gorse := gorseio.NewGorseClient(cfg.Address, cfg.APIKey) + return &client{ + config: cfg, + gorse: gorse, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// IsEnabled 检查是否启用 +func (c *client) IsEnabled() bool { + return c.config.Enabled +} + +// InsertFeedback 插入用户反馈 +func (c *client) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error { + if !c.config.Enabled { + return nil + } + + _, err := c.gorse.InsertFeedback(ctx, []gorseio.Feedback{ + { + FeedbackType: string(feedbackType), + UserId: userID, + ItemId: itemID, + Timestamp: time.Now().UTC().Truncate(time.Second), + }, + }) + return err +} + +// DeleteFeedback 删除用户反馈 +func (c *client) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error { + if !c.config.Enabled { + return nil + } + + _, err := c.gorse.DeleteFeedback(ctx, string(feedbackType), userID, itemID) + return err +} + +// GetRecommend 获取个性化推荐列表 +func (c *client) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) { + if !c.config.Enabled { + return nil, nil + } + + result, err := c.gorse.GetRecommend(ctx, userID, "", n, offset) + if err != nil { + return nil, err + } + + return result, nil +} + +// GetNonPersonalized 获取非个性化推荐 +// name: 推荐器名称,如 "most_liked_weekly" +// n: 返回数量 +// offset: 偏移量 +// userID: 可选,用于排除用户已读物品 +func (c *client) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) { + if !c.config.Enabled { + return nil, nil + } + + // 构建URL + url := fmt.Sprintf("%s/api/non-personalized/%s?n=%d&offset=%d", c.config.Address, name, n, offset) + if userID != "" { + url += fmt.Sprintf("&user-id=%s", userID) + } + + // 创建请求 + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // 设置API Key + if c.config.APIKey != "" { + req.Header.Set("X-API-Key", c.config.APIKey) + } + + // 发送请求 + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // 读取响应 + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("gorse api error: status=%d, body=%s", resp.StatusCode, string(body)) + } + + // 解析响应 + var scores []Score + if err := json.Unmarshal(body, &scores); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // 提取ID + ids := make([]string, len(scores)) + for i, score := range scores { + ids[i] = score.Id + } + + return ids, nil +} + +// UpsertItem 插入或更新物品 +func (c *client) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error { + if !c.config.Enabled { + return nil + } + + _, err := c.gorse.InsertItem(ctx, gorseio.Item{ + ItemId: itemID, + IsHidden: false, + Categories: categories, + Comment: comment, + Timestamp: time.Now().UTC().Truncate(time.Second), + }) + return err +} + +// UpsertItemWithEmbedding 插入或更新物品(带embedding) +func (c *client) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error { + if !c.config.Enabled { + return nil + } + + // 生成embedding + var embedding []float64 + if textToEmbed != "" { + var err error + embedding, err = GetEmbedding(textToEmbed) + if err != nil { + log.Printf("[WARN] Failed to get embedding for item %s: %v, using zero vector", itemID, err) + embedding = make([]float64, 1024) + } + } else { + embedding = make([]float64, 1024) + } + + _, err := c.gorse.InsertItem(ctx, gorseio.Item{ + ItemId: itemID, + IsHidden: false, + Categories: categories, + Comment: comment, + Timestamp: time.Now().UTC().Truncate(time.Second), + Labels: map[string]any{ + "embedding": embedding, + }, + }) + return err +} + +// DeleteItem 删除物品 +func (c *client) DeleteItem(ctx context.Context, itemID string) error { + if !c.config.Enabled { + return nil + } + + _, err := c.gorse.DeleteItem(ctx, itemID) + return err +} + +// UpsertUser 插入或更新用户 +func (c *client) UpsertUser(ctx context.Context, userID string, labels map[string]any) error { + if !c.config.Enabled { + return nil + } + + _, err := c.gorse.InsertUser(ctx, gorseio.User{ + UserId: userID, + Labels: labels, + }) + return err +} + +// noopClient 空操作客户端(用于未启用推荐功能时) +type noopClient struct{} + +func (c *noopClient) IsEnabled() bool { return false } +func (c *noopClient) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error { + return nil +} +func (c *noopClient) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error { + return nil +} +func (c *noopClient) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) { + return nil, nil +} +func (c *noopClient) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) { + return nil, nil +} +func (c *noopClient) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error { + return nil +} +func (c *noopClient) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error { + return nil +} +func (c *noopClient) DeleteItem(ctx context.Context, itemID string) error { return nil } +func (c *noopClient) UpsertUser(ctx context.Context, userID string, labels map[string]any) error { + return nil +} + +// 确保实现了接口 +var _ Client = (*client)(nil) +var _ Client = (*noopClient)(nil) + +// log 用于内部日志 +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) +} diff --git a/internal/pkg/gorse/config.go b/internal/pkg/gorse/config.go new file mode 100644 index 0000000..9d3329a --- /dev/null +++ b/internal/pkg/gorse/config.go @@ -0,0 +1,23 @@ +package gorse + +import ( + "carrot_bbs/internal/config" +) + +// Config Gorse客户端配置(从config.GorseConfig转换) +type Config struct { + Address string + APIKey string + Enabled bool + Dashboard string +} + +// ConfigFromAppConfig 从应用配置创建Gorse配置 +func ConfigFromAppConfig(cfg *config.GorseConfig) Config { + return Config{ + Address: cfg.Address, + APIKey: cfg.APIKey, + Enabled: cfg.Enabled, + Dashboard: cfg.Dashboard, + } +} diff --git a/internal/pkg/gorse/embedding.go b/internal/pkg/gorse/embedding.go new file mode 100644 index 0000000..30057c6 --- /dev/null +++ b/internal/pkg/gorse/embedding.go @@ -0,0 +1,106 @@ +package gorse + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "time" +) + +// EmbeddingConfig embedding服务配置 +type EmbeddingConfig struct { + APIKey string + URL string + Model string +} + +var defaultEmbeddingConfig = EmbeddingConfig{ + APIKey: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD", + URL: "https://api.littlelan.cn/v1/embeddings", + Model: "BAAI/bge-m3", +} + +// SetEmbeddingConfig 设置embedding配置 +func SetEmbeddingConfig(apiKey, url, model string) { + if apiKey != "" { + defaultEmbeddingConfig.APIKey = apiKey + } + if url != "" { + defaultEmbeddingConfig.URL = url + } + if model != "" { + defaultEmbeddingConfig.Model = model + } +} + +// GetEmbedding 获取文本的embedding +func GetEmbedding(text string) ([]float64, error) { + type embeddingRequest struct { + Input string `json:"input"` + Model string `json:"model"` + } + + type embeddingResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + } `json:"data"` + } + + reqBody := embeddingRequest{ + Input: text, + Model: defaultEmbeddingConfig.Model, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", defaultEmbeddingConfig.URL, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+defaultEmbeddingConfig.APIKey) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("embedding API error: status=%d, body=%s", resp.StatusCode, string(body)) + } + + var result embeddingResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(result.Data) == 0 { + return nil, fmt.Errorf("no embedding returned") + } + + return result.Data[0].Embedding, nil +} + +// InitEmbeddingWithConfig 从应用配置初始化embedding +func InitEmbeddingWithConfig(apiKey, url, model string) { + if apiKey == "" { + log.Println("[WARN] Gorse embedding API key not set, using default") + } + defaultEmbeddingConfig.APIKey = apiKey + if url != "" { + defaultEmbeddingConfig.URL = url + } + if model != "" { + defaultEmbeddingConfig.Model = model + } +} \ No newline at end of file diff --git a/internal/pkg/jwt/jwt.go b/internal/pkg/jwt/jwt.go new file mode 100644 index 0000000..e433ba2 --- /dev/null +++ b/internal/pkg/jwt/jwt.go @@ -0,0 +1,105 @@ +package jwt + +import ( + "errors" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + ErrInvalidToken = errors.New("invalid token") + ErrExpiredToken = errors.New("token has expired") +) + +// Claims JWT 声明 +type Claims struct { + UserID string `json:"user_id"` + Username string `json:"username"` + jwt.RegisteredClaims +} + +// JWT JWT工具 +type JWT struct { + secretKey string + accessTokenExpire time.Duration + refreshTokenExpire time.Duration +} + +// New 创建JWT实例 +func New(secret string, accessExpire, refreshExpire time.Duration) *JWT { + return &JWT{ + secretKey: secret, + accessTokenExpire: accessExpire, + refreshTokenExpire: refreshExpire, + } +} + +// GenerateAccessToken 生成访问令牌 +func (j *JWT) GenerateAccessToken(userID, username string) (string, error) { + now := time.Now() + claims := Claims{ + UserID: userID, + Username: username, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + Issuer: "carrot_bbs", + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(j.secretKey)) +} + +// GenerateRefreshToken 生成刷新令牌 +func (j *JWT) GenerateRefreshToken(userID, username string) (string, error) { + now := time.Now() + claims := Claims{ + UserID: userID, + Username: username, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + Issuer: "carrot_bbs", + ID: "refresh", + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(j.secretKey)) +} + +// ParseToken 解析令牌 +func (j *JWT) ParseToken(tokenString string) (*Claims, error) { + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(j.secretKey), nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(*Claims); ok && token.Valid { + return claims, nil + } + + return nil, ErrInvalidToken +} + +// ValidateToken 验证令牌 +func (j *JWT) ValidateToken(tokenString string) error { + claims, err := j.ParseToken(tokenString) + if err != nil { + return err + } + + // 检查是否是刷新令牌 + if claims.ID == "refresh" { + return errors.New("cannot use refresh token as access token") + } + + return nil +} diff --git a/internal/pkg/openai/client.go b/internal/pkg/openai/client.go new file mode 100644 index 0000000..06ac12e --- /dev/null +++ b/internal/pkg/openai/client.go @@ -0,0 +1,438 @@ +package openai + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "image" + _ "image/gif" + "image/jpeg" + _ "image/png" + "io" + "net/http" + "net/url" + "strings" + "time" + + xdraw "golang.org/x/image/draw" +) + +const moderationSystemPrompt = "你是中文社区的内容审核助手,负责对“帖子标题、正文、配图”做联合审核。目标是平衡社区安全与正常交流:必须拦截高风险违规内容,但不要误伤正常玩梗、二创、吐槽和轻度调侃。请只输出指定JSON。\n\n审核流程:\n1) 先判断是否命中硬性违规;\n2) 再判断语境(玩笑/自嘲/朋友间互动/作品讨论);\n3) 做文图交叉判断(文本+图片合并理解);\n4) 给出 approved 与简短 reason。\n\n硬性违规(命中任一项必须 approved=false):\nA. 宣传对立与煽动撕裂:\n- 明确煽动群体对立、地域对立、性别对立、民族宗教对立,鼓动仇恨、排斥、报复。\nB. 严重人身攻击与网暴引导:\n- 持续性侮辱贬损、羞辱人格、号召围攻/骚扰/挂人/线下冲突。\nC. 开盒/人肉/隐私暴露:\n- 故意公开、拼接、索取他人可识别隐私信息(姓名+联系方式、身份证号、住址、学校单位、车牌、定位轨迹等);\n- 图片/截图中出现可识别隐私信息并伴随曝光意图,也按违规处理。\nD. 其他高危违规:\n- 违法犯罪、暴力威胁、极端仇恨、色情低俗、诈骗引流、恶意广告等。\n\n放行规则(以下通常 approved=true):\n- 正常玩梗、表情包、谐音梗、二次创作、无恶意的吐槽;\n- 非定向、轻度口语化吐槽(无明确攻击对象、无网暴号召、无隐私暴露);\n- 对社会事件/作品的理性讨论、观点争论(即使语气尖锐,但未煽动对立或人身攻击)。\n\n边界判定:\n- 若只是“梗文化表达”且不指向现实伤害,优先通过;\n- 若存在明确伤害意图(煽动、围攻、曝光隐私),必须拒绝;\n- 对模糊内容不因个别粗口直接拒绝,需结合对象、意图、号召性和可执行性综合判断。\n\nreason 要求:\n- approved=false 时:中文10-30字,说明核心违规点;\n- approved=true 时:reason 为空字符串。\n\n输出格式(严格):\n仅输出一行JSON对象,不要Markdown,不要额外解释:\n{\"approved\": true/false, \"reason\": \"...\"}" + +const ( + defaultMaxImagesPerModerationRequest = 1 + maxModerationResultRetries = 3 + maxChatCompletionRetries = 3 + initialRetryBackoff = 500 * time.Millisecond + maxDownloadImageBytes = 10 * 1024 * 1024 + maxModerationImageSide = 1280 + compressedJPEGQuality = 72 + maxCompressedPayloadBytes = 1536 * 1024 +) + +type Client interface { + IsEnabled() bool + Config() Config + ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error) + ModerateComment(ctx context.Context, content string, images []string) (bool, string, error) +} + +type clientImpl struct { + cfg Config + httpClient *http.Client +} + +func NewClient(cfg Config) Client { + timeout := cfg.RequestTimeoutSeconds + if timeout <= 0 { + timeout = 30 + } + return &clientImpl{ + cfg: cfg, + httpClient: &http.Client{ + Timeout: time.Duration(timeout) * time.Second, + }, + } +} + +func (c *clientImpl) IsEnabled() bool { + return c.cfg.Enabled && c.cfg.APIKey != "" && c.cfg.BaseURL != "" +} + +func (c *clientImpl) Config() Config { + return c.cfg +} + +func (c *clientImpl) ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error) { + if !c.IsEnabled() { + return true, "", nil + } + return c.moderateContentInBatches(ctx, fmt.Sprintf("帖子标题:%s\n帖子内容:%s", title, content), images) +} + +func (c *clientImpl) ModerateComment(ctx context.Context, content string, images []string) (bool, string, error) { + if !c.IsEnabled() { + return true, "", nil + } + return c.moderateContentInBatches(ctx, fmt.Sprintf("评论内容:%s", content), images) +} + +func (c *clientImpl) moderateContentInBatches(ctx context.Context, contentPrompt string, images []string) (bool, string, error) { + cleanImages := normalizeImageURLs(images) + optimizedImages := c.optimizeImagesForModeration(ctx, cleanImages) + maxImagesPerRequest := c.maxImagesPerModerationRequest() + totalBatches := 1 + if len(optimizedImages) > 0 { + totalBatches = (len(optimizedImages) + maxImagesPerRequest - 1) / maxImagesPerRequest + } + + // 图片超过单批上限时分批审核,任一批次拒绝即整体拒绝 + for i := 0; i < totalBatches; i++ { + start := i * maxImagesPerRequest + end := start + maxImagesPerRequest + if end > len(optimizedImages) { + end = len(optimizedImages) + } + + batchImages := []string{} + if len(optimizedImages) > 0 { + batchImages = optimizedImages[start:end] + } + + approved, reason, err := c.moderateSingleBatch(ctx, contentPrompt, batchImages, i+1, totalBatches) + if err != nil { + return false, "", err + } + if !approved { + if strings.TrimSpace(reason) != "" && totalBatches > 1 { + reason = fmt.Sprintf("第%d/%d批图片未通过:%s", i+1, totalBatches, reason) + } + return false, reason, nil + } + } + + return true, "", nil +} + +func (c *clientImpl) moderateSingleBatch( + ctx context.Context, + contentPrompt string, + images []string, + batchNo, totalBatches int, +) (bool, string, error) { + userPrompt := fmt.Sprintf( + "%s\n图片批次:%d/%d(本次仅提供当前批次图片)", + contentPrompt, + batchNo, + totalBatches, + ) + + var lastErr error + for attempt := 0; attempt < maxModerationResultRetries; attempt++ { + replyText, err := c.chatCompletion(ctx, c.cfg.ModerationModel, moderationSystemPrompt, userPrompt, images, 0.1, 220) + if err != nil { + lastErr = err + } else { + parsed := struct { + Approved bool `json:"approved"` + Reason string `json:"reason"` + }{} + if err := json.Unmarshal([]byte(extractJSONObject(replyText)), &parsed); err != nil { + lastErr = fmt.Errorf("failed to parse moderation result: %w", err) + } else { + return parsed.Approved, parsed.Reason, nil + } + } + + if attempt == maxModerationResultRetries-1 { + break + } + if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil { + return false, "", sleepErr + } + } + + return false, "", fmt.Errorf( + "moderation batch %d/%d failed after %d attempts: %w", + batchNo, + totalBatches, + maxModerationResultRetries, + lastErr, + ) +} + +type chatCompletionsRequest struct { + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` +} + +type chatMessage struct { + Role string `json:"role"` + Content interface{} `json:"content"` +} + +type contentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL *imageURLPart `json:"image_url,omitempty"` +} + +type imageURLPart struct { + URL string `json:"url"` +} + +type chatCompletionsResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` +} + +func (c *clientImpl) chatCompletion( + ctx context.Context, + model string, + systemPrompt string, + userPrompt string, + images []string, + temperature float64, + maxTokens int, +) (string, error) { + if model == "" { + return "", fmt.Errorf("model is empty") + } + + cleanImages := normalizeImageURLs(images) + + userParts := []contentPart{ + {Type: "text", Text: userPrompt}, + } + for _, image := range cleanImages { + userParts = append(userParts, contentPart{ + Type: "image_url", + ImageURL: &imageURLPart{URL: image}, + }) + } + + reqBody := chatCompletionsRequest{ + Model: model, + Messages: []chatMessage{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: userParts}, + }, + Temperature: temperature, + MaxTokens: maxTokens, + } + + data, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("marshal request: %w", err) + } + + baseURL := strings.TrimRight(c.cfg.BaseURL, "/") + endpoint := baseURL + "/v1/chat/completions" + if strings.HasSuffix(baseURL, "/v1") { + endpoint = baseURL + "/chat/completions" + } + + var lastErr error + for attempt := 0; attempt < maxChatCompletionRetries; attempt++ { + body, statusCode, err := c.doChatCompletionRequest(ctx, endpoint, data) + if err != nil { + lastErr = err + } else if statusCode >= 400 { + lastErr = fmt.Errorf("openai error status=%d body=%s", statusCode, string(body)) + if !isRetryableStatusCode(statusCode) { + return "", lastErr + } + } else { + var parsed chatCompletionsResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return "", fmt.Errorf("decode response: %w", err) + } + if len(parsed.Choices) == 0 { + return "", fmt.Errorf("empty response choices") + } + return parsed.Choices[0].Message.Content, nil + } + + if attempt == maxChatCompletionRetries-1 { + break + } + if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil { + return "", sleepErr + } + } + + return "", lastErr +} + +func (c *clientImpl) doChatCompletionRequest(ctx context.Context, endpoint string, data []byte) ([]byte, int, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) + if err != nil { + return nil, 0, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.cfg.APIKey) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("request openai: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, 0, fmt.Errorf("read response: %w", err) + } + return body, resp.StatusCode, nil +} + +func isRetryableStatusCode(statusCode int) bool { + if statusCode == http.StatusTooManyRequests { + return true + } + return statusCode >= 500 && statusCode <= 599 +} + +func sleepWithBackoff(ctx context.Context, attempt int) error { + delay := initialRetryBackoff * time.Duration(1<= 0 && end > start { + return text[start : end+1] + } + return text +} + +func (c *clientImpl) maxImagesPerModerationRequest() int { + // 审核固定单图请求,降低单次payload体积,减少超时风险。 + if c.cfg.ModerationMaxImagesPerRequest <= 0 { + return defaultMaxImagesPerModerationRequest + } + if c.cfg.ModerationMaxImagesPerRequest > 1 { + return 1 + } + return c.cfg.ModerationMaxImagesPerRequest +} + +func (c *clientImpl) optimizeImagesForModeration(ctx context.Context, images []string) []string { + if len(images) == 0 { + return images + } + + optimized := make([]string, 0, len(images)) + for _, imageURL := range images { + optimized = append(optimized, c.tryCompressImageForModeration(ctx, imageURL)) + } + return optimized +} + +func (c *clientImpl) tryCompressImageForModeration(ctx context.Context, imageURL string) string { + parsed, err := url.Parse(imageURL) + if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") { + return imageURL + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil) + if err != nil { + return imageURL + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return imageURL + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + return imageURL + } + if !strings.HasPrefix(strings.ToLower(resp.Header.Get("Content-Type")), "image/") { + return imageURL + } + + originBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxDownloadImageBytes)) + if err != nil || len(originBytes) == 0 { + return imageURL + } + + srcImg, _, err := image.Decode(bytes.NewReader(originBytes)) + if err != nil { + return imageURL + } + + dstImg := resizeIfNeeded(srcImg, maxModerationImageSide) + var buf bytes.Buffer + if err := jpeg.Encode(&buf, dstImg, &jpeg.Options{Quality: compressedJPEGQuality}); err != nil { + return imageURL + } + + compressed := buf.Bytes() + if len(compressed) == 0 || len(compressed) > maxCompressedPayloadBytes { + return imageURL + } + // 压缩效果不明显时直接使用原图URL,避免增大请求体。 + if len(compressed) >= int(float64(len(originBytes))*0.95) { + return imageURL + } + + return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(compressed) +} + +func resizeIfNeeded(src image.Image, maxSide int) image.Image { + bounds := src.Bounds() + w := bounds.Dx() + h := bounds.Dy() + if w <= 0 || h <= 0 || max(w, h) <= maxSide { + return src + } + + newW, newH := w, h + if w >= h { + newW = maxSide + newH = int(float64(h) * (float64(maxSide) / float64(w))) + } else { + newH = maxSide + newW = int(float64(w) * (float64(maxSide) / float64(h))) + } + if newW < 1 { + newW = 1 + } + if newH < 1 { + newH = 1 + } + + dst := image.NewRGBA(image.Rect(0, 0, newW, newH)) + xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, bounds, xdraw.Over, nil) + return dst +} diff --git a/internal/pkg/openai/config.go b/internal/pkg/openai/config.go new file mode 100644 index 0000000..e71efb1 --- /dev/null +++ b/internal/pkg/openai/config.go @@ -0,0 +1,27 @@ +package openai + +import "carrot_bbs/internal/config" + +// Config OpenAI 兼容接口配置 +type Config struct { + Enabled bool + BaseURL string + APIKey string + ModerationModel string + ModerationMaxImagesPerRequest int + RequestTimeoutSeconds int + StrictModeration bool +} + +// ConfigFromAppConfig 从应用配置转换 +func ConfigFromAppConfig(cfg *config.OpenAIConfig) Config { + return Config{ + Enabled: cfg.Enabled, + BaseURL: cfg.BaseURL, + APIKey: cfg.APIKey, + ModerationModel: cfg.ModerationModel, + ModerationMaxImagesPerRequest: cfg.ModerationMaxImagesPerRequest, + RequestTimeoutSeconds: cfg.RequestTimeout, + StrictModeration: cfg.StrictModeration, + } +} diff --git a/internal/pkg/redis/redis.go b/internal/pkg/redis/redis.go new file mode 100644 index 0000000..8d3e926 --- /dev/null +++ b/internal/pkg/redis/redis.go @@ -0,0 +1,119 @@ +package redis + +import ( + "context" + "fmt" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + + "carrot_bbs/internal/config" +) + +// Client Redis客户端 +type Client struct { + rdb *redis.Client + isMiniRedis bool + mr *miniredis.Miniredis +} + +// New 创建Redis客户端 +func New(cfg *config.RedisConfig) (*Client, error) { + switch cfg.Type { + case "miniredis": + // 启动内嵌Redis模拟 + mr, err := miniredis.Run() + if err != nil { + return nil, fmt.Errorf("failed to start miniredis: %w", err) + } + rdb := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + Password: "", + DB: 0, + }) + return &Client{ + rdb: rdb, + isMiniRedis: true, + mr: mr, + }, nil + case "redis": + // 使用真实Redis + rdb := redis.NewClient(&redis.Options{ + Addr: cfg.Redis.Addr(), + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + PoolSize: cfg.PoolSize, + }) + ctx := context.Background() + if err := rdb.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("failed to connect to redis: %w", err) + } + return &Client{rdb: rdb, isMiniRedis: false}, nil + default: + // 默认使用miniredis + mr, err := miniredis.Run() + if err != nil { + return nil, fmt.Errorf("failed to start miniredis: %w", err) + } + rdb := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + return &Client{ + rdb: rdb, + isMiniRedis: true, + mr: mr, + }, nil + } +} + +// Get 获取值 +func (c *Client) Get(ctx context.Context, key string) (string, error) { + return c.rdb.Get(ctx, key).Result() +} + +// Set 设置值 +func (c *Client) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error { + return c.rdb.Set(ctx, key, value, expiration).Err() +} + +// Del 删除键 +func (c *Client) Del(ctx context.Context, keys ...string) error { + return c.rdb.Del(ctx, keys...).Err() +} + +// Exists 检查键是否存在 +func (c *Client) Exists(ctx context.Context, keys ...string) (int64, error) { + return c.rdb.Exists(ctx, keys...).Result() +} + +// Incr 递增 +func (c *Client) Incr(ctx context.Context, key string) (int64, error) { + return c.rdb.Incr(ctx, key).Result() +} + +// Expire 设置过期时间 +func (c *Client) Expire(ctx context.Context, key string, expiration time.Duration) (bool, error) { + return c.rdb.Expire(ctx, key, expiration).Result() +} + +// GetClient 获取原生客户端 +func (c *Client) GetClient() *redis.Client { + return c.rdb +} + +// Close 关闭连接 +func (c *Client) Close() error { + if err := c.rdb.Close(); err != nil { + return err + } + if c.mr != nil { + c.mr.Close() + } + return nil +} + +// IsMiniRedis 返回是否是miniredis +func (c *Client) IsMiniRedis() bool { + return c.isMiniRedis +} diff --git a/internal/pkg/response/response.go b/internal/pkg/response/response.go new file mode 100644 index 0000000..ebd5f3e --- /dev/null +++ b/internal/pkg/response/response.go @@ -0,0 +1,117 @@ +package response + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +// Response 统一响应结构 +type Response struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// ResponseSnakeCase 统一响应结构(snake_case) +type ResponseSnakeCase struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// Success 成功响应 +func Success(c *gin.Context, data interface{}) { + c.JSON(http.StatusOK, Response{ + Code: 0, + Message: "success", + Data: data, + }) +} + +// SuccessWithMessage 成功响应带消息 +func SuccessWithMessage(c *gin.Context, message string, data interface{}) { + c.JSON(http.StatusOK, Response{ + Code: 0, + Message: message, + Data: data, + }) +} + +// Error 错误响应 +func Error(c *gin.Context, code int, message string) { + c.JSON(http.StatusBadRequest, Response{ + Code: code, + Message: message, + }) +} + +// ErrorWithStatusCode 带状态码的错误响应 +func ErrorWithStatusCode(c *gin.Context, statusCode int, code int, message string) { + c.JSON(statusCode, Response{ + Code: code, + Message: message, + }) +} + +// BadRequest 参数错误 +func BadRequest(c *gin.Context, message string) { + ErrorWithStatusCode(c, http.StatusBadRequest, 400, message) +} + +// Unauthorized 未授权 +func Unauthorized(c *gin.Context, message string) { + if message == "" { + message = "unauthorized" + } + ErrorWithStatusCode(c, http.StatusUnauthorized, 401, message) +} + +// Forbidden 禁止访问 +func Forbidden(c *gin.Context, message string) { + if message == "" { + message = "forbidden" + } + ErrorWithStatusCode(c, http.StatusForbidden, 403, message) +} + +// NotFound 资源不存在 +func NotFound(c *gin.Context, message string) { + if message == "" { + message = "resource not found" + } + ErrorWithStatusCode(c, http.StatusNotFound, 404, message) +} + +// InternalServerError 服务器内部错误 +func InternalServerError(c *gin.Context, message string) { + if message == "" { + message = "internal server error" + } + ErrorWithStatusCode(c, http.StatusInternalServerError, 500, message) +} + +// PaginatedResponse 分页响应 +type PaginatedResponse struct { + List interface{} `json:"list"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalPages int `json:"total_pages"` +} + +// Paginated 分页成功响应 +func Paginated(c *gin.Context, list interface{}, total int64, page, pageSize int) { + totalPages := int(total) / pageSize + if int(total)%pageSize > 0 { + totalPages++ + } + + Success(c, PaginatedResponse{ + List: list, + Total: total, + Page: page, + PageSize: pageSize, + TotalPages: totalPages, + }) +} diff --git a/internal/pkg/s3/s3.go b/internal/pkg/s3/s3.go new file mode 100644 index 0000000..47d0b45 --- /dev/null +++ b/internal/pkg/s3/s3.go @@ -0,0 +1,119 @@ +package s3 + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + + "carrot_bbs/internal/config" +) + +// Client S3客户端 +type Client struct { + client *minio.Client + bucket string + domain string +} + +// New 创建S3客户端 +func New(cfg *config.S3Config) (*Client, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + client, err := minio.New(cfg.Endpoint, &minio.Options{ + Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""), + Secure: cfg.UseSSL, + }) + if err != nil { + return nil, fmt.Errorf("failed to create S3 client: %w", err) + } + + // 检查bucket是否存在 + exists, err := client.BucketExists(ctx, cfg.Bucket) + if err != nil { + return nil, fmt.Errorf("failed to check bucket: %w", err) + } + + if !exists { + if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{ + Region: cfg.Region, + }); err != nil { + return nil, fmt.Errorf("failed to create bucket: %w", err) + } + } + + // 如果没有配置domain,则使用默认的endpoint + domain := cfg.Domain + if domain == "" { + domain = cfg.Endpoint + } + + return &Client{ + client: client, + bucket: cfg.Bucket, + domain: domain, + }, nil +} + +// Upload 上传文件 +func (c *Client) Upload(ctx context.Context, objectName string, filePath string, contentType string) (string, error) { + _, err := c.client.FPutObject(ctx, c.bucket, objectName, filePath, minio.PutObjectOptions{ + ContentType: contentType, + }) + if err != nil { + return "", fmt.Errorf("failed to upload file: %w", err) + } + + return fmt.Sprintf("%s/%s", c.bucket, objectName), nil +} + +// UploadData 上传数据 +func (c *Client) UploadData(ctx context.Context, objectName string, data []byte, contentType string) (string, error) { + _, err := c.client.PutObject(ctx, c.bucket, objectName, bytes.NewReader(data), int64(len(data)), minio.PutObjectOptions{ + ContentType: contentType, + }) + if err != nil { + return "", fmt.Errorf("failed to upload data: %w", err) + } + + // 返回完整URL,包含bucket名称 + scheme := "https" + if c.domain == c.bucket || c.domain == "" { + scheme = "http" + } + return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil +} + +// GetURL 获取文件URL - 使用自定义域名 +func (c *Client) GetURL(ctx context.Context, objectName string) (string, error) { + // 使用自定义域名构建URL,包含bucket名称 + scheme := "https" + if c.domain == c.bucket || c.domain == "" { + scheme = "http" + } + return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil +} + +// GetPresignedURL 获取预签名URL(用于私有桶) +func (c *Client) GetPresignedURL(ctx context.Context, objectName string) (string, error) { + url, err := c.client.PresignedGetObject(ctx, c.bucket, objectName, time.Hour*24, nil) + if err != nil { + return "", fmt.Errorf("failed to get presigned URL: %w", err) + } + + return url.String(), nil +} + +// Delete 删除文件 +func (c *Client) Delete(ctx context.Context, objectName string) error { + return c.client.RemoveObject(ctx, c.bucket, objectName, minio.RemoveObjectOptions{}) +} + +// GetClient 获取原生客户端 +func (c *Client) GetClient() *minio.Client { + return c.client +} diff --git a/internal/pkg/utils/avatar.go b/internal/pkg/utils/avatar.go new file mode 100644 index 0000000..3282ba7 --- /dev/null +++ b/internal/pkg/utils/avatar.go @@ -0,0 +1,52 @@ +package utils + +import ( + "net/url" +) + +// AvatarServiceBaseURL 默认头像服务基础URL (使用 UI Avatars API) +const AvatarServiceBaseURL = "https://ui-avatars.com/api" + +// DefaultAvatarSize 默认头像尺寸 +const DefaultAvatarSize = 100 + +// AvatarInfo 头像信息 +type AvatarInfo struct { + Username string + Nickname string + Avatar string +} + +// GetAvatarOrDefault 获取头像URL,如果为空则返回在线头像生成服务的URL +// 优先使用已有的头像,否则使用昵称或用户名生成默认头像 +func GetAvatarOrDefault(username, nickname, avatar string) string { + if avatar != "" { + return avatar + } + // 使用用户名生成默认头像URL(优先使用昵称) + displayName := nickname + if displayName == "" { + displayName = username + } + return GenerateDefaultAvatarURL(displayName) +} + +// GetAvatarOrDefaultFromInfo 从 AvatarInfo 获取头像URL +func GetAvatarOrDefaultFromInfo(info AvatarInfo) string { + return GetAvatarOrDefault(info.Username, info.Nickname, info.Avatar) +} + +// GenerateDefaultAvatarURL 生成默认头像URL +// 使用 UI Avatars API 生成基于用户名首字母的头像 +func GenerateDefaultAvatarURL(name string) string { + if name == "" { + name = "?" + } + // 使用 UI Avatars API 生成头像 + params := url.Values{} + params.Set("name", url.QueryEscape(name)) + params.Set("size", "100") + params.Set("background", "0D8ABC") // 默认蓝色背景 + params.Set("color", "ffffff") // 白色文字 + return AvatarServiceBaseURL + "?" + params.Encode() +} diff --git a/internal/pkg/utils/hash.go b/internal/pkg/utils/hash.go new file mode 100644 index 0000000..0ecf1c2 --- /dev/null +++ b/internal/pkg/utils/hash.go @@ -0,0 +1,17 @@ +package utils + +import ( + "golang.org/x/crypto/bcrypt" +) + +// HashPassword 密码哈希 +func HashPassword(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bytes), err +} + +// CheckPasswordHash 验证密码 +func CheckPasswordHash(password, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} diff --git a/internal/pkg/utils/snowflake.go b/internal/pkg/utils/snowflake.go new file mode 100644 index 0000000..d01da71 --- /dev/null +++ b/internal/pkg/utils/snowflake.go @@ -0,0 +1,261 @@ +package utils + +import ( + "errors" + "os" + "sync" + "time" +) + +// 雪花算法常量定义 +const ( + // 64位ID结构:1位符号位 + 41位时间戳 + 10位机器ID + 12位序列号 + + // 机器ID占用的位数 + nodeIDBits uint64 = 10 + // 序列号占用的位数 + sequenceBits uint64 = 12 + + // 机器ID的最大值 (0-1023) + maxNodeID int64 = -1 ^ (-1 << nodeIDBits) + // 序列号的最大值 (0-4095) + maxSequence int64 = -1 ^ (-1 << sequenceBits) + + // 机器ID左移位数 + nodeIDShift uint64 = sequenceBits + // 时间戳左移位数 + timestampShift uint64 = sequenceBits + nodeIDBits + + // 自定义纪元时间:2024-01-01 00:00:00 UTC + // 使用自定义纪元可以延长ID有效期约24年(从2024年开始) + customEpoch int64 = 1704067200000 // 2024-01-01 00:00:00 UTC 的毫秒时间戳 +) + +// 错误定义 +var ( + // ErrInvalidNodeID 机器ID无效 + ErrInvalidNodeID = errors.New("node ID must be between 0 and 1023") + // ErrClockBackwards 时钟回拨 + ErrClockBackwards = errors.New("clock moved backwards, refusing to generate ID") +) + +// IDInfo 解析后的ID信息 +type IDInfo struct { + Timestamp int64 // 生成ID时的时间戳(毫秒) + NodeID int64 // 机器ID + Sequence int64 // 序列号 +} + +// Snowflake 雪花算法ID生成器 +type Snowflake struct { + mu sync.Mutex // 互斥锁,保证线程安全 + nodeID int64 // 机器ID (0-1023) + sequence int64 // 当前序列号 (0-4095) + lastTimestamp int64 // 上次生成ID的时间戳 +} + +// 全局雪花算法实例 +var ( + globalSnowflake *Snowflake + globalSnowflakeOnce sync.Once + globalSnowflakeErr error +) + +// InitSnowflake 初始化全局雪花算法实例 +// nodeID: 机器ID,范围0-1023,可以通过环境变量 NODE_ID 配置 +func InitSnowflake(nodeID int64) error { + globalSnowflake, globalSnowflakeErr = NewSnowflake(nodeID) + return globalSnowflakeErr +} + +// GetSnowflake 获取全局雪花算法实例 +// 如果未初始化,会自动使用默认配置初始化 +func GetSnowflake() *Snowflake { + globalSnowflakeOnce.Do(func() { + if globalSnowflake == nil { + globalSnowflake, globalSnowflakeErr = NewSnowflake(-1) + } + }) + return globalSnowflake +} + +// NewSnowflake 创建雪花算法ID生成器实例 +// nodeID: 机器ID,范围0-1023,可以通过环境变量 NODE_ID 配置 +// 如果nodeID为-1,则尝试从环境变量 NODE_ID 读取 +func NewSnowflake(nodeID int64) (*Snowflake, error) { + // 如果传入-1,尝试从环境变量读取 + if nodeID == -1 { + nodeIDStr := os.Getenv("NODE_ID") + if nodeIDStr != "" { + // 解析环境变量 + parsedID, err := parseInt(nodeIDStr) + if err != nil { + return nil, ErrInvalidNodeID + } + nodeID = parsedID + } else { + // 默认使用0 + nodeID = 0 + } + } + + // 验证机器ID范围 + if nodeID < 0 || nodeID > maxNodeID { + return nil, ErrInvalidNodeID + } + + return &Snowflake{ + nodeID: nodeID, + sequence: 0, + lastTimestamp: 0, + }, nil +} + +// parseInt 辅助函数:解析整数 +func parseInt(s string) (int64, error) { + var result int64 + var negative bool + + if len(s) == 0 { + return 0, errors.New("empty string") + } + + i := 0 + if s[0] == '-' { + negative = true + i = 1 + } + + for ; i < len(s); i++ { + if s[i] < '0' || s[i] > '9' { + return 0, errors.New("invalid character") + } + result = result*10 + int64(s[i]-'0') + } + + if negative { + result = -result + } + + return result, nil +} + +// GenerateID 生成唯一的雪花算法ID +// 返回值:生成的ID,以及可能的错误(如时钟回拨) +// 线程安全:使用互斥锁保证并发安全 +func (s *Snowflake) GenerateID() (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // 获取当前时间戳(毫秒) + now := currentTimestamp() + + // 处理时钟回拨 + if now < s.lastTimestamp { + return 0, ErrClockBackwards + } + + // 同一毫秒内 + if now == s.lastTimestamp { + // 序列号递增 + s.sequence = (s.sequence + 1) & maxSequence + // 序列号溢出,等待下一毫秒 + if s.sequence == 0 { + now = s.waitNextMillis(now) + } + } else { + // 不同毫秒,序列号重置为0 + s.sequence = 0 + } + + // 更新上次生成时间 + s.lastTimestamp = now + + // 组装ID + // ID结构:时间戳部分 | 机器ID部分 | 序列号部分 + id := ((now - customEpoch) << timestampShift) | + (s.nodeID << nodeIDShift) | + s.sequence + + return id, nil +} + +// waitNextMillis 等待到下一毫秒 +// 参数:当前时间戳 +// 返回值:下一毫秒的时间戳 +func (s *Snowflake) waitNextMillis(timestamp int64) int64 { + now := currentTimestamp() + for now <= timestamp { + now = currentTimestamp() + } + return now +} + +// ParseID 解析雪花算法ID,提取其中的信息 +// id: 要解析的雪花算法ID +// 返回值:包含时间戳、机器ID、序列号的结构体 +func ParseID(id int64) *IDInfo { + // 提取序列号(低12位) + sequence := id & maxSequence + + // 提取机器ID(中间10位) + nodeID := (id >> nodeIDShift) & maxNodeID + + // 提取时间戳(高41位) + timestamp := (id >> timestampShift) + customEpoch + + return &IDInfo{ + Timestamp: timestamp, + NodeID: nodeID, + Sequence: sequence, + } +} + +// currentTimestamp 获取当前时间戳(毫秒) +func currentTimestamp() int64 { + return time.Now().UnixNano() / 1e6 +} + +// GetNodeID 获取当前机器ID +func (s *Snowflake) GetNodeID() int64 { + return s.nodeID +} + +// GetCustomEpoch 获取自定义纪元时间 +func GetCustomEpoch() int64 { + return customEpoch +} + +// IDToTime 将雪花算法ID转换为生成时间 +// 这是一个便捷方法,等价于 ParseID(id).Timestamp +func IDToTime(id int64) time.Time { + info := ParseID(id) + return time.Unix(0, info.Timestamp*1e6) // 毫秒转纳秒 +} + +// ValidateID 验证ID是否为有效的雪花算法ID +// 检查时间戳是否在合理范围内 +func ValidateID(id int64) bool { + if id <= 0 { + return false + } + + info := ParseID(id) + + // 检查时间戳是否在合理范围内 + // 不能早于纪元时间,不能晚于当前时间太多(允许1分钟的时钟偏差) + now := currentTimestamp() + if info.Timestamp < customEpoch || info.Timestamp > now+60000 { + return false + } + + // 检查机器ID和序列号是否在有效范围内 + if info.NodeID < 0 || info.NodeID > maxNodeID { + return false + } + if info.Sequence < 0 || info.Sequence > maxSequence { + return false + } + + return true +} diff --git a/internal/pkg/utils/validator.go b/internal/pkg/utils/validator.go new file mode 100644 index 0000000..81782a3 --- /dev/null +++ b/internal/pkg/utils/validator.go @@ -0,0 +1,46 @@ +package utils + +import ( + "regexp" + "strings" +) + +// ValidateEmail 验证邮箱 +func ValidateEmail(email string) bool { + pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$` + matched, _ := regexp.MatchString(pattern, email) + return matched +} + +// ValidateUsername 验证用户名 +func ValidateUsername(username string) bool { + if len(username) < 3 || len(username) > 30 { + return false + } + pattern := `^[a-zA-Z0-9_]+$` + matched, _ := regexp.MatchString(pattern, username) + return matched +} + +// ValidatePassword 验证密码强度 +func ValidatePassword(password string) bool { + if len(password) < 6 || len(password) > 50 { + return false + } + return true +} + +// ValidatePhone 验证手机号 +func ValidatePhone(phone string) bool { + pattern := `^1[3-9]\d{9}$` + matched, _ := regexp.MatchString(pattern, phone) + return matched +} + +// SanitizeHTML 清理HTML +func SanitizeHTML(input string) string { + // 简单实现,实际使用建议用专门库 + input = strings.ReplaceAll(input, "<", "<") + input = strings.ReplaceAll(input, ">", ">") + return input +} diff --git a/internal/pkg/websocket/websocket.go b/internal/pkg/websocket/websocket.go new file mode 100644 index 0000000..d093830 --- /dev/null +++ b/internal/pkg/websocket/websocket.go @@ -0,0 +1,440 @@ +package websocket + +import ( + "carrot_bbs/internal/model" + "encoding/json" + "log" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// WebSocket消息类型常量 +const ( + MessageTypePing = "ping" + MessageTypePong = "pong" + MessageTypeMessage = "message" + MessageTypeTyping = "typing" + MessageTypeRead = "read" + MessageTypeAck = "ack" + MessageTypeError = "error" + MessageTypeRecall = "recall" // 撤回消息 + MessageTypeSystem = "system" // 系统消息 + MessageTypeNotification = "notification" // 通知消息 + MessageTypeAnnouncement = "announcement" // 公告消息 + + // 群组相关消息类型 + MessageTypeGroupMessage = "group_message" // 群消息 + MessageTypeGroupTyping = "group_typing" // 群输入状态 + MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等) + MessageTypeGroupMention = "group_mention" // @提及通知 + MessageTypeGroupRead = "group_read" // 群消息已读 + MessageTypeGroupRecall = "group_recall" // 群消息撤回 + + // Meta事件详细类型 + MetaDetailTypeHeartbeat = "heartbeat" + MetaDetailTypeTyping = "typing" + MetaDetailTypeAck = "ack" // 消息发送确认 + MetaDetailTypeRead = "read" // 已读回执 +) + +// WSMessage WebSocket消息结构 +type WSMessage struct { + Type string `json:"type"` + Data interface{} `json:"data"` + Timestamp int64 `json:"timestamp"` +} + +// ChatMessage 聊天消息结构 +type ChatMessage struct { + ID string `json:"id"` + ConversationID string `json:"conversation_id"` + SenderID string `json:"sender_id"` + Seq int64 `json:"seq"` + Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组) + ReplyToID *string `json:"reply_to_id,omitempty"` + CreatedAt int64 `json:"created_at"` +} + +// SystemMessage 系统消息结构 +type SystemMessage struct { + ID string `json:"id"` // 消息ID + Type string `json:"type"` // 消息子类型(如:account_banned, post_approved等) + Title string `json:"title"` // 消息标题 + Content string `json:"content"` // 消息内容 + Data map[string]interface{} `json:"data"` // 额外数据 + CreatedAt int64 `json:"created_at"` // 创建时间戳 +} + +// NotificationMessage 通知消息结构 +type NotificationMessage struct { + ID string `json:"id"` // 通知ID + Type string `json:"type"` // 通知类型(like, comment, follow, mention等) + Title string `json:"title"` // 通知标题 + Content string `json:"content"` // 通知内容 + TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户 + ResourceType string `json:"resource_type"` // 资源类型(post, comment等) + ResourceID string `json:"resource_id"` // 资源ID + Extra map[string]interface{} `json:"extra"` // 额外数据 + CreatedAt int64 `json:"created_at"` // 创建时间戳 +} + +// NotificationUser 通知中的用户信息 +type NotificationUser struct { + ID string `json:"id"` + Username string `json:"username"` + Avatar string `json:"avatar"` +} + +// AnnouncementMessage 公告消息结构 +type AnnouncementMessage struct { + ID string `json:"id"` // 公告ID + Title string `json:"title"` // 公告标题 + Content string `json:"content"` // 公告内容 + Priority int `json:"priority"` // 优先级(1-10) + CreatedAt int64 `json:"created_at"` // 创建时间戳 +} + +// GroupMessage 群消息结构 +type GroupMessage struct { + ID string `json:"id"` // 消息ID + ConversationID string `json:"conversation_id"` // 会话ID(群聊会话) + GroupID string `json:"group_id"` // 群组ID + SenderID string `json:"sender_id"` // 发送者ID + Seq int64 `json:"seq"` // 消息序号 + Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组) + ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID + MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表 + MentionAll bool `json:"mention_all"` // 是否@所有人 + CreatedAt int64 `json:"created_at"` // 创建时间戳 +} + +// GroupTypingMessage 群输入状态消息 +type GroupTypingMessage struct { + GroupID string `json:"group_id"` // 群组ID + UserID string `json:"user_id"` // 用户ID + Username string `json:"username"` // 用户名 + IsTyping bool `json:"is_typing"` // 是否正在输入 +} + +// GroupNoticeMessage 群组通知消息 +type GroupNoticeMessage struct { + NoticeType string `json:"notice_type"` // 通知类型:member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved + GroupID string `json:"group_id"` // 群组ID + Data interface{} `json:"data"` // 通知数据 + Timestamp int64 `json:"timestamp"` // 时间戳 + MessageID string `json:"message_id,omitempty"` // 消息ID(如果通知保存为消息) + Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息) +} + +// GroupNoticeData 通知数据结构 +type GroupNoticeData struct { + // 成员变动 + UserID string `json:"user_id,omitempty"` // 变动的用户ID + Username string `json:"username,omitempty"` // 用户名 + OperatorID string `json:"operator_id,omitempty"` // 操作者ID + OpName string `json:"op_name,omitempty"` // 操作者名称 + NewRole string `json:"new_role,omitempty"` // 新角色 + OldRole string `json:"old_role,omitempty"` // 旧角色 + MemberCount int `json:"member_count,omitempty"` // 当前成员数 + + // 群设置变更 + MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态 +} + +// GroupMentionMessage @提及通知消息 +type GroupMentionMessage struct { + GroupID string `json:"group_id"` // 群组ID + MessageID string `json:"message_id"` // 消息ID + FromUserID string `json:"from_user_id"` // 发送者ID + FromName string `json:"from_name"` // 发送者名称 + Content string `json:"content"` // 消息内容预览 + MentionAll bool `json:"mention_all"` // 是否@所有人 + CreatedAt int64 `json:"created_at"` // 创建时间戳 +} + +// AckMessage 消息发送确认结构 +type AckMessage struct { + ConversationID string `json:"conversation_id"` // 会话ID + GroupID string `json:"group_id,omitempty"` // 群组ID(群聊时) + ID string `json:"id"` // 消息ID + SenderID string `json:"sender_id"` // 发送者ID + Seq int64 `json:"seq"` // 消息序号 + Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组) + CreatedAt int64 `json:"created_at"` // 创建时间戳 +} + +// Client WebSocket客户端 +type Client struct { + ID string + UserID string + Conn *websocket.Conn + Send chan []byte + Manager *WebSocketManager + IsClosed bool + Mu sync.Mutex + closeOnce sync.Once // 确保 Send channel 只关闭一次 +} + +// WebSocketManager WebSocket连接管理器 +type WebSocketManager struct { + clients map[string]*Client // userID -> Client + register chan *Client + unregister chan *Client + broadcast chan *BroadcastMessage + mutex sync.RWMutex +} + +// BroadcastMessage 广播消息 +type BroadcastMessage struct { + Message *WSMessage + ExcludeUser string // 排除的用户ID,为空表示不排除 + TargetUser string // 目标用户ID,为空表示广播给所有用户 +} + +// NewWebSocketManager 创建WebSocket管理器 +func NewWebSocketManager() *WebSocketManager { + return &WebSocketManager{ + clients: make(map[string]*Client), + register: make(chan *Client, 100), + unregister: make(chan *Client, 100), + broadcast: make(chan *BroadcastMessage, 100), + } +} + +// Start 启动管理器 +func (m *WebSocketManager) Start() { + go func() { + for { + select { + case client := <-m.register: + m.mutex.Lock() + m.clients[client.UserID] = client + m.mutex.Unlock() + log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients)) + + case client := <-m.unregister: + m.mutex.Lock() + if _, ok := m.clients[client.UserID]; ok { + delete(m.clients, client.UserID) + // 使用 closeOnce 确保 channel 只关闭一次,避免 panic + client.closeOnce.Do(func() { + close(client.Send) + }) + log.Printf("WebSocket client disconnected: userID=%s", client.UserID) + } + m.mutex.Unlock() + + case broadcast := <-m.broadcast: + m.sendMessage(broadcast) + } + } + }() +} + +// Register 注册客户端 +func (m *WebSocketManager) Register(client *Client) { + m.register <- client +} + +// Unregister 注销客户端 +func (m *WebSocketManager) Unregister(client *Client) { + m.unregister <- client +} + +// Broadcast 广播消息给所有用户 +func (m *WebSocketManager) Broadcast(msg *WSMessage) { + m.broadcast <- &BroadcastMessage{ + Message: msg, + TargetUser: "", + } +} + +// SendToUser 发送消息给指定用户 +func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) { + m.broadcast <- &BroadcastMessage{ + Message: msg, + TargetUser: userID, + } +} + +// SendToUsers 发送消息给指定用户列表 +func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) { + for _, userID := range userIDs { + m.SendToUser(userID, msg) + } +} + +// GetClient 获取客户端 +func (m *WebSocketManager) GetClient(userID string) (*Client, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + client, ok := m.clients[userID] + return client, ok +} + +// GetAllClients 获取所有客户端 +func (m *WebSocketManager) GetAllClients() map[string]*Client { + m.mutex.RLock() + defer m.mutex.RUnlock() + return m.clients +} + +// GetClientCount 获取在线用户数量 +func (m *WebSocketManager) GetClientCount() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + return len(m.clients) +} + +// IsUserOnline 检查用户是否在线 +func (m *WebSocketManager) IsUserOnline(userID string) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + _, ok := m.clients[userID] + log.Printf("[DEBUG IsUserOnline] 检查用户 %s, 结果=%v, 当前在线用户=%v", userID, ok, m.clients) + return ok +} + +// sendMessage 发送消息 +func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) { + msgBytes, err := json.Marshal(broadcast.Message) + if err != nil { + log.Printf("Failed to marshal message: %v", err) + return + } + + log.Printf("[DEBUG WebSocket] sendMessage: 目标用户=%s, 当前在线用户数=%d, 消息类型=%s", + broadcast.TargetUser, len(m.clients), broadcast.Message.Type) + + m.mutex.RLock() + defer m.mutex.RUnlock() + + for userID, client := range m.clients { + // 如果指定了目标用户,只发送给目标用户 + if broadcast.TargetUser != "" && userID != broadcast.TargetUser { + continue + } + + // 如果指定了排除用户,跳过 + if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser { + continue + } + + select { + case client.Send <- msgBytes: + log.Printf("[DEBUG WebSocket] 成功发送消息到用户 %s, 消息类型=%s", userID, broadcast.Message.Type) + default: + log.Printf("Failed to send message to user %s: channel full", userID) + } + } +} + +// SendPing 发送心跳 +func (c *Client) SendPing() error { + c.Mu.Lock() + defer c.Mu.Unlock() + if c.IsClosed { + return nil + } + msg := WSMessage{ + Type: MessageTypePing, + Data: nil, + Timestamp: time.Now().UnixMilli(), + } + msgBytes, _ := json.Marshal(msg) + return c.Conn.WriteMessage(websocket.TextMessage, msgBytes) +} + +// SendPong 发送Pong响应 +func (c *Client) SendPong() error { + c.Mu.Lock() + defer c.Mu.Unlock() + if c.IsClosed { + return nil + } + msg := WSMessage{ + Type: MessageTypePong, + Data: nil, + Timestamp: time.Now().UnixMilli(), + } + msgBytes, _ := json.Marshal(msg) + return c.Conn.WriteMessage(websocket.TextMessage, msgBytes) +} + +// WritePump 写入泵,将消息从Manager发送到客户端 +func (c *Client) WritePump() { + defer func() { + c.Conn.Close() + c.Mu.Lock() + c.IsClosed = true + c.Mu.Unlock() + }() + + for { + message, ok := <-c.Send + if !ok { + c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + c.Mu.Lock() + if c.IsClosed { + c.Mu.Unlock() + return + } + err := c.Conn.WriteMessage(websocket.TextMessage, message) + c.Mu.Unlock() + + if err != nil { + log.Printf("Write error: %v", err) + return + } + } +} + +// ReadPump 读取泵,从客户端读取消息 +func (c *Client) ReadPump(handler func(msg *WSMessage)) { + defer func() { + c.Manager.Unregister(c) + c.Conn.Close() + c.Mu.Lock() + c.IsClosed = true + c.Mu.Unlock() + }() + + c.Conn.SetReadLimit(512 * 1024) // 512KB + c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.Conn.SetPongHandler(func(string) error { + c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + for { + _, message, err := c.Conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("WebSocket error: %v", err) + } + break + } + + var wsMsg WSMessage + if err := json.Unmarshal(message, &wsMsg); err != nil { + log.Printf("Failed to unmarshal message: %v", err) + continue + } + + handler(&wsMsg) + } +} + +// CreateWSMessage 创建WebSocket消息 +func CreateWSMessage(msgType string, data interface{}) *WSMessage { + return &WSMessage{ + Type: msgType, + Data: data, + Timestamp: time.Now().UnixMilli(), + } +} diff --git a/internal/repository/comment_repo.go b/internal/repository/comment_repo.go new file mode 100644 index 0000000..9836c0c --- /dev/null +++ b/internal/repository/comment_repo.go @@ -0,0 +1,296 @@ +package repository + +import ( + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// CommentRepository 评论仓储 +type CommentRepository struct { + db *gorm.DB +} + +// NewCommentRepository 创建评论仓储 +func NewCommentRepository(db *gorm.DB) *CommentRepository { + return &CommentRepository{db: db} +} + +// Create 创建评论 +func (r *CommentRepository) Create(comment *model.Comment) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 创建评论 + err := tx.Create(comment).Error + if err != nil { + return err + } + + // 增加帖子的评论数并同步热度分 + if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID). + Updates(map[string]interface{}{ + "comments_count": gorm.Expr("comments_count + 1"), + "hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"), + }).Error; err != nil { + return err + } + + // 如果是回复,增加父评论的回复数 + if comment.ParentID != nil && *comment.ParentID != "" { + if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID). + UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil { + return err + } + } + + return nil + }) +} + +// GetByID 根据ID获取评论 +func (r *CommentRepository) GetByID(id string) (*model.Comment, error) { + var comment model.Comment + err := r.db.Preload("User").First(&comment, "id = ?", id).Error + if err != nil { + return nil, err + } + return &comment, nil +} + +// Update 更新评论 +func (r *CommentRepository) Update(comment *model.Comment) error { + return r.db.Save(comment).Error +} + +// UpdateModerationStatus 更新评论审核状态 +func (r *CommentRepository) UpdateModerationStatus(commentID string, status model.CommentStatus) error { + return r.db.Model(&model.Comment{}). + Where("id = ?", commentID). + Update("status", status).Error +} + +// Delete 删除评论(软删除,同时清理关联数据) +func (r *CommentRepository) Delete(id string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 先查询评论获取post_id和parent_id + var comment model.Comment + if err := tx.First(&comment, "id = ?", id).Error; err != nil { + return err + } + + // 删除评论点赞记录 + if err := tx.Where("comment_id = ?", id).Delete(&model.CommentLike{}).Error; err != nil { + return err + } + + // 删除评论(软删除) + if err := tx.Delete(&model.Comment{}, "id = ?", id).Error; err != nil { + return err + } + + // 减少帖子的评论数并同步热度分 + if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID). + Updates(map[string]interface{}{ + "comments_count": gorm.Expr("comments_count - 1"), + "hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"), + }).Error; err != nil { + return err + } + + // 如果是回复,减少父评论的回复数 + if comment.ParentID != nil && *comment.ParentID != "" { + if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID). + UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil { + return err + } + } + + return nil + }) +} + +// GetByPostID 获取帖子评论 +func (r *CommentRepository) GetByPostID(postID string, page, pageSize int) ([]*model.Comment, int64, error) { + var comments []*model.Comment + var total int64 + + r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished). + Preload("User"). + Offset(offset).Limit(pageSize). + Order("created_at ASC"). + Find(&comments).Error + + return comments, total, err +} + +// GetByPostIDWithReplies 获取帖子评论(包含回复,扁平化结构) +// 所有层级的回复都扁平化展示在顶级评论的 replies 中 +func (r *CommentRepository) GetByPostIDWithReplies(postID string, page, pageSize, replyLimit int) ([]*model.Comment, int64, error) { + var comments []*model.Comment + var total int64 + + r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished). + Preload("User"). + Offset(offset).Limit(pageSize). + Order("created_at ASC"). + Find(&comments).Error + + if err != nil { + return nil, 0, err + } + + if len(comments) == 0 { + return comments, total, nil + } + + rootIDs := make([]string, 0, len(comments)) + commentsByID := make(map[string]*model.Comment, len(comments)) + for _, comment := range comments { + rootIDs = append(rootIDs, comment.ID) + commentsByID[comment.ID] = comment + } + + // 批量加载所有回复,内存中按 root_id 分组并裁剪每个根评论的返回条数 + var allReplies []*model.Comment + if err := r.db.Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished). + Preload("User"). + Order("created_at ASC"). + Find(&allReplies).Error; err != nil { + return nil, 0, err + } + + repliesByRoot := make(map[string][]*model.Comment, len(rootIDs)) + for _, reply := range allReplies { + if reply.RootID == nil { + continue + } + rootID := *reply.RootID + if replyLimit <= 0 || len(repliesByRoot[rootID]) < replyLimit { + repliesByRoot[rootID] = append(repliesByRoot[rootID], reply) + } + } + + type replyCountRow struct { + RootID string + Total int64 + } + var replyCountRows []replyCountRow + if err := r.db.Model(&model.Comment{}). + Select("root_id, COUNT(*) AS total"). + Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished). + Group("root_id"). + Scan(&replyCountRows).Error; err != nil { + return nil, 0, err + } + + replyCountMap := make(map[string]int64, len(replyCountRows)) + for _, row := range replyCountRows { + replyCountMap[row.RootID] = row.Total + } + + for _, rootID := range rootIDs { + comment := commentsByID[rootID] + comment.Replies = repliesByRoot[rootID] + comment.RepliesCount = int(replyCountMap[rootID]) + } + + return comments, total, nil +} + +// loadFlatReplies 加载评论的所有回复(扁平化,所有层级都在同一层) +func (r *CommentRepository) loadFlatReplies(rootComment *model.Comment, limit int) { + var allReplies []*model.Comment + + // 查询所有以该评论为根评论的回复(不包括顶级评论本身) + r.db.Where("root_id = ? AND status = ?", rootComment.ID, model.CommentStatusPublished). + Preload("User"). + Order("created_at ASC"). + Limit(limit). + Find(&allReplies) + + rootComment.Replies = allReplies +} + +// GetRepliesByRootID 根据根评论ID分页获取回复(扁平化) +func (r *CommentRepository) GetRepliesByRootID(rootID string, page, pageSize int) ([]*model.Comment, int64, error) { + var replies []*model.Comment + var total int64 + + // 统计总数 + r.db.Model(&model.Comment{}).Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).Count(&total) + + // 分页查询 + offset := (page - 1) * pageSize + err := r.db.Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished). + Preload("User"). + Order("created_at ASC"). + Offset(offset). + Limit(pageSize). + Find(&replies).Error + + return replies, total, err +} + +// GetReplies 获取回复 +func (r *CommentRepository) GetReplies(parentID string) ([]*model.Comment, error) { + var comments []*model.Comment + err := r.db.Where("parent_id = ? AND status = ?", parentID, model.CommentStatusPublished). + Preload("User"). + Order("created_at ASC"). + Find(&comments).Error + return comments, err +} + +// Like 点赞评论 +func (r *CommentRepository) Like(commentID, userID string) error { + // 检查是否已经点赞 + var existing model.CommentLike + err := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).First(&existing).Error + if err == nil { + // 已经点赞 + return nil + } + if err != gorm.ErrRecordNotFound { + return err + } + + // 创建点赞记录 + like := &model.CommentLike{ + CommentID: commentID, + UserID: userID, + } + err = r.db.Create(like).Error + if err != nil { + return err + } + + // 增加评论点赞数 + return r.db.Model(&model.Comment{}).Where("id = ?", commentID). + UpdateColumn("likes_count", gorm.Expr("likes_count + 1")).Error +} + +// Unlike 取消点赞评论 +func (r *CommentRepository) Unlike(commentID, userID string) error { + result := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).Delete(&model.CommentLike{}) + if result.Error != nil { + return result.Error + } + if result.RowsAffected > 0 { + // 减少评论点赞数 + return r.db.Model(&model.Comment{}).Where("id = ?", commentID). + UpdateColumn("likes_count", gorm.Expr("likes_count - 1")).Error + } + return nil +} + +// IsLiked 检查是否已点赞 +func (r *CommentRepository) IsLiked(commentID, userID string) bool { + var count int64 + r.db.Model(&model.CommentLike{}).Where("comment_id = ? AND user_id = ?", commentID, userID).Count(&count) + return count > 0 +} diff --git a/internal/repository/device_token_repo.go b/internal/repository/device_token_repo.go new file mode 100644 index 0000000..49cb5a6 --- /dev/null +++ b/internal/repository/device_token_repo.go @@ -0,0 +1,166 @@ +package repository + +import ( + "time" + + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// DeviceTokenRepository 设备Token仓储 +type DeviceTokenRepository struct { + db *gorm.DB +} + +// NewDeviceTokenRepository 创建设备Token仓储 +func NewDeviceTokenRepository(db *gorm.DB) *DeviceTokenRepository { + return &DeviceTokenRepository{db: db} +} + +// Create 创建设备Token +func (r *DeviceTokenRepository) Create(token *model.DeviceToken) error { + return r.db.Create(token).Error +} + +// GetByID 根据ID获取设备Token +func (r *DeviceTokenRepository) GetByID(id int64) (*model.DeviceToken, error) { + var token model.DeviceToken + err := r.db.First(&token, "id = ?", id).Error + if err != nil { + return nil, err + } + return &token, nil +} + +// Update 更新设备Token +func (r *DeviceTokenRepository) Update(token *model.DeviceToken) error { + return r.db.Save(token).Error +} + +// Delete 删除设备Token(软删除) +func (r *DeviceTokenRepository) Delete(id int64) error { + return r.db.Delete(&model.DeviceToken{}, id).Error +} + +// GetByUserID 获取用户所有设备 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *DeviceTokenRepository) GetByUserID(userID string) ([]*model.DeviceToken, error) { + var tokens []*model.DeviceToken + err := r.db.Where("user_id = ?", userID). + Order("created_at DESC"). + Find(&tokens).Error + return tokens, err +} + +// GetActiveByUserID 获取用户活跃设备 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *DeviceTokenRepository) GetActiveByUserID(userID string) ([]*model.DeviceToken, error) { + var tokens []*model.DeviceToken + err := r.db.Where("user_id = ? AND is_active = ?", userID, true). + Order("last_used_at DESC"). + Find(&tokens).Error + return tokens, err +} + +// GetByDeviceID 根据设备ID获取设备Token +func (r *DeviceTokenRepository) GetByDeviceID(deviceID string) (*model.DeviceToken, error) { + var token model.DeviceToken + err := r.db.Where("device_id = ?", deviceID).First(&token).Error + if err != nil { + return nil, err + } + return &token, nil +} + +// GetByPushToken 根据推送Token获取设备信息 +func (r *DeviceTokenRepository) GetByPushToken(pushToken string) (*model.DeviceToken, error) { + var token model.DeviceToken + err := r.db.Where("push_token = ?", pushToken).First(&token).Error + if err != nil { + return nil, err + } + return &token, nil +} + +// DeactivateAllExcept 登出其他设备(停用除指定设备外的所有设备) +func (r *DeviceTokenRepository) DeactivateAllExcept(userID int64, deviceID string) error { + return r.db.Model(&model.DeviceToken{}). + Where("user_id = ? AND device_id != ?", userID, deviceID). + Update("is_active", false).Error +} + +// Upsert 创建或更新设备Token +// 如果设备ID已存在,则更新Token和激活状态;否则创建新记录 +func (r *DeviceTokenRepository) Upsert(token *model.DeviceToken) error { + var existing model.DeviceToken + err := r.db.Where("device_id = ?", token.DeviceID).First(&existing).Error + + if err == gorm.ErrRecordNotFound { + // 创建新记录 + return r.db.Create(token).Error + } else if err != nil { + return err + } + + // 更新现有记录 + return r.db.Model(&existing).Updates(map[string]interface{}{ + "push_token": token.PushToken, + "is_active": true, + "device_name": token.DeviceName, + "last_used_at": time.Now(), + }).Error +} + +// UpdateLastUsed 更新最后使用时间 +func (r *DeviceTokenRepository) UpdateLastUsed(deviceID string) error { + return r.db.Model(&model.DeviceToken{}). + Where("device_id = ?", deviceID). + Update("last_used_at", time.Now()).Error +} + +// Deactivate 停用设备 +func (r *DeviceTokenRepository) Deactivate(deviceID string) error { + return r.db.Model(&model.DeviceToken{}). + Where("device_id = ?", deviceID). + Update("is_active", false).Error +} + +// Activate 激活设备 +func (r *DeviceTokenRepository) Activate(deviceID string) error { + return r.db.Model(&model.DeviceToken{}). + Where("device_id = ?", deviceID). + Updates(map[string]interface{}{ + "is_active": true, + "last_used_at": time.Now(), + }).Error +} + +// DeleteByUserID 删除用户所有设备Token +func (r *DeviceTokenRepository) DeleteByUserID(userID int64) error { + return r.db.Where("user_id = ?", userID).Delete(&model.DeviceToken{}).Error +} + +// GetDeviceCountByUserID 获取用户设备数量 +func (r *DeviceTokenRepository) GetDeviceCountByUserID(userID int64) (int64, error) { + var count int64 + err := r.db.Model(&model.DeviceToken{}). + Where("user_id = ?", userID). + Count(&count).Error + return count, err +} + +// GetActiveDeviceCountByUserID 获取用户活跃设备数量 +func (r *DeviceTokenRepository) GetActiveDeviceCountByUserID(userID int64) (int64, error) { + var count int64 + err := r.db.Model(&model.DeviceToken{}). + Where("user_id = ? AND is_active = ?", userID, true). + Count(&count).Error + return count, err +} + +// DeleteInactiveDevices 删除长时间未使用的设备 +func (r *DeviceTokenRepository) DeleteInactiveDevices(before time.Time) error { + return r.db.Where("is_active = ? AND last_used_at < ?", false, before). + Delete(&model.DeviceToken{}).Error +} diff --git a/internal/repository/group_join_request_repo.go b/internal/repository/group_join_request_repo.go new file mode 100644 index 0000000..81966e4 --- /dev/null +++ b/internal/repository/group_join_request_repo.go @@ -0,0 +1,50 @@ +package repository + +import ( + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +type GroupJoinRequestRepository interface { + Create(req *model.GroupJoinRequest) error + GetByFlag(flag string) (*model.GroupJoinRequest, error) + Update(req *model.GroupJoinRequest) error + GetPendingByGroupAndTarget(groupID, targetUserID string, reqType model.GroupJoinRequestType) (*model.GroupJoinRequest, error) +} + +type groupJoinRequestRepository struct { + db *gorm.DB +} + +func NewGroupJoinRequestRepository(db *gorm.DB) GroupJoinRequestRepository { + return &groupJoinRequestRepository{db: db} +} + +func (r *groupJoinRequestRepository) Create(req *model.GroupJoinRequest) error { + return r.db.Create(req).Error +} + +func (r *groupJoinRequestRepository) GetByFlag(flag string) (*model.GroupJoinRequest, error) { + var req model.GroupJoinRequest + if err := r.db.Where("flag = ?", flag).First(&req).Error; err != nil { + return nil, err + } + return &req, nil +} + +func (r *groupJoinRequestRepository) Update(req *model.GroupJoinRequest) error { + return r.db.Save(req).Error +} + +func (r *groupJoinRequestRepository) GetPendingByGroupAndTarget(groupID, targetUserID string, reqType model.GroupJoinRequestType) (*model.GroupJoinRequest, error) { + var req model.GroupJoinRequest + err := r.db.Where("group_id = ? AND target_user_id = ? AND request_type = ? AND status = ?", + groupID, targetUserID, reqType, model.GroupJoinRequestStatusPending). + Order("created_at DESC"). + First(&req).Error + if err != nil { + return nil, err + } + return &req, nil +} diff --git a/internal/repository/group_repo.go b/internal/repository/group_repo.go new file mode 100644 index 0000000..cfbefe6 --- /dev/null +++ b/internal/repository/group_repo.go @@ -0,0 +1,242 @@ +package repository + +import ( + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// GroupRepository 群组仓库接口 +type GroupRepository interface { + // 群组操作 + Create(group *model.Group) error + GetByID(id string) (*model.Group, error) + Update(group *model.Group) error + Delete(id string) error + GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error) + + // 群成员操作 + AddMember(member *model.GroupMember) error + GetMember(groupID string, userID string) (*model.GroupMember, error) + GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) + UpdateMember(member *model.GroupMember) error + RemoveMember(groupID string, userID string) error + GetMemberCount(groupID string) (int64, error) + IsMember(groupID string, userID string) (bool, error) + GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) + + // 角色相关 + GetMemberRole(groupID string, userID string) (string, error) + SetMemberRole(groupID string, userID string, role string) error + GetAdmins(groupID string) ([]model.GroupMember, error) + + // 群公告操作 + CreateAnnouncement(announcement *model.GroupAnnouncement) error + GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) + GetAnnouncementByID(id string) (*model.GroupAnnouncement, error) + DeleteAnnouncement(id string) error +} + +// groupRepository 群组仓库实现 +type groupRepository struct { + db *gorm.DB +} + +// NewGroupRepository 创建群组仓库 +func NewGroupRepository(db *gorm.DB) GroupRepository { + return &groupRepository{db: db} +} + +// Create 创建群组 +func (r *groupRepository) Create(group *model.Group) error { + return r.db.Create(group).Error +} + +// GetByID 根据ID获取群组 +func (r *groupRepository) GetByID(id string) (*model.Group, error) { + var group model.Group + err := r.db.First(&group, "id = ?", id).Error + if err != nil { + return nil, err + } + return &group, nil +} + +// Update 更新群组 +func (r *groupRepository) Update(group *model.Group) error { + return r.db.Save(group).Error +} + +// Delete 删除群组 +func (r *groupRepository) Delete(id string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 删除群成员 + if err := tx.Where("group_id = ?", id).Delete(&model.GroupMember{}).Error; err != nil { + return err + } + // 删除群公告 + if err := tx.Where("group_id = ?", id).Delete(&model.GroupAnnouncement{}).Error; err != nil { + return err + } + // 删除群组 + if err := tx.Delete(&model.Group{}, "id = ?", id).Error; err != nil { + return err + } + return nil + }) +} + +// GetByOwnerID 根据群主ID获取群组列表 +func (r *groupRepository) GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error) { + var groups []model.Group + var total int64 + + query := r.db.Model(&model.Group{}).Where("owner_id = ?", ownerID) + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error + return groups, total, err +} + +// AddMember 添加群成员 +func (r *groupRepository) AddMember(member *model.GroupMember) error { + return r.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(member).Error; err != nil { + return err + } + // 更新群组成员数量 + return tx.Model(&model.Group{}).Where("id = ?", member.GroupID). + Update("member_count", gorm.Expr("member_count + ?", 1)).Error + }) +} + +// GetMember 获取群成员 +func (r *groupRepository) GetMember(groupID string, userID string) (*model.GroupMember, error) { + var member model.GroupMember + err := r.db.First(&member, "group_id = ? AND user_id = ?", groupID, userID).Error + if err != nil { + return nil, err + } + return &member, nil +} + +// GetMembers 获取群成员列表 +func (r *groupRepository) GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) { + var members []model.GroupMember + var total int64 + + query := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID) + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Offset(offset).Limit(pageSize).Order("created_at ASC").Find(&members).Error + return members, total, err +} + +// UpdateMember 更新群成员 +func (r *groupRepository) UpdateMember(member *model.GroupMember) error { + return r.db.Save(member).Error +} + +// RemoveMember 移除群成员 +func (r *groupRepository) RemoveMember(groupID string, userID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 删除成员 + if err := tx.Where("group_id = ? AND user_id = ?", groupID, userID).Delete(&model.GroupMember{}).Error; err != nil { + return err + } + // 更新群组成员数量 + return tx.Model(&model.Group{}).Where("id = ?", groupID). + Update("member_count", gorm.Expr("member_count - ?", 1)).Error + }) +} + +// GetMemberCount 获取群成员数量 +func (r *groupRepository) GetMemberCount(groupID string) (int64, error) { + var count int64 + err := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID).Count(&count).Error + return count, err +} + +// IsMember 检查是否是群成员 +func (r *groupRepository) IsMember(groupID string, userID string) (bool, error) { + var count int64 + err := r.db.Model(&model.GroupMember{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count).Error + return count > 0, err +} + +// GetUserGroups 获取用户加入的群组列表 +func (r *groupRepository) GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) { + var groups []model.Group + var total int64 + + // 通过群成员表查询用户加入的群组 + subQuery := r.db.Model(&model.GroupMember{}). + Select("group_id"). + Where("user_id = ?", userID) + + query := r.db.Model(&model.Group{}).Where("id IN (?)", subQuery) + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error + return groups, total, err +} + +// GetMemberRole 获取成员角色 +func (r *groupRepository) GetMemberRole(groupID string, userID string) (string, error) { + member, err := r.GetMember(groupID, userID) + if err != nil { + return "", err + } + return member.Role, nil +} + +// SetMemberRole 设置成员角色 +func (r *groupRepository) SetMemberRole(groupID string, userID string, role string) error { + return r.db.Model(&model.GroupMember{}). + Where("group_id = ? AND user_id = ?", groupID, userID). + Update("role", role).Error +} + +// GetAdmins 获取群管理员列表 +func (r *groupRepository) GetAdmins(groupID string) ([]model.GroupMember, error) { + var admins []model.GroupMember + err := r.db.Where("group_id = ? AND role = ?", groupID, model.GroupRoleAdmin).Find(&admins).Error + return admins, err +} + +// CreateAnnouncement 创建群公告 +func (r *groupRepository) CreateAnnouncement(announcement *model.GroupAnnouncement) error { + return r.db.Create(announcement).Error +} + +// GetAnnouncements 获取群公告列表 +func (r *groupRepository) GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) { + var announcements []model.GroupAnnouncement + var total int64 + + query := r.db.Model(&model.GroupAnnouncement{}).Where("group_id = ?", groupID) + query.Count(&total) + + offset := (page - 1) * pageSize + // 置顶的排在前面,然后按时间倒序 + err := query.Offset(offset).Limit(pageSize).Order("is_pinned DESC, created_at DESC").Find(&announcements).Error + return announcements, total, err +} + +// GetAnnouncementByID 根据ID获取群公告 +func (r *groupRepository) GetAnnouncementByID(id string) (*model.GroupAnnouncement, error) { + var announcement model.GroupAnnouncement + err := r.db.First(&announcement, "id = ?", id).Error + if err != nil { + return nil, err + } + return &announcement, nil +} + +// DeleteAnnouncement 删除群公告 +func (r *groupRepository) DeleteAnnouncement(id string) error { + return r.db.Delete(&model.GroupAnnouncement{}, "id = ?", id).Error +} diff --git a/internal/repository/message_repo.go b/internal/repository/message_repo.go new file mode 100644 index 0000000..5bb2515 --- /dev/null +++ b/internal/repository/message_repo.go @@ -0,0 +1,543 @@ +package repository + +import ( + "carrot_bbs/internal/model" + "fmt" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// MessageRepository 消息仓储 +type MessageRepository struct { + db *gorm.DB +} + +// NewMessageRepository 创建消息仓储 +func NewMessageRepository(db *gorm.DB) *MessageRepository { + return &MessageRepository{db: db} +} + +// CreateMessage 创建消息 +func (r *MessageRepository) CreateMessage(msg *model.Message) error { + return r.db.Create(msg).Error +} + +// GetConversation 获取会话 +func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) { + var conv model.Conversation + err := r.db.Preload("Group").First(&conv, "id = ?", id).Error + if err != nil { + return nil, err + } + return &conv, nil +} + +// GetOrCreatePrivateConversation 获取或创建私聊会话 +// 使用参与者关系表来管理会话 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) { + var conv model.Conversation + + fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: user1ID=%s, user2ID=%s\n", user1ID, user2ID) + + // 查找两个用户共同参与的私聊会话 + err := r.db.Table("conversations c"). + Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID). + Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID). + Where("c.type = ?", model.ConversationTypePrivate). + First(&conv).Error + + if err == nil { + _ = r.db.Model(&model.ConversationParticipant{}). + Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}). + Update("hidden_at", nil).Error + fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: found existing conversation, ID=%s\n", conv.ID) + return &conv, nil + } + + if err != gorm.ErrRecordNotFound { + return nil, err + } + + // 没找到会话,创建新会话 + fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: no existing conversation found, creating new one\n") + conv = model.Conversation{ + Type: model.ConversationTypePrivate, + } + + // 使用事务创建会话和参与者 + err = r.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(&conv).Error; err != nil { + return err + } + + // 创建参与者记录 - UserID 存储为 string (UUID) + participants := []model.ConversationParticipant{ + {ConversationID: conv.ID, UserID: user1ID}, + {ConversationID: conv.ID, UserID: user2ID}, + } + if err := tx.Create(&participants).Error; err != nil { + return err + } + + return nil + }) + + if err == nil { + fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: created new conversation, ID=%s\n", conv.ID) + } + + return &conv, err +} + +// GetConversations 获取用户会话列表 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) { + var convs []*model.Conversation + var total int64 + + // 获取总数 + r.db.Model(&model.ConversationParticipant{}). + Where("user_id = ? AND hidden_at IS NULL", userID). + Count(&total) + + if total == 0 { + return convs, total, nil + } + + offset := (page - 1) * pageSize + // 查询会话列表并预加载关联数据: + // 当前用户维度先按置顶排序,再按更新时间排序 + err := r.db.Model(&model.Conversation{}). + Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id"). + Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID). + Preload("Group"). + Offset(offset). + Limit(pageSize). + Order("cp.is_pinned DESC"). + Order("conversations.updated_at DESC"). + Find(&convs).Error + + return convs, total, err +} + +// GetMessages 获取会话消息 +func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) { + var messages []*model.Message + var total int64 + + r.db.Model(&model.Message{}).Where("conversation_id = ?", conversationID).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("conversation_id = ?", conversationID). + Offset(offset). + Limit(pageSize). + Order("seq DESC"). + Find(&messages).Error + + return messages, total, err +} + +// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步) +func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) { + var messages []*model.Message + err := r.db.Where("conversation_id = ? AND seq > ?", conversationID, afterSeq). + Order("seq ASC"). + Limit(limit). + Find(&messages).Error + return messages, err +} + +// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多) +func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) { + var messages []*model.Message + fmt.Printf("[DEBUG] GetMessagesBeforeSeq: conversationID=%s, beforeSeq=%d, limit=%d\n", conversationID, beforeSeq, limit) + err := r.db.Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq). + Order("seq DESC"). // 降序获取最新消息在前 + Limit(limit). + Find(&messages).Error + fmt.Printf("[DEBUG] GetMessagesBeforeSeq: found %d messages, seq range: ", len(messages)) + for i, m := range messages { + if i < 5 || i >= len(messages)-2 { + fmt.Printf("%d ", m.Seq) + } else if i == 5 { + fmt.Printf("... ") + } + } + fmt.Println() + // 反转回正序 + for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 { + messages[i], messages[j] = messages[j], messages[i] + } + return messages, err +} + +// GetConversationParticipants 获取会话参与者 +func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) { + var participants []*model.ConversationParticipant + err := r.db.Where("conversation_id = ?", conversationID).Find(&participants).Error + return participants, err +} + +// GetParticipant 获取用户在会话中的参与者信息 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *MessageRepository) GetParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) { + var participant model.ConversationParticipant + err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error + if err != nil { + // 如果找不到参与者,尝试添加(修复没有参与者记录的问题) + if err == gorm.ErrRecordNotFound { + // 检查会话是否存在 + var conv model.Conversation + if err := r.db.First(&conv, conversationID).Error; err == nil { + // 会话存在,添加参与者 + participant = model.ConversationParticipant{ + ConversationID: conversationID, + UserID: userID, + } + if err := r.db.Create(&participant).Error; err != nil { + return nil, err + } + return &participant, nil + } + } + return nil, err + } + return &participant, nil +} + +// UpdateLastReadSeq 更新已读位置 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error { + result := r.db.Model(&model.ConversationParticipant{}). + Where("conversation_id = ? AND user_id = ?", conversationID, userID). + Update("last_read_seq", lastReadSeq) + + if result.Error != nil { + return result.Error + } + + // 如果没有更新任何记录,说明参与者记录不存在,需要插入 + if result.RowsAffected == 0 { + // 尝试插入新记录(跨数据库 upsert) + err := r.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "conversation_id"}, + {Name: "user_id"}, + }, + DoUpdates: clause.Assignments(map[string]interface{}{ + "last_read_seq": lastReadSeq, + "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), + }), + }).Create(&model.ConversationParticipant{ + ConversationID: conversationID, + UserID: userID, + LastReadSeq: lastReadSeq, + }).Error + if err != nil { + return err + } + } + + return nil +} + +// UpdatePinned 更新会话置顶状态(用户维度) +func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error { + result := r.db.Model(&model.ConversationParticipant{}). + Where("conversation_id = ? AND user_id = ?", conversationID, userID). + Update("is_pinned", isPinned) + + if result.Error != nil { + return result.Error + } + + if result.RowsAffected == 0 { + return r.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "conversation_id"}, + {Name: "user_id"}, + }, + DoUpdates: clause.Assignments(map[string]interface{}{ + "is_pinned": isPinned, + "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), + }), + }).Create(&model.ConversationParticipant{ + ConversationID: conversationID, + UserID: userID, + IsPinned: isPinned, + }).Error + } + + return nil +} + +// GetUnreadCount 获取未读消息数 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) { + var participant model.ConversationParticipant + err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error + if err != nil { + return 0, err + } + + var count int64 + err = r.db.Model(&model.Message{}). + Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, participant.LastReadSeq). + Count(&count).Error + return count, err +} + +// UpdateConversationLastSeq 更新会话的最后消息seq和时间 +func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error { + return r.db.Model(&model.Conversation{}). + Where("id = ?", conversationID). + Updates(map[string]interface{}{ + "last_seq": seq, + "last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"), + }).Error +} + +// GetNextSeq 获取会话的下一个seq值 +func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) { + var conv model.Conversation + err := r.db.Select("last_seq").First(&conv, conversationID).Error + if err != nil { + return 0, err + } + return conv.LastSeq + 1, nil +} + +// CreateMessageWithSeq 创建消息并更新seq(事务操作) +func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 获取当前seq并+1 + var conv model.Conversation + if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil { + return err + } + + msg.Seq = conv.LastSeq + 1 + + // 创建消息 + if err := tx.Create(msg).Error; err != nil { + return err + } + + // 更新会话的last_seq + if err := tx.Model(&model.Conversation{}). + Where("id = ?", msg.ConversationID). + Updates(map[string]interface{}{ + "last_seq": msg.Seq, + "last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"), + }).Error; err != nil { + return err + } + + // 新消息到达后,自动恢复被“仅自己删除”的会话 + if err := tx.Model(&model.ConversationParticipant{}). + Where("conversation_id = ?", msg.ConversationID). + Update("hidden_at", nil).Error; err != nil { + return err + } + + return nil + }) +} + +// GetAllUnreadCount 获取用户所有会话的未读消息总数 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) { + var totalUnread int64 + err := r.db.Table("conversation_participants AS cp"). + Joins("LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL", userID). + Where("cp.user_id = ?", userID). + Select("COALESCE(COUNT(m.id), 0)"). + Scan(&totalUnread).Error + return totalUnread, err +} + +// GetMessageByID 根据ID获取消息 +func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) { + var message model.Message + err := r.db.First(&message, "id = ?", messageID).Error + if err != nil { + return nil, err + } + return &message, nil +} + +// CountMessagesBySenderInConversation 统计会话中某用户已发送消息数 +func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) { + var count int64 + err := r.db.Model(&model.Message{}). + Where("conversation_id = ? AND sender_id = ?", conversationID, senderID). + Count(&count).Error + return count, err +} + +// UpdateMessageStatus 更新消息状态 +func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error { + return r.db.Model(&model.Message{}). + Where("id = ?", messageID). + Update("status", status).Error +} + +// GetOrCreateSystemParticipant 获取或创建用户在系统会话中的参与者记录 +// 系统会话是虚拟会话,但需要参与者记录来跟踪已读状态 +func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) { + var participant model.ConversationParticipant + err := r.db.Where("conversation_id = ? AND user_id = ?", + model.SystemConversationID, userID).First(&participant).Error + + if err == nil { + return &participant, nil + } + + if err != gorm.ErrRecordNotFound { + return nil, err + } + + // 自动创建参与者记录 + participant = model.ConversationParticipant{ + ConversationID: model.SystemConversationID, + UserID: userID, + LastReadSeq: 0, + } + + if err := r.db.Create(&participant).Error; err != nil { + return nil, err + } + + return &participant, nil +} + +// GetSystemMessagesUnreadCount 获取系统消息未读数 +func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) { + // 获取或创建参与者记录 + participant, err := r.GetOrCreateSystemParticipant(userID) + if err != nil { + return 0, err + } + + // 计算未读数:查询 seq > last_read_seq 的消息 + var count int64 + err = r.db.Model(&model.Message{}). + Where("conversation_id = ? AND seq > ?", + model.SystemConversationID, participant.LastReadSeq). + Count(&count).Error + + return count, err +} + +// MarkAllSystemMessagesAsRead 标记所有系统消息已读 +func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error { + // 获取系统会话的最新 seq + var maxSeq int64 + err := r.db.Model(&model.Message{}). + Where("conversation_id = ?", model.SystemConversationID). + Select("COALESCE(MAX(seq), 0)"). + Scan(&maxSeq).Error + if err != nil { + return err + } + + // 使用跨数据库 upsert 方式更新或创建参与者记录 + return r.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "conversation_id"}, + {Name: "user_id"}, + }, + DoUpdates: clause.Assignments(map[string]interface{}{ + "last_read_seq": maxSeq, + "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), + }), + }).Create(&model.ConversationParticipant{ + ConversationID: model.SystemConversationID, + UserID: userID, + LastReadSeq: maxSeq, + }).Error +} + +// GetConversationByGroupID 通过群组ID获取会话 +func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) { + var conv model.Conversation + err := r.db.Where("group_id = ?", groupID).First(&conv).Error + if err != nil { + return nil, err + } + return &conv, nil +} + +// RemoveParticipant 移除会话参与者 +// 当用户退出群聊时,需要同时移除其在对应会话中的参与者记录 +func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error { + return r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID). + Delete(&model.ConversationParticipant{}).Error +} + +// AddParticipant 添加会话参与者 +// 当用户加入群聊时,需要同时将其添加到对应会话的参与者记录 +func (r *MessageRepository) AddParticipant(conversationID string, userID string) error { + // 先检查是否已经是参与者 + var count int64 + err := r.db.Model(&model.ConversationParticipant{}). + Where("conversation_id = ? AND user_id = ?", conversationID, userID). + Count(&count).Error + if err != nil { + return err + } + + // 如果已经是参与者,直接返回 + if count > 0 { + return nil + } + + // 添加参与者 + participant := model.ConversationParticipant{ + ConversationID: conversationID, + UserID: userID, + LastReadSeq: 0, + } + return r.db.Create(&participant).Error +} + +// DeleteConversationByGroupID 删除群组对应的会话及其参与者 +// 当解散群组时调用 +func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error { + // 获取群组对应的会话 + conv, err := r.GetConversationByGroupID(groupID) + if err != nil { + // 如果会话不存在,直接返回 + if err == gorm.ErrRecordNotFound { + return nil + } + return err + } + + return r.db.Transaction(func(tx *gorm.DB) error { + // 删除会话参与者 + if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error; err != nil { + return err + } + // 删除会话中的消息 + if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error; err != nil { + return err + } + // 删除会话 + if err := tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error; err != nil { + return err + } + return nil + }) +} + +// HideConversationForUser 仅对当前用户隐藏会话(私聊删除) +func (r *MessageRepository) HideConversationForUser(conversationID, userID string) error { + now := time.Now() + return r.db.Model(&model.ConversationParticipant{}). + Where("conversation_id = ? AND user_id = ?", conversationID, userID). + Update("hidden_at", &now).Error +} diff --git a/internal/repository/notification_repo.go b/internal/repository/notification_repo.go new file mode 100644 index 0000000..83e3f28 --- /dev/null +++ b/internal/repository/notification_repo.go @@ -0,0 +1,78 @@ +package repository + +import ( + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// NotificationRepository 通知仓储 +type NotificationRepository struct { + db *gorm.DB +} + +// NewNotificationRepository 创建通知仓储 +func NewNotificationRepository(db *gorm.DB) *NotificationRepository { + return &NotificationRepository{db: db} +} + +// Create 创建通知 +func (r *NotificationRepository) Create(notification *model.Notification) error { + return r.db.Create(notification).Error +} + +// GetByID 根据ID获取通知 +func (r *NotificationRepository) GetByID(id string) (*model.Notification, error) { + var notification model.Notification + err := r.db.First(¬ification, "id = ?", id).Error + if err != nil { + return nil, err + } + return ¬ification, nil +} + +// GetByUserID 获取用户通知 +func (r *NotificationRepository) GetByUserID(userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) { + var notifications []*model.Notification + var total int64 + + query := r.db.Model(&model.Notification{}).Where("user_id = ?", userID) + + if unreadOnly { + query = query.Where("is_read = ?", false) + } + + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(¬ifications).Error + + return notifications, total, err +} + +// MarkAsRead 标记为已读 +func (r *NotificationRepository) MarkAsRead(id string) error { + return r.db.Model(&model.Notification{}).Where("id = ?", id).Update("is_read", true).Error +} + +// MarkAllAsRead 标记所有为已读 +func (r *NotificationRepository) MarkAllAsRead(userID string) error { + return r.db.Model(&model.Notification{}).Where("user_id = ?", userID).Update("is_read", true).Error +} + +// Delete 删除通知 +func (r *NotificationRepository) Delete(id string) error { + return r.db.Delete(&model.Notification{}, "id = ?", id).Error +} + +// GetUnreadCount 获取未读数量 +func (r *NotificationRepository) GetUnreadCount(userID string) (int64, error) { + var count int64 + err := r.db.Model(&model.Notification{}).Where("user_id = ? AND is_read = ?", userID, false).Count(&count).Error + return count, err +} + +// DeleteAllByUserID 删除用户所有通知 +func (r *NotificationRepository) DeleteAllByUserID(userID string) error { + return r.db.Where("user_id = ?", userID).Delete(&model.Notification{}).Error +} diff --git a/internal/repository/post_repo.go b/internal/repository/post_repo.go new file mode 100644 index 0000000..56a8574 --- /dev/null +++ b/internal/repository/post_repo.go @@ -0,0 +1,360 @@ +package repository + +import ( + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// PostRepository 帖子仓储 +type PostRepository struct { + db *gorm.DB +} + +// NewPostRepository 创建帖子仓储 +func NewPostRepository(db *gorm.DB) *PostRepository { + return &PostRepository{db: db} +} + +// Create 创建帖子 +func (r *PostRepository) Create(post *model.Post, images []string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 创建帖子 + if err := tx.Create(post).Error; err != nil { + return err + } + + // 创建图片记录 + for i, url := range images { + image := &model.PostImage{ + PostID: post.ID, + URL: url, + SortOrder: i, + } + if err := tx.Create(image).Error; err != nil { + return err + } + } + + return nil + }) +} + +// GetByID 根据ID获取帖子 +func (r *PostRepository) GetByID(id string) (*model.Post, error) { + var post model.Post + err := r.db.Preload("User").Preload("Images").First(&post, "id = ?", id).Error + if err != nil { + return nil, err + } + return &post, nil +} + +// Update 更新帖子 +func (r *PostRepository) Update(post *model.Post) error { + return r.db.Save(post).Error +} + +// UpdateModerationStatus 更新帖子审核状态 +func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error { + updates := map[string]interface{}{ + "status": status, + "reviewed_at": gorm.Expr("CURRENT_TIMESTAMP"), + "reviewed_by": reviewedBy, + "reject_reason": rejectReason, + } + return r.db.Model(&model.Post{}).Where("id = ?", postID).Updates(updates).Error +} + +// Delete 删除帖子(软删除,同时清理关联数据) +func (r *PostRepository) Delete(id string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 删除帖子图片 + if err := tx.Where("post_id = ?", id).Delete(&model.PostImage{}).Error; err != nil { + return err + } + + // 删除帖子点赞记录 + if err := tx.Where("post_id = ?", id).Delete(&model.PostLike{}).Error; err != nil { + return err + } + + // 删除帖子收藏记录 + if err := tx.Where("post_id = ?", id).Delete(&model.Favorite{}).Error; err != nil { + return err + } + + // 删除评论点赞记录(子查询获取该帖子所有评论ID) + if err := tx.Where("comment_id IN (SELECT id FROM comments WHERE post_id = ?)", id).Delete(&model.CommentLike{}).Error; err != nil { + return err + } + + // 删除帖子评论 + if err := tx.Where("post_id = ?", id).Delete(&model.Comment{}).Error; err != nil { + return err + } + + // 最后删除帖子本身(软删除) + return tx.Delete(&model.Post{}, "id = ?", id).Error + }) +} + +// List 分页获取帖子列表 +func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) { + var posts []*model.Post + var total int64 + + query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished) + + if userID != "" { + query = query.Where("user_id = ?", userID) + } + + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error + + return posts, total, err +} + +// GetUserPosts 获取用户帖子 +func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) { + var posts []*model.Post + var total int64 + + r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error + + return posts, total, err +} + +// GetFavorites 获取用户收藏 +func (r *PostRepository) GetFavorites(userID string, page, pageSize int) ([]*model.Post, int64, error) { + var posts []*model.Post + var total int64 + + subQuery := r.db.Model(&model.Favorite{}).Where("user_id = ?", userID).Select("post_id") + r.db.Model(&model.Post{}).Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error + + return posts, total, err +} + +// Like 点赞帖子 +func (r *PostRepository) Like(postID, userID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 检查是否已经点赞 + var existing model.PostLike + err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error + if err == nil { + // 已经点赞,直接返回 + return nil + } + if err != gorm.ErrRecordNotFound { + return err + } + + // 创建点赞记录 + if err := tx.Create(&model.PostLike{ + PostID: postID, + UserID: userID, + }).Error; err != nil { + return err + } + + // 增加帖子点赞数并同步热度分 + return tx.Model(&model.Post{}).Where("id = ?", postID). + Updates(map[string]interface{}{ + "likes_count": gorm.Expr("likes_count + 1"), + "hot_score": gorm.Expr("(likes_count + 1) * 2 + comments_count * 3 + views_count * 0.1"), + }).Error + }) +} + +// Unlike 取消点赞 +func (r *PostRepository) Unlike(postID, userID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.PostLike{}) + if result.Error != nil { + return result.Error + } + if result.RowsAffected > 0 { + // 减少帖子点赞数并同步热度分 + return tx.Model(&model.Post{}).Where("id = ?", postID). + Updates(map[string]interface{}{ + "likes_count": gorm.Expr("likes_count - 1"), + "hot_score": gorm.Expr("(likes_count - 1) * 2 + comments_count * 3 + views_count * 0.1"), + }).Error + } + return nil + }) +} + +// IsLiked 检查是否点赞 +func (r *PostRepository) IsLiked(postID, userID string) bool { + var count int64 + r.db.Model(&model.PostLike{}).Where("post_id = ? AND user_id = ?", postID, userID).Count(&count) + return count > 0 +} + +// Favorite 收藏帖子 +func (r *PostRepository) Favorite(postID, userID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 检查是否已经收藏 + var existing model.Favorite + err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error + if err == nil { + // 已经收藏,直接返回 + return nil + } + if err != gorm.ErrRecordNotFound { + return err + } + + // 创建收藏记录 + if err := tx.Create(&model.Favorite{ + PostID: postID, + UserID: userID, + }).Error; err != nil { + return err + } + + // 增加帖子收藏数 + return tx.Model(&model.Post{}).Where("id = ?", postID). + UpdateColumn("favorites_count", gorm.Expr("favorites_count + 1")).Error + }) +} + +// Unfavorite 取消收藏 +func (r *PostRepository) Unfavorite(postID, userID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.Favorite{}) + if result.Error != nil { + return result.Error + } + if result.RowsAffected > 0 { + // 减少帖子收藏数 + return tx.Model(&model.Post{}).Where("id = ?", postID). + UpdateColumn("favorites_count", gorm.Expr("favorites_count - 1")).Error + } + return nil + }) +} + +// IsFavorited 检查是否收藏 +func (r *PostRepository) IsFavorited(postID, userID string) bool { + var count int64 + r.db.Model(&model.Favorite{}).Where("post_id = ? AND user_id = ?", postID, userID).Count(&count) + return count > 0 +} + +// IncrementViews 增加帖子观看量 +func (r *PostRepository) IncrementViews(postID string) error { + return r.db.Model(&model.Post{}).Where("id = ?", postID). + Updates(map[string]interface{}{ + "views_count": gorm.Expr("views_count + 1"), + "hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"), + }).Error +} + +// Search 搜索帖子 +func (r *PostRepository) Search(keyword string, page, pageSize int) ([]*model.Post, int64, error) { + var posts []*model.Post + var total int64 + + query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished) + + // 搜索标题和内容 + if keyword != "" { + if r.db.Dialector.Name() == "postgres" { + // PostgreSQL 使用全文检索表达式,为 pg_trgm/GIN 索引升级预留路径 + query = query.Where( + "to_tsvector('simple', COALESCE(title, '') || ' ' || COALESCE(content, '')) @@ plainto_tsquery('simple', ?)", + keyword, + ) + } else { + searchPattern := "%" + keyword + "%" + query = query.Where("title LIKE ? OR content LIKE ?", searchPattern, searchPattern) + } + } + + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error + + return posts, total, err +} + +// GetFollowingPosts 获取关注用户的帖子 +func (r *PostRepository) GetFollowingPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) { + var posts []*model.Post + var total int64 + + // 子查询:获取当前用户关注的所有用户ID + subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id") + + // 统计总数 + r.db.Model(&model.Post{}).Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished). + Preload("User").Preload("Images"). + Offset(offset).Limit(pageSize). + Order("created_at DESC"). + Find(&posts).Error + + return posts, total, err +} + +// GetHotPosts 获取热门帖子(按点赞数和评论数排序) +func (r *PostRepository) GetHotPosts(page, pageSize int) ([]*model.Post, int64, error) { + var posts []*model.Post + var total int64 + + r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished).Count(&total) + + offset := (page - 1) * pageSize + // 热门排序使用预计算热度分,避免每次请求进行表达式排序计算 + err := r.db.Where("status = ?", model.PostStatusPublished).Preload("User").Preload("Images"). + Offset(offset).Limit(pageSize). + Order("hot_score DESC, created_at DESC"). + Find(&posts).Error + + return posts, total, err +} + +// GetByIDs 根据ID列表获取帖子(保持传入顺序) +func (r *PostRepository) GetByIDs(ids []string) ([]*model.Post, error) { + if len(ids) == 0 { + return []*model.Post{}, nil + } + + var posts []*model.Post + err := r.db.Preload("User").Preload("Images"). + Where("id IN ? AND status = ?", ids, model.PostStatusPublished). + Find(&posts).Error + if err != nil { + return nil, err + } + + // 按传入ID顺序排序 + postMap := make(map[string]*model.Post) + for _, post := range posts { + postMap[post.ID] = post + } + + ordered := make([]*model.Post, 0, len(ids)) + for _, id := range ids { + if post, ok := postMap[id]; ok { + ordered = append(ordered, post) + } + } + + return ordered, nil +} diff --git a/internal/repository/push_repo.go b/internal/repository/push_repo.go new file mode 100644 index 0000000..78c29c7 --- /dev/null +++ b/internal/repository/push_repo.go @@ -0,0 +1,172 @@ +package repository + +import ( + "time" + + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// PushRecordRepository 推送记录仓储 +type PushRecordRepository struct { + db *gorm.DB +} + +// NewPushRecordRepository 创建推送记录仓储 +func NewPushRecordRepository(db *gorm.DB) *PushRecordRepository { + return &PushRecordRepository{db: db} +} + +// Create 创建推送记录 +func (r *PushRecordRepository) Create(record *model.PushRecord) error { + return r.db.Create(record).Error +} + +// GetByID 根据ID获取推送记录 +func (r *PushRecordRepository) GetByID(id int64) (*model.PushRecord, error) { + var record model.PushRecord + err := r.db.First(&record, "id = ?", id).Error + if err != nil { + return nil, err + } + return &record, nil +} + +// Update 更新推送记录 +func (r *PushRecordRepository) Update(record *model.PushRecord) error { + return r.db.Save(record).Error +} + +// GetPendingPushes 获取待推送记录 +func (r *PushRecordRepository) GetPendingPushes(limit int) ([]*model.PushRecord, error) { + var records []*model.PushRecord + err := r.db.Where("push_status = ?", model.PushStatusPending). + Where("expired_at IS NULL OR expired_at > ?", time.Now()). + Order("created_at ASC"). + Limit(limit). + Find(&records).Error + return records, err +} + +// GetByUserID 根据用户ID获取推送记录 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (r *PushRecordRepository) GetByUserID(userID string, limit, offset int) ([]*model.PushRecord, error) { + var records []*model.PushRecord + err := r.db.Where("user_id = ?", userID). + Order("created_at DESC"). + Offset(offset). + Limit(limit). + Find(&records).Error + return records, err +} + +// GetByMessageID 根据消息ID获取推送记录 +func (r *PushRecordRepository) GetByMessageID(messageID int64) ([]*model.PushRecord, error) { + var records []*model.PushRecord + err := r.db.Where("message_id = ?", messageID). + Order("created_at DESC"). + Find(&records).Error + return records, err +} + +// GetFailedPushesForRetry 获取失败待重试的推送 +func (r *PushRecordRepository) GetFailedPushesForRetry(limit int) ([]*model.PushRecord, error) { + var records []*model.PushRecord + err := r.db.Where("push_status = ?", model.PushStatusFailed). + Where("retry_count < max_retry"). + Where("expired_at IS NULL OR expired_at > ?", time.Now()). + Order("created_at ASC"). + Limit(limit). + Find(&records).Error + return records, err +} + +// BatchCreate 批量创建推送记录 +func (r *PushRecordRepository) BatchCreate(records []*model.PushRecord) error { + if len(records) == 0 { + return nil + } + return r.db.Create(&records).Error +} + +// BatchUpdateStatus 批量更新推送状态 +func (r *PushRecordRepository) BatchUpdateStatus(ids []int64, status model.PushStatus) error { + if len(ids) == 0 { + return nil + } + updates := map[string]interface{}{ + "push_status": status, + } + if status == model.PushStatusPushed { + updates["pushed_at"] = time.Now() + } + return r.db.Model(&model.PushRecord{}). + Where("id IN ?", ids). + Updates(updates).Error +} + +// UpdateStatus 更新单条记录状态 +func (r *PushRecordRepository) UpdateStatus(id int64, status model.PushStatus) error { + updates := map[string]interface{}{ + "push_status": status, + } + if status == model.PushStatusPushed { + updates["pushed_at"] = time.Now() + } + return r.db.Model(&model.PushRecord{}). + Where("id = ?", id). + Updates(updates).Error +} + +// MarkAsFailed 标记为失败 +func (r *PushRecordRepository) MarkAsFailed(id int64, errMsg string) error { + return r.db.Model(&model.PushRecord{}). + Where("id = ?", id). + Updates(map[string]interface{}{ + "push_status": model.PushStatusFailed, + "error_message": errMsg, + "retry_count": gorm.Expr("retry_count + 1"), + }).Error +} + +// MarkAsDelivered 标记为已送达 +func (r *PushRecordRepository) MarkAsDelivered(id int64) error { + return r.db.Model(&model.PushRecord{}). + Where("id = ?", id). + Updates(map[string]interface{}{ + "push_status": model.PushStatusDelivered, + "delivered_at": time.Now(), + }).Error +} + +// DeleteExpiredRecords 删除过期的推送记录(软删除) +func (r *PushRecordRepository) DeleteExpiredRecords() error { + return r.db.Where("expired_at IS NOT NULL AND expired_at < ?", time.Now()). + Delete(&model.PushRecord{}).Error +} + +// GetStatsByUserID 获取用户推送统计 +func (r *PushRecordRepository) GetStatsByUserID(userID int64) (map[model.PushStatus]int64, error) { + type statusCount struct { + Status model.PushStatus + Count int64 + } + var results []statusCount + + err := r.db.Model(&model.PushRecord{}). + Select("push_status as status, count(*) as count"). + Where("user_id = ?", userID). + Group("push_status"). + Scan(&results).Error + + if err != nil { + return nil, err + } + + stats := make(map[model.PushStatus]int64) + for _, r := range results { + stats[r.Status] = r.Count + } + return stats, nil +} diff --git a/internal/repository/sticker_repo.go b/internal/repository/sticker_repo.go new file mode 100644 index 0000000..d5a2e20 --- /dev/null +++ b/internal/repository/sticker_repo.go @@ -0,0 +1,112 @@ +package repository + +import ( + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// StickerRepository 自定义表情仓库接口 +type StickerRepository interface { + // 获取用户的所有表情 + GetByUserID(userID string) ([]model.UserSticker, error) + // 根据ID获取表情 + GetByID(id string) (*model.UserSticker, error) + // 创建表情 + Create(sticker *model.UserSticker) error + // 删除表情 + Delete(id string) error + // 删除用户的所有表情 + DeleteByUserID(userID string) error + // 检查表情是否存在 + Exists(userID string, url string) (bool, error) + // 更新排序 + UpdateSortOrder(id string, sortOrder int) error + // 批量更新排序 + BatchUpdateSortOrder(userID string, orders map[string]int) error + // 获取用户表情数量 + CountByUserID(userID string) (int64, error) +} + +// stickerRepository 自定义表情仓库实现 +type stickerRepository struct { + db *gorm.DB +} + +// NewStickerRepository 创建自定义表情仓库 +func NewStickerRepository(db *gorm.DB) StickerRepository { + return &stickerRepository{db: db} +} + +// GetByUserID 获取用户的所有表情 +func (r *stickerRepository) GetByUserID(userID string) ([]model.UserSticker, error) { + var stickers []model.UserSticker + err := r.db.Where("user_id = ?", userID). + Order("sort_order ASC, created_at DESC"). + Find(&stickers).Error + return stickers, err +} + +// GetByID 根据ID获取表情 +func (r *stickerRepository) GetByID(id string) (*model.UserSticker, error) { + var sticker model.UserSticker + err := r.db.Where("id = ?", id).First(&sticker).Error + if err != nil { + return nil, err + } + return &sticker, nil +} + +// Create 创建表情 +func (r *stickerRepository) Create(sticker *model.UserSticker) error { + return r.db.Create(sticker).Error +} + +// Delete 删除表情 +func (r *stickerRepository) Delete(id string) error { + return r.db.Where("id = ?", id).Delete(&model.UserSticker{}).Error +} + +// DeleteByUserID 删除用户的所有表情 +func (r *stickerRepository) DeleteByUserID(userID string) error { + return r.db.Where("user_id = ?", userID).Delete(&model.UserSticker{}).Error +} + +// Exists 检查表情是否存在 +func (r *stickerRepository) Exists(userID string, url string) (bool, error) { + var count int64 + err := r.db.Model(&model.UserSticker{}). + Where("user_id = ? AND url = ?", userID, url). + Count(&count).Error + return count > 0, err +} + +// UpdateSortOrder 更新排序 +func (r *stickerRepository) UpdateSortOrder(id string, sortOrder int) error { + return r.db.Model(&model.UserSticker{}). + Where("id = ?", id). + Update("sort_order", sortOrder).Error +} + +// BatchUpdateSortOrder 批量更新排序 +func (r *stickerRepository) BatchUpdateSortOrder(userID string, orders map[string]int) error { + return r.db.Transaction(func(tx *gorm.DB) error { + for id, sortOrder := range orders { + if err := tx.Model(&model.UserSticker{}). + Where("id = ? AND user_id = ?", id, userID). + Update("sort_order", sortOrder).Error; err != nil { + return err + } + } + return nil + }) +} + +// CountByUserID 获取用户表情数量 +func (r *stickerRepository) CountByUserID(userID string) (int64, error) { + var count int64 + err := r.db.Model(&model.UserSticker{}). + Where("user_id = ?", userID). + Count(&count).Error + return count, err +} diff --git a/internal/repository/system_notification_repo.go b/internal/repository/system_notification_repo.go new file mode 100644 index 0000000..2506fd4 --- /dev/null +++ b/internal/repository/system_notification_repo.go @@ -0,0 +1,114 @@ +package repository + +import ( + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// SystemNotificationRepository 系统通知仓储 +type SystemNotificationRepository struct { + db *gorm.DB +} + +// NewSystemNotificationRepository 创建系统通知仓储 +func NewSystemNotificationRepository(db *gorm.DB) *SystemNotificationRepository { + return &SystemNotificationRepository{db: db} +} + +// Create 创建系统通知 +func (r *SystemNotificationRepository) Create(notification *model.SystemNotification) error { + return r.db.Create(notification).Error +} + +// GetByID 根据ID获取通知 +func (r *SystemNotificationRepository) GetByID(id int64) (*model.SystemNotification, error) { + var notification model.SystemNotification + err := r.db.First(¬ification, "id = ?", id).Error + if err != nil { + return nil, err + } + return ¬ification, nil +} + +// GetByReceiverID 获取用户的通知列表 +func (r *SystemNotificationRepository) GetByReceiverID(receiverID string, page, pageSize int) ([]*model.SystemNotification, int64, error) { + var notifications []*model.SystemNotification + var total int64 + + query := r.db.Model(&model.SystemNotification{}).Where("receiver_id = ?", receiverID) + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Offset(offset). + Limit(pageSize). + Order("created_at DESC"). + Find(¬ifications).Error + + return notifications, total, err +} + +// GetUnreadByReceiverID 获取用户的未读通知列表 +func (r *SystemNotificationRepository) GetUnreadByReceiverID(receiverID string, limit int) ([]*model.SystemNotification, error) { + var notifications []*model.SystemNotification + err := r.db.Where("receiver_id = ? AND is_read = ?", receiverID, false). + Order("created_at DESC"). + Limit(limit). + Find(¬ifications).Error + return notifications, err +} + +// GetUnreadCount 获取用户未读通知数量 +func (r *SystemNotificationRepository) GetUnreadCount(receiverID string) (int64, error) { + var count int64 + err := r.db.Model(&model.SystemNotification{}). + Where("receiver_id = ? AND is_read = ?", receiverID, false). + Count(&count).Error + return count, err +} + +// MarkAsRead 标记单条通知为已读 +func (r *SystemNotificationRepository) MarkAsRead(id int64, receiverID string) error { + now := model.SystemNotification{}.UpdatedAt + return r.db.Model(&model.SystemNotification{}). + Where("id = ? AND receiver_id = ?", id, receiverID). + Updates(map[string]interface{}{ + "is_read": true, + "read_at": now, + }).Error +} + +// MarkAllAsRead 标记用户所有通知为已读 +func (r *SystemNotificationRepository) MarkAllAsRead(receiverID string) error { + now := model.SystemNotification{}.UpdatedAt + return r.db.Model(&model.SystemNotification{}). + Where("receiver_id = ? AND is_read = ?", receiverID, false). + Updates(map[string]interface{}{ + "is_read": true, + "read_at": now, + }).Error +} + +// Delete 删除通知(软删除) +func (r *SystemNotificationRepository) Delete(id int64, receiverID string) error { + return r.db.Where("id = ? AND receiver_id = ?", id, receiverID). + Delete(&model.SystemNotification{}).Error +} + +// GetByType 获取用户指定类型的通知 +func (r *SystemNotificationRepository) GetByType(receiverID string, notifyType model.SystemNotificationType, page, pageSize int) ([]*model.SystemNotification, int64, error) { + var notifications []*model.SystemNotification + var total int64 + + query := r.db.Model(&model.SystemNotification{}). + Where("receiver_id = ? AND type = ?", receiverID, notifyType) + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Offset(offset). + Limit(pageSize). + Order("created_at DESC"). + Find(¬ifications).Error + + return notifications, total, err +} diff --git a/internal/repository/user_repo.go b/internal/repository/user_repo.go new file mode 100644 index 0000000..05cf202 --- /dev/null +++ b/internal/repository/user_repo.go @@ -0,0 +1,404 @@ +package repository + +import ( + "carrot_bbs/internal/model" + "fmt" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// UserRepository 用户仓储 +type UserRepository struct { + db *gorm.DB +} + +// NewUserRepository 创建用户仓储 +func NewUserRepository(db *gorm.DB) *UserRepository { + return &UserRepository{db: db} +} + +// Create 创建用户 +func (r *UserRepository) Create(user *model.User) error { + return r.db.Create(user).Error +} + +// GetByID 根据ID获取用户 +func (r *UserRepository) GetByID(id string) (*model.User, error) { + var user model.User + err := r.db.First(&user, "id = ?", id).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// GetByUsername 根据用户名获取用户 +func (r *UserRepository) GetByUsername(username string) (*model.User, error) { + var user model.User + err := r.db.First(&user, "username = ?", username).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// GetByEmail 根据邮箱获取用户 +func (r *UserRepository) GetByEmail(email string) (*model.User, error) { + var user model.User + err := r.db.First(&user, "email = ?", email).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// GetByPhone 根据手机号获取用户 +func (r *UserRepository) GetByPhone(phone string) (*model.User, error) { + var user model.User + err := r.db.First(&user, "phone = ?", phone).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// Update 更新用户 +func (r *UserRepository) Update(user *model.User) error { + return r.db.Save(user).Error +} + +// Delete 删除用户 +func (r *UserRepository) Delete(id string) error { + return r.db.Delete(&model.User{}, "id = ?", id).Error +} + +// List 分页获取用户列表 +func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) { + var users []*model.User + var total int64 + + r.db.Model(&model.User{}).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Order("created_at DESC, id DESC").Offset(offset).Limit(pageSize).Find(&users).Error + + return users, total, err +} + +// GetFollowers 获取用户粉丝 +func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) { + var users []*model.User + var total int64 + + subQuery := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id") + r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("id IN (?)", subQuery). + Order("created_at DESC, id DESC"). + Offset(offset).Limit(pageSize). + Find(&users).Error + + return users, total, err +} + +// GetFollowing 获取用户关注 +func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) { + var users []*model.User + var total int64 + + subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id") + r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("id IN (?)", subQuery). + Order("created_at DESC, id DESC"). + Offset(offset).Limit(pageSize). + Find(&users).Error + + return users, total, err +} + +// CreateFollow 创建关注关系 +func (r *UserRepository) CreateFollow(follow *model.Follow) error { + return r.db.Create(follow).Error +} + +// DeleteFollow 删除关注关系 +func (r *UserRepository) DeleteFollow(followerID, followingID string) error { + return r.db.Where("follower_id = ? AND following_id = ?", followerID, followingID).Delete(&model.Follow{}).Error +} + +// IsFollowing 检查是否关注了某用户 +func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) { + var count int64 + err := r.db.Model(&model.Follow{}).Where("follower_id = ? AND following_id = ?", followerID, followingID).Count(&count).Error + if err != nil { + return false, err + } + return count > 0, nil +} + +// IncrementFollowersCount 增加用户粉丝数 +func (r *UserRepository) IncrementFollowersCount(userID string) error { + return r.db.Model(&model.User{}).Where("id = ?", userID). + UpdateColumn("followers_count", gorm.Expr("followers_count + 1")).Error +} + +// DecrementFollowersCount 减少用户粉丝数 +func (r *UserRepository) DecrementFollowersCount(userID string) error { + return r.db.Model(&model.User{}).Where("id = ? AND followers_count > 0", userID). + UpdateColumn("followers_count", gorm.Expr("followers_count - 1")).Error +} + +// IncrementFollowingCount 增加用户关注数 +func (r *UserRepository) IncrementFollowingCount(userID string) error { + return r.db.Model(&model.User{}).Where("id = ?", userID). + UpdateColumn("following_count", gorm.Expr("following_count + 1")).Error +} + +// DecrementFollowingCount 减少用户关注数 +func (r *UserRepository) DecrementFollowingCount(userID string) error { + return r.db.Model(&model.User{}).Where("id = ? AND following_count > 0", userID). + UpdateColumn("following_count", gorm.Expr("following_count - 1")).Error +} + +// RefreshFollowersCount 刷新用户粉丝数(通过实际计数) +func (r *UserRepository) RefreshFollowersCount(userID string) error { + var count int64 + err := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&count).Error + if err != nil { + return err + } + return r.db.Model(&model.User{}).Where("id = ?", userID). + UpdateColumn("followers_count", count).Error +} + +// GetPostsCount 获取用户帖子数(实时计算) +func (r *UserRepository) GetPostsCount(userID string) (int64, error) { + var count int64 + err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error + return count, err +} + +// GetPostsCountBatch 批量获取用户帖子数(实时计算) +// 返回 map[userID]postsCount +func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) { + result := make(map[string]int64) + if len(userIDs) == 0 { + return result, nil + } + + // 初始化所有用户ID的计数为0 + for _, userID := range userIDs { + result[userID] = 0 + } + + // 使用 GROUP BY 一次性查询所有用户的帖子数 + type CountResult struct { + UserID string + Count int64 + } + var counts []CountResult + err := r.db.Model(&model.Post{}). + Select("user_id, count(*) as count"). + Where("user_id IN ?", userIDs). + Group("user_id"). + Scan(&counts).Error + if err != nil { + return nil, err + } + + // 更新查询结果 + for _, c := range counts { + result[c.UserID] = c.Count + } + + return result, nil +} + +// RefreshFollowingCount 刷新用户关注数(通过实际计数) +func (r *UserRepository) RefreshFollowingCount(userID string) error { + var count int64 + err := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&count).Error + if err != nil { + return err + } + return r.db.Model(&model.User{}).Where("id = ?", userID). + UpdateColumn("following_count", count).Error +} + +// IsBlocked 检查拉黑关系是否存在(blocker -> blocked) +func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) { + var count int64 + err := r.db.Model(&model.UserBlock{}). + Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID). + Count(&count).Error + if err != nil { + return false, err + } + return count > 0, nil +} + +// IsBlockedEitherDirection 检查是否任一方向存在拉黑 +func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) { + var count int64 + err := r.db.Model(&model.UserBlock{}). + Where("(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)", + userA, userB, userB, userA). + Count(&count).Error + if err != nil { + return false, err + } + return count > 0, nil +} + +// BlockUserAndCleanupRelations 拉黑用户并清理双向关注关系(事务) +func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + block := &model.UserBlock{ + BlockerID: blockerID, + BlockedID: blockedID, + } + if err := tx.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}}, + DoNothing: true, + }).Create(block).Error; err != nil { + return err + } + + if err := tx.Where("follower_id = ? AND following_id = ?", blockerID, blockedID). + Delete(&model.Follow{}).Error; err != nil { + return err + } + if err := tx.Where("follower_id = ? AND following_id = ?", blockedID, blockerID). + Delete(&model.Follow{}).Error; err != nil { + return err + } + + for _, uid := range []string{blockerID, blockedID} { + var followersCount int64 + if err := tx.Model(&model.Follow{}).Where("following_id = ?", uid).Count(&followersCount).Error; err != nil { + return err + } + if err := tx.Model(&model.User{}).Where("id = ?", uid). + UpdateColumn("followers_count", followersCount).Error; err != nil { + return err + } + + var followingCount int64 + if err := tx.Model(&model.Follow{}).Where("follower_id = ?", uid).Count(&followingCount).Error; err != nil { + return err + } + if err := tx.Model(&model.User{}).Where("id = ?", uid). + UpdateColumn("following_count", followingCount).Error; err != nil { + return err + } + } + + return nil + }) +} + +// UnblockUser 取消拉黑 +func (r *UserRepository) UnblockUser(blockerID, blockedID string) error { + return r.db.Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID). + Delete(&model.UserBlock{}).Error +} + +// GetBlockedUsers 获取用户黑名单列表 +func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) { + var users []*model.User + var total int64 + + subQuery := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id") + r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total) + + offset := (page - 1) * pageSize + err := r.db.Where("id IN (?)", subQuery). + Order("created_at DESC, id DESC"). + Offset(offset). + Limit(pageSize). + Find(&users).Error + + return users, total, err +} + +// Search 搜索用户 +func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) { + var users []*model.User + var total int64 + + query := r.db.Model(&model.User{}) + + // 搜索用户名、昵称、简介 + if keyword != "" { + if r.db.Dialector.Name() == "postgres" { + query = query.Where( + "to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)", + keyword, + ) + } else { + searchPattern := "%" + keyword + "%" + query = query.Where("username LIKE ? OR nickname LIKE ? OR bio LIKE ?", searchPattern, searchPattern, searchPattern) + } + } + + query.Count(&total) + + offset := (page - 1) * pageSize + err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&users).Error + + return users, total, err +} + +// GetMutualFollowStatus 批量获取双向关注状态 +// 返回 map[userID][isFollowing, isFollowingMe] +func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) { + result := make(map[string][2]bool) + + if len(targetUserIDs) == 0 { + return result, nil + } + + fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUserID=%s, targetUserIDs=%v\n", currentUserID, targetUserIDs) + + // 初始化所有目标用户为未关注状态 + for _, userID := range targetUserIDs { + result[userID] = [2]bool{false, false} + } + + // 查询当前用户关注了哪些目标用户 (isFollowing) + var followingIDs []string + err := r.db.Model(&model.Follow{}). + Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs). + Pluck("following_id", &followingIDs).Error + if err != nil { + return nil, err + } + fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUser follows these targets: %v\n", followingIDs) + for _, id := range followingIDs { + status := result[id] + status[0] = true + result[id] = status + } + + // 查询哪些目标用户关注了当前用户 (isFollowingMe) + var followerIDs []string + err = r.db.Model(&model.Follow{}). + Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID). + Pluck("follower_id", &followerIDs).Error + if err != nil { + return nil, err + } + fmt.Printf("[DEBUG] GetMutualFollowStatus: these targets follow currentUser: %v\n", followerIDs) + for _, id := range followerIDs { + status := result[id] + status[1] = true + result[id] = status + } + + fmt.Printf("[DEBUG] GetMutualFollowStatus: final result=%v\n", result) + return result, nil +} diff --git a/internal/repository/vote_repo.go b/internal/repository/vote_repo.go new file mode 100644 index 0000000..b566ceb --- /dev/null +++ b/internal/repository/vote_repo.go @@ -0,0 +1,141 @@ +package repository + +import ( + "carrot_bbs/internal/model" + "errors" + + "gorm.io/gorm" +) + +// VoteRepository 投票仓储 +type VoteRepository struct { + db *gorm.DB +} + +// NewVoteRepository 创建投票仓储 +func NewVoteRepository(db *gorm.DB) *VoteRepository { + return &VoteRepository{db: db} +} + +// CreateOptions 批量创建投票选项 +func (r *VoteRepository) CreateOptions(postID string, options []string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + for i, content := range options { + option := &model.VoteOption{ + PostID: postID, + Content: content, + SortOrder: i, + } + if err := tx.Create(option).Error; err != nil { + return err + } + } + return nil + }) +} + +// GetOptionsByPostID 获取帖子的所有投票选项 +func (r *VoteRepository) GetOptionsByPostID(postID string) ([]model.VoteOption, error) { + var options []model.VoteOption + err := r.db.Where("post_id = ?", postID).Order("sort_order ASC").Find(&options).Error + return options, err +} + +// Vote 用户投票 +func (r *VoteRepository) Vote(postID, userID, optionID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 检查用户是否已投票 + var existing model.UserVote + err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error + if err == nil { + // 已经投票,返回错误 + return errors.New("user already voted") + } + if err != gorm.ErrRecordNotFound { + return err + } + + // 验证选项是否属于该帖子 + var option model.VoteOption + if err := tx.Where("id = ? AND post_id = ?", optionID, postID).First(&option).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return errors.New("invalid option") + } + return err + } + + // 创建投票记录 + if err := tx.Create(&model.UserVote{ + PostID: postID, + UserID: userID, + OptionID: optionID, + }).Error; err != nil { + return err + } + + // 原子增加选项投票数 + return tx.Model(&model.VoteOption{}).Where("id = ?", optionID). + UpdateColumn("votes_count", gorm.Expr("votes_count + 1")).Error + }) +} + +// Unvote 取消投票 +func (r *VoteRepository) Unvote(postID, userID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 获取用户的投票记录 + var userVote model.UserVote + err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil // 没有投票记录,直接返回 + } + return err + } + + // 删除投票记录 + result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.UserVote{}) + if result.Error != nil { + return result.Error + } + + if result.RowsAffected > 0 { + // 原子减少选项投票数 + return tx.Model(&model.VoteOption{}).Where("id = ?", userVote.OptionID). + UpdateColumn("votes_count", gorm.Expr("votes_count - 1")).Error + } + + return nil + }) +} + +// GetUserVote 获取用户在指定帖子的投票 +func (r *VoteRepository) GetUserVote(postID, userID string) (*model.UserVote, error) { + var userVote model.UserVote + err := r.db.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } + return nil, err + } + return &userVote, nil +} + +// UpdateOption 更新选项内容 +func (r *VoteRepository) UpdateOption(optionID, content string) error { + return r.db.Model(&model.VoteOption{}).Where("id = ?", optionID). + Update("content", content).Error +} + +// DeleteOptionsByPostID 删除帖子的所有投票选项 +func (r *VoteRepository) DeleteOptionsByPostID(postID string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + // 删除关联的用户投票记录 + if err := tx.Where("post_id = ?", postID).Delete(&model.UserVote{}).Error; err != nil { + return err + } + + // 删除投票选项 + return tx.Where("post_id = ?", postID).Delete(&model.VoteOption{}).Error + }) +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000..d617acb --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,334 @@ +package router + +import ( + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/handler" + "carrot_bbs/internal/middleware" + "carrot_bbs/internal/service" +) + +// Router 路由配置 +type Router struct { + engine *gin.Engine + userHandler *handler.UserHandler + postHandler *handler.PostHandler + commentHandler *handler.CommentHandler + messageHandler *handler.MessageHandler + notificationHandler *handler.NotificationHandler + uploadHandler *handler.UploadHandler + wsHandler *handler.WebSocketHandler + pushHandler *handler.PushHandler + systemMessageHandler *handler.SystemMessageHandler + groupHandler *handler.GroupHandler + stickerHandler *handler.StickerHandler + gorseHandler *handler.GorseHandler + voteHandler *handler.VoteHandler + jwtService *service.JWTService +} + +// New 创建路由 +func New( + userHandler *handler.UserHandler, + postHandler *handler.PostHandler, + commentHandler *handler.CommentHandler, + messageHandler *handler.MessageHandler, + notificationHandler *handler.NotificationHandler, + uploadHandler *handler.UploadHandler, + jwtService *service.JWTService, + wsHandler *handler.WebSocketHandler, + pushHandler *handler.PushHandler, + systemMessageHandler *handler.SystemMessageHandler, + groupHandler *handler.GroupHandler, + stickerHandler *handler.StickerHandler, + gorseHandler *handler.GorseHandler, + voteHandler *handler.VoteHandler, +) *Router { + // 设置JWT服务 + userHandler.SetJWTService(jwtService) + + r := &Router{ + engine: gin.Default(), + userHandler: userHandler, + postHandler: postHandler, + commentHandler: commentHandler, + messageHandler: messageHandler, + notificationHandler: notificationHandler, + uploadHandler: uploadHandler, + wsHandler: wsHandler, + pushHandler: pushHandler, + systemMessageHandler: systemMessageHandler, + groupHandler: groupHandler, + stickerHandler: stickerHandler, + gorseHandler: gorseHandler, + voteHandler: voteHandler, + jwtService: jwtService, + } + + r.setupRoutes() + return r +} + +// setupRoutes 设置路由 +func (r *Router) setupRoutes() { + // 中间件 + r.engine.Use(middleware.CORS()) + + // 健康检查 + r.engine.GET("/health", func(c *gin.Context) { + c.JSON(200, gin.H{"status": "ok"}) + }) + + // WebSocket 路由 + if r.wsHandler != nil { + r.engine.GET("/ws", r.wsHandler.HandleWebSocket) + } + + // API v1 + v1 := r.engine.Group("/api/v1") + { + // 认证路由(公开) + auth := v1.Group("/auth") + { + auth.POST("/register", r.userHandler.Register) + auth.POST("/register/send-code", r.userHandler.SendRegisterCode) + auth.POST("/login", r.userHandler.Login) + auth.POST("/password/send-code", r.userHandler.SendPasswordResetCode) + auth.POST("/password/reset", r.userHandler.ResetPassword) + auth.POST("/refresh", r.userHandler.RefreshToken) + } + + // 需要认证的路由 + authMiddleware := middleware.Auth(r.jwtService) + + // 用户路由 + users := v1.Group("/users") + { + // 当前用户 + users.GET("/me", authMiddleware, r.userHandler.GetCurrentUser) + users.PUT("/me", authMiddleware, r.userHandler.UpdateUser) + users.POST("/me/email/send-code", authMiddleware, r.userHandler.SendEmailVerifyCode) + users.POST("/me/email/verify", authMiddleware, r.userHandler.VerifyEmail) + users.POST("/me/avatar", authMiddleware, r.uploadHandler.UploadAvatar) + users.POST("/me/cover", authMiddleware, r.uploadHandler.UploadCover) + users.POST("/change-password/send-code", authMiddleware, r.userHandler.SendChangePasswordCode) + users.POST("/change-password", authMiddleware, r.userHandler.ChangePassword) + + // 搜索用户 + users.GET("/search", middleware.OptionalAuth(r.jwtService), r.userHandler.Search) + + // 其他用户 + users.GET("/:id", middleware.OptionalAuth(r.jwtService), r.userHandler.GetUserByID) + users.POST("/:id/follow", authMiddleware, r.userHandler.FollowUser) + users.DELETE("/:id/follow", authMiddleware, r.userHandler.UnfollowUser) + users.POST("/:id/block", authMiddleware, r.userHandler.BlockUser) + users.DELETE("/:id/block", authMiddleware, r.userHandler.UnblockUser) + users.GET("/:id/block-status", authMiddleware, r.userHandler.GetBlockStatus) + users.GET("/blocks", authMiddleware, r.userHandler.GetBlockedUsers) + users.GET("/:id/following", authMiddleware, r.userHandler.GetFollowingList) + users.GET("/:id/followers", authMiddleware, r.userHandler.GetFollowersList) + + // 用户帖子 - 使用 OptionalAuth 获取当前用户点赞/收藏状态 + users.GET("/:id/posts", middleware.OptionalAuth(r.jwtService), r.postHandler.GetUserPosts) + users.GET("/:id/favorites", middleware.OptionalAuth(r.jwtService), r.postHandler.GetFavorites) + } + + // 认证路由(公开) + authPublic := v1.Group("/auth") + { + authPublic.GET("/check-username", r.userHandler.CheckUsername) + } + + // 帖子路由 + posts := v1.Group("/posts") + { + // 使用 OptionalAuth 中间件来获取用户登录状态 + posts.GET("", middleware.OptionalAuth(r.jwtService), r.postHandler.List) + posts.GET("/search", middleware.OptionalAuth(r.jwtService), r.postHandler.Search) + posts.GET("/:id", middleware.OptionalAuth(r.jwtService), r.postHandler.GetByID) + posts.POST("", authMiddleware, r.postHandler.Create) + posts.PUT("/:id", authMiddleware, r.postHandler.Update) + posts.DELETE("/:id", authMiddleware, r.postHandler.Delete) + + // 浏览量记录(可选认证,允许游客浏览) + posts.POST("/:id/view", middleware.OptionalAuth(r.jwtService), r.postHandler.RecordView) + + // 点赞 + posts.POST("/:id/like", authMiddleware, r.postHandler.Like) + posts.DELETE("/:id/like", authMiddleware, r.postHandler.Unlike) + + // 收藏 + posts.POST("/:id/favorite", authMiddleware, r.postHandler.Favorite) + posts.DELETE("/:id/favorite", authMiddleware, r.postHandler.Unfavorite) + + // 投票相关路由 + posts.POST("/vote", authMiddleware, r.voteHandler.CreateVotePost) // 创建投票帖子 + posts.GET("/:id/vote", middleware.OptionalAuth(r.jwtService), r.voteHandler.GetVoteResult) // 获取投票结果 + posts.POST("/:id/vote", authMiddleware, r.voteHandler.Vote) // 投票 + posts.DELETE("/:id/vote", authMiddleware, r.voteHandler.Unvote) // 取消投票 + } + + // 投票选项路由 + voteOptions := v1.Group("/vote-options") + voteOptions.Use(authMiddleware) + { + voteOptions.PUT("/:id", r.voteHandler.UpdateVoteOption) // 更新选项 + } + + // 评论路由 + comments := v1.Group("/comments") + { + comments.GET("/post/:id", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetByPostID) + comments.GET("/:id", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetByID) + comments.POST("", authMiddleware, r.commentHandler.Create) + comments.PUT("/:id", authMiddleware, r.commentHandler.Update) + comments.DELETE("/:id", authMiddleware, r.commentHandler.Delete) + comments.GET("/:id/replies", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetReplies) + comments.GET("/:id/replies/flat", middleware.OptionalAuth(r.jwtService), r.commentHandler.GetRepliesByRootID) // 扁平化分页获取回复 + // 评论点赞 + comments.POST("/:id/like", authMiddleware, r.commentHandler.Like) + comments.DELETE("/:id/like", authMiddleware, r.commentHandler.Unlike) + } + + // 会话路由(新版 RESTful action 风格) + conversations := v1.Group("/conversations") + conversations.Use(authMiddleware) + { + // 获取会话列表 + conversations.GET("/list", r.messageHandler.HandleGetConversationList) + // 创建会话 + conversations.POST("/create", r.messageHandler.HandleCreateConversation) + // 获取会话详情 + conversations.GET("/get", r.messageHandler.HandleGetConversation) + // 获取会话消息 + conversations.GET("/get_messages", r.messageHandler.HandleGetMessages) + // 发送消息 + conversations.POST("/send_message", r.messageHandler.HandleSendMessage) + // 标记已读 + conversations.POST("/mark_read", r.messageHandler.HandleMarkRead) + // 会话置顶 + conversations.POST("/set_pinned", r.messageHandler.HandleSetConversationPinned) + // 获取未读消息总数 + conversations.GET("/unread/count", r.messageHandler.GetUnreadCount) + // 仅自己删除会话 + conversations.DELETE("/:id/self", r.messageHandler.HandleDeleteConversationForSelf) + } + + // 消息操作路由 + messages := v1.Group("/messages") + messages.Use(authMiddleware) + { + // 撤回/删除消息(统一接口) + messages.POST("/delete_msg", r.messageHandler.HandleDeleteMsg) + } + + // 通知路由 + notifications := v1.Group("/notifications") + { + notifications.GET("", authMiddleware, r.notificationHandler.GetNotifications) + notifications.POST("/:id/read", authMiddleware, r.notificationHandler.MarkAsRead) + notifications.POST("/read-all", authMiddleware, r.notificationHandler.MarkAllAsRead) + notifications.GET("/unread-count", authMiddleware, r.notificationHandler.GetUnreadCount) + notifications.DELETE("/:id", authMiddleware, r.notificationHandler.DeleteNotification) + notifications.DELETE("", authMiddleware, r.notificationHandler.ClearAllNotifications) + } + + // 上传路由 + uploads := v1.Group("/uploads") + { + uploads.POST("/images", authMiddleware, r.uploadHandler.UploadImage) + } + + // 推送相关路由 + if r.pushHandler != nil { + pushGroup := v1.Group("/push") + pushGroup.Use(authMiddleware) + { + pushGroup.POST("/devices", r.pushHandler.RegisterDevice) + pushGroup.GET("/devices", r.pushHandler.GetMyDevices) + pushGroup.DELETE("/devices/:device_id", r.pushHandler.UnregisterDevice) + pushGroup.PUT("/devices/:device_id/token", r.pushHandler.UpdateDeviceToken) + pushGroup.GET("/records", r.pushHandler.GetPushRecords) + } + } + + // 系统消息路由 + if r.systemMessageHandler != nil { + msgGroup := v1.Group("/messages") + msgGroup.Use(authMiddleware) + { + msgGroup.GET("/system", r.systemMessageHandler.GetSystemMessages) + msgGroup.GET("/system/unread-count", r.systemMessageHandler.GetUnreadCount) + msgGroup.PUT("/system/:id/read", r.systemMessageHandler.MarkAsRead) + msgGroup.PUT("/system/read-all", r.systemMessageHandler.MarkAllAsRead) + } + } + + // 群组路由(新版 RESTful action 风格) + if r.groupHandler != nil { + groups := v1.Group("/groups") + groups.Use(authMiddleware) + { + // 群组管理 + groups.POST("/create", r.groupHandler.HandleCreateGroup) + groups.GET("/list", r.groupHandler.HandleGetUserGroups) + groups.GET("/get", r.groupHandler.HandleGetGroupInfo) + groups.GET("/get_my_info", r.groupHandler.HandleGetMyMemberInfo) + groups.POST("/dissolve", r.groupHandler.HandleDissolveGroup) + groups.POST("/transfer", r.groupHandler.HandleTransferOwner) + + // 成员管理 + groups.POST("/invite_members", r.groupHandler.HandleInviteMembers) + groups.POST("/join", r.groupHandler.HandleJoinGroup) + groups.POST("/respond_invite", r.groupHandler.HandleRespondInvite) + groups.POST("/set_group_leave", r.groupHandler.HandleSetGroupLeave) + groups.GET("/get_members", r.groupHandler.HandleGetGroupMemberList) + groups.POST("/set_group_kick", r.groupHandler.HandleSetGroupKick) + groups.POST("/set_group_admin", r.groupHandler.HandleSetGroupAdmin) + groups.POST("/set_nickname", r.groupHandler.HandleSetNickname) + groups.POST("/set_group_ban", r.groupHandler.HandleSetGroupBan) + + // 群设置 + groups.POST("/set_group_whole_ban", r.groupHandler.HandleSetGroupWholeBan) + groups.POST("/set_join_type", r.groupHandler.HandleSetJoinType) + groups.POST("/set_group_name", r.groupHandler.HandleSetGroupName) + groups.POST("/set_group_avatar", r.groupHandler.HandleSetGroupAvatar) + + // 群公告 + groups.POST("/create_announcement", r.groupHandler.HandleCreateAnnouncement) + groups.GET("/get_announcements", r.groupHandler.HandleGetAnnouncements) + groups.POST("/delete_announcement", r.groupHandler.HandleDeleteAnnouncement) + + // 加群请求处理(预留) + groups.POST("/set_group_add_request", r.groupHandler.HandleSetGroupAddRequest) + } + } + + // 自定义表情路由 + if r.stickerHandler != nil { + stickers := v1.Group("/stickers") + stickers.Use(authMiddleware) + { + stickers.GET("", r.stickerHandler.GetStickers) + stickers.POST("", r.stickerHandler.AddSticker) + stickers.DELETE("", r.stickerHandler.DeleteSticker) + stickers.POST("/reorder", r.stickerHandler.ReorderStickers) + stickers.GET("/check", r.stickerHandler.CheckStickerExists) + } + } + + // Gorse 管理路由 + if r.gorseHandler != nil { + gorseGroup := v1.Group("/gorse") + { + gorseGroup.GET("/status", r.gorseHandler.GetStatus) + gorseGroup.POST("/import", r.gorseHandler.ImportData) + } + } + } +} + +// Engine 获取引擎 +func (r *Router) Engine() *gin.Engine { + return r.engine +} diff --git a/internal/service/audit_service.go b/internal/service/audit_service.go new file mode 100644 index 0000000..efebbe0 --- /dev/null +++ b/internal/service/audit_service.go @@ -0,0 +1,759 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + "sync" + "time" + + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +// ==================== 内容审核服务接口和实现 ==================== + +// AuditServiceProvider 内容审核服务提供商接口 +type AuditServiceProvider interface { + // AuditText 审核文本 + AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) + // AuditImage 审核图片 + AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) + // GetName 获取提供商名称 + GetName() string +} + +// AuditResult 审核结果 +type AuditResult struct { + Pass bool `json:"pass"` // 是否通过 + Risk string `json:"risk"` // 风险等级: low, medium, high + Labels []string `json:"labels"` // 标签列表 + Suggest string `json:"suggest"` // 建议: pass, review, block + Detail string `json:"detail"` // 详细说明 + Provider string `json:"provider"` // 服务提供商 +} + +// AuditService 内容审核服务接口 +type AuditService interface { + // AuditText 审核文本 + AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error) + // AuditImage 审核图片 + AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) + // GetAuditResult 获取审核结果 + GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error) + // SetProvider 设置审核服务提供商 + SetProvider(provider AuditServiceProvider) + // GetProvider 获取当前审核服务提供商 + GetProvider() AuditServiceProvider +} + +// auditServiceImpl 内容审核服务实现 +type auditServiceImpl struct { + db *gorm.DB + provider AuditServiceProvider + config *AuditConfig + mu sync.RWMutex +} + +// AuditConfig 内容审核服务配置 +type AuditConfig struct { + Enabled bool `mapstructure:"enabled" yaml:"enabled"` + // 审核服务提供商: local, aliyun, tencent, baidu + Provider string `mapstructure:"provider" yaml:"provider"` + // 阿里云配置 + AliyunAccessKey string `mapstructure:"aliyun_access_key" yaml:"aliyun_access_key"` + AliyunSecretKey string `mapstructure:"aliyun_secret_key" yaml:"aliyun_secret_key"` + AliyunRegion string `mapstructure:"aliyun_region" yaml:"aliyun_region"` + // 腾讯云配置 + TencentSecretID string `mapstructure:"tencent_secret_id" yaml:"tencent_secret_id"` + TencentSecretKey string `mapstructure:"tencent_secret_key" yaml:"tencent_secret_key"` + // 百度云配置 + BaiduAPIKey string `mapstructure:"baidu_api_key" yaml:"baidu_api_key"` + BaiduSecretKey string `mapstructure:"baidu_secret_key" yaml:"baidu_secret_key"` + // 是否自动审核 + AutoAudit bool `mapstructure:"auto_audit" yaml:"auto_audit"` + // 审核超时时间(秒) + Timeout int `mapstructure:"timeout" yaml:"timeout"` +} + +// NewAuditService 创建内容审核服务 +func NewAuditService(db *gorm.DB, config *AuditConfig) AuditService { + s := &auditServiceImpl{ + db: db, + config: config, + } + + // 根据配置初始化提供商 + if config.Enabled { + provider := s.initProvider(config.Provider) + s.provider = provider + } + + return s +} + +// initProvider 根据配置初始化审核服务提供商 +func (s *auditServiceImpl) initProvider(providerType string) AuditServiceProvider { + switch strings.ToLower(providerType) { + case "aliyun": + return NewAliyunAuditProvider(s.config.AliyunAccessKey, s.config.AliyunSecretKey, s.config.AliyunRegion) + case "tencent": + return NewTencentAuditProvider(s.config.TencentSecretID, s.config.TencentSecretKey) + case "baidu": + return NewBaiduAuditProvider(s.config.BaiduAPIKey, s.config.BaiduSecretKey) + case "local": + fallthrough + default: + // 默认使用本地审核服务 + return NewLocalAuditProvider() + } +} + +// AuditText 审核文本 +func (s *auditServiceImpl) AuditText(ctx context.Context, text string, auditType string) (*AuditResult, error) { + if !s.config.Enabled { + // 如果审核服务未启用,直接返回通过 + return &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Detail: "Audit service disabled", + }, nil + } + + if text == "" { + return &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Detail: "Empty text", + }, nil + } + + var result *AuditResult + var err error + + // 使用提供商审核 + if s.provider != nil { + result, err = s.provider.AuditText(ctx, text, auditType) + } else { + // 如果没有设置提供商,使用本地审核 + localProvider := NewLocalAuditProvider() + result, err = localProvider.AuditText(ctx, text, auditType) + } + + if err != nil { + log.Printf("Audit text error: %v", err) + return &AuditResult{ + Pass: false, + Risk: "high", + Suggest: "review", + Detail: fmt.Sprintf("Audit error: %v", err), + }, err + } + + // 记录审核日志 + go s.saveAuditLog(ctx, "text", "", text, auditType, result) + + return result, nil +} + +// AuditImage 审核图片 +func (s *auditServiceImpl) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) { + if !s.config.Enabled { + return &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Detail: "Audit service disabled", + }, nil + } + + if imageURL == "" { + return &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Detail: "Empty image URL", + }, nil + } + + var result *AuditResult + var err error + + // 使用提供商审核 + if s.provider != nil { + result, err = s.provider.AuditImage(ctx, imageURL) + } else { + // 如果没有设置提供商,使用本地审核 + localProvider := NewLocalAuditProvider() + result, err = localProvider.AuditImage(ctx, imageURL) + } + + if err != nil { + log.Printf("Audit image error: %v", err) + return &AuditResult{ + Pass: false, + Risk: "high", + Suggest: "review", + Detail: fmt.Sprintf("Audit error: %v", err), + }, err + } + + // 记录审核日志 + go s.saveAuditLog(ctx, "image", "", "", "image", result) + + return result, nil +} + +// GetAuditResult 获取审核结果 +func (s *auditServiceImpl) GetAuditResult(ctx context.Context, auditID string) (*AuditResult, error) { + if s.db == nil || auditID == "" { + return nil, fmt.Errorf("invalid audit ID") + } + + var auditLog model.AuditLog + if err := s.db.Where("id = ?", auditID).First(&auditLog).Error; err != nil { + return nil, err + } + + result := &AuditResult{ + Pass: auditLog.Result == model.AuditResultPass, + Risk: string(auditLog.RiskLevel), + Suggest: auditLog.Suggestion, + Detail: auditLog.Detail, + } + + // 解析标签 + if auditLog.Labels != "" { + json.Unmarshal([]byte(auditLog.Labels), &result.Labels) + } + + return result, nil +} + +// SetProvider 设置审核服务提供商 +func (s *auditServiceImpl) SetProvider(provider AuditServiceProvider) { + s.mu.Lock() + defer s.mu.Unlock() + s.provider = provider +} + +// GetProvider 获取当前审核服务提供商 +func (s *auditServiceImpl) GetProvider() AuditServiceProvider { + s.mu.RLock() + defer s.mu.RUnlock() + return s.provider +} + +// saveAuditLog 保存审核日志 +func (s *auditServiceImpl) saveAuditLog(ctx context.Context, contentType, content, imageURL, auditType string, result *AuditResult) { + if s.db == nil { + return + } + + auditLog := model.AuditLog{ + ContentType: contentType, + Content: content, + ContentURL: imageURL, + AuditType: auditType, + Labels: strings.Join(result.Labels, ","), + Suggestion: result.Suggest, + Detail: result.Detail, + Source: model.AuditSourceAuto, + Status: "completed", + } + + if result.Pass { + auditLog.Result = model.AuditResultPass + } else if result.Suggest == "review" { + auditLog.Result = model.AuditResultReview + } else { + auditLog.Result = model.AuditResultBlock + } + + switch result.Risk { + case "low": + auditLog.RiskLevel = model.AuditRiskLevelLow + case "medium": + auditLog.RiskLevel = model.AuditRiskLevelMedium + case "high": + auditLog.RiskLevel = model.AuditRiskLevelHigh + default: + auditLog.RiskLevel = model.AuditRiskLevelLow + } + + if err := s.db.Create(&auditLog).Error; err != nil { + log.Printf("Failed to save audit log: %v", err) + } +} + +// ==================== 本地审核服务提供商 ==================== + +// localAuditProvider 本地审核服务提供商 +type localAuditProvider struct { + // 可以注入敏感词服务进行本地审核 + sensitiveService SensitiveService +} + +// NewLocalAuditProvider 创建本地审核服务提供商 +func NewLocalAuditProvider() AuditServiceProvider { + return &localAuditProvider{ + sensitiveService: nil, + } +} + +// GetName 获取提供商名称 +func (p *localAuditProvider) GetName() string { + return "local" +} + +// AuditText 审核文本 +func (p *localAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) { + // 本地审核逻辑 + // 1. 敏感词检查 + // 2. 规则匹配 + // 3. 简单的关键词检测 + + result := &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Labels: []string{}, + Provider: "local", + } + + // 如果有敏感词服务,使用它进行检测 + if p.sensitiveService != nil { + hasSensitive, words := p.sensitiveService.Check(ctx, text) + if hasSensitive { + result.Pass = false + result.Risk = "high" + result.Suggest = "block" + result.Detail = fmt.Sprintf("包含敏感词: %s", strings.Join(words, ",")) + result.Labels = append(result.Labels, "sensitive") + } + } + + // 简单的关键词检测规则 + // 实际项目中应该从数据库加载 + suspiciousPatterns := []string{ + "诈骗", + "钓鱼", + "木马", + "病毒", + } + + for _, pattern := range suspiciousPatterns { + if strings.Contains(text, pattern) { + result.Pass = false + result.Risk = "high" + result.Suggest = "block" + result.Labels = append(result.Labels, "suspicious") + if result.Detail == "" { + result.Detail = fmt.Sprintf("包含可疑内容: %s", pattern) + } else { + result.Detail += fmt.Sprintf(", %s", pattern) + } + } + } + + return result, nil +} + +// AuditImage 审核图片 +func (p *localAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) { + // 本地图片审核逻辑 + // 1. 图片URL合法性检查 + // 2. 图片格式检查 + // 3. 可以扩展接入本地图片识别服务 + + result := &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Labels: []string{}, + Provider: "local", + } + + // 检查URL是否为空 + if imageURL == "" { + result.Detail = "Empty image URL" + return result, nil + } + + // 检查是否为支持的图片URL格式 + validPrefixes := []string{"http://", "https://", "s3://", "oss://", "cos://"} + isValid := false + for _, prefix := range validPrefixes { + if strings.HasPrefix(strings.ToLower(imageURL), prefix) { + isValid = true + break + } + } + + if !isValid { + result.Pass = false + result.Risk = "medium" + result.Suggest = "review" + result.Detail = "Invalid image URL format" + result.Labels = append(result.Labels, "invalid_url") + } + + return result, nil +} + +// SetSensitiveService 设置敏感词服务 +func (p *localAuditProvider) SetSensitiveService(ss SensitiveService) { + p.sensitiveService = ss +} + +// ==================== 阿里云审核服务提供商 ==================== + +// aliyunAuditProvider 阿里云审核服务提供商 +type aliyunAuditProvider struct { + accessKey string + secretKey string + region string +} + +// NewAliyunAuditProvider 创建阿里云审核服务提供商 +func NewAliyunAuditProvider(accessKey, secretKey, region string) AuditServiceProvider { + return &aliyunAuditProvider{ + accessKey: accessKey, + secretKey: secretKey, + region: region, + } +} + +// GetName 获取提供商名称 +func (p *aliyunAuditProvider) GetName() string { + return "aliyun" +} + +// AuditText 审核文本 +func (p *aliyunAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) { + // 阿里云内容安全API调用 + // 实际项目中需要实现阿里云SDK调用 + // 这里预留接口 + + result := &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Labels: []string{}, + Provider: "aliyun", + Detail: "Aliyun audit not implemented, using pass", + } + + // TODO: 实现阿里云内容安全API调用 + // 具体参考: https://help.aliyun.com/document_detail/28417.html + + return result, nil +} + +// AuditImage 审核图片 +func (p *aliyunAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) { + result := &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Labels: []string{}, + Provider: "aliyun", + Detail: "Aliyun image audit not implemented, using pass", + } + + // TODO: 实现阿里云图片审核API调用 + + return result, nil +} + +// ==================== 腾讯云审核服务提供商 ==================== + +// tencentAuditProvider 腾讯云审核服务提供商 +type tencentAuditProvider struct { + secretID string + secretKey string +} + +// NewTencentAuditProvider 创建腾讯云审核服务提供商 +func NewTencentAuditProvider(secretID, secretKey string) AuditServiceProvider { + return &tencentAuditProvider{ + secretID: secretID, + secretKey: secretKey, + } +} + +// GetName 获取提供商名称 +func (p *tencentAuditProvider) GetName() string { + return "tencent" +} + +// AuditText 审核文本 +func (p *tencentAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) { + result := &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Labels: []string{}, + Provider: "tencent", + Detail: "Tencent audit not implemented, using pass", + } + + // TODO: 实现腾讯云内容审核API调用 + // 具体参考: https://cloud.tencent.com/document/product/1124/64508 + + return result, nil +} + +// AuditImage 审核图片 +func (p *tencentAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) { + result := &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Labels: []string{}, + Provider: "tencent", + Detail: "Tencent image audit not implemented, using pass", + } + + // TODO: 实现腾讯云图片审核API调用 + + return result, nil +} + +// ==================== 百度云审核服务提供商 ==================== + +// baiduAuditProvider 百度云审核服务提供商 +type baiduAuditProvider struct { + apiKey string + secretKey string +} + +// NewBaiduAuditProvider 创建百度云审核服务提供商 +func NewBaiduAuditProvider(apiKey, secretKey string) AuditServiceProvider { + return &baiduAuditProvider{ + apiKey: apiKey, + secretKey: secretKey, + } +} + +// GetName 获取提供商名称 +func (p *baiduAuditProvider) GetName() string { + return "baidu" +} + +// AuditText 审核文本 +func (p *baiduAuditProvider) AuditText(ctx context.Context, text string, scene string) (*AuditResult, error) { + result := &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Labels: []string{}, + Provider: "baidu", + Detail: "Baidu audit not implemented, using pass", + } + + // TODO: 实现百度云内容审核API调用 + // 具体参考: https://cloud.baidu.com/doc/ANTISPAM/s/Jjw0r1iF6 + + return result, nil +} + +// AuditImage 审核图片 +func (p *baiduAuditProvider) AuditImage(ctx context.Context, imageURL string) (*AuditResult, error) { + result := &AuditResult{ + Pass: true, + Risk: "low", + Suggest: "pass", + Labels: []string{}, + Provider: "baidu", + Detail: "Baidu image audit not implemented, using pass", + } + + // TODO: 实现百度云图片审核API调用 + + return result, nil +} + +// ==================== 审核结果回调处理 ==================== + +// AuditCallback 审核回调处理 +type AuditCallback struct { + service AuditService +} + +// NewAuditCallback 创建审核回调处理 +func NewAuditCallback(service AuditService) *AuditCallback { + return &AuditCallback{ + service: service, + } +} + +// HandleTextCallback 处理文本审核回调 +func (c *AuditCallback) HandleTextCallback(ctx context.Context, auditID string, result *AuditResult) error { + if c.service == nil || auditID == "" || result == nil { + return fmt.Errorf("invalid parameters") + } + + log.Printf("Processing text audit callback: auditID=%s, result=%+v", auditID, result) + + // 根据审核结果执行相应操作 + // 例如: 更新帖子状态、发送通知等 + + return nil +} + +// HandleImageCallback 处理图片审核回调 +func (c *AuditCallback) HandleImageCallback(ctx context.Context, auditID string, result *AuditResult) error { + if c.service == nil || auditID == "" || result == nil { + return fmt.Errorf("invalid parameters") + } + + log.Printf("Processing image audit callback: auditID=%s, result=%+v", auditID, result) + + // 根据审核结果执行相应操作 + // 例如: 更新图片状态、删除违规图片等 + + return nil +} + +// ==================== 辅助函数 ==================== + +// IsContentSafe 判断内容是否安全 +func IsContentSafe(result *AuditResult) bool { + if result == nil { + return true + } + return result.Pass && result.Suggest != "block" +} + +// NeedReview 判断内容是否需要人工复审 +func NeedReview(result *AuditResult) bool { + if result == nil { + return false + } + return result.Suggest == "review" +} + +// GetRiskLevel 获取风险等级 +func GetRiskLevel(result *AuditResult) string { + if result == nil { + return "low" + } + return result.Risk +} + +// FormatAuditResult 格式化审核结果为字符串 +func FormatAuditResult(result *AuditResult) string { + if result == nil { + return "{}" + } + data, _ := json.Marshal(result) + return string(data) +} + +// ParseAuditResult 从字符串解析审核结果 +func ParseAuditResult(data string) (*AuditResult, error) { + if data == "" { + return nil, fmt.Errorf("empty data") + } + var result AuditResult + if err := json.Unmarshal([]byte(data), &result); err != nil { + return nil, err + } + return &result, nil +} + +// ==================== 审核日志查询 ==================== + +// GetAuditLogs 获取审核日志列表 +func GetAuditLogs(db *gorm.DB, targetType string, targetID string, result string, page, pageSize int) ([]model.AuditLog, int64, error) { + query := db.Model(&model.AuditLog{}) + + if targetType != "" { + query = query.Where("target_type = ?", targetType) + } + if targetID != "" { + query = query.Where("target_id = ?", targetID) + } + if result != "" { + query = query.Where("result = ?", result) + } + + var total int64 + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + var logs []model.AuditLog + offset := (page - 1) * pageSize + if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&logs).Error; err != nil { + return nil, 0, err + } + + return logs, total, nil +} + +// ==================== 定时任务 ==================== + +// AuditScheduler 审核调度器 +type AuditScheduler struct { + db *gorm.DB + service AuditService + interval time.Duration + stopCh chan bool +} + +// NewAuditScheduler 创建审核调度器 +func NewAuditScheduler(db *gorm.DB, service AuditService, interval time.Duration) *AuditScheduler { + return &AuditScheduler{ + db: db, + service: service, + interval: interval, + stopCh: make(chan bool), + } +} + +// Start 启动调度器 +func (s *AuditScheduler) Start() { + go func() { + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.processPendingAudits() + case <-s.stopCh: + return + } + } + }() +} + +// Stop 停止调度器 +func (s *AuditScheduler) Stop() { + s.stopCh <- true +} + +// processPendingAudits 处理待审核内容 +func (s *AuditScheduler) processPendingAudits() { + // 查询待审核的内容 + // 1. 查询审核状态为 pending 的记录 + // 2. 调用审核服务 + // 3. 更新审核状态 + + // 示例逻辑,实际需要根据业务需求实现 + log.Println("Processing pending audits...") +} + +// CleanupOldLogs 清理旧的审核日志 +func CleanupOldLogs(db *gorm.DB, days int) error { + // 清理指定天数之前的审核日志 + cutoffTime := time.Now().AddDate(0, 0, -days) + return db.Where("created_at < ? AND result = ?", cutoffTime, model.AuditResultPass).Delete(&model.AuditLog{}).Error +} diff --git a/internal/service/chat_service.go b/internal/service/chat_service.go new file mode 100644 index 0000000..17fd563 --- /dev/null +++ b/internal/service/chat_service.go @@ -0,0 +1,622 @@ +package service + +import ( + "context" + "errors" + "fmt" + "log" + "time" + + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/websocket" + "carrot_bbs/internal/repository" + + "gorm.io/gorm" +) + +// 撤回消息的时间限制(2分钟) +const RecallMessageTimeout = 2 * time.Minute + +// ChatService 聊天服务接口 +type ChatService interface { + // 会话管理 + GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) + GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) + GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) + DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error + SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error + + // 消息操作 + SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) + GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) + GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) + GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) + + // 已读管理 + MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error + GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) + GetAllUnreadCount(ctx context.Context, userID string) (int64, error) + + // 消息扩展功能 + RecallMessage(ctx context.Context, messageID string, userID string) error + DeleteMessage(ctx context.Context, messageID string, userID string) error + + // WebSocket相关 + SendTyping(ctx context.Context, senderID string, conversationID string) + BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) + + // 系统消息推送 + IsUserOnline(userID string) bool + PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error + PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error + PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error + + // 仅保存消息到数据库,不发送 WebSocket 推送(供群聊等自行推送的场景使用) + SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) +} + +// chatServiceImpl 聊天服务实现 +type chatServiceImpl struct { + db *gorm.DB + repo *repository.MessageRepository + userRepo *repository.UserRepository + sensitive SensitiveService + wsManager *websocket.WebSocketManager +} + +// NewChatService 创建聊天服务 +func NewChatService( + db *gorm.DB, + repo *repository.MessageRepository, + userRepo *repository.UserRepository, + sensitive SensitiveService, + wsManager *websocket.WebSocketManager, +) ChatService { + return &chatServiceImpl{ + db: db, + repo: repo, + userRepo: userRepo, + sensitive: sensitive, + wsManager: wsManager, + } +} + +// GetOrCreateConversation 获取或创建私聊会话 +func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) { + return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID) +} + +// GetConversationList 获取用户的会话列表 +func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) { + return s.repo.GetConversations(userID, page, pageSize) +} + +// GetConversationByID 获取会话详情 +func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) { + // 验证用户是否是会话参与者 + participant, err := s.repo.GetParticipant(conversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("conversation not found or no permission") + } + return nil, fmt.Errorf("failed to get participant: %w", err) + } + + // 获取会话信息 + conv, err := s.repo.GetConversation(conversationID) + if err != nil { + return nil, fmt.Errorf("failed to get conversation: %w", err) + } + + // 填充用户的已读位置信息 + _ = participant // 可以用于返回已读位置等信息 + + return conv, nil +} + +// DeleteConversationForSelf 仅自己删除会话 +func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error { + participant, err := s.repo.GetParticipant(conversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("conversation not found or no permission") + } + return fmt.Errorf("failed to get participant: %w", err) + } + if participant.ConversationID == "" { + return errors.New("conversation not found or no permission") + } + + if err := s.repo.HideConversationForUser(conversationID, userID); err != nil { + return fmt.Errorf("failed to hide conversation: %w", err) + } + return nil +} + +// SetConversationPinned 设置会话置顶(用户维度) +func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error { + participant, err := s.repo.GetParticipant(conversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("conversation not found or no permission") + } + return fmt.Errorf("failed to get participant: %w", err) + } + if participant.ConversationID == "" { + return errors.New("conversation not found or no permission") + } + + if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil { + return fmt.Errorf("failed to update pinned status: %w", err) + } + return nil +} + +// SendMessage 发送消息(使用 segments) +func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) { + // 首先验证会话是否存在 + conv, err := s.repo.GetConversation(conversationID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("会话不存在,请重新创建会话") + } + return nil, fmt.Errorf("failed to get conversation: %w", err) + } + + // 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向 + if conv.Type == model.ConversationTypePrivate && s.userRepo != nil { + participants, pErr := s.repo.GetConversationParticipants(conversationID) + if pErr != nil { + return nil, fmt.Errorf("failed to get participants: %w", pErr) + } + var sentCount *int64 + for _, p := range participants { + if p.UserID == senderID { + continue + } + blocked, bErr := s.userRepo.IsBlocked(p.UserID, senderID) + if bErr != nil { + return nil, fmt.Errorf("failed to check block status: %w", bErr) + } + if blocked { + return nil, ErrUserBlocked + } + + // 陌生人限制:对方未回关前,只允许发送一条文本消息,且禁止发送图片 + isFollowedBack, fErr := s.userRepo.IsFollowing(p.UserID, senderID) + if fErr != nil { + return nil, fmt.Errorf("failed to check follow status: %w", fErr) + } + if !isFollowedBack { + if containsImageSegment(segments) { + return nil, errors.New("对方未关注你,暂不支持发送图片") + } + if sentCount == nil { + c, cErr := s.repo.CountMessagesBySenderInConversation(conversationID, senderID) + if cErr != nil { + return nil, fmt.Errorf("failed to count sender messages: %w", cErr) + } + sentCount = &c + } + if *sentCount >= 1 { + return nil, errors.New("对方未关注你前,仅允许发送一条消息") + } + } + } + } + + // 验证用户是否是会话参与者 + participant, err := s.repo.GetParticipant(conversationID, senderID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("您不是该会话的参与者") + } + return nil, fmt.Errorf("failed to get participant: %w", err) + } + + // 创建消息 + message := &model.Message{ + ConversationID: conversationID, + SenderID: senderID, // 直接使用string类型的UUID + Segments: segments, + ReplyToID: replyToID, + Status: model.MessageStatusNormal, + } + + // 使用事务创建消息并更新seq + if err := s.repo.CreateMessageWithSeq(message); err != nil { + return nil, fmt.Errorf("failed to save message: %w", err) + } + + // 发送消息给接收者 + log.Printf("[DEBUG SendMessage] 私聊消息 segments 类型: %T, 值: %+v", message.Segments, message.Segments) + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, websocket.ChatMessage{ + ID: message.ID, + ConversationID: message.ConversationID, + SenderID: senderID, + Segments: message.Segments, + Seq: message.Seq, + CreatedAt: message.CreatedAt.UnixMilli(), + }) + + // 获取会话中的其他参与者 + participants, err := s.repo.GetConversationParticipants(conversationID) + if err == nil { + for _, p := range participants { + // 不发给自己 + if p.UserID == senderID { + continue + } + // 如果接收者在线,发送实时消息 + if s.wsManager != nil { + isOnline := s.wsManager.IsUserOnline(p.UserID) + log.Printf("[DEBUG SendMessage] 接收者 UserID=%s, 在线状态=%v", p.UserID, isOnline) + if isOnline { + log.Printf("[DEBUG SendMessage] 发送WebSocket消息给 UserID=%s, 消息类型=%s", p.UserID, wsMsg.Type) + s.wsManager.SendToUser(p.UserID, wsMsg) + } + } + } + } else { + log.Printf("[DEBUG SendMessage] 获取参与者失败: %v", err) + } + + _ = participant // 避免未使用变量警告 + + return message, nil +} + +func containsImageSegment(segments model.MessageSegments) bool { + for _, seg := range segments { + if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" { + return true + } + } + return false +} + +// GetMessages 获取消息历史(分页) +func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) { + // 验证用户是否是会话参与者 + _, err := s.repo.GetParticipant(conversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, 0, errors.New("conversation not found or no permission") + } + return nil, 0, fmt.Errorf("failed to get participant: %w", err) + } + + return s.repo.GetMessages(conversationID, page, pageSize) +} + +// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步) +func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) { + // 验证用户是否是会话参与者 + _, err := s.repo.GetParticipant(conversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("conversation not found or no permission") + } + return nil, fmt.Errorf("failed to get participant: %w", err) + } + + if limit <= 0 { + limit = 100 + } + + return s.repo.GetMessagesAfterSeq(conversationID, afterSeq, limit) +} + +// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多) +func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) { + // 验证用户是否是会话参与者 + _, err := s.repo.GetParticipant(conversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("conversation not found or no permission") + } + return nil, fmt.Errorf("failed to get participant: %w", err) + } + + if limit <= 0 { + limit = 20 + } + + return s.repo.GetMessagesBeforeSeq(conversationID, beforeSeq, limit) +} + +// MarkAsRead 标记已读 +func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error { + // 验证用户是否是会话参与者 + _, err := s.repo.GetParticipant(conversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("conversation not found or no permission") + } + return fmt.Errorf("failed to get participant: %w", err) + } + + // 更新参与者的已读位置 + err = s.repo.UpdateLastReadSeq(conversationID, userID, seq) + if err != nil { + return fmt.Errorf("failed to update last read seq: %w", err) + } + + // 发送已读回执(作为 meta 事件) + if s.wsManager != nil { + wsMsg := websocket.CreateWSMessage("meta", map[string]interface{}{ + "detail_type": websocket.MetaDetailTypeRead, + "conversation_id": conversationID, + "seq": seq, + "user_id": userID, + }) + + // 获取会话中的所有参与者 + participants, err := s.repo.GetConversationParticipants(conversationID) + if err == nil { + // 推送给会话中的所有参与者(包括自己) + for _, p := range participants { + if s.wsManager.IsUserOnline(p.UserID) { + s.wsManager.SendToUser(p.UserID, wsMsg) + } + } + } + } + + return nil +} + +// GetUnreadCount 获取指定会话的未读消息数 +func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) { + // 验证用户是否是会话参与者 + _, err := s.repo.GetParticipant(conversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return 0, errors.New("conversation not found or no permission") + } + return 0, fmt.Errorf("failed to get participant: %w", err) + } + + return s.repo.GetUnreadCount(conversationID, userID) +} + +// GetAllUnreadCount 获取所有会话的未读消息总数 +func (s *chatServiceImpl) GetAllUnreadCount(ctx context.Context, userID string) (int64, error) { + return s.repo.GetAllUnreadCount(userID) +} + +// RecallMessage 撤回消息(2分钟内) +func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, userID string) error { + // 获取消息 + var message model.Message + err := s.db.First(&message, "id = ?", messageID).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("message not found") + } + return fmt.Errorf("failed to get message: %w", err) + } + + // 验证是否是消息发送者 + if message.SenderIDStr() != userID { + return errors.New("can only recall your own messages") + } + + // 验证消息是否已被撤回 + if message.Status == model.MessageStatusRecalled { + return errors.New("message already recalled") + } + + // 验证是否在2分钟内 + if time.Since(message.CreatedAt) > RecallMessageTimeout { + return errors.New("message recall timeout (2 minutes)") + } + + // 更新消息状态为已撤回 + err = s.db.Model(&message).Update("status", model.MessageStatusRecalled).Error + if err != nil { + return fmt.Errorf("failed to recall message: %w", err) + } + + // 发送撤回通知 + if s.wsManager != nil { + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeRecall, map[string]interface{}{ + "messageId": messageID, + "conversationId": message.ConversationID, + "senderId": userID, + }) + + // 通知会话中的所有参与者 + participants, err := s.repo.GetConversationParticipants(message.ConversationID) + if err == nil { + for _, p := range participants { + if s.wsManager.IsUserOnline(p.UserID) { + s.wsManager.SendToUser(p.UserID, wsMsg) + } + } + } + } + + return nil +} + +// DeleteMessage 删除消息(仅对自己可见) +func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, userID string) error { + // 获取消息 + var message model.Message + err := s.db.First(&message, "id = ?", messageID).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("message not found") + } + return fmt.Errorf("failed to get message: %w", err) + } + + // 验证用户是否是会话参与者 + _, err = s.repo.GetParticipant(message.ConversationID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("no permission to delete this message") + } + return fmt.Errorf("failed to get participant: %w", err) + } + + // 对于删除消息,我们使用软删除,但需要确保只对当前用户隐藏 + // 这里简化处理:只有发送者可以删除自己的消息 + if message.SenderIDStr() != userID { + return errors.New("can only delete your own messages") + } + + // 更新消息状态为已删除 + err = s.db.Model(&message).Update("status", model.MessageStatusDeleted).Error + if err != nil { + return fmt.Errorf("failed to delete message: %w", err) + } + + return nil +} + +// SendTyping 发送正在输入状态 +func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conversationID string) { + if s.wsManager == nil { + return + } + + // 验证用户是否是会话参与者 + _, err := s.repo.GetParticipant(conversationID, senderID) + if err != nil { + return + } + + // 获取会话中的其他参与者 + participants, err := s.repo.GetConversationParticipants(conversationID) + if err != nil { + return + } + + for _, p := range participants { + if p.UserID == senderID { + continue + } + // 发送正在输入状态 + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeTyping, map[string]string{ + "conversationId": conversationID, + "senderId": senderID, + }) + + if s.wsManager.IsUserOnline(p.UserID) { + s.wsManager.SendToUser(p.UserID, wsMsg) + } + } +} + +// BroadcastMessage 广播消息给用户 +func (s *chatServiceImpl) BroadcastMessage(ctx context.Context, msg *websocket.WSMessage, targetUser string) { + if s.wsManager != nil { + s.wsManager.SendToUser(targetUser, msg) + } +} + +// IsUserOnline 检查用户是否在线 +func (s *chatServiceImpl) IsUserOnline(userID string) bool { + if s.wsManager == nil { + return false + } + return s.wsManager.IsUserOnline(userID) +} + +// PushSystemMessage 推送系统消息给指定用户 +func (s *chatServiceImpl) PushSystemMessage(userID string, msgType, title, content string, data map[string]interface{}) error { + if s.wsManager == nil { + return errors.New("websocket manager not available") + } + + if !s.wsManager.IsUserOnline(userID) { + return errors.New("user is offline") + } + + sysMsg := &websocket.SystemMessage{ + ID: "", // 由调用方生成 + Type: msgType, + Title: title, + Content: content, + Data: data, + CreatedAt: time.Now().UnixMilli(), + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg) + s.wsManager.SendToUser(userID, wsMsg) + return nil +} + +// PushNotificationMessage 推送通知消息给指定用户 +func (s *chatServiceImpl) PushNotificationMessage(userID string, notification *websocket.NotificationMessage) error { + if s.wsManager == nil { + return errors.New("websocket manager not available") + } + + if !s.wsManager.IsUserOnline(userID) { + return errors.New("user is offline") + } + + // 确保时间戳已设置 + if notification.CreatedAt == 0 { + notification.CreatedAt = time.Now().UnixMilli() + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification) + s.wsManager.SendToUser(userID, wsMsg) + return nil +} + +// PushAnnouncementMessage 广播公告消息给所有在线用户 +func (s *chatServiceImpl) PushAnnouncementMessage(announcement *websocket.AnnouncementMessage) error { + if s.wsManager == nil { + return errors.New("websocket manager not available") + } + + // 确保时间戳已设置 + if announcement.CreatedAt == 0 { + announcement.CreatedAt = time.Now().UnixMilli() + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement) + s.wsManager.Broadcast(wsMsg) + return nil +} + +// SaveMessage 仅保存消息到数据库,不发送 WebSocket 推送 +// 适用于群聊等由调用方自行负责推送的场景 +func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) { + // 验证会话是否存在 + _, err := s.repo.GetConversation(conversationID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("会话不存在,请重新创建会话") + } + return nil, fmt.Errorf("failed to get conversation: %w", err) + } + + // 验证用户是否是会话参与者 + _, err = s.repo.GetParticipant(conversationID, senderID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("您不是该会话的参与者") + } + return nil, fmt.Errorf("failed to get participant: %w", err) + } + + message := &model.Message{ + ConversationID: conversationID, + SenderID: senderID, + Segments: segments, + ReplyToID: replyToID, + Status: model.MessageStatusNormal, + } + + if err := s.repo.CreateMessageWithSeq(message); err != nil { + return nil, fmt.Errorf("failed to save message: %w", err) + } + + return message, nil +} diff --git a/internal/service/comment_service.go b/internal/service/comment_service.go new file mode 100644 index 0000000..2a9481a --- /dev/null +++ b/internal/service/comment_service.go @@ -0,0 +1,273 @@ +package service + +import ( + "context" + "errors" + "fmt" + "log" + "strings" + + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/gorse" + "carrot_bbs/internal/repository" +) + +// CommentService 评论服务 +type CommentService struct { + commentRepo *repository.CommentRepository + postRepo *repository.PostRepository + systemMessageService SystemMessageService + gorseClient gorse.Client + postAIService *PostAIService +} + +// NewCommentService 创建评论服务 +func NewCommentService(commentRepo *repository.CommentRepository, postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *CommentService { + return &CommentService{ + commentRepo: commentRepo, + postRepo: postRepo, + systemMessageService: systemMessageService, + gorseClient: gorseClient, + postAIService: postAIService, + } +} + +// Create 创建评论 +func (s *CommentService) Create(ctx context.Context, postID, userID, content string, parentID *string, images string, imageURLs []string) (*model.Comment, error) { + if s.postAIService != nil { + // 采用异步审核,前端先立即返回 + } + + // 获取帖子信息用于发送通知 + post, err := s.postRepo.GetByID(postID) + if err != nil { + return nil, err + } + + comment := &model.Comment{ + PostID: postID, + UserID: userID, + Content: content, + ParentID: parentID, + Images: images, + Status: model.CommentStatusPending, + } + + // 如果有父评论,设置根评论ID + var parentUserID string + if parentID != nil { + parent, err := s.commentRepo.GetByID(*parentID) + if err == nil && parent != nil { + if parent.RootID != nil { + comment.RootID = parent.RootID + } else { + comment.RootID = parentID + } + parentUserID = parent.UserID + } + } + + err = s.commentRepo.Create(comment) + if err != nil { + return nil, err + } + + // 重新查询以获取关联的 User + comment, err = s.commentRepo.GetByID(comment.ID) + if err != nil { + return nil, err + } + + go s.reviewCommentAsync(comment.ID, userID, postID, content, imageURLs, parentID, parentUserID, post.UserID) + + return comment, nil +} + +func (s *CommentService) reviewCommentAsync( + commentID, userID, postID, content string, + imageURLs []string, + parentID *string, + parentUserID string, + postOwnerID string, +) { + // 未启用AI时,直接通过审核并发送后续通知 + if s.postAIService == nil || !s.postAIService.IsEnabled() { + if err := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); err != nil { + log.Printf("[WARN] Failed to publish comment without AI moderation: %v", err) + return + } + s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID) + return + } + + err := s.postAIService.ModerateComment(context.Background(), content, imageURLs) + if err != nil { + var rejectedErr *CommentModerationRejectedError + if errors.As(err, &rejectedErr) { + if delErr := s.commentRepo.Delete(commentID); delErr != nil { + log.Printf("[WARN] Failed to delete rejected comment %s: %v", commentID, delErr) + } + s.notifyCommentModerationRejected(userID, rejectedErr.Reason) + return + } + + // 审核服务异常时降级放行,避免评论长期pending + if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil { + log.Printf("[WARN] Failed to publish comment %s after moderation error: %v", commentID, updateErr) + return + } + log.Printf("[WARN] Comment moderation failed, fallback publish comment=%s err=%v", commentID, err) + s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID) + return + } + + if updateErr := s.commentRepo.UpdateModerationStatus(commentID, model.CommentStatusPublished); updateErr != nil { + log.Printf("[WARN] Failed to publish comment %s: %v", commentID, updateErr) + return + } + s.afterCommentPublished(userID, postID, commentID, parentID, parentUserID, postOwnerID) +} + +func (s *CommentService) afterCommentPublished(userID, postID, commentID string, parentID *string, parentUserID, postOwnerID string) { + // 发送系统消息通知 + if s.systemMessageService != nil { + go func() { + if parentID != nil && parentUserID != "" { + // 回复评论,通知被回复的人 + if parentUserID != userID { + notifyErr := s.systemMessageService.SendReplyNotification(context.Background(), parentUserID, userID, postID, *parentID, commentID) + if notifyErr != nil { + fmt.Printf("[DEBUG] Error sending reply notification: %v\n", notifyErr) + } + } + } else { + // 评论帖子,通知帖子作者 + if postOwnerID != userID { + notifyErr := s.systemMessageService.SendCommentNotification(context.Background(), postOwnerID, userID, postID, commentID) + if notifyErr != nil { + fmt.Printf("[DEBUG] Error sending comment notification: %v\n", notifyErr) + } + } + } + }() + } + + // 推送评论行为到Gorse(异步) + go func() { + if s.gorseClient.IsEnabled() { + if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeComment, userID, postID); err != nil { + log.Printf("[WARN] Failed to insert comment feedback to Gorse: %v", err) + } + } + }() +} + +func (s *CommentService) notifyCommentModerationRejected(userID, reason string) { + if s.systemMessageService == nil || strings.TrimSpace(userID) == "" { + return + } + + content := "您发布的评论未通过AI审核,请修改后重试。" + if strings.TrimSpace(reason) != "" { + content = fmt.Sprintf("您发布的评论未通过AI审核,原因:%s。请修改后重试。", reason) + } + + go func() { + if err := s.systemMessageService.SendSystemAnnouncement( + context.Background(), + []string{userID}, + "评论审核未通过", + content, + ); err != nil { + log.Printf("[WARN] Failed to send comment moderation reject notification: %v", err) + } + }() +} + +// GetByID 根据ID获取评论 +func (s *CommentService) GetByID(ctx context.Context, id string) (*model.Comment, error) { + return s.commentRepo.GetByID(id) +} + +// GetByPostID 获取帖子评论 +func (s *CommentService) GetByPostID(ctx context.Context, postID string, page, pageSize int) ([]*model.Comment, int64, error) { + // 使用带回复的查询,默认加载前3条回复 + return s.commentRepo.GetByPostIDWithReplies(postID, page, pageSize, 3) +} + +// GetRepliesByRootID 根据根评论ID分页获取回复 +func (s *CommentService) GetRepliesByRootID(ctx context.Context, rootID string, page, pageSize int) ([]*model.Comment, int64, error) { + return s.commentRepo.GetRepliesByRootID(rootID, page, pageSize) +} + +// GetReplies 获取回复 +func (s *CommentService) GetReplies(ctx context.Context, parentID string) ([]*model.Comment, error) { + return s.commentRepo.GetReplies(parentID) +} + +// Update 更新评论 +func (s *CommentService) Update(ctx context.Context, comment *model.Comment) error { + return s.commentRepo.Update(comment) +} + +// Delete 删除评论 +func (s *CommentService) Delete(ctx context.Context, id string) error { + return s.commentRepo.Delete(id) +} + +// Like 点赞评论 +func (s *CommentService) Like(ctx context.Context, commentID, userID string) error { + // 获取评论信息用于发送通知 + comment, err := s.commentRepo.GetByID(commentID) + if err != nil { + return err + } + + err = s.commentRepo.Like(commentID, userID) + if err != nil { + return err + } + + // 发送评论/回复点赞通知(只有不是给自己点赞时才发送) + if s.systemMessageService != nil && comment.UserID != userID { + go func() { + var notifyErr error + if comment.ParentID != nil { + notifyErr = s.systemMessageService.SendLikeReplyNotification( + context.Background(), + comment.UserID, + userID, + comment.PostID, + commentID, + comment.Content, + ) + } else { + notifyErr = s.systemMessageService.SendLikeCommentNotification( + context.Background(), + comment.UserID, + userID, + comment.PostID, + commentID, + comment.Content, + ) + } + if notifyErr != nil { + fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr) + } else { + fmt.Printf("[DEBUG] Like notification sent successfully\n") + } + }() + } + + return nil +} + +// Unlike 取消点赞评论 +func (s *CommentService) Unlike(ctx context.Context, commentID, userID string) error { + return s.commentRepo.Unlike(commentID, userID) +} + +// IsLiked 检查是否已点赞 +func (s *CommentService) IsLiked(ctx context.Context, commentID, userID string) bool { + return s.commentRepo.IsLiked(commentID, userID) +} diff --git a/internal/service/email_code_service.go b/internal/service/email_code_service.go new file mode 100644 index 0000000..87d0a01 --- /dev/null +++ b/internal/service/email_code_service.go @@ -0,0 +1,234 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/json" + "fmt" + "math/big" + "strings" + "time" + + "carrot_bbs/internal/cache" + "carrot_bbs/internal/pkg/utils" +) + +const ( + verifyCodeTTL = 10 * time.Minute + verifyCodeRateLimitTTL = 60 * time.Second +) + +const ( + CodePurposeRegister = "register" + CodePurposePasswordReset = "password_reset" + CodePurposeEmailVerify = "email_verify" + CodePurposeChangePassword = "change_password" +) + +type verificationCodePayload struct { + Code string `json:"code"` + Purpose string `json:"purpose"` + Email string `json:"email"` + ExpiresAt int64 `json:"expires_at"` +} + +type EmailCodeService interface { + SendCode(ctx context.Context, purpose, email string) error + VerifyCode(purpose, email, code string) error +} + +type emailCodeServiceImpl struct { + emailService EmailService + cache cache.Cache +} + +func NewEmailCodeService(emailService EmailService, cacheBackend cache.Cache) EmailCodeService { + if cacheBackend == nil { + cacheBackend = cache.GetCache() + } + return &emailCodeServiceImpl{ + emailService: emailService, + cache: cacheBackend, + } +} + +func verificationCodeCacheKey(purpose, email string) string { + return fmt.Sprintf("auth:verify_code:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email))) +} + +func verificationCodeRateLimitKey(purpose, email string) string { + return fmt.Sprintf("auth:verify_code_rate_limit:%s:%s", purpose, strings.ToLower(strings.TrimSpace(email))) +} + +func generateNumericCode(length int) (string, error) { + if length <= 0 { + return "", fmt.Errorf("invalid code length") + } + max := big.NewInt(10) + result := make([]byte, length) + for i := 0; i < length; i++ { + n, err := rand.Int(rand.Reader, max) + if err != nil { + return "", err + } + result[i] = byte('0' + n.Int64()) + } + return string(result), nil +} + +func (s *emailCodeServiceImpl) SendCode(ctx context.Context, purpose, email string) error { + if strings.TrimSpace(email) == "" || !utils.ValidateEmail(email) { + return ErrInvalidEmail + } + if s.emailService == nil || !s.emailService.IsEnabled() { + return ErrEmailServiceUnavailable + } + if s.cache == nil { + return ErrVerificationCodeUnavailable + } + + rateLimitKey := verificationCodeRateLimitKey(purpose, email) + if s.cache.Exists(rateLimitKey) { + return ErrVerificationCodeTooFrequent + } + + code, err := generateNumericCode(6) + if err != nil { + return fmt.Errorf("generate verification code failed: %w", err) + } + payload := verificationCodePayload{ + Code: code, + Purpose: purpose, + Email: strings.ToLower(strings.TrimSpace(email)), + ExpiresAt: time.Now().Add(verifyCodeTTL).Unix(), + } + cacheKey := verificationCodeCacheKey(purpose, email) + s.cache.Set(cacheKey, payload, verifyCodeTTL) + s.cache.Set(rateLimitKey, "1", verifyCodeRateLimitTTL) + + subject, sceneText := verificationEmailMeta(purpose) + textBody := fmt.Sprintf("【%s】验证码:%s\n有效期:10分钟\n请勿将验证码泄露给他人。", sceneText, code) + htmlBody := buildVerificationEmailHTML(sceneText, code) + if err := s.emailService.Send(ctx, SendEmailRequest{ + To: []string{email}, + Subject: subject, + TextBody: textBody, + HTMLBody: htmlBody, + }); err != nil { + s.cache.Delete(cacheKey) + return fmt.Errorf("send verification email failed: %w", err) + } + + return nil +} + +func (s *emailCodeServiceImpl) VerifyCode(purpose, email, code string) error { + if strings.TrimSpace(email) == "" || strings.TrimSpace(code) == "" { + return ErrVerificationCodeInvalid + } + if s.cache == nil { + return ErrVerificationCodeUnavailable + } + + cacheKey := verificationCodeCacheKey(purpose, email) + raw, ok := s.cache.Get(cacheKey) + if !ok { + return ErrVerificationCodeExpired + } + + var payload verificationCodePayload + switch v := raw.(type) { + case string: + if err := json.Unmarshal([]byte(v), &payload); err != nil { + return ErrVerificationCodeInvalid + } + case []byte: + if err := json.Unmarshal(v, &payload); err != nil { + return ErrVerificationCodeInvalid + } + case verificationCodePayload: + payload = v + default: + data, err := json.Marshal(v) + if err != nil { + return ErrVerificationCodeInvalid + } + if err := json.Unmarshal(data, &payload); err != nil { + return ErrVerificationCodeInvalid + } + } + + if payload.Purpose != purpose || payload.Email != strings.ToLower(strings.TrimSpace(email)) { + return ErrVerificationCodeInvalid + } + if payload.ExpiresAt > 0 && time.Now().Unix() > payload.ExpiresAt { + s.cache.Delete(cacheKey) + return ErrVerificationCodeExpired + } + if payload.Code != strings.TrimSpace(code) { + return ErrVerificationCodeInvalid + } + + s.cache.Delete(cacheKey) + return nil +} + +func verificationEmailMeta(purpose string) (subject string, sceneText string) { + switch purpose { + case CodePurposeRegister: + return "Carrot BBS 注册验证码", "注册账号" + case CodePurposePasswordReset: + return "Carrot BBS 找回密码验证码", "找回密码" + case CodePurposeEmailVerify: + return "Carrot BBS 邮箱验证验证码", "验证邮箱" + case CodePurposeChangePassword: + return "Carrot BBS 修改密码验证码", "修改密码" + default: + return "Carrot BBS 验证码", "身份验证" + } +} + +func buildVerificationEmailHTML(sceneText, code string) string { + return fmt.Sprintf(` + + + + + Carrot BBS 验证码 + + + + + + +
+ + + + + + + + + + +
+
Carrot BBS
+
%s 验证
+
+

你好,

+

你正在进行 %s 操作,请使用下方验证码完成验证:

+
+
验证码(10分钟内有效)
+
%s
+
+

如果不是你本人操作,请忽略此邮件,并及时检查账号安全。

+

请勿向任何人透露验证码,平台不会以任何理由索取验证码。

+
+ 此邮件由系统自动发送,请勿直接回复。
+ © Carrot BBS +
+
+ +`, sceneText, sceneText, code) +} diff --git a/internal/service/email_service.go b/internal/service/email_service.go new file mode 100644 index 0000000..e879731 --- /dev/null +++ b/internal/service/email_service.go @@ -0,0 +1,82 @@ +package service + +import ( + "context" + "fmt" + "strings" + + emailpkg "carrot_bbs/internal/pkg/email" +) + +// SendEmailRequest 发信请求 +type SendEmailRequest struct { + To []string + Cc []string + Bcc []string + ReplyTo []string + Subject string + TextBody string + HTMLBody string + Attachments []string +} + +type EmailService interface { + IsEnabled() bool + Send(ctx context.Context, req SendEmailRequest) error + SendText(ctx context.Context, to []string, subject, body string) error + SendHTML(ctx context.Context, to []string, subject, html string) error +} + +type emailServiceImpl struct { + client emailpkg.Client +} + +func NewEmailService(client emailpkg.Client) EmailService { + return &emailServiceImpl{client: client} +} + +func (s *emailServiceImpl) IsEnabled() bool { + return s.client != nil && s.client.IsEnabled() +} + +func (s *emailServiceImpl) Send(ctx context.Context, req SendEmailRequest) error { + if s.client == nil { + return fmt.Errorf("email client is nil") + } + if !s.client.IsEnabled() { + return fmt.Errorf("email service is disabled") + } + if len(req.To) == 0 { + return fmt.Errorf("email recipient is empty") + } + if strings.TrimSpace(req.Subject) == "" { + return fmt.Errorf("email subject is empty") + } + + return s.client.Send(ctx, emailpkg.Message{ + To: req.To, + Cc: req.Cc, + Bcc: req.Bcc, + ReplyTo: req.ReplyTo, + Subject: req.Subject, + TextBody: req.TextBody, + HTMLBody: req.HTMLBody, + Attachments: req.Attachments, + }) +} + +func (s *emailServiceImpl) SendText(ctx context.Context, to []string, subject, body string) error { + return s.Send(ctx, SendEmailRequest{ + To: to, + Subject: subject, + TextBody: body, + }) +} + +func (s *emailServiceImpl) SendHTML(ctx context.Context, to []string, subject, html string) error { + return s.Send(ctx, SendEmailRequest{ + To: to, + Subject: subject, + HTMLBody: html, + }) +} diff --git a/internal/service/group_service.go b/internal/service/group_service.go new file mode 100644 index 0000000..226eae8 --- /dev/null +++ b/internal/service/group_service.go @@ -0,0 +1,1491 @@ +package service + +import ( + "errors" + "fmt" + "log" + "strconv" + "time" + + "carrot_bbs/internal/cache" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/utils" + "carrot_bbs/internal/pkg/websocket" + "carrot_bbs/internal/repository" + + "gorm.io/gorm" +) + +// 缓存TTL常量 +const ( + GroupMembersTTL = 120 * time.Second // 群组成员缓存120秒 + GroupMembersNullTTL = 5 * time.Second + GroupCacheJitter = 0.1 +) + +// 群组服务错误定义 +var ( + ErrGroupNotFound = errors.New("群组不存在") + ErrNotGroupMember = errors.New("不是群成员") + ErrNotGroupAdmin = errors.New("不是群管理员") + ErrNotGroupOwner = errors.New("不是群主") + ErrGroupFull = errors.New("群已满") + ErrAlreadyMember = errors.New("已经是群成员") + ErrCannotRemoveOwner = errors.New("不能移除群主") + ErrCannotMuteOwner = errors.New("不能禁言群主") + ErrMuted = errors.New("你已被禁言") + ErrMuteAllEnabled = errors.New("全员禁言中") + ErrCannotJoin = errors.New("该群不允许加入") + ErrJoinRequestPending = errors.New("加群申请已提交") + ErrGroupRequestNotFound = errors.New("加群请求不存在") + ErrGroupRequestHandled = errors.New("该加群请求已处理") + ErrNotRequestTarget = errors.New("不是邀请目标用户") + ErrNoEligibleInvitee = errors.New("没有可邀请的用户") + ErrNotMutualFollow = errors.New("仅支持邀请互相关注用户") +) + +// GroupService 群组服务接口 +type GroupService interface { + // 群组管理 + CreateGroup(ownerID string, name string, description string, memberIDs []string) (*model.Group, error) + GetGroupByID(id string) (*model.Group, error) + UpdateGroup(userID string, groupID string, updates map[string]interface{}) error + DissolveGroup(userID string, groupID string) error + TransferOwner(userID string, groupID string, newOwnerID string) error + GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) + GetMemberCount(groupID string) (int, error) + + // 成员管理 + InviteMembers(userID string, groupID string, memberIDs []string) error + JoinGroup(userID string, groupID string) error + RespondInvite(userID string, flag string, approve bool, reason string) error + SetGroupAddRequest(userID string, flag string, approve bool, reason string) error + LeaveGroup(userID string, groupID string) error + RemoveMember(userID string, groupID string, targetUserID string) error + GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) + SetMemberRole(userID string, groupID string, targetUserID string, role string) error + SetMemberNickname(userID string, groupID string, nickname string) error + MuteMember(userID string, groupID string, targetUserID string, muted bool) error + + // 群设置 + SetMuteAll(userID string, groupID string, muteAll bool) error + SetJoinType(userID string, groupID string, joinType int) error + + // 群公告 + CreateAnnouncement(userID string, groupID string, content string) (*model.GroupAnnouncement, error) + GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) + DeleteAnnouncement(userID string, announcementID string) error + + // 权限检查 + CanSendGroupMessage(userID string, groupID string) error + IsGroupAdmin(userID string, groupID string) bool + IsGroupOwner(userID string, groupID string) bool + + // 获取成员信息 + GetMember(groupID string, userID string) (*model.GroupMember, error) +} + +// GroupMembersResult 群组成员缓存结果 +type GroupMembersResult struct { + Members []model.GroupMember + Total int64 +} + +// groupService 群组服务实现 +type groupService struct { + db *gorm.DB + groupRepo repository.GroupRepository + userRepo *repository.UserRepository + messageRepo *repository.MessageRepository + requestRepo repository.GroupJoinRequestRepository + notifyRepo *repository.SystemNotificationRepository + wsManager *websocket.WebSocketManager + cache cache.Cache +} + +// NewGroupService 创建群组服务 +func NewGroupService(db *gorm.DB, groupRepo repository.GroupRepository, userRepo *repository.UserRepository, messageRepo *repository.MessageRepository, wsManager *websocket.WebSocketManager) GroupService { + return &groupService{ + db: db, + groupRepo: groupRepo, + userRepo: userRepo, + messageRepo: messageRepo, + requestRepo: repository.NewGroupJoinRequestRepository(db), + notifyRepo: repository.NewSystemNotificationRepository(db), + wsManager: wsManager, + cache: cache.GetCache(), + } +} + +// ==================== 群组管理 ==================== + +// CreateGroup 创建群组 +func (s *groupService) CreateGroup(ownerID string, name string, description string, memberIDs []string) (*model.Group, error) { + // 创建群组(ID会在BeforeCreate中自动生成) + group := &model.Group{ + Name: name, + Description: description, + OwnerID: ownerID, + MemberCount: 1, // 群主 + MaxMembers: 500, + JoinType: model.JoinTypeAnyone, + MuteAll: false, + } + + // 保存群组 + if err := s.groupRepo.Create(group); err != nil { + return nil, err + } + + // 添加群主为成员 + ownerMember := &model.GroupMember{ + GroupID: group.ID, + UserID: ownerID, + Role: model.GroupRoleOwner, + JoinTime: time.Now(), + } + if err := s.groupRepo.AddMember(ownerMember); err != nil { + // 回滚:删除群组 + _ = s.groupRepo.Delete(group.ID) + return nil, err + } + + // 邀请初始成员 + if len(memberIDs) > 0 { + for _, memberID := range memberIDs { + if memberID == ownerID { + continue // 跳过群主 + } + member := &model.GroupMember{ + GroupID: group.ID, + UserID: memberID, + Role: model.GroupRoleMember, + JoinTime: time.Now(), + } + if err := s.groupRepo.AddMember(member); err != nil { + // 单个成员添加失败不回滚整个操作 + continue + } + } + } + + // 创建群组会话(Conversation) + conversationID, err := utils.GetSnowflake().GenerateID() + if err != nil { + conversationID = int64(time.Now().UnixNano()) + } + conversation := &model.Conversation{ + ID: strconv.FormatInt(conversationID, 10), + Type: model.ConversationTypeGroup, + GroupID: &group.ID, + } + + // 在事务中创建会话和参与者 + err = s.db.Transaction(func(tx *gorm.DB) error { + // 创建会话 + if err := tx.Create(conversation).Error; err != nil { + return err + } + + // 添加群主为会话参与者 + ownerParticipant := model.ConversationParticipant{ + ConversationID: conversation.ID, + UserID: ownerID, + LastReadSeq: 0, + } + if err := tx.Create(&ownerParticipant).Error; err != nil { + return err + } + + // 添加被邀请的成员为会话参与者 + for _, memberID := range memberIDs { + if memberID == ownerID { + continue + } + participant := model.ConversationParticipant{ + ConversationID: conversation.ID, + UserID: memberID, + LastReadSeq: 0, + } + if err := tx.Create(&participant).Error; err != nil { + // 单个参与者添加失败继续处理其他成员 + continue + } + } + + return nil + }) + + if err != nil { + // 记录错误但不影响群组创建成功 + } + + return group, nil +} + +// GetGroupByID 根据ID获取群组 +func (s *groupService) GetGroupByID(id string) (*model.Group, error) { + group, err := s.groupRepo.GetByID(id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGroupNotFound + } + return nil, err + } + return group, nil +} + +// GetMemberCount 实时获取群成员数量 +func (s *groupService) GetMemberCount(groupID string) (int, error) { + count, err := s.groupRepo.GetMemberCount(groupID) + if err != nil { + return 0, err + } + return int(count), nil +} + +// UpdateGroup 更新群组信息 +func (s *groupService) UpdateGroup(userID string, groupID string, updates map[string]interface{}) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 检查权限:只有群主和管理员可以更新群信息 + if !s.IsGroupAdmin(userID, groupID) { + return ErrNotGroupAdmin + } + + // 不允许直接修改的字段 + delete(updates, "id") + delete(updates, "owner_id") + delete(updates, "member_count") + delete(updates, "created_at") + + // 应用更新 + if name, ok := updates["name"].(string); ok { + group.Name = name + } + if description, ok := updates["description"].(string); ok { + group.Description = description + } + if avatar, ok := updates["avatar"].(string); ok { + group.Avatar = avatar + } + + return s.groupRepo.Update(group) +} + +// DissolveGroup 解散群组 +func (s *groupService) DissolveGroup(userID string, groupID string) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 检查权限:只有群主可以解散群 + if group.OwnerID != userID { + return ErrNotGroupOwner + } + + // 先删除群组对应的会话(包括参与者、消息) + if s.messageRepo != nil { + if err := s.messageRepo.DeleteConversationByGroupID(groupID); err != nil { + log.Printf("[DissolveGroup] 删除会话失败: groupID=%s, err=%v", groupID, err) + // 继续删除群组,不因为会话删除失败而中断 + } + } + + return s.groupRepo.Delete(groupID) +} + +// TransferOwner 转让群主 +func (s *groupService) TransferOwner(userID string, groupID string, newOwnerID string) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 检查权限:只有群主可以转让群主 + if group.OwnerID != userID { + return ErrNotGroupOwner + } + + // 检查新群主是否是群成员 + isMember, err := s.groupRepo.IsMember(groupID, newOwnerID) + if err != nil { + return err + } + if !isMember { + return ErrNotGroupMember + } + + // 在事务中更新 + return s.db.Transaction(func(tx *gorm.DB) error { + // 更新群组的群主 + group.OwnerID = newOwnerID + if err := tx.Save(group).Error; err != nil { + return err + } + + // 更新原群主为管理员 + if err := tx.Model(&model.GroupMember{}). + Where("group_id = ? AND user_id = ?", groupID, userID). + Update("role", model.GroupRoleAdmin).Error; err != nil { + return err + } + + // 更新新群主为群主 + if err := tx.Model(&model.GroupMember{}). + Where("group_id = ? AND user_id = ?", groupID, newOwnerID). + Update("role", model.GroupRoleOwner).Error; err != nil { + return err + } + + return nil + }) +} + +// GetUserGroups 获取用户加入的群组列表 +func (s *groupService) GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) { + return s.groupRepo.GetUserGroups(userID, page, pageSize) +} + +// ==================== 成员管理 ==================== + +func (s *groupService) newRequestFlag() string { + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + return strconv.FormatInt(time.Now().UnixNano(), 10) + } + return strconv.FormatInt(id, 10) +} + +func (s *groupService) createSystemNotification(receiverID string, notifyType model.SystemNotificationType, content string, extra *model.SystemNotificationExtra) { + if s.notifyRepo == nil { + return + } + notification := &model.SystemNotification{ + ReceiverID: receiverID, + Type: notifyType, + Content: content, + ExtraData: extra, + } + if err := s.notifyRepo.Create(notification); err != nil { + log.Printf("[groupService] create system notification failed: receiverID=%s type=%s err=%v", receiverID, notifyType, err) + } +} + +func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID string, operatorID string) { + if groupID == "" || targetUserID == "" { + return + } + + targetUserName := "用户" + if targetUser, err := s.userRepo.GetByID(targetUserID); err == nil && targetUser != nil && targetUser.Nickname != "" { + targetUserName = targetUser.Nickname + } + noticeContent := "\"" + targetUserName + "\" 加入了群聊" + + var savedMessage *model.Message + if s.messageRepo != nil { + conv, err := s.messageRepo.GetConversationByGroupID(groupID) + if err == nil && conv != nil { + msg := &model.Message{ + ConversationID: conv.ID, + SenderID: model.SystemSenderIDStr, + Segments: model.MessageSegments{ + {Type: "text", Data: map[string]interface{}{"text": noticeContent}}, + }, + Status: model.MessageStatusNormal, + Category: model.CategoryNotification, + } + if err := s.messageRepo.CreateMessageWithSeq(msg); err != nil { + log.Printf("[broadcastMemberJoinNotice] 保存入群提示消息失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err) + } else { + savedMessage = msg + } + } else { + log.Printf("[broadcastMemberJoinNotice] 获取群组会话失败: groupID=%s, err=%v", groupID, err) + } + } + + if s.wsManager == nil { + return + } + + noticeMsg := websocket.GroupNoticeMessage{ + NoticeType: "member_join", + GroupID: groupID, + Data: websocket.GroupNoticeData{ + UserID: targetUserID, + Username: targetUserName, + OperatorID: operatorID, + }, + Timestamp: time.Now().UnixMilli(), + } + if savedMessage != nil { + noticeMsg.MessageID = savedMessage.ID + noticeMsg.Seq = savedMessage.Seq + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg) + members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000) + if err != nil { + log.Printf("[broadcastMemberJoinNotice] 获取群成员失败: groupID=%s, err=%v", groupID, err) + return + } + for _, m := range members { + if s.wsManager.IsUserOnline(m.UserID) { + s.wsManager.SendToUser(m.UserID, wsMsg) + } + } +} + +func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userID string, operatorID string) error { + if group == nil { + return ErrGroupNotFound + } + isMember, err := s.groupRepo.IsMember(group.ID, userID) + if err != nil { + return err + } + if isMember { + return nil + } + memberCount, err := s.groupRepo.GetMemberCount(group.ID) + if err != nil { + return err + } + if int(memberCount) >= group.MaxMembers { + return ErrGroupFull + } + + member := &model.GroupMember{ + GroupID: group.ID, + UserID: userID, + Role: model.GroupRoleMember, + JoinTime: time.Now(), + } + if err := s.groupRepo.AddMember(member); err != nil { + return err + } + if s.messageRepo != nil { + conv, err := s.messageRepo.GetConversationByGroupID(group.ID) + if err == nil && conv != nil { + if err := s.messageRepo.AddParticipant(conv.ID, userID); err != nil { + log.Printf("[addMemberToGroupAndConversation] 添加会话参与者失败: groupID=%s, userID=%s, err=%v", group.ID, userID, err) + } + } + } + cache.InvalidateGroupMembers(s.cache, group.ID) + s.broadcastMemberJoinNotice(group.ID, userID, operatorID) + return nil +} + +func (s *groupService) collectGroupReviewerIDs(groupID string, ownerID string, skipUserID string) []string { + reviewerSet := make(map[string]struct{}) + if ownerID != "" && ownerID != skipUserID { + reviewerSet[ownerID] = struct{}{} + } + + var reviewers []model.GroupMember + if err := s.db.Where("group_id = ? AND role IN ?", groupID, []string{model.GroupRoleOwner, model.GroupRoleAdmin}).Find(&reviewers).Error; err == nil { + for _, reviewer := range reviewers { + if reviewer.UserID == "" || reviewer.UserID == skipUserID { + continue + } + reviewerSet[reviewer.UserID] = struct{}{} + } + } + + result := make([]string, 0, len(reviewerSet)) + for id := range reviewerSet { + result = append(result, id) + } + return result +} + +func (s *groupService) notifyJoinApplyReviewers( + group *model.Group, + req *model.GroupJoinRequest, + applicantName string, + applicantAvatar string, + targetUserID string, + targetUserName string, + targetUserAvatar string, +) { + if group == nil || req == nil { + return + } + reviewerIDs := s.collectGroupReviewerIDs(group.ID, group.OwnerID, req.InitiatorID) + for _, reviewerID := range reviewerIDs { + s.createSystemNotification(reviewerID, model.SysNotifyGroupJoinApply, applicantName+" 申请加入群聊 "+group.Name, &model.SystemNotificationExtra{ + ActorIDStr: req.InitiatorID, + ActorName: applicantName, + AvatarURL: applicantAvatar, + TargetID: req.Flag, + TargetTitle: group.Name, + TargetType: string(model.GroupJoinRequestTypeJoinApply), + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + GroupDescription: group.Description, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(req.Status), + TargetUserID: targetUserID, + TargetUserName: targetUserName, + TargetUserAvatar: targetUserAvatar, + }) + } +} + +func (s *groupService) sendGroupInviteToTarget(group *model.Group, req *model.GroupJoinRequest, inviterName, inviterAvatar string) { + if group == nil || req == nil { + return + } + s.createSystemNotification(req.TargetUserID, model.SysNotifyGroupInvite, inviterName+" 邀请你加入群聊 "+group.Name, &model.SystemNotificationExtra{ + ActorIDStr: req.InitiatorID, + ActorName: inviterName, + AvatarURL: inviterAvatar, + TargetID: req.Flag, + TargetTitle: group.Name, + TargetType: string(model.GroupJoinRequestTypeInvite), + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + GroupDescription: group.Description, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(req.Status), + TargetUserID: req.TargetUserID, + TargetUserName: "", + TargetUserAvatar: "", + }) +} + +// InviteMembers 邀请成员 +func (s *groupService) InviteMembers(userID string, groupID string, memberIDs []string) error { + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + isInviterMember, err := s.groupRepo.IsMember(groupID, userID) + if err != nil { + return err + } + if !isInviterMember { + return ErrNotGroupMember + } + isInviterAdmin := s.IsGroupAdmin(userID, groupID) + + inviter, _ := s.userRepo.GetByID(userID) + inviterName := "群成员" + inviterAvatar := "" + if inviter != nil && inviter.Nickname != "" { + inviterName = inviter.Nickname + } + if inviter != nil { + inviterAvatar = inviter.Avatar + } + + createdCount := 0 + for _, memberID := range memberIDs { + if memberID == "" || memberID == userID { + continue + } + isMember, err := s.groupRepo.IsMember(groupID, memberID) + if err != nil { + continue + } + if isMember { + continue + } + + isFollowing, err := s.userRepo.IsFollowing(userID, memberID) + if err != nil || !isFollowing { + continue + } + isFollowedBack, err := s.userRepo.IsFollowing(memberID, userID) + if err != nil || !isFollowedBack { + continue + } + + if _, err := s.requestRepo.GetPendingByGroupAndTarget(groupID, memberID, model.GroupJoinRequestTypeInvite); err == nil { + continue + } else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + continue + } + expireAt := time.Now().Add(72 * time.Hour) + req := &model.GroupJoinRequest{ + Flag: s.newRequestFlag(), + GroupID: groupID, + InitiatorID: userID, + TargetUserID: memberID, + RequestType: model.GroupJoinRequestTypeInvite, + Status: model.GroupJoinRequestStatusPending, + ExpireAt: &expireAt, + } + if err := s.requestRepo.Create(req); err != nil { + continue + } + + if isInviterAdmin { + // 群主/管理员邀请:直接发送邀请卡片,等待被邀请人确认 + s.sendGroupInviteToTarget(group, req, inviterName, inviterAvatar) + createdCount++ + continue + } + + inviteeName := "用户" + inviteeAvatar := "" + if invitee, e := s.userRepo.GetByID(memberID); e == nil && invitee != nil { + if invitee.Nickname != "" { + inviteeName = invitee.Nickname + } + inviteeAvatar = invitee.Avatar + } + reviewerIDs := s.collectGroupReviewerIDs(group.ID, group.OwnerID, userID) + for _, reviewerID := range reviewerIDs { + s.createSystemNotification(reviewerID, model.SysNotifyGroupJoinApply, inviterName+" 邀请 "+inviteeName+" 加入群聊 "+group.Name+",请审批", &model.SystemNotificationExtra{ + ActorIDStr: userID, + ActorName: inviterName, + AvatarURL: inviterAvatar, + TargetID: req.Flag, + TargetTitle: group.Name, + TargetType: string(model.GroupJoinRequestTypeInvite), + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + GroupDescription: group.Description, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(req.Status), + TargetUserID: memberID, + TargetUserName: inviteeName, + TargetUserAvatar: inviteeAvatar, + }) + } + createdCount++ + } + + if createdCount == 0 { + return ErrNoEligibleInvitee + } + return nil +} + +// JoinGroup 加入群组 +func (s *groupService) JoinGroup(userID string, groupID string) error { + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 检查是否已经是群成员 + isMember, err := s.groupRepo.IsMember(groupID, userID) + if err != nil { + return err + } + if isMember { + return ErrAlreadyMember + } + + if group.JoinType == model.JoinTypeForbidden { + return ErrCannotJoin + } + + if group.JoinType == model.JoinTypeApproval { + if pendingReq, err := s.requestRepo.GetPendingByGroupAndTarget(groupID, userID, model.GroupJoinRequestTypeJoinApply); err == nil { + applicant, _ := s.userRepo.GetByID(userID) + applicantName := "用户" + applicantAvatar := "" + if applicant != nil && applicant.Nickname != "" { + applicantName = applicant.Nickname + } + if applicant != nil { + applicantAvatar = applicant.Avatar + } + // 已有待审批单时补发一次提醒,避免管理端漏看 + s.notifyJoinApplyReviewers( + group, + pendingReq, + applicantName, + applicantAvatar, + userID, + applicantName, + applicantAvatar, + ) + return ErrJoinRequestPending + } else if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + applicant, _ := s.userRepo.GetByID(userID) + applicantName := "用户" + applicantAvatar := "" + if applicant != nil && applicant.Nickname != "" { + applicantName = applicant.Nickname + } + if applicant != nil { + applicantAvatar = applicant.Avatar + } + expireAt := time.Now().Add(72 * time.Hour) + req := &model.GroupJoinRequest{ + Flag: s.newRequestFlag(), + GroupID: groupID, + InitiatorID: userID, + TargetUserID: userID, + RequestType: model.GroupJoinRequestTypeJoinApply, + Status: model.GroupJoinRequestStatusPending, + ExpireAt: &expireAt, + } + if err := s.requestRepo.Create(req); err != nil { + return err + } + + s.notifyJoinApplyReviewers( + group, + req, + applicantName, + applicantAvatar, + userID, + applicantName, + applicantAvatar, + ) + return ErrJoinRequestPending + } + + return s.addMemberToGroupAndConversation(group, userID, userID) +} + +func (s *groupService) RespondInvite(userID string, flag string, approve bool, reason string) error { + req, err := s.requestRepo.GetByFlag(flag) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupRequestNotFound + } + return err + } + if req.RequestType != model.GroupJoinRequestTypeInvite { + return ErrGroupRequestNotFound + } + if req.Status != model.GroupJoinRequestStatusPending { + return ErrGroupRequestHandled + } + if req.TargetUserID != userID { + return ErrNotRequestTarget + } + + group, err := s.groupRepo.GetByID(req.GroupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + now := time.Now() + req.ReviewerID = userID + req.ReviewedAt = &now + req.Reason = reason + + if approve { + if err := s.addMemberToGroupAndConversation(group, userID, userID); err != nil { + return err + } + req.Status = model.GroupJoinRequestStatusAccepted + s.createSystemNotification(req.InitiatorID, model.SysNotifySystem, "你发出的群邀请已被接受", &model.SystemNotificationExtra{ + GroupID: group.ID, + GroupName: group.Name, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(model.GroupJoinRequestStatusAccepted), + }) + } else { + req.Status = model.GroupJoinRequestStatusRejected + s.createSystemNotification(req.InitiatorID, model.SysNotifySystem, "你发出的群邀请已被拒绝", &model.SystemNotificationExtra{ + GroupID: group.ID, + GroupName: group.Name, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(model.GroupJoinRequestStatusRejected), + Reason: reason, + }) + } + + return s.requestRepo.Update(req) +} + +func (s *groupService) SetGroupAddRequest(userID string, flag string, approve bool, reason string) error { + req, err := s.requestRepo.GetByFlag(flag) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupRequestNotFound + } + return err + } + if req.RequestType != model.GroupJoinRequestTypeJoinApply && req.RequestType != model.GroupJoinRequestTypeInvite { + return ErrGroupRequestNotFound + } + if req.Status != model.GroupJoinRequestStatusPending { + return ErrGroupRequestHandled + } + if req.RequestType == model.GroupJoinRequestTypeInvite && req.ReviewerID != "" { + // invite 类型中 reviewerID 非空表示已完成管理员审批,等待被邀请人确认 + return ErrGroupRequestHandled + } + if !s.IsGroupAdmin(userID, req.GroupID) { + return ErrNotGroupAdmin + } + + group, err := s.groupRepo.GetByID(req.GroupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + now := time.Now() + req.ReviewerID = userID + req.ReviewedAt = &now + req.Reason = reason + targetUserName := "用户" + if u, e := s.userRepo.GetByID(req.TargetUserID); e == nil && u != nil && u.Nickname != "" { + targetUserName = u.Nickname + } + reviewerName := "管理员" + reviewerAvatar := "" + if u, e := s.userRepo.GetByID(userID); e == nil && u != nil { + if u.Nickname != "" { + reviewerName = u.Nickname + } + reviewerAvatar = u.Avatar + } + + if approve { + if req.RequestType == model.GroupJoinRequestTypeInvite { + // 管理员审批通过后,不直接拉人;发送邀请卡片等待对方确认 + inviterName := "群成员" + inviterAvatar := "" + if u, e := s.userRepo.GetByID(req.InitiatorID); e == nil && u != nil { + if u.Nickname != "" { + inviterName = u.Nickname + } + inviterAvatar = u.Avatar + } + s.sendGroupInviteToTarget(group, req, inviterName, inviterAvatar) + s.createSystemNotification(req.InitiatorID, model.SysNotifySystem, "你邀请 "+targetUserName+" 入群已通过审批,等待对方确认", &model.SystemNotificationExtra{ + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(model.GroupJoinRequestStatusPending), + }) + } else { + if err := s.addMemberToGroupAndConversation(group, req.TargetUserID, userID); err != nil { + return err + } + req.Status = model.GroupJoinRequestStatusAccepted + s.createSystemNotification(req.TargetUserID, model.SysNotifyGroupJoinApproved, "你申请加入群聊 "+group.Name+" 已通过", &model.SystemNotificationExtra{ + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(model.GroupJoinRequestStatusAccepted), + }) + } + // 同步通知其他可审批人:该请求已被处理 + reviewerIDs := s.collectGroupReviewerIDs(group.ID, group.OwnerID, "") + for _, reviewerID := range reviewerIDs { + if reviewerID == userID { + continue + } + s.createSystemNotification(reviewerID, model.SysNotifyGroupJoinApply, reviewerName+" 已同意该入群请求", &model.SystemNotificationExtra{ + ActorIDStr: userID, + ActorName: reviewerName, + AvatarURL: reviewerAvatar, + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + GroupDescription: group.Description, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(model.GroupJoinRequestStatusAccepted), + TargetUserID: req.TargetUserID, + TargetUserName: targetUserName, + }) + } + } else { + req.Status = model.GroupJoinRequestStatusRejected + if req.RequestType == model.GroupJoinRequestTypeInvite { + // 成员邀请被管理员拒绝,仅通知邀请人 + s.createSystemNotification(req.InitiatorID, model.SysNotifySystem, "你邀请 "+targetUserName+" 入群未通过审批", &model.SystemNotificationExtra{ + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(model.GroupJoinRequestStatusRejected), + Reason: reason, + }) + } else { + s.createSystemNotification(req.TargetUserID, model.SysNotifyGroupJoinRejected, "你申请加入群聊 "+group.Name+" 被拒绝", &model.SystemNotificationExtra{ + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(model.GroupJoinRequestStatusRejected), + Reason: reason, + }) + } + // 同步通知其他可审批人:该请求已被处理 + reviewerIDs := s.collectGroupReviewerIDs(group.ID, group.OwnerID, "") + for _, reviewerID := range reviewerIDs { + if reviewerID == userID { + continue + } + s.createSystemNotification(reviewerID, model.SysNotifyGroupJoinApply, reviewerName+" 已拒绝该入群请求", &model.SystemNotificationExtra{ + ActorIDStr: userID, + ActorName: reviewerName, + AvatarURL: reviewerAvatar, + GroupID: group.ID, + GroupName: group.Name, + GroupAvatar: group.Avatar, + GroupDescription: group.Description, + Flag: req.Flag, + RequestType: string(req.RequestType), + RequestStatus: string(model.GroupJoinRequestStatusRejected), + Reason: reason, + TargetUserID: req.TargetUserID, + TargetUserName: targetUserName, + }) + } + } + return s.requestRepo.Update(req) +} + +// LeaveGroup 退出群组 +func (s *groupService) LeaveGroup(userID string, groupID string) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 群主不能退出群 + if group.OwnerID == userID { + return ErrCannotRemoveOwner + } + + // 移除群成员 + if err := s.groupRepo.RemoveMember(groupID, userID); err != nil { + return err + } + + // 移除会话参与者记录 + // 根据群组ID查找对应的会话 + conv, err := s.messageRepo.GetConversationByGroupID(groupID) + if err != nil { + // 如果找不到会话,记录日志但不阻塞退出群流程 + fmt.Printf("[WARN] LeaveGroup: conversation not found for group %s, error: %v\n", groupID, err) + } else { + // 移除该用户在会话中的参与者记录 + if err := s.messageRepo.RemoveParticipant(conv.ID, userID); err != nil { + // 如果移除参与者失败,记录日志但不阻塞退出群流程 + fmt.Printf("[WARN] LeaveGroup: failed to remove participant %s from conversation %s, error: %v\n", userID, conv.ID, err) + } + } + + // 失效群组成员缓存 + cache.InvalidateGroupMembers(s.cache, groupID) + + return nil +} + +// RemoveMember 移除成员 +func (s *groupService) RemoveMember(userID string, groupID string, targetUserID string) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 不能移除群主 + if targetUserID == group.OwnerID { + return ErrCannotRemoveOwner + } + + // 检查权限 + targetRole, err := s.groupRepo.GetMemberRole(groupID, targetUserID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotGroupMember + } + return err + } + + // 群主可以移除任何人 + // 管理员只能移除普通成员 + if group.OwnerID != userID { + if targetRole == model.GroupRoleOwner { + return ErrNotGroupOwner + } + if targetRole == model.GroupRoleAdmin && !s.IsGroupOwner(userID, groupID) { + return ErrNotGroupAdmin + } + } + + // 移除群成员 + if err := s.groupRepo.RemoveMember(groupID, targetUserID); err != nil { + return err + } + + // 同时移除会话参与者 + if s.messageRepo != nil { + conv, err := s.messageRepo.GetConversationByGroupID(groupID) + if err == nil && conv != nil { + if err := s.messageRepo.RemoveParticipant(conv.ID, targetUserID); err != nil { + log.Printf("[RemoveMember] 移除会话参与者失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err) + } + } + } + + // 失效群组成员缓存 + cache.InvalidateGroupMembers(s.cache, groupID) + + return nil +} + +// GetMembers 获取群成员列表(带缓存) +func (s *groupService) GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) { + cacheSettings := cache.GetSettings() + groupMembersTTL := cacheSettings.GroupMembersTTL + if groupMembersTTL <= 0 { + groupMembersTTL = GroupMembersTTL + } + nullTTL := cacheSettings.NullTTL + if nullTTL <= 0 { + nullTTL = GroupMembersNullTTL + } + jitter := cacheSettings.JitterRatio + if jitter <= 0 { + jitter = GroupCacheJitter + } + + // 生成缓存键 + cacheKey := cache.GroupMembersKey(groupID, page, pageSize) + result, err := cache.GetOrLoadTyped[*GroupMembersResult]( + s.cache, + cacheKey, + groupMembersTTL, + jitter, + nullTTL, + func() (*GroupMembersResult, error) { + members, total, err := s.groupRepo.GetMembers(groupID, page, pageSize) + if err != nil { + return nil, err + } + return &GroupMembersResult{ + Members: members, + Total: total, + }, nil + }, + ) + if err != nil { + return nil, 0, err + } + if result == nil { + return []model.GroupMember{}, 0, nil + } + return result.Members, result.Total, nil +} + +// SetMemberRole 设置成员角色 +func (s *groupService) SetMemberRole(userID string, groupID string, targetUserID string, role string) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 检查权限:只有群主可以设置角色 + if group.OwnerID != userID { + return ErrNotGroupOwner + } + + // 不能修改群主的角色 + if targetUserID == group.OwnerID { + return ErrCannotRemoveOwner + } + + // 验证角色值 + if role != model.GroupRoleAdmin && role != model.GroupRoleMember { + return errors.New("无效的角色") + } + + err = s.groupRepo.SetMemberRole(groupID, targetUserID, role) + if err != nil { + return err + } + + // 失效群组成员缓存 + cache.InvalidateGroupMembers(s.cache, groupID) + + return nil +} + +// SetMemberNickname 设置群内昵称 +func (s *groupService) SetMemberNickname(userID string, groupID string, nickname string) error { + // 获取成员信息 + member, err := s.groupRepo.GetMember(groupID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotGroupMember + } + return err + } + + member.Nickname = nickname + err = s.groupRepo.UpdateMember(member) + if err != nil { + return err + } + + // 失效群组成员缓存 + cache.InvalidateGroupMembers(s.cache, groupID) + + return nil +} + +// MuteMember 禁言成员 +func (s *groupService) MuteMember(userID string, groupID string, targetUserID string, muted bool) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 不能禁言群主 + if targetUserID == group.OwnerID { + return ErrCannotMuteOwner + } + + // 检查权限:群主可以禁言任何人,管理员只能禁言普通成员 + if group.OwnerID != userID { + targetRole, err := s.groupRepo.GetMemberRole(groupID, targetUserID) + if err != nil { + return err + } + if targetRole == model.GroupRoleAdmin { + return ErrNotGroupAdmin + } + } + + member, err := s.groupRepo.GetMember(groupID, targetUserID) + if err != nil { + log.Printf("[MuteMember] 获取成员失败: %v", err) + return err + } + + log.Printf("[MuteMember] 禁言前状态: member.Muted=%v, 即将设置 muted=%v", member.Muted, muted) + member.Muted = muted + if err := s.groupRepo.UpdateMember(member); err != nil { + log.Printf("[MuteMember] 更新成员失败: %v", err) + return err + } + log.Printf("[MuteMember] 禁言状态已更新到数据库") + + // 验证更新结果 + updatedMember, _ := s.groupRepo.GetMember(groupID, targetUserID) + if updatedMember != nil { + log.Printf("[MuteMember] 验证: member.Muted=%v", updatedMember.Muted) + } + + // 获取被禁言用户的显示名称 + targetUser, _ := s.userRepo.GetByID(targetUserID) + targetUserName := "用户" + if targetUser != nil { + targetUserName = targetUser.Nickname + } + + // 构建通知内容 + noticeType := "muted" + noticeContent := "\"" + targetUserName + "\" 已被管理员禁言" + if !muted { + noticeType = "unmuted" + noticeContent = "\"" + targetUserName + "\" 已被管理员解除禁言" + } + + // 保存禁言/解禁消息到数据库 + var savedMessage *model.Message + if s.messageRepo != nil { + // 获取群组会话 + conv, err := s.messageRepo.GetConversationByGroupID(groupID) + if err == nil && conv != nil { + // 创建系统消息 + msg := &model.Message{ + ConversationID: conv.ID, + SenderID: model.SystemSenderIDStr, + Segments: model.MessageSegments{ + {Type: "text", Data: map[string]interface{}{"text": noticeContent}}, + }, + Status: model.MessageStatusNormal, + Category: model.CategoryNotification, + } + + // 保存消息并获取 seq + if err := s.messageRepo.CreateMessageWithSeq(msg); err != nil { + log.Printf("[MuteMember] 保存禁言消息失败: %v", err) + } else { + savedMessage = msg + log.Printf("[MuteMember] 禁言消息已保存, ID=%s, Seq=%d", msg.ID, msg.Seq) + } + } else { + log.Printf("[MuteMember] 获取群组会话失败: %v", err) + } + } + + // 发送WebSocket通知给群成员 + if s.wsManager != nil { + log.Printf("[MuteMember] 准备发送禁言通知: groupID=%s, targetUserID=%s, noticeType=%s, operatorID=%s", groupID, targetUserID, noticeType, userID) + + // 构建通知消息,包含保存的消息信息 + noticeMsg := websocket.GroupNoticeMessage{ + NoticeType: noticeType, + GroupID: groupID, + Data: websocket.GroupNoticeData{ + UserID: targetUserID, + OperatorID: userID, + }, + Timestamp: time.Now().UnixMilli(), + } + + // 如果消息已保存,添加消息ID和seq + if savedMessage != nil { + noticeMsg.MessageID = savedMessage.ID + noticeMsg.Seq = savedMessage.Seq + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeGroupNotice, noticeMsg) + log.Printf("[MuteMember] 创建的WebSocket消息: Type=%s, Data=%+v", wsMsg.Type, wsMsg.Data) + + // 获取所有群成员并发送通知 + members, _, err := s.groupRepo.GetMembers(groupID, 1, 1000) + if err == nil { + log.Printf("[MuteMember] 获取到群成员数量: %d", len(members)) + for _, m := range members { + isOnline := s.wsManager.IsUserOnline(m.UserID) + log.Printf("[MuteMember] 成员 %s 在线状态: %v", m.UserID, isOnline) + if isOnline { + s.wsManager.SendToUser(m.UserID, wsMsg) + log.Printf("[MuteMember] 已发送通知给成员: %s", m.UserID) + } + } + } else { + log.Printf("[MuteMember] 获取群成员失败: %v", err) + } + } + + // 失效群组成员缓存 + cache.InvalidateGroupMembers(s.cache, groupID) + + return nil +} + +// ==================== 群设置 ==================== + +// SetMuteAll 全员禁言 +func (s *groupService) SetMuteAll(userID string, groupID string, muteAll bool) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 检查权限:只有群主可以设置全员禁言 + if group.OwnerID != userID { + return ErrNotGroupOwner + } + + group.MuteAll = muteAll + return s.groupRepo.Update(group) +} + +// SetJoinType 设置加群方式 +func (s *groupService) SetJoinType(userID string, groupID string, joinType int) error { + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrGroupNotFound + } + return err + } + + // 检查权限:只有群主可以设置加群方式 + if group.OwnerID != userID { + return ErrNotGroupOwner + } + + // 验证加群方式 + if joinType < 0 || joinType > 2 { + return errors.New("无效的加群方式") + } + + group.JoinType = model.JoinType(joinType) + return s.groupRepo.Update(group) +} + +// ==================== 群公告 ==================== + +// CreateAnnouncement 创建群公告 +func (s *groupService) CreateAnnouncement(userID string, groupID string, content string) (*model.GroupAnnouncement, error) { + // 检查群组是否存在 + _, err := s.groupRepo.GetByID(groupID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrGroupNotFound + } + return nil, err + } + + // 检查权限:只有群主和管理员可以发布公告 + if !s.IsGroupAdmin(userID, groupID) { + return nil, ErrNotGroupAdmin + } + + announcement := &model.GroupAnnouncement{ + GroupID: groupID, + Content: content, + AuthorID: userID, + } + + if err := s.groupRepo.CreateAnnouncement(announcement); err != nil { + return nil, err + } + + return announcement, nil +} + +// GetAnnouncements 获取群公告列表 +func (s *groupService) GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) { + return s.groupRepo.GetAnnouncements(groupID, page, pageSize) +} + +// DeleteAnnouncement 删除群公告 +func (s *groupService) DeleteAnnouncement(userID string, announcementID string) error { + // 获取公告 + announcement, err := s.groupRepo.GetAnnouncementByID(announcementID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("公告不存在") + } + return err + } + + // 检查群组是否存在 + group, err := s.groupRepo.GetByID(announcement.GroupID) + if err != nil { + return err + } + + // 检查权限:只有群主和管理员可以删除公告 + if !s.IsGroupAdmin(userID, group.ID) { + return ErrNotGroupAdmin + } + + return s.groupRepo.DeleteAnnouncement(announcementID) +} + +// ==================== 权限检查 ==================== + +// CanSendGroupMessage 检查是否可以发送群消息 +func (s *groupService) CanSendGroupMessage(userID string, groupID string) error { + // 检查是否是群成员 + member, err := s.groupRepo.GetMember(groupID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotGroupMember + } + return err + } + + // 检查是否被禁言 + if member.Muted { + return ErrMuted + } + + // 检查是否全员禁言 + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + return err + } + if group.MuteAll { + return ErrMuteAllEnabled + } + + return nil +} + +// IsGroupAdmin 检查是否是群管理员 +func (s *groupService) IsGroupAdmin(userID string, groupID string) bool { + role, err := s.groupRepo.GetMemberRole(groupID, userID) + if err != nil { + return false + } + return role == model.GroupRoleOwner || role == model.GroupRoleAdmin +} + +// IsGroupOwner 检查是否是群主 +func (s *groupService) IsGroupOwner(userID string, groupID string) bool { + group, err := s.groupRepo.GetByID(groupID) + if err != nil { + return false + } + return group.OwnerID == userID +} + +// GetMember 获取指定用户在群组中的成员信息 +func (s *groupService) GetMember(groupID string, userID string) (*model.GroupMember, error) { + return s.groupRepo.GetMember(groupID, userID) +} diff --git a/internal/service/jwt_service.go b/internal/service/jwt_service.go new file mode 100644 index 0000000..3344737 --- /dev/null +++ b/internal/service/jwt_service.go @@ -0,0 +1,38 @@ +package service + +import ( + "carrot_bbs/internal/pkg/jwt" + "time" +) + +// JWTService JWT服务 +type JWTService struct { + jwt *jwt.JWT +} + +// NewJWTService 创建JWT服务 +func NewJWTService(secret string, accessExpire, refreshExpire int64) *JWTService { + return &JWTService{ + jwt: jwt.New(secret, time.Duration(accessExpire)*time.Second, time.Duration(refreshExpire)*time.Second), + } +} + +// GenerateAccessToken 生成访问令牌 +func (s *JWTService) GenerateAccessToken(userID, username string) (string, error) { + return s.jwt.GenerateAccessToken(userID, username) +} + +// GenerateRefreshToken 生成刷新令牌 +func (s *JWTService) GenerateRefreshToken(userID, username string) (string, error) { + return s.jwt.GenerateRefreshToken(userID, username) +} + +// ParseToken 解析令牌 +func (s *JWTService) ParseToken(tokenString string) (*jwt.Claims, error) { + return s.jwt.ParseToken(tokenString) +} + +// ValidateToken 验证令牌 +func (s *JWTService) ValidateToken(tokenString string) error { + return s.jwt.ValidateToken(tokenString) +} diff --git a/internal/service/message_service.go b/internal/service/message_service.go new file mode 100644 index 0000000..f8e7482 --- /dev/null +++ b/internal/service/message_service.go @@ -0,0 +1,215 @@ +package service + +import ( + "context" + "time" + + "carrot_bbs/internal/cache" + "carrot_bbs/internal/model" + "carrot_bbs/internal/repository" +) + +// 缓存TTL常量 +const ( + ConversationListTTL = 60 * time.Second // 会话列表缓存60秒 + ConversationDetailTTL = 60 * time.Second // 会话详情缓存60秒 + UnreadCountTTL = 30 * time.Second // 未读数缓存30秒 + ConversationNullTTL = 5 * time.Second + UnreadNullTTL = 5 * time.Second + CacheJitterRatio = 0.1 +) + +// MessageService 消息服务 +type MessageService struct { + messageRepo *repository.MessageRepository + cache cache.Cache +} + +// NewMessageService 创建消息服务 +func NewMessageService(messageRepo *repository.MessageRepository) *MessageService { + return &MessageService{ + messageRepo: messageRepo, + cache: cache.GetCache(), + } +} + +// ConversationListResult 会话列表缓存结果 +type ConversationListResult struct { + Conversations []*model.Conversation + Total int64 +} + +// SendMessage 发送消息(使用 segments) +// senderID 和 receiverID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID string, segments model.MessageSegments) (*model.Message, error) { + // 获取或创建会话 + conv, err := s.messageRepo.GetOrCreatePrivateConversation(senderID, receiverID) + if err != nil { + return nil, err + } + + msg := &model.Message{ + ConversationID: conv.ID, + SenderID: senderID, + Segments: segments, + Status: model.MessageStatusNormal, + } + + // 使用事务创建消息并更新seq + err = s.messageRepo.CreateMessageWithSeq(msg) + if err != nil { + return nil, err + } + + // 失效会话列表缓存(发送者和接收者) + cache.InvalidateConversationList(s.cache, senderID) + cache.InvalidateConversationList(s.cache, receiverID) + + // 失效未读数缓存 + cache.InvalidateUnreadConversation(s.cache, receiverID) + cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID) + + return msg, nil +} + +// GetConversations 获取会话列表(带缓存) +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) { + cacheSettings := cache.GetSettings() + conversationTTL := cacheSettings.ConversationTTL + if conversationTTL <= 0 { + conversationTTL = ConversationListTTL + } + nullTTL := cacheSettings.NullTTL + if nullTTL <= 0 { + nullTTL = ConversationNullTTL + } + jitter := cacheSettings.JitterRatio + if jitter <= 0 { + jitter = CacheJitterRatio + } + + // 生成缓存键 + cacheKey := cache.ConversationListKey(userID, page, pageSize) + result, err := cache.GetOrLoadTyped[*ConversationListResult]( + s.cache, + cacheKey, + conversationTTL, + jitter, + nullTTL, + func() (*ConversationListResult, error) { + conversations, total, err := s.messageRepo.GetConversations(userID, page, pageSize) + if err != nil { + return nil, err + } + return &ConversationListResult{ + Conversations: conversations, + Total: total, + }, nil + }, + ) + if err != nil { + return nil, 0, err + } + if result == nil { + return []*model.Conversation{}, 0, nil + } + return result.Conversations, result.Total, nil +} + +// GetMessages 获取消息列表 +func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) { + return s.messageRepo.GetMessages(conversationID, page, pageSize) +} + +// GetMessagesAfterSeq 获取指定seq之后的消息(增量同步) +func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID string, afterSeq int64, limit int) ([]*model.Message, error) { + return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit) +} + +// MarkAsRead 标记为已读 +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error { + err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq) + if err != nil { + return err + } + + // 失效未读数缓存 + cache.InvalidateUnreadConversation(s.cache, userID) + cache.InvalidateUnreadDetail(s.cache, userID, conversationID) + + // 失效会话列表缓存 + cache.InvalidateConversationList(s.cache, userID) + + return nil +} + +// GetUnreadCount 获取未读消息数(带缓存) +// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) { + cacheSettings := cache.GetSettings() + unreadTTL := cacheSettings.UnreadCountTTL + if unreadTTL <= 0 { + unreadTTL = UnreadCountTTL + } + nullTTL := cacheSettings.NullTTL + if nullTTL <= 0 { + nullTTL = UnreadNullTTL + } + jitter := cacheSettings.JitterRatio + if jitter <= 0 { + jitter = CacheJitterRatio + } + + // 生成缓存键 + cacheKey := cache.UnreadDetailKey(userID, conversationID) + + return cache.GetOrLoadTyped[int64]( + s.cache, + cacheKey, + unreadTTL, + jitter, + nullTTL, + func() (int64, error) { + return s.messageRepo.GetUnreadCount(conversationID, userID) + }, + ) +} + +// GetOrCreateConversation 获取或创建私聊会话 +// user1ID 和 user2ID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 +func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) { + conv, err := s.messageRepo.GetOrCreatePrivateConversation(user1ID, user2ID) + if err != nil { + return nil, err + } + + // 失效会话列表缓存 + cache.InvalidateConversationList(s.cache, user1ID) + cache.InvalidateConversationList(s.cache, user2ID) + + return conv, nil +} + +// GetConversationParticipants 获取会话参与者列表 +func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) { + return s.messageRepo.GetConversationParticipants(conversationID) +} + +// ParseConversationID 辅助函数:直接返回字符串ID(已经是string类型) +func ParseConversationID(idStr string) (string, error) { + return idStr, nil +} + +// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用) +func (s *MessageService) InvalidateUserConversationCache(userID string) { + cache.InvalidateConversationList(s.cache, userID) + cache.InvalidateUnreadConversation(s.cache, userID) +} + +// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用) +func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) { + cache.InvalidateUnreadConversation(s.cache, userID) + cache.InvalidateUnreadDetail(s.cache, userID, conversationID) +} diff --git a/internal/service/notification_service.go b/internal/service/notification_service.go new file mode 100644 index 0000000..e44ba1a --- /dev/null +++ b/internal/service/notification_service.go @@ -0,0 +1,169 @@ +package service + +import ( + "context" + "time" + + "carrot_bbs/internal/cache" + "carrot_bbs/internal/model" + "carrot_bbs/internal/repository" +) + +// 缓存TTL常量 +const ( + NotificationUnreadCountTTL = 30 * time.Second // 通知未读数缓存30秒 + NotificationNullTTL = 5 * time.Second + NotificationCacheJitter = 0.1 +) + +// NotificationService 通知服务 +type NotificationService struct { + notificationRepo *repository.NotificationRepository + cache cache.Cache +} + +// NewNotificationService 创建通知服务 +func NewNotificationService(notificationRepo *repository.NotificationRepository) *NotificationService { + return &NotificationService{ + notificationRepo: notificationRepo, + cache: cache.GetCache(), + } +} + +// Create 创建通知 +func (s *NotificationService) Create(ctx context.Context, userID string, notificationType model.NotificationType, title, content string) (*model.Notification, error) { + notification := &model.Notification{ + UserID: userID, + Type: notificationType, + Title: title, + Content: content, + IsRead: false, + } + + err := s.notificationRepo.Create(notification) + if err != nil { + return nil, err + } + + // 失效未读数缓存 + cache.InvalidateUnreadSystem(s.cache, userID) + + return notification, nil +} + +// GetByUserID 获取用户通知 +func (s *NotificationService) GetByUserID(ctx context.Context, userID string, page, pageSize int, unreadOnly bool) ([]*model.Notification, int64, error) { + return s.notificationRepo.GetByUserID(userID, page, pageSize, unreadOnly) +} + +// MarkAsRead 标记为已读 +func (s *NotificationService) MarkAsRead(ctx context.Context, id string) error { + err := s.notificationRepo.MarkAsRead(id) + if err != nil { + return err + } + + // 注意:这里无法获取userID,所以不在缓存中失效 + // 调用方应该使用MarkAsReadWithUserID方法 + + return nil +} + +// MarkAsReadWithUserID 标记为已读(带用户ID,用于缓存失效) +func (s *NotificationService) MarkAsReadWithUserID(ctx context.Context, id, userID string) error { + err := s.notificationRepo.MarkAsRead(id) + if err != nil { + return err + } + + // 失效未读数缓存 + cache.InvalidateUnreadSystem(s.cache, userID) + + return nil +} + +// MarkAllAsRead 标记所有为已读 +func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID string) error { + err := s.notificationRepo.MarkAllAsRead(userID) + if err != nil { + return err + } + + // 失效未读数缓存 + cache.InvalidateUnreadSystem(s.cache, userID) + + return nil +} + +// Delete 删除通知 +func (s *NotificationService) Delete(ctx context.Context, id string) error { + return s.notificationRepo.Delete(id) +} + +// GetUnreadCount 获取未读数量(带缓存) +func (s *NotificationService) GetUnreadCount(ctx context.Context, userID string) (int64, error) { + cacheSettings := cache.GetSettings() + unreadTTL := cacheSettings.UnreadCountTTL + if unreadTTL <= 0 { + unreadTTL = NotificationUnreadCountTTL + } + nullTTL := cacheSettings.NullTTL + if nullTTL <= 0 { + nullTTL = NotificationNullTTL + } + jitter := cacheSettings.JitterRatio + if jitter <= 0 { + jitter = NotificationCacheJitter + } + + // 生成缓存键 + cacheKey := cache.UnreadSystemKey(userID) + return cache.GetOrLoadTyped[int64]( + s.cache, + cacheKey, + unreadTTL, + jitter, + nullTTL, + func() (int64, error) { + return s.notificationRepo.GetUnreadCount(userID) + }, + ) +} + +// DeleteNotification 删除通知(带用户验证) +func (s *NotificationService) DeleteNotification(ctx context.Context, id, userID string) error { + // 先检查通知是否属于该用户 + notification, err := s.notificationRepo.GetByID(id) + if err != nil { + return err + } + if notification.UserID != userID { + return ErrUnauthorizedNotification + } + + err = s.notificationRepo.Delete(id) + if err != nil { + return err + } + + // 失效未读数缓存 + cache.InvalidateUnreadSystem(s.cache, userID) + + return nil +} + +// ClearAllNotifications 清空所有通知 +func (s *NotificationService) ClearAllNotifications(ctx context.Context, userID string) error { + err := s.notificationRepo.DeleteAllByUserID(userID) + if err != nil { + return err + } + + // 失效未读数缓存 + cache.InvalidateUnreadSystem(s.cache, userID) + + return nil +} + +// 错误定义 +var ErrUnauthorizedNotification = &ServiceError{Code: 403, Message: "unauthorized to delete this notification"} diff --git a/internal/service/post_ai_service.go b/internal/service/post_ai_service.go new file mode 100644 index 0000000..23c947d --- /dev/null +++ b/internal/service/post_ai_service.go @@ -0,0 +1,103 @@ +package service + +import ( + "context" + "log" + "strings" + + "carrot_bbs/internal/pkg/openai" +) + +// PostModerationRejectedError 帖子审核拒绝错误 +type PostModerationRejectedError struct { + Reason string +} + +func (e *PostModerationRejectedError) Error() string { + if strings.TrimSpace(e.Reason) == "" { + return "post rejected by moderation" + } + return "post rejected by moderation: " + e.Reason +} + +// UserMessage 返回给前端的用户可读文案 +func (e *PostModerationRejectedError) UserMessage() string { + if strings.TrimSpace(e.Reason) == "" { + return "内容未通过审核,请修改后重试" + } + return strings.TrimSpace(e.Reason) +} + +// CommentModerationRejectedError 评论审核拒绝错误 +type CommentModerationRejectedError struct { + Reason string +} + +func (e *CommentModerationRejectedError) Error() string { + if strings.TrimSpace(e.Reason) == "" { + return "comment rejected by moderation" + } + return "comment rejected by moderation: " + e.Reason +} + +// UserMessage 返回给前端的用户可读文案 +func (e *CommentModerationRejectedError) UserMessage() string { + if strings.TrimSpace(e.Reason) == "" { + return "评论未通过审核,请修改后重试" + } + return strings.TrimSpace(e.Reason) +} + +type PostAIService struct { + openAIClient openai.Client +} + +func NewPostAIService(openAIClient openai.Client) *PostAIService { + return &PostAIService{ + openAIClient: openAIClient, + } +} + +func (s *PostAIService) IsEnabled() bool { + return s != nil && s.openAIClient != nil && s.openAIClient.IsEnabled() +} + +// ModeratePost 审核帖子内容,返回 nil 表示通过 +func (s *PostAIService) ModeratePost(ctx context.Context, title, content string, images []string) error { + if !s.IsEnabled() { + return nil + } + + approved, reason, err := s.openAIClient.ModeratePost(ctx, title, content, images) + if err != nil { + if s.openAIClient.Config().StrictModeration { + return err + } + log.Printf("[WARN] AI moderation failed, fallback allow: %v", err) + return nil + } + if !approved { + return &PostModerationRejectedError{Reason: reason} + } + return nil +} + +// ModerateComment 审核评论内容,返回 nil 表示通过 +func (s *PostAIService) ModerateComment(ctx context.Context, content string, images []string) error { + if !s.IsEnabled() { + return nil + } + + approved, reason, err := s.openAIClient.ModerateComment(ctx, content, images) + if err != nil { + if s.openAIClient.Config().StrictModeration { + return err + } + log.Printf("[WARN] AI comment moderation failed, fallback allow: %v", err) + return nil + } + if !approved { + return &CommentModerationRejectedError{Reason: reason} + } + return nil +} diff --git a/internal/service/post_service.go b/internal/service/post_service.go new file mode 100644 index 0000000..46c5b9b --- /dev/null +++ b/internal/service/post_service.go @@ -0,0 +1,593 @@ +package service + +import ( + "context" + "errors" + "fmt" + "log" + "strings" + "time" + + "carrot_bbs/internal/cache" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/gorse" + "carrot_bbs/internal/repository" +) + +// 缓存TTL常量 +const ( + PostListTTL = 30 * time.Second // 帖子列表缓存30秒 + PostListNullTTL = 5 * time.Second + PostListJitterRatio = 0.15 + anonymousViewUserID = "_anon_view" +) + +// PostService 帖子服务 +type PostService struct { + postRepo *repository.PostRepository + systemMessageService SystemMessageService + cache cache.Cache + gorseClient gorse.Client + postAIService *PostAIService +} + +// NewPostService 创建帖子服务 +func NewPostService(postRepo *repository.PostRepository, systemMessageService SystemMessageService, gorseClient gorse.Client, postAIService *PostAIService) *PostService { + return &PostService{ + postRepo: postRepo, + systemMessageService: systemMessageService, + cache: cache.GetCache(), + gorseClient: gorseClient, + postAIService: postAIService, + } +} + +// PostListResult 帖子列表缓存结果 +type PostListResult struct { + Posts []*model.Post + Total int64 +} + +// Create 创建帖子 +func (s *PostService) Create(ctx context.Context, userID, title, content string, images []string) (*model.Post, error) { + post := &model.Post{ + UserID: userID, + Title: title, + Content: content, + Status: model.PostStatusPending, + } + + err := s.postRepo.Create(post, images) + if err != nil { + return nil, err + } + + // 失效帖子列表缓存 + cache.InvalidatePostList(s.cache) + + // 同步到Gorse推荐系统(异步) + go s.reviewPostAsync(post.ID, userID, title, content, images) + + // 重新查询以获取关联的 User 和 Images + return s.postRepo.GetByID(post.ID) +} + +func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) { + // 未启用AI时,直接发布 + if s.postAIService == nil || !s.postAIService.IsEnabled() { + if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil { + log.Printf("[WARN] Failed to publish post without AI moderation: %v", err) + } + return + } + + err := s.postAIService.ModeratePost(context.Background(), title, content, images) + if err != nil { + var rejectedErr *PostModerationRejectedError + if errors.As(err, &rejectedErr) { + if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil { + log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr) + } + s.notifyModerationRejected(userID, rejectedErr.Reason) + return + } + + // 规则审核不可用时,降级为发布,避免长时间pending + if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil { + log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr) + } + log.Printf("[WARN] Post moderation failed, fallback publish post=%s err=%v", postID, err) + return + } + + if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil { + log.Printf("[WARN] Failed to publish post %s: %v", postID, err) + return + } + + if s.gorseClient.IsEnabled() { + post, getErr := s.postRepo.GetByID(postID) + if getErr != nil { + log.Printf("[WARN] Failed to load published post for gorse sync: %v", getErr) + return + } + categories := s.buildPostCategories(post) + comment := post.Title + textToEmbed := post.Title + " " + post.Content + if upsertErr := s.gorseClient.UpsertItemWithEmbedding(context.Background(), post.ID, categories, comment, textToEmbed); upsertErr != nil { + log.Printf("[WARN] Failed to upsert item to Gorse: %v", upsertErr) + } + } +} + +func (s *PostService) notifyModerationRejected(userID, reason string) { + if s.systemMessageService == nil || strings.TrimSpace(userID) == "" { + return + } + + content := "您发布的帖子未通过AI审核,请修改后重试。" + if strings.TrimSpace(reason) != "" { + content = fmt.Sprintf("您发布的帖子未通过AI审核,原因:%s。请修改后重试。", reason) + } + + go func() { + if err := s.systemMessageService.SendSystemAnnouncement( + context.Background(), + []string{userID}, + "帖子审核未通过", + content, + ); err != nil { + log.Printf("[WARN] Failed to send moderation reject notification: %v", err) + } + }() +} + +// GetByID 根据ID获取帖子 +func (s *PostService) GetByID(ctx context.Context, id string) (*model.Post, error) { + return s.postRepo.GetByID(id) +} + +// Update 更新帖子 +func (s *PostService) Update(ctx context.Context, post *model.Post) error { + err := s.postRepo.Update(post) + if err != nil { + return err + } + + // 失效帖子详情缓存和列表缓存 + cache.InvalidatePostDetail(s.cache, post.ID) + cache.InvalidatePostList(s.cache) + + return nil +} + +// Delete 删除帖子 +func (s *PostService) Delete(ctx context.Context, id string) error { + err := s.postRepo.Delete(id) + if err != nil { + return err + } + + // 失效帖子详情缓存和列表缓存 + cache.InvalidatePostDetail(s.cache, id) + cache.InvalidatePostList(s.cache) + + // 从Gorse中删除帖子(异步) + go func() { + if s.gorseClient.IsEnabled() { + if err := s.gorseClient.DeleteItem(context.Background(), id); err != nil { + log.Printf("[WARN] Failed to delete item from Gorse: %v", err) + } + } + }() + + return nil +} + +// List 获取帖子列表(带缓存) +func (s *PostService) List(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) { + cacheSettings := cache.GetSettings() + postListTTL := cacheSettings.PostListTTL + if postListTTL <= 0 { + postListTTL = PostListTTL + } + nullTTL := cacheSettings.NullTTL + if nullTTL <= 0 { + nullTTL = PostListNullTTL + } + jitter := cacheSettings.JitterRatio + if jitter <= 0 { + jitter = PostListJitterRatio + } + + // 生成缓存键(包含 userID 维度,避免过滤查询与全量查询互相污染) + cacheKey := cache.PostListKey("latest", userID, page, pageSize) + + result, err := cache.GetOrLoadTyped[*PostListResult]( + s.cache, + cacheKey, + postListTTL, + jitter, + nullTTL, + func() (*PostListResult, error) { + posts, total, err := s.postRepo.List(page, pageSize, userID) + if err != nil { + return nil, err + } + return &PostListResult{Posts: posts, Total: total}, nil + }, + ) + if err != nil { + return nil, 0, err + } + if result == nil { + return []*model.Post{}, 0, nil + } + + // 兼容历史脏缓存:旧缓存序列化会丢失 Post.User,导致前端显示“匿名用户” + // 这里检测并回源重建一次缓存,避免在 TTL 内持续返回缺失作者的数据。 + missingAuthor := false + for _, post := range result.Posts { + if post != nil && post.UserID != "" && post.User == nil { + missingAuthor = true + break + } + } + if missingAuthor { + posts, total, loadErr := s.postRepo.List(page, pageSize, userID) + if loadErr != nil { + return nil, 0, loadErr + } + result = &PostListResult{Posts: posts, Total: total} + cache.SetWithJitter(s.cache, cacheKey, result, postListTTL, jitter) + } + + return result.Posts, result.Total, nil +} + +// GetLatestPosts 获取最新帖子(语义化别名) +func (s *PostService) GetLatestPosts(ctx context.Context, page, pageSize int, userID string) ([]*model.Post, int64, error) { + return s.List(ctx, page, pageSize, userID) +} + +// GetUserPosts 获取用户帖子 +func (s *PostService) GetUserPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) { + return s.postRepo.GetUserPosts(userID, page, pageSize) +} + +// Like 点赞 +func (s *PostService) Like(ctx context.Context, postID, userID string) error { + // 获取帖子信息用于发送通知 + post, err := s.postRepo.GetByID(postID) + if err != nil { + return err + } + + err = s.postRepo.Like(postID, userID) + if err != nil { + return err + } + + // 失效帖子详情缓存 + cache.InvalidatePostDetail(s.cache, postID) + + // 发送点赞通知(不给自己发通知) + if s.systemMessageService != nil && post.UserID != userID { + go func() { + notifyErr := s.systemMessageService.SendLikeNotification(context.Background(), post.UserID, userID, postID) + if notifyErr != nil { + fmt.Printf("[DEBUG] Error sending like notification: %v\n", notifyErr) + } else { + fmt.Printf("[DEBUG] Like notification sent successfully\n") + } + }() + } + + // 推送点赞行为到Gorse(异步) + go func() { + if s.gorseClient.IsEnabled() { + if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil { + log.Printf("[WARN] Failed to insert like feedback to Gorse: %v", err) + } + } + }() + + return nil +} + +// Unlike 取消点赞 +func (s *PostService) Unlike(ctx context.Context, postID, userID string) error { + err := s.postRepo.Unlike(postID, userID) + if err != nil { + return err + } + + // 失效帖子详情缓存 + cache.InvalidatePostDetail(s.cache, postID) + + // 删除Gorse中的点赞反馈(异步) + go func() { + if s.gorseClient.IsEnabled() { + if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeLike, userID, postID); err != nil { + log.Printf("[WARN] Failed to delete like feedback from Gorse: %v", err) + } + } + }() + + return nil +} + +// IsLiked 检查是否点赞 +func (s *PostService) IsLiked(ctx context.Context, postID, userID string) bool { + return s.postRepo.IsLiked(postID, userID) +} + +// Favorite 收藏 +func (s *PostService) Favorite(ctx context.Context, postID, userID string) error { + // 获取帖子信息用于发送通知 + post, err := s.postRepo.GetByID(postID) + if err != nil { + return err + } + + err = s.postRepo.Favorite(postID, userID) + if err != nil { + return err + } + + // 失效帖子详情缓存 + cache.InvalidatePostDetail(s.cache, postID) + + // 发送收藏通知(不给自己发通知) + if s.systemMessageService != nil && post.UserID != userID { + go func() { + notifyErr := s.systemMessageService.SendFavoriteNotification(context.Background(), post.UserID, userID, postID) + if notifyErr != nil { + fmt.Printf("[DEBUG] Error sending favorite notification: %v\n", notifyErr) + } else { + fmt.Printf("[DEBUG] Favorite notification sent successfully\n") + } + }() + } + + // 推送收藏行为到Gorse(异步) + go func() { + if s.gorseClient.IsEnabled() { + if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil { + log.Printf("[WARN] Failed to insert favorite feedback to Gorse: %v", err) + } + } + }() + + return nil +} + +// Unfavorite 取消收藏 +func (s *PostService) Unfavorite(ctx context.Context, postID, userID string) error { + err := s.postRepo.Unfavorite(postID, userID) + if err != nil { + return err + } + + // 失效帖子详情缓存 + cache.InvalidatePostDetail(s.cache, postID) + + // 删除Gorse中的收藏反馈(异步) + go func() { + if s.gorseClient.IsEnabled() { + if err := s.gorseClient.DeleteFeedback(context.Background(), gorse.FeedbackTypeStar, userID, postID); err != nil { + log.Printf("[WARN] Failed to delete favorite feedback from Gorse: %v", err) + } + } + }() + + return nil +} + +// IsFavorited 检查是否收藏 +func (s *PostService) IsFavorited(ctx context.Context, postID, userID string) bool { + return s.postRepo.IsFavorited(postID, userID) +} + +// IncrementViews 增加帖子观看量并同步到Gorse +func (s *PostService) IncrementViews(ctx context.Context, postID, userID string) error { + if err := s.postRepo.IncrementViews(postID); err != nil { + return err + } + + // 同步浏览行为到Gorse(异步) + go func() { + if !s.gorseClient.IsEnabled() { + return + } + + feedbackUserID := userID + if feedbackUserID == "" { + feedbackUserID = anonymousViewUserID + } + + if err := s.gorseClient.InsertFeedback(context.Background(), gorse.FeedbackTypeRead, feedbackUserID, postID); err != nil { + log.Printf("[WARN] Failed to insert read feedback to Gorse: %v", err) + } + }() + + return nil +} + +// GetFavorites 获取收藏列表 +func (s *PostService) GetFavorites(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) { + return s.postRepo.GetFavorites(userID, page, pageSize) +} + +// Search 搜索帖子 +func (s *PostService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.Post, int64, error) { + return s.postRepo.Search(keyword, page, pageSize) +} + +// GetFollowingPosts 获取关注用户的帖子(带缓存) +func (s *PostService) GetFollowingPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) { + cacheSettings := cache.GetSettings() + postListTTL := cacheSettings.PostListTTL + if postListTTL <= 0 { + postListTTL = PostListTTL + } + nullTTL := cacheSettings.NullTTL + if nullTTL <= 0 { + nullTTL = PostListNullTTL + } + jitter := cacheSettings.JitterRatio + if jitter <= 0 { + jitter = PostListJitterRatio + } + + // 生成缓存键 + cacheKey := cache.PostListKey("follow", userID, page, pageSize) + + result, err := cache.GetOrLoadTyped[*PostListResult]( + s.cache, + cacheKey, + postListTTL, + jitter, + nullTTL, + func() (*PostListResult, error) { + posts, total, err := s.postRepo.GetFollowingPosts(userID, page, pageSize) + if err != nil { + return nil, err + } + return &PostListResult{Posts: posts, Total: total}, nil + }, + ) + if err != nil { + return nil, 0, err + } + if result == nil { + return []*model.Post{}, 0, nil + } + return result.Posts, result.Total, nil +} + +// GetHotPosts 获取热门帖子(使用Gorse非个性化推荐) +func (s *PostService) GetHotPosts(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) { + // 如果Gorse启用,使用自定义的非个性化推荐器 + if s.gorseClient.IsEnabled() { + offset := (page - 1) * pageSize + // 使用 most_liked_weekly 推荐器获取周热门 + // 多取1条用于判断是否还有下一页 + itemIDs, err := s.gorseClient.GetNonPersonalized(ctx, "most_liked_weekly", pageSize+1, offset, "") + if err != nil { + log.Printf("[WARN] Gorse GetNonPersonalized failed: %v, fallback to database", err) + return s.getHotPostsFromDB(ctx, page, pageSize) + } + if len(itemIDs) > 0 { + hasNext := len(itemIDs) > pageSize + if hasNext { + itemIDs = itemIDs[:pageSize] + } + posts, err := s.postRepo.GetByIDs(itemIDs) + if err != nil { + return nil, 0, err + } + // 近似 total:当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估 + estimatedTotal := int64(offset + len(posts)) + if hasNext { + estimatedTotal = int64(offset + pageSize + 1) + } + return posts, estimatedTotal, nil + } + } + + // 降级:从数据库获取 + return s.getHotPostsFromDB(ctx, page, pageSize) +} + +// getHotPostsFromDB 从数据库获取热门帖子(降级路径) +func (s *PostService) getHotPostsFromDB(ctx context.Context, page, pageSize int) ([]*model.Post, int64, error) { + // 直接查询数据库,不再使用本地缓存(Gorse失败降级时使用) + posts, total, err := s.postRepo.GetHotPosts(page, pageSize) + if err != nil { + return nil, 0, err + } + return posts, total, nil +} + +// GetRecommendedPosts 获取推荐帖子 +func (s *PostService) GetRecommendedPosts(ctx context.Context, userID string, page, pageSize int) ([]*model.Post, int64, error) { + // 如果Gorse未启用或用户未登录,降级为热门帖子 + if !s.gorseClient.IsEnabled() || userID == "" { + return s.GetHotPosts(ctx, page, pageSize) + } + + // 计算偏移量 + offset := (page - 1) * pageSize + + // 从Gorse获取推荐列表 + // 多取1条用于判断是否还有下一页 + itemIDs, err := s.gorseClient.GetRecommend(ctx, userID, pageSize+1, offset) + if err != nil { + log.Printf("[WARN] Gorse recommendation failed: %v, fallback to hot posts", err) + return s.GetHotPosts(ctx, page, pageSize) + } + + // 如果没有推荐结果,降级为热门帖子 + if len(itemIDs) == 0 { + return s.GetHotPosts(ctx, page, pageSize) + } + + hasNext := len(itemIDs) > pageSize + if hasNext { + itemIDs = itemIDs[:pageSize] + } + + // 根据ID列表查询帖子详情 + posts, err := s.postRepo.GetByIDs(itemIDs) + if err != nil { + return nil, 0, err + } + + // 近似 total:当 hasNext 为 true 时,按分页窗口估算,避免因脏数据/缺失数据导致总页数被低估 + estimatedTotal := int64(offset + len(posts)) + if hasNext { + estimatedTotal = int64(offset + pageSize + 1) + } + return posts, estimatedTotal, nil +} + +// buildPostCategories 构建帖子的类别标签 +func (s *PostService) buildPostCategories(post *model.Post) []string { + var categories []string + + // 热度标签 + if post.ViewsCount > 1000 { + categories = append(categories, "hot_high") + } else if post.ViewsCount > 100 { + categories = append(categories, "hot_medium") + } + + // 点赞标签 + if post.LikesCount > 100 { + categories = append(categories, "likes_100+") + } else if post.LikesCount > 50 { + categories = append(categories, "likes_50+") + } else if post.LikesCount > 10 { + categories = append(categories, "likes_10+") + } + + // 评论标签 + if post.CommentsCount > 50 { + categories = append(categories, "comments_50+") + } else if post.CommentsCount > 10 { + categories = append(categories, "comments_10+") + } + + // 时间标签 + age := time.Since(post.CreatedAt) + if age < 24*time.Hour { + categories = append(categories, "today") + } else if age < 7*24*time.Hour { + categories = append(categories, "this_week") + } else if age < 30*24*time.Hour { + categories = append(categories, "this_month") + } + + return categories +} diff --git a/internal/service/push_service.go b/internal/service/push_service.go new file mode 100644 index 0000000..80fc1c1 --- /dev/null +++ b/internal/service/push_service.go @@ -0,0 +1,575 @@ +package service + +import ( + "context" + "errors" + "fmt" + "time" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/websocket" + "carrot_bbs/internal/repository" +) + +// 推送相关常量 +const ( + // DefaultPushTimeout 默认推送超时时间 + DefaultPushTimeout = 30 * time.Second + // MaxRetryCount 最大重试次数 + MaxRetryCount = 3 + // DefaultExpiredTime 默认消息过期时间(24小时) + DefaultExpiredTime = 24 * time.Hour + // PushQueueSize 推送队列大小 + PushQueueSize = 1000 +) + +// PushPriority 推送优先级 +type PushPriority int + +const ( + PriorityLow PushPriority = 1 // 低优先级(营销消息等) + PriorityNormal PushPriority = 5 // 普通优先级(系统通知) + PriorityHigh PushPriority = 8 // 高优先级(聊天消息) + PriorityCritical PushPriority = 10 // 最高优先级(重要系统通知) +) + +// PushService 推送服务接口 +type PushService interface { + // 推送核心方法 + PushMessage(ctx context.Context, userID string, message *model.Message) error + PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error + + // 系统消息推送 + PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error + PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error + PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error + + // 系统通知推送(新接口,使用独立的 SystemNotification 模型) + PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error + + // 设备管理 + RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error + UnregisterDevice(ctx context.Context, deviceID string) error + UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error + + // 推送记录管理 + CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error) + GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error) + + // 后台任务 + StartPushWorker(ctx context.Context) + StopPushWorker() +} + +// pushServiceImpl 推送服务实现 +type pushServiceImpl struct { + pushRepo *repository.PushRecordRepository + deviceRepo *repository.DeviceTokenRepository + messageRepo *repository.MessageRepository + wsManager *websocket.WebSocketManager + + // 推送队列 + pushQueue chan *pushTask + stopChan chan struct{} +} + +// pushTask 推送任务 +type pushTask struct { + userID string + message *model.Message + priority int +} + +// NewPushService 创建推送服务 +func NewPushService( + pushRepo *repository.PushRecordRepository, + deviceRepo *repository.DeviceTokenRepository, + messageRepo *repository.MessageRepository, + wsManager *websocket.WebSocketManager, +) PushService { + return &pushServiceImpl{ + pushRepo: pushRepo, + deviceRepo: deviceRepo, + messageRepo: messageRepo, + wsManager: wsManager, + pushQueue: make(chan *pushTask, PushQueueSize), + stopChan: make(chan struct{}), + } +} + +// PushMessage 推送消息给用户 +func (s *pushServiceImpl) PushMessage(ctx context.Context, userID string, message *model.Message) error { + return s.PushToUser(ctx, userID, message, int(PriorityNormal)) +} + +// PushToUser 带优先级的推送 +func (s *pushServiceImpl) PushToUser(ctx context.Context, userID string, message *model.Message, priority int) error { + // 首先尝试WebSocket推送(实时推送) + if s.pushViaWebSocket(ctx, userID, message) { + // WebSocket推送成功,记录推送状态 + record, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelWebSocket) + if err != nil { + return fmt.Errorf("failed to create push record: %w", err) + } + record.MarkPushed() + if err := s.pushRepo.Update(record); err != nil { + return fmt.Errorf("failed to update push record: %w", err) + } + return nil + } + + // WebSocket推送失败,加入推送队列等待移动端推送 + select { + case s.pushQueue <- &pushTask{ + userID: userID, + message: message, + priority: priority, + }: + return nil + default: + // 队列已满,直接创建待推送记录 + _, err := s.CreatePushRecord(ctx, userID, message.ID, model.PushChannelFCM) + if err != nil { + return fmt.Errorf("failed to create pending push record: %w", err) + } + return errors.New("push queue is full, message queued for later delivery") + } +} + +// pushViaWebSocket 通过WebSocket推送消息 +// 返回true表示推送成功,false表示用户不在线 +func (s *pushServiceImpl) pushViaWebSocket(ctx context.Context, userID string, message *model.Message) bool { + if s.wsManager == nil { + return false + } + + if !s.wsManager.IsUserOnline(userID) { + return false + } + + // 判断是否为系统消息/通知消息 + if message.IsSystemMessage() || message.Category == model.CategoryNotification { + // 使用 NotificationMessage 格式推送系统通知 + // 从 segments 中提取文本内容 + content := dto.ExtractTextContentFromModel(message.Segments) + + notification := &websocket.NotificationMessage{ + ID: fmt.Sprintf("%s", message.ID), + Type: string(message.SystemType), + Content: content, + Extra: make(map[string]interface{}), + CreatedAt: message.CreatedAt.UnixMilli(), + } + + // 填充额外数据 + if message.ExtraData != nil { + notification.Extra["actor_id"] = message.ExtraData.ActorID + notification.Extra["actor_name"] = message.ExtraData.ActorName + notification.Extra["avatar_url"] = message.ExtraData.AvatarURL + notification.Extra["target_id"] = message.ExtraData.TargetID + notification.Extra["target_type"] = message.ExtraData.TargetType + notification.Extra["action_url"] = message.ExtraData.ActionURL + notification.Extra["action_time"] = message.ExtraData.ActionTime + + // 设置触发用户信息 + if message.ExtraData.ActorID > 0 { + notification.TriggerUser = &websocket.NotificationUser{ + ID: fmt.Sprintf("%d", message.ExtraData.ActorID), + Username: message.ExtraData.ActorName, + Avatar: message.ExtraData.AvatarURL, + } + } + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification) + s.wsManager.SendToUser(userID, wsMsg) + return true + } + + // 构建普通聊天消息的 WebSocket 消息 - 使用新的 WSEventResponse 格式 + // 获取会话类型 (private/group) + detailType := "private" + if message.ConversationID != "" { + // 从会话中获取类型,需要查询数据库或从消息中判断 + // 这里暂时默认为 private,group 类型需要额外逻辑 + } + + // 直接使用 message.Segments + segments := message.Segments + + event := &dto.WSEventResponse{ + ID: fmt.Sprintf("%s", message.ID), + Time: message.CreatedAt.UnixMilli(), + Type: "message", + DetailType: detailType, + Seq: fmt.Sprintf("%d", message.Seq), + Segments: segments, + SenderID: message.SenderID, + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeMessage, event) + s.wsManager.SendToUser(userID, wsMsg) + return true +} + +// pushViaFCM 通过FCM推送(预留接口) +func (s *pushServiceImpl) pushViaFCM(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error { + // TODO: 实现FCM推送 + // 1. 构建FCM消息 + // 2. 调用Firebase Admin SDK发送消息 + // 3. 处理发送结果 + return errors.New("FCM push not implemented") +} + +// pushViaAPNs 通过APNs推送(预留接口) +func (s *pushServiceImpl) pushViaAPNs(ctx context.Context, deviceToken *model.DeviceToken, message *model.Message) error { + // TODO: 实现APNs推送 + // 1. 构建APNs消息 + // 2. 调用APNs SDK发送消息 + // 3. 处理发送结果 + return errors.New("APNs push not implemented") +} + +// RegisterDevice 注册设备 +func (s *pushServiceImpl) RegisterDevice(ctx context.Context, userID string, deviceID string, deviceType model.DeviceType, pushToken string) error { + deviceToken := &model.DeviceToken{ + UserID: userID, + DeviceID: deviceID, + DeviceType: deviceType, + PushToken: pushToken, + IsActive: true, + } + deviceToken.UpdateLastUsed() + + return s.deviceRepo.Upsert(deviceToken) +} + +// UnregisterDevice 注销设备 +func (s *pushServiceImpl) UnregisterDevice(ctx context.Context, deviceID string) error { + return s.deviceRepo.Deactivate(deviceID) +} + +// UpdateDeviceToken 更新设备Token +func (s *pushServiceImpl) UpdateDeviceToken(ctx context.Context, deviceID string, newPushToken string) error { + deviceToken, err := s.deviceRepo.GetByDeviceID(deviceID) + if err != nil { + return fmt.Errorf("device not found: %w", err) + } + + deviceToken.PushToken = newPushToken + deviceToken.Activate() + + return s.deviceRepo.Update(deviceToken) +} + +// CreatePushRecord 创建推送记录 +func (s *pushServiceImpl) CreatePushRecord(ctx context.Context, userID string, messageID string, channel model.PushChannel) (*model.PushRecord, error) { + expiredAt := time.Now().Add(DefaultExpiredTime) + record := &model.PushRecord{ + UserID: userID, + MessageID: messageID, + PushChannel: channel, + PushStatus: model.PushStatusPending, + MaxRetry: MaxRetryCount, + ExpiredAt: &expiredAt, + } + + if err := s.pushRepo.Create(record); err != nil { + return nil, fmt.Errorf("failed to create push record: %w", err) + } + + return record, nil +} + +// GetPendingPushes 获取待推送记录 +func (s *pushServiceImpl) GetPendingPushes(ctx context.Context, userID string) ([]*model.PushRecord, error) { + return s.pushRepo.GetByUserID(userID, 100, 0) +} + +// StartPushWorker 启动推送工作协程 +func (s *pushServiceImpl) StartPushWorker(ctx context.Context) { + go s.processPushQueue() + go s.retryFailedPushes() +} + +// StopPushWorker 停止推送工作协程 +func (s *pushServiceImpl) StopPushWorker() { + close(s.stopChan) +} + +// processPushQueue 处理推送队列 +func (s *pushServiceImpl) processPushQueue() { + for { + select { + case <-s.stopChan: + return + case task := <-s.pushQueue: + s.processPushTask(task) + } + } +} + +// processPushTask 处理单个推送任务 +func (s *pushServiceImpl) processPushTask(task *pushTask) { + ctx, cancel := context.WithTimeout(context.Background(), DefaultPushTimeout) + defer cancel() + + // 获取用户活跃设备 + devices, err := s.deviceRepo.GetActiveByUserID(task.userID) + if err != nil || len(devices) == 0 { + // 没有可用设备,创建待推送记录 + s.CreatePushRecord(ctx, task.userID, task.message.ID, model.PushChannelFCM) + return + } + + // 对每个设备创建推送记录并尝试推送 + for _, device := range devices { + record, err := s.CreatePushRecord(ctx, task.userID, task.message.ID, s.getChannelForDevice(device)) + if err != nil { + continue + } + + var pushErr error + switch { + case device.IsIOS(): + pushErr = s.pushViaAPNs(ctx, device, task.message) + case device.IsAndroid(): + pushErr = s.pushViaFCM(ctx, device, task.message) + default: + // Web设备只支持WebSocket + continue + } + + if pushErr != nil { + record.MarkFailed(pushErr.Error()) + } else { + record.MarkPushed() + } + s.pushRepo.Update(record) + } +} + +// getChannelForDevice 根据设备类型获取推送通道 +func (s *pushServiceImpl) getChannelForDevice(device *model.DeviceToken) model.PushChannel { + switch device.DeviceType { + case model.DeviceTypeIOS: + return model.PushChannelAPNs + case model.DeviceTypeAndroid: + return model.PushChannelFCM + default: + return model.PushChannelWebSocket + } +} + +// retryFailedPushes 重试失败的推送 +func (s *pushServiceImpl) retryFailedPushes() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-s.stopChan: + return + case <-ticker.C: + s.doRetry() + } + } +} + +// doRetry 执行重试 +func (s *pushServiceImpl) doRetry() { + ctx := context.Background() + + // 获取失败待重试的推送 + records, err := s.pushRepo.GetFailedPushesForRetry(100) + if err != nil { + return + } + + for _, record := range records { + // 检查是否过期 + if record.IsExpired() { + record.MarkExpired() + s.pushRepo.Update(record) + continue + } + + // 获取消息 + message, err := s.messageRepo.GetMessageByID(record.MessageID) + if err != nil { + record.MarkFailed("message not found") + s.pushRepo.Update(record) + continue + } + + // 尝试WebSocket推送 + if s.pushViaWebSocket(ctx, record.UserID, message) { + record.MarkDelivered() + s.pushRepo.Update(record) + continue + } + + // 获取设备并尝试移动端推送 + if record.DeviceToken != "" { + device, err := s.deviceRepo.GetByPushToken(record.DeviceToken) + if err != nil { + record.MarkFailed("device not found") + s.pushRepo.Update(record) + continue + } + + var pushErr error + switch { + case device.IsIOS(): + pushErr = s.pushViaAPNs(ctx, device, message) + case device.IsAndroid(): + pushErr = s.pushViaFCM(ctx, device, message) + } + + if pushErr != nil { + record.MarkFailed(pushErr.Error()) + } else { + record.MarkPushed() + } + s.pushRepo.Update(record) + } + } +} + +// PushSystemMessage 推送系统消息 +func (s *pushServiceImpl) PushSystemMessage(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) error { + // 首先尝试WebSocket推送 + if s.pushSystemViaWebSocket(ctx, userID, msgType, title, content, data) { + return nil + } + + // 用户不在线,创建待推送记录(移动端上线后可通过其他方式获取) + // 系统消息通常不需要离线推送,客户端上线后会主动拉取 + return errors.New("user is offline, system message will be available on next sync") +} + +// pushSystemViaWebSocket 通过WebSocket推送系统消息 +func (s *pushServiceImpl) pushSystemViaWebSocket(ctx context.Context, userID string, msgType, title, content string, data map[string]interface{}) bool { + if s.wsManager == nil { + return false + } + + if !s.wsManager.IsUserOnline(userID) { + return false + } + + sysMsg := &websocket.SystemMessage{ + Type: msgType, + Title: title, + Content: content, + Data: data, + CreatedAt: time.Now().UnixMilli(), + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeSystem, sysMsg) + s.wsManager.SendToUser(userID, wsMsg) + return true +} + +// PushNotification 推送通知消息 +func (s *pushServiceImpl) PushNotification(ctx context.Context, userID string, notification *websocket.NotificationMessage) error { + // 首先尝试WebSocket推送 + if s.pushNotificationViaWebSocket(ctx, userID, notification) { + return nil + } + + // 用户不在线,创建待推送记录 + // 通知消息可以等用户上线后拉取 + return errors.New("user is offline, notification will be available on next sync") +} + +// pushNotificationViaWebSocket 通过WebSocket推送通知消息 +func (s *pushServiceImpl) pushNotificationViaWebSocket(ctx context.Context, userID string, notification *websocket.NotificationMessage) bool { + if s.wsManager == nil { + return false + } + + if !s.wsManager.IsUserOnline(userID) { + return false + } + + if notification.CreatedAt == 0 { + notification.CreatedAt = time.Now().UnixMilli() + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, notification) + s.wsManager.SendToUser(userID, wsMsg) + return true +} + +// PushAnnouncement 广播公告消息 +func (s *pushServiceImpl) PushAnnouncement(ctx context.Context, announcement *websocket.AnnouncementMessage) error { + if s.wsManager == nil { + return errors.New("websocket manager not available") + } + + if announcement.CreatedAt == 0 { + announcement.CreatedAt = time.Now().UnixMilli() + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeAnnouncement, announcement) + s.wsManager.Broadcast(wsMsg) + return nil +} + +// PushSystemNotification 推送系统通知(使用独立的 SystemNotification 模型) +func (s *pushServiceImpl) PushSystemNotification(ctx context.Context, userID string, notification *model.SystemNotification) error { + // 首先尝试WebSocket推送 + if s.pushSystemNotificationViaWebSocket(ctx, userID, notification) { + return nil + } + + // 用户不在线,系统通知已存储在数据库中,用户上线后会主动拉取 + return nil +} + +// pushSystemNotificationViaWebSocket 通过WebSocket推送系统通知 +func (s *pushServiceImpl) pushSystemNotificationViaWebSocket(ctx context.Context, userID string, notification *model.SystemNotification) bool { + if s.wsManager == nil { + return false + } + + if !s.wsManager.IsUserOnline(userID) { + return false + } + + // 构建 WebSocket 通知消息 + wsNotification := &websocket.NotificationMessage{ + ID: fmt.Sprintf("%d", notification.ID), + Type: string(notification.Type), + Title: notification.Title, + Content: notification.Content, + Extra: make(map[string]interface{}), + CreatedAt: notification.CreatedAt.UnixMilli(), + } + + // 填充额外数据 + if notification.ExtraData != nil { + wsNotification.Extra["actor_id_str"] = notification.ExtraData.ActorIDStr + wsNotification.Extra["actor_name"] = notification.ExtraData.ActorName + wsNotification.Extra["avatar_url"] = notification.ExtraData.AvatarURL + wsNotification.Extra["target_id"] = notification.ExtraData.TargetID + wsNotification.Extra["target_type"] = notification.ExtraData.TargetType + wsNotification.Extra["action_url"] = notification.ExtraData.ActionURL + wsNotification.Extra["action_time"] = notification.ExtraData.ActionTime + + // 设置触发用户信息 + if notification.ExtraData.ActorIDStr != "" { + wsNotification.TriggerUser = &websocket.NotificationUser{ + ID: notification.ExtraData.ActorIDStr, + Username: notification.ExtraData.ActorName, + Avatar: notification.ExtraData.AvatarURL, + } + } + } + + wsMsg := websocket.CreateWSMessage(websocket.MessageTypeNotification, wsNotification) + s.wsManager.SendToUser(userID, wsMsg) + return true +} diff --git a/internal/service/sensitive_service.go b/internal/service/sensitive_service.go new file mode 100644 index 0000000..8ef69fa --- /dev/null +++ b/internal/service/sensitive_service.go @@ -0,0 +1,559 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "log" + "regexp" + "strings" + "sync" + "time" + "unicode/utf8" + + "carrot_bbs/internal/model" + redisclient "carrot_bbs/internal/pkg/redis" + + "gorm.io/gorm" +) + +// ==================== DFA 敏感词过滤实现 ==================== + +// SensitiveNode 敏感词树节点 +type SensitiveNode struct { + // 子节点映射 + Children map[rune]*SensitiveNode + // 是否为敏感词结尾 + IsEnd bool + // 敏感词信息(仅在 IsEnd 为 true 时有效) + Word string + Level model.SensitiveWordLevel + Category model.SensitiveWordCategory +} + +// NewSensitiveNode 创建新的敏感词节点 +func NewSensitiveNode() *SensitiveNode { + return &SensitiveNode{ + Children: make(map[rune]*SensitiveNode), + IsEnd: false, + } +} + +// SensitiveWordTree 敏感词树 +type SensitiveWordTree struct { + root *SensitiveNode + wordCount int + mu sync.RWMutex + lastReload time.Time +} + +// NewSensitiveWordTree 创建新的敏感词树 +func NewSensitiveWordTree() *SensitiveWordTree { + return &SensitiveWordTree{ + root: NewSensitiveNode(), + wordCount: 0, + lastReload: time.Now(), + } +} + +// AddWord 添加敏感词到树中 +func (t *SensitiveWordTree) AddWord(word string, level model.SensitiveWordLevel, category model.SensitiveWordCategory) { + if word == "" { + return + } + + t.mu.Lock() + defer t.mu.Unlock() + + node := t.root + // 转换为小写进行匹配(不区分大小写) + lowerWord := strings.ToLower(word) + runes := []rune(lowerWord) + + for _, r := range runes { + child, exists := node.Children[r] + if !exists { + child = NewSensitiveNode() + node.Children[r] = child + } + node = child + } + + // 如果不是已存在的敏感词,则计数+1 + if !node.IsEnd { + t.wordCount++ + } + + node.IsEnd = true + node.Word = word + node.Level = level + node.Category = category +} + +// RemoveWord 从树中移除敏感词 +func (t *SensitiveWordTree) RemoveWord(word string) { + if word == "" { + return + } + + t.mu.Lock() + defer t.mu.Unlock() + + lowerWord := strings.ToLower(word) + runes := []rune(lowerWord) + + // 查找节点 + node := t.root + for _, r := range runes { + child, exists := node.Children[r] + if !exists { + return // 敏感词不存在 + } + node = child + } + + if node.IsEnd { + node.IsEnd = false + node.Word = "" + t.wordCount-- + } +} + +// Check 检查文本是否包含敏感词,返回是否包含及敏感词列表 +func (t *SensitiveWordTree) Check(text string) (bool, []string) { + if text == "" { + return false, nil + } + + t.mu.RLock() + defer t.mu.RUnlock() + + var foundWords []string + runes := []rune(strings.ToLower(text)) + length := len(runes) + + // 用于标记已找到的敏感词位置,避免重复计算 + marked := make([]bool, length) + + for i := 0; i < length; i++ { + // 从当前位置开始搜索 + node := t.root + matchEnd := -1 + matchWord := "" + + for j := i; j < length; j++ { + child, exists := node.Children[runes[j]] + if !exists { + break + } + node = child + + if node.IsEnd { + matchEnd = j + matchWord = node.Word + } + } + + // 标记找到的敏感词位置 + if matchEnd >= 0 && !marked[i] { + for k := i; k <= matchEnd; k++ { + marked[k] = true + } + foundWords = append(foundWords, matchWord) + } + } + + return len(foundWords) > 0, foundWords +} + +// Replace 替换文本中的敏感词 +func (t *SensitiveWordTree) Replace(text string, repl string) string { + if text == "" { + return text + } + + t.mu.RLock() + defer t.mu.RUnlock() + + runes := []rune(text) + length := len(runes) + result := make([]rune, 0, length) + + // 用于标记已替换的位置 + marked := make([]bool, length) + + for i := 0; i < length; i++ { + if marked[i] { + continue + } + + // 从当前位置开始搜索 + node := t.root + matchEnd := -1 + + for j := i; j < length; j++ { + child, exists := node.Children[runes[j]] + if !exists { + break + } + node = child + + if node.IsEnd { + matchEnd = j + } + } + + if matchEnd >= 0 { + // 标记已替换的位置 + for k := i; k <= matchEnd; k++ { + marked[k] = true + } + // 追加替换符 + replRunes := []rune(repl) + result = append(result, replRunes...) + // 跳过已匹配的字符 + i = matchEnd + } else { + // 追加原字符 + result = append(result, runes[i]) + } + } + + return string(result) +} + +// WordCount 获取敏感词数量 +func (t *SensitiveWordTree) WordCount() int { + t.mu.RLock() + defer t.mu.RUnlock() + return t.wordCount +} + +// ==================== 敏感词服务实现 ==================== + +// SensitiveService 敏感词服务接口 +type SensitiveService interface { + // Check 检查文本是否包含敏感词 + Check(ctx context.Context, text string) (bool, []string) + // Replace 替换敏感词 + Replace(ctx context.Context, text string, repl string) string + // AddWord 添加敏感词 + AddWord(ctx context.Context, word string, category string, level int) error + // RemoveWord 移除敏感词 + RemoveWord(ctx context.Context, word string) error + // Reload 重新加载敏感词库 + Reload(ctx context.Context) error + // GetWordCount 获取敏感词数量 + GetWordCount(ctx context.Context) int +} + +// sensitiveServiceImpl 敏感词服务实现 +type sensitiveServiceImpl struct { + tree *SensitiveWordTree + db *gorm.DB + redis *redisclient.Client + config *SensitiveConfig + mu sync.RWMutex + replaceStr string +} + +// SensitiveConfig 敏感词服务配置 +type SensitiveConfig struct { + Enabled bool `mapstructure:"enabled" yaml:"enabled"` + ReplaceStr string `mapstructure:"replace_str" yaml:"replace_str"` + // 最小匹配长度 + MinMatchLen int `mapstructure:"min_match_len" yaml:"min_match_len"` + // 是否从数据库加载 + LoadFromDB bool `mapstructure:"load_from_db" yaml:"load_from_db"` + // 是否从Redis加载 + LoadFromRedis bool `mapstructure:"load_from_redis" yaml:"load_from_redis"` + // Redis键前缀 + RedisKeyPrefix string `mapstructure:"redis_key_prefix" yaml:"redis_key_prefix"` +} + +// NewSensitiveService 创建敏感词服务 +func NewSensitiveService(db *gorm.DB, redisClient *redisclient.Client, config *SensitiveConfig) SensitiveService { + s := &sensitiveServiceImpl{ + tree: NewSensitiveWordTree(), + db: db, + redis: redisClient, + config: config, + replaceStr: config.ReplaceStr, + } + + // 如果未设置替换符,默认使用 *** + if s.replaceStr == "" { + s.replaceStr = "***" + } + + // 初始化加载敏感词 + if config.LoadFromDB { + if err := s.loadFromDB(context.Background()); err != nil { + log.Printf("Failed to load sensitive words from database: %v", err) + } + } + + if config.LoadFromRedis && redisClient != nil { + if err := s.loadFromRedis(context.Background()); err != nil { + log.Printf("Failed to load sensitive words from redis: %v", err) + } + } + + return s +} + +// Check 检查文本是否包含敏感词 +func (s *sensitiveServiceImpl) Check(ctx context.Context, text string) (bool, []string) { + if !s.config.Enabled { + return false, nil + } + if text == "" { + return false, nil + } + return s.tree.Check(text) +} + +// Replace 替换敏感词 +func (s *sensitiveServiceImpl) Replace(ctx context.Context, text string, repl string) string { + if !s.config.Enabled { + return text + } + if text == "" { + return text + } + + // 如果未指定替换符,使用默认替换符 + if repl == "" { + repl = s.replaceStr + } + + return s.tree.Replace(text, repl) +} + +// AddWord 添加敏感词 +func (s *sensitiveServiceImpl) AddWord(ctx context.Context, word string, category string, level int) error { + if word == "" { + return fmt.Errorf("word cannot be empty") + } + + // 转换为敏感词级别 + wordLevel := model.SensitiveWordLevel(level) + if wordLevel < 1 || wordLevel > 3 { + wordLevel = model.SensitiveWordLevelLow + } + + // 转换为敏感词分类 + wordCategory := model.SensitiveWordCategory(category) + if wordCategory == "" { + wordCategory = model.SensitiveWordCategoryOther + } + + // 添加到树 + s.tree.AddWord(word, wordLevel, wordCategory) + + // 持久化到数据库 + if s.db != nil { + sensitiveWord := model.SensitiveWord{ + Word: word, + Category: wordCategory, + Level: wordLevel, + IsActive: true, + } + + // 使用 upsert 逻辑 + var existing model.SensitiveWord + result := s.db.Where("word = ?", word).First(&existing) + if result.Error == gorm.ErrRecordNotFound { + if err := s.db.Create(&sensitiveWord).Error; err != nil { + log.Printf("Failed to save sensitive word to database: %v", err) + } + } else if result.Error == nil { + // 更新已存在的记录 + existing.Category = wordCategory + existing.Level = wordLevel + existing.IsActive = true + if err := s.db.Save(&existing).Error; err != nil { + log.Printf("Failed to update sensitive word in database: %v", err) + } + } + } + + // 同步到 Redis + if s.redis != nil && s.config.RedisKeyPrefix != "" { + key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word) + data := map[string]interface{}{ + "word": word, + "category": category, + "level": level, + } + jsonData, _ := json.Marshal(data) + s.redis.Set(ctx, key, jsonData, 0) + } + + return nil +} + +// RemoveWord 移除敏感词 +func (s *sensitiveServiceImpl) RemoveWord(ctx context.Context, word string) error { + if word == "" { + return fmt.Errorf("word cannot be empty") + } + + // 从树中移除 + s.tree.RemoveWord(word) + + // 从数据库中标记为不活跃 + if s.db != nil { + result := s.db.Model(&model.SensitiveWord{}).Where("word = ?", word).Update("is_active", false) + if result.Error != nil { + log.Printf("Failed to deactivate sensitive word in database: %v", result.Error) + } + } + + // 从 Redis 中删除 + if s.redis != nil && s.config.RedisKeyPrefix != "" { + key := fmt.Sprintf("%s:%s", s.config.RedisKeyPrefix, word) + s.redis.Del(ctx, key) + } + + return nil +} + +// Reload 重新加载敏感词库 +func (s *sensitiveServiceImpl) Reload(ctx context.Context) error { + // 清空现有树 + s.tree = NewSensitiveWordTree() + + // 从数据库加载 + if s.config.LoadFromDB { + if err := s.loadFromDB(ctx); err != nil { + return fmt.Errorf("failed to load from database: %w", err) + } + } + + // 从 Redis 加载 + if s.config.LoadFromRedis && s.redis != nil { + if err := s.loadFromRedis(ctx); err != nil { + return fmt.Errorf("failed to load from redis: %w", err) + } + } + + return nil +} + +// GetWordCount 获取敏感词数量 +func (s *sensitiveServiceImpl) GetWordCount(ctx context.Context) int { + return s.tree.WordCount() +} + +// loadFromDB 从数据库加载敏感词 +func (s *sensitiveServiceImpl) loadFromDB(ctx context.Context) error { + if s.db == nil { + return nil + } + + var words []model.SensitiveWord + if err := s.db.Where("is_active = ?", true).Find(&words).Error; err != nil { + return err + } + + for _, word := range words { + s.tree.AddWord(word.Word, word.Level, word.Category) + } + + log.Printf("Loaded %d sensitive words from database", len(words)) + return nil +} + +// loadFromRedis 从 Redis 加载敏感词 +func (s *sensitiveServiceImpl) loadFromRedis(ctx context.Context) error { + if s.redis == nil || s.config.RedisKeyPrefix == "" { + return nil + } + + // 使用 SCAN 命令代替 KEYS,避免阻塞 + pattern := fmt.Sprintf("%s:*", s.config.RedisKeyPrefix) + var cursor uint64 + for { + keys, nextCursor, err := s.redis.GetClient().Scan(ctx, cursor, pattern, 100).Result() + if err != nil { + return err + } + + for _, key := range keys { + data, err := s.redis.Get(ctx, key) + if err != nil { + continue + } + + var wordData map[string]interface{} + if err := json.Unmarshal([]byte(data), &wordData); err != nil { + continue + } + + word, _ := wordData["word"].(string) + category, _ := wordData["category"].(string) + level, _ := wordData["level"].(float64) + + if word != "" { + s.tree.AddWord(word, model.SensitiveWordLevel(int(level)), model.SensitiveWordCategory(category)) + } + } + + cursor = nextCursor + if cursor == 0 { + break + } + } + + return nil +} + +// ==================== 辅助函数 ==================== + +// ContainsSensitiveWord 快速检查文本是否包含敏感词 +func ContainsSensitiveWord(text string, tree *SensitiveWordTree) bool { + if tree == nil || text == "" { + return false + } + hasSensitive, _ := tree.Check(text) + return hasSensitive +} + +// FilterSensitiveWords 过滤敏感词并返回替换后的文本 +func FilterSensitiveWords(text string, tree *SensitiveWordTree, repl string) string { + if tree == nil || text == "" { + return text + } + if repl == "" { + repl = "***" + } + return tree.Replace(text, repl) +} + +// ValidateTextLength 验证文本长度是否合法 +func ValidateTextLength(text string, minLen, maxLen int) bool { + length := utf8.RuneCountInString(text) + return length >= minLen && length <= maxLen +} + +// SanitizeText 清理文本,移除多余空白字符 +func SanitizeText(text string) string { + // 替换多个连续空白字符为单个空格 + spaceReg := regexp.MustCompile(`\s+`) + text = spaceReg.ReplaceAllString(text, " ") + // 去除首尾空白 + return strings.TrimSpace(text) +} + +// ==================== 默认敏感词列表 ==================== + +// DefaultSensitiveWords 返回默认敏感词列表(示例) +func DefaultSensitiveWords() map[string]struct{} { + return map[string]struct{}{ + // 示例敏感词,实际需要从数据库或配置加载 + "测试敏感词1": {}, + "测试敏感词2": {}, + "测试敏感词3": {}, + } +} diff --git a/internal/service/sticker_service.go b/internal/service/sticker_service.go new file mode 100644 index 0000000..feb6aee --- /dev/null +++ b/internal/service/sticker_service.go @@ -0,0 +1,139 @@ +package service + +import ( + "carrot_bbs/internal/model" + "carrot_bbs/internal/repository" + "errors" + "net/url" + "strings" +) + +var ( + ErrStickerAlreadyExists = errors.New("sticker already exists") + ErrInvalidStickerURL = errors.New("invalid sticker url") +) + +// StickerService 自定义表情服务接口 +type StickerService interface { + // 获取用户的所有表情 + GetUserStickers(userID string) ([]model.UserSticker, error) + // 添加表情 + AddSticker(userID string, url string, width, height int) (*model.UserSticker, error) + // 删除表情 + DeleteSticker(userID string, stickerID string) error + // 检查表情是否已存在 + CheckExists(userID string, url string) (bool, error) + // 重新排序 + ReorderStickers(userID string, orders map[string]int) error + // 获取用户表情数量 + GetStickerCount(userID string) (int64, error) +} + +// stickerService 自定义表情服务实现 +type stickerService struct { + stickerRepo repository.StickerRepository +} + +// NewStickerService 创建自定义表情服务 +func NewStickerService(stickerRepo repository.StickerRepository) StickerService { + return &stickerService{ + stickerRepo: stickerRepo, + } +} + +// GetUserStickers 获取用户的所有表情 +func (s *stickerService) GetUserStickers(userID string) ([]model.UserSticker, error) { + stickers, err := s.stickerRepo.GetByUserID(userID) + if err != nil { + return nil, err + } + + // 兼容历史脏数据:过滤本地文件 URI,避免客户端加载 file:// 报错 + filtered := make([]model.UserSticker, 0, len(stickers)) + for _, sticker := range stickers { + if isValidStickerURL(sticker.URL) { + filtered = append(filtered, sticker) + } + } + return filtered, nil +} + +// AddSticker 添加表情 +func (s *stickerService) AddSticker(userID string, url string, width, height int) (*model.UserSticker, error) { + if !isValidStickerURL(url) { + return nil, ErrInvalidStickerURL + } + + // 检查是否已存在 + exists, err := s.stickerRepo.Exists(userID, url) + if err != nil { + return nil, err + } + if exists { + return nil, ErrStickerAlreadyExists + } + + // 获取当前数量用于设置排序 + count, err := s.stickerRepo.CountByUserID(userID) + if err != nil { + return nil, err + } + + sticker := &model.UserSticker{ + UserID: userID, + URL: url, + Width: width, + Height: height, + SortOrder: int(count), // 新表情添加到末尾 + } + + if err := s.stickerRepo.Create(sticker); err != nil { + return nil, err + } + + return sticker, nil +} + +func isValidStickerURL(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return false + } + + parsed, err := url.Parse(trimmed) + if err != nil { + return false + } + + scheme := strings.ToLower(parsed.Scheme) + return scheme == "http" || scheme == "https" +} + +// DeleteSticker 删除表情 +func (s *stickerService) DeleteSticker(userID string, stickerID string) error { + // 先检查表情是否属于该用户 + sticker, err := s.stickerRepo.GetByID(stickerID) + if err != nil { + return err + } + if sticker.UserID != userID { + return errors.New("sticker not found") + } + + return s.stickerRepo.Delete(stickerID) +} + +// CheckExists 检查表情是否已存在 +func (s *stickerService) CheckExists(userID string, url string) (bool, error) { + return s.stickerRepo.Exists(userID, url) +} + +// ReorderStickers 重新排序 +func (s *stickerService) ReorderStickers(userID string, orders map[string]int) error { + return s.stickerRepo.BatchUpdateSortOrder(userID, orders) +} + +// GetStickerCount 获取用户表情数量 +func (s *stickerService) GetStickerCount(userID string) (int64, error) { + return s.stickerRepo.CountByUserID(userID) +} diff --git a/internal/service/system_message_service.go b/internal/service/system_message_service.go new file mode 100644 index 0000000..231421c --- /dev/null +++ b/internal/service/system_message_service.go @@ -0,0 +1,462 @@ +package service + +import ( + "context" + "fmt" + "time" + + "carrot_bbs/internal/cache" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/utils" + "carrot_bbs/internal/repository" +) + +// SystemMessageService 系统消息服务接口 +type SystemMessageService interface { + // 发送互动通知 + SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error + SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error + SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error + SendFollowNotification(ctx context.Context, userID string, operatorID string) error + SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error + SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error + SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error + SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error + + // 发送系统公告 + SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error + SendBroadcastAnnouncement(ctx context.Context, title string, content string) error +} + +type systemMessageServiceImpl struct { + notifyRepo *repository.SystemNotificationRepository + pushService PushService + userRepo *repository.UserRepository + postRepo *repository.PostRepository + cache cache.Cache +} + +// NewSystemMessageService 创建系统消息服务 +func NewSystemMessageService( + notifyRepo *repository.SystemNotificationRepository, + pushService PushService, + userRepo *repository.UserRepository, + postRepo *repository.PostRepository, +) SystemMessageService { + return &systemMessageServiceImpl{ + notifyRepo: notifyRepo, + pushService: pushService, + userRepo: userRepo, + postRepo: postRepo, + cache: cache.GetCache(), + } +} + +// SendLikeNotification 发送点赞通知 +func (s *systemMessageServiceImpl) SendLikeNotification(ctx context.Context, userID string, operatorID string, postID string) error { + // 获取操作者信息 + actorName, avatarURL, err := s.getActorInfo(ctx, operatorID) + if err != nil { + return err + } + + // 获取帖子标题 + postTitle, err := s.getPostTitle(postID) + if err != nil { + postTitle = "您的帖子" + } + + extraData := &model.SystemNotificationExtra{ + ActorIDStr: operatorID, + ActorName: actorName, + AvatarURL: avatarURL, + TargetID: postID, + TargetTitle: postTitle, + TargetType: "post", + ActionURL: fmt.Sprintf("/posts/%s", postID), + ActionTime: time.Now().Format(time.RFC3339), + } + + content := fmt.Sprintf("%s 赞了「%s」", actorName, postTitle) + + // 创建通知 + notification, err := s.createNotification(ctx, userID, model.SysNotifyLikePost, content, extraData) + if err != nil { + return fmt.Errorf("failed to create like notification: %w", err) + } + + // 推送通知 + return s.pushService.PushSystemNotification(ctx, userID, notification) +} + +// SendCommentNotification 发送评论通知 +func (s *systemMessageServiceImpl) SendCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string) error { + // 获取操作者信息 + actorName, avatarURL, err := s.getActorInfo(ctx, operatorID) + if err != nil { + return err + } + + // 获取帖子标题 + postTitle, err := s.getPostTitle(postID) + if err != nil { + postTitle = "您的帖子" + } + + extraData := &model.SystemNotificationExtra{ + ActorIDStr: operatorID, + ActorName: actorName, + AvatarURL: avatarURL, + TargetID: postID, + TargetTitle: postTitle, + TargetType: "comment", + ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID), + ActionTime: time.Now().Format(time.RFC3339), + } + + content := fmt.Sprintf("%s 评论了「%s」", actorName, postTitle) + + // 创建通知 + notification, err := s.createNotification(ctx, userID, model.SysNotifyComment, content, extraData) + if err != nil { + return fmt.Errorf("failed to create comment notification: %w", err) + } + + // 推送通知 + return s.pushService.PushSystemNotification(ctx, userID, notification) +} + +// SendReplyNotification 发送回复通知 +func (s *systemMessageServiceImpl) SendReplyNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, replyID string) error { + // 获取操作者信息 + actorName, avatarURL, err := s.getActorInfo(ctx, operatorID) + if err != nil { + return err + } + + // 获取帖子标题 + postTitle, err := s.getPostTitle(postID) + if err != nil { + postTitle = "您的帖子" + } + + extraData := &model.SystemNotificationExtra{ + ActorIDStr: operatorID, + ActorName: actorName, + AvatarURL: avatarURL, + TargetID: replyID, + TargetTitle: postTitle, + TargetType: "reply", + ActionURL: fmt.Sprintf("/posts/%s?comment=%s&reply=%s", postID, commentID, replyID), + ActionTime: time.Now().Format(time.RFC3339), + } + + content := fmt.Sprintf("%s 回复了您在「%s」的评论", actorName, postTitle) + + // 创建通知 + notification, err := s.createNotification(ctx, userID, model.SysNotifyReply, content, extraData) + if err != nil { + return fmt.Errorf("failed to create reply notification: %w", err) + } + + // 推送通知 + return s.pushService.PushSystemNotification(ctx, userID, notification) +} + +// SendFollowNotification 发送关注通知 +func (s *systemMessageServiceImpl) SendFollowNotification(ctx context.Context, userID string, operatorID string) error { + fmt.Printf("[DEBUG] SendFollowNotification: userID=%s, operatorID=%s\n", userID, operatorID) + + // 获取操作者信息 + actorName, avatarURL, err := s.getActorInfo(ctx, operatorID) + if err != nil { + fmt.Printf("[DEBUG] SendFollowNotification: getActorInfo error: %v\n", err) + return err + } + fmt.Printf("[DEBUG] SendFollowNotification: actorName=%s, avatarURL=%s\n", actorName, avatarURL) + + extraData := &model.SystemNotificationExtra{ + ActorIDStr: operatorID, + ActorName: actorName, + AvatarURL: avatarURL, + TargetID: "", + TargetType: "user", + ActionURL: fmt.Sprintf("/users/%s", operatorID), + ActionTime: time.Now().Format(time.RFC3339), + } + + content := fmt.Sprintf("%s 关注了你", actorName) + + // 创建通知 + notification, err := s.createNotification(ctx, userID, model.SysNotifyFollow, content, extraData) + if err != nil { + return fmt.Errorf("failed to create follow notification: %w", err) + } + + fmt.Printf("[DEBUG] SendFollowNotification: notification created, ID=%d, Content=%s\n", notification.ID, notification.Content) + + // 推送通知 + pushErr := s.pushService.PushSystemNotification(ctx, userID, notification) + if pushErr != nil { + fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification error: %v\n", pushErr) + } else { + fmt.Printf("[DEBUG] SendFollowNotification: PushSystemNotification success\n") + } + return pushErr +} + +// SendFavoriteNotification 发送收藏通知 +func (s *systemMessageServiceImpl) SendFavoriteNotification(ctx context.Context, userID string, operatorID string, postID string) error { + // 获取操作者信息 + actorName, avatarURL, err := s.getActorInfo(ctx, operatorID) + if err != nil { + return err + } + + // 获取帖子标题 + postTitle, err := s.getPostTitle(postID) + if err != nil { + postTitle = "您的帖子" + } + + extraData := &model.SystemNotificationExtra{ + ActorIDStr: operatorID, + ActorName: actorName, + AvatarURL: avatarURL, + TargetID: postID, + TargetTitle: postTitle, + TargetType: "post", + ActionURL: fmt.Sprintf("/posts/%s", postID), + ActionTime: time.Now().Format(time.RFC3339), + } + + content := fmt.Sprintf("%s 收藏了「%s」", actorName, postTitle) + + // 创建通知 + notification, err := s.createNotification(ctx, userID, model.SysNotifyFavoritePost, content, extraData) + if err != nil { + return fmt.Errorf("failed to create favorite notification: %w", err) + } + + // 推送通知 + return s.pushService.PushSystemNotification(ctx, userID, notification) +} + +// SendLikeCommentNotification 发送评论点赞通知 +func (s *systemMessageServiceImpl) SendLikeCommentNotification(ctx context.Context, userID string, operatorID string, postID string, commentID string, commentContent string) error { + // 获取操作者信息 + actorName, avatarURL, err := s.getActorInfo(ctx, operatorID) + if err != nil { + return err + } + + // 截取评论内容预览(最多50字) + preview := commentContent + runes := []rune(preview) + if len(runes) > 50 { + preview = string(runes[:50]) + "..." + } + + extraData := &model.SystemNotificationExtra{ + ActorIDStr: operatorID, + ActorName: actorName, + AvatarURL: avatarURL, + TargetID: postID, + TargetTitle: preview, + TargetType: "comment", + ActionURL: fmt.Sprintf("/posts/%s?comment=%s", postID, commentID), + ActionTime: time.Now().Format(time.RFC3339), + } + + content := fmt.Sprintf("%s 赞了您的评论", actorName) + + // 创建通知 + notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeComment, content, extraData) + if err != nil { + return fmt.Errorf("failed to create like comment notification: %w", err) + } + + // 推送通知 + return s.pushService.PushSystemNotification(ctx, userID, notification) +} + +// SendLikeReplyNotification 发送回复点赞通知 +func (s *systemMessageServiceImpl) SendLikeReplyNotification(ctx context.Context, userID string, operatorID string, postID string, replyID string, replyContent string) error { + // 获取操作者信息 + actorName, avatarURL, err := s.getActorInfo(ctx, operatorID) + if err != nil { + return err + } + + // 截取回复内容预览(最多50字) + preview := replyContent + runes := []rune(preview) + if len(runes) > 50 { + preview = string(runes[:50]) + "..." + } + + extraData := &model.SystemNotificationExtra{ + ActorIDStr: operatorID, + ActorName: actorName, + AvatarURL: avatarURL, + TargetID: postID, + TargetTitle: preview, + TargetType: "reply", + ActionURL: fmt.Sprintf("/posts/%s?reply=%s", postID, replyID), + ActionTime: time.Now().Format(time.RFC3339), + } + + content := fmt.Sprintf("%s 赞了您的回复", actorName) + + // 创建通知 + notification, err := s.createNotification(ctx, userID, model.SysNotifyLikeReply, content, extraData) + if err != nil { + return fmt.Errorf("failed to create like reply notification: %w", err) + } + + // 推送通知 + return s.pushService.PushSystemNotification(ctx, userID, notification) +} + +// SendMentionNotification 发送@提及通知 +func (s *systemMessageServiceImpl) SendMentionNotification(ctx context.Context, userID string, operatorID string, postID string) error { + // 获取操作者信息 + actorName, avatarURL, err := s.getActorInfo(ctx, operatorID) + if err != nil { + return err + } + + // 获取帖子标题 + postTitle, err := s.getPostTitle(postID) + if err != nil { + postTitle = "您的帖子" + } + + extraData := &model.SystemNotificationExtra{ + ActorIDStr: operatorID, + ActorName: actorName, + AvatarURL: avatarURL, + TargetID: postID, + TargetTitle: postTitle, + TargetType: "post", + ActionURL: fmt.Sprintf("/posts/%s", postID), + ActionTime: time.Now().Format(time.RFC3339), + } + + content := fmt.Sprintf("%s 在「%s」中提到了你", actorName, postTitle) + + // 创建通知 + notification, err := s.createNotification(ctx, userID, model.SysNotifyMention, content, extraData) + if err != nil { + return fmt.Errorf("failed to create mention notification: %w", err) + } + + // 推送通知 + return s.pushService.PushSystemNotification(ctx, userID, notification) +} + +// SendSystemAnnouncement 发送系统公告给指定用户 +func (s *systemMessageServiceImpl) SendSystemAnnouncement(ctx context.Context, userIDs []string, title string, content string) error { + for _, userID := range userIDs { + extraData := &model.SystemNotificationExtra{ + TargetType: "announcement", + ActionTime: time.Now().Format(time.RFC3339), + } + + notification, err := s.createNotification(ctx, userID, model.SysNotifyAnnounce, fmt.Sprintf("【%s】%s", title, content), extraData) + if err != nil { + continue // 单个失败不影响其他用户 + } + + // 推送通知(使用高优先级) + if err := s.pushService.PushSystemNotification(ctx, userID, notification); err != nil { + continue + } + } + return nil +} + +// SendBroadcastAnnouncement 发送广播公告给所有在线用户 +func (s *systemMessageServiceImpl) SendBroadcastAnnouncement(ctx context.Context, title string, content string) error { + // TODO: 实现广播公告 + // 1. 获取所有在线用户 + // 2. 批量发送公告 + // 3. 对于离线用户,存储为待推送记录 + return fmt.Errorf("broadcast announcement not implemented") +} + +// createNotification 创建系统通知(存储到独立表) +func (s *systemMessageServiceImpl) createNotification(ctx context.Context, userID string, notifyType model.SystemNotificationType, content string, extraData *model.SystemNotificationExtra) (*model.SystemNotification, error) { + fmt.Printf("[DEBUG] createNotification: userID=%s, notifyType=%s\n", userID, notifyType) + + // 生成雪花算法ID + id, err := utils.GetSnowflake().GenerateID() + if err != nil { + fmt.Printf("[DEBUG] createNotification: failed to generate ID: %v\n", err) + return nil, fmt.Errorf("failed to generate notification ID: %w", err) + } + + notification := &model.SystemNotification{ + ID: id, + ReceiverID: userID, + Type: notifyType, + Content: content, + ExtraData: extraData, + IsRead: false, + } + + fmt.Printf("[DEBUG] createNotification: notification created with ID=%d, ReceiverID=%s\n", id, userID) + + // 保存通知到数据库 + if err := s.notifyRepo.Create(notification); err != nil { + fmt.Printf("[DEBUG] createNotification: failed to save notification: %v\n", err) + return nil, fmt.Errorf("failed to save notification: %w", err) + } + + // 失效系统消息未读数缓存 + cache.InvalidateUnreadSystem(s.cache, userID) + + fmt.Printf("[DEBUG] createNotification: notification saved successfully, ID=%d\n", notification.ID) + return notification, nil +} + +// getActorInfo 获取操作者信息 +func (s *systemMessageServiceImpl) getActorInfo(ctx context.Context, operatorID string) (string, string, error) { + // 从用户仓储获取用户信息 + if s.userRepo != nil { + user, err := s.userRepo.GetByID(operatorID) + if err != nil { + fmt.Printf("[DEBUG] getActorInfo: failed to get user %s: %v\n", operatorID, err) + return "用户", utils.GenerateDefaultAvatarURL("用户"), nil // 返回默认值,不阻断流程 + } + avatar := utils.GetAvatarOrDefault(user.Username, user.Nickname, user.Avatar) + return user.Nickname, avatar, nil + } + // 如果没有用户仓储,返回默认值 + return "用户", utils.GenerateDefaultAvatarURL("用户"), nil +} + +// getPostTitle 获取帖子标题 +func (s *systemMessageServiceImpl) getPostTitle(postID string) (string, error) { + if s.postRepo == nil { + if len(postID) >= 8 { + return fmt.Sprintf("帖子#%s", postID[:8]), nil + } + return fmt.Sprintf("帖子#%s", postID), nil + } + post, err := s.postRepo.GetByID(postID) + if err != nil { + if len(postID) >= 8 { + return fmt.Sprintf("帖子#%s", postID[:8]), nil + } + return fmt.Sprintf("帖子#%s", postID), nil + } + if post.Title != "" { + return post.Title, nil + } + // 如果没有标题,返回内容前20个字符 + if len(post.Content) > 20 { + return post.Content[:20] + "...", nil + } + return post.Content, nil +} diff --git a/internal/service/upload_service.go b/internal/service/upload_service.go new file mode 100644 index 0000000..01b0598 --- /dev/null +++ b/internal/service/upload_service.go @@ -0,0 +1,273 @@ +package service + +import ( + "bytes" + "context" + "crypto/sha256" + "fmt" + "image" + "image/jpeg" + "image/png" + "io" + "mime" + "mime/multipart" + "net/http" + "path/filepath" + "strings" + + "carrot_bbs/internal/pkg/s3" + _ "golang.org/x/image/bmp" + _ "golang.org/x/image/tiff" +) + +// UploadService 上传服务 +type UploadService struct { + s3Client *s3.Client + userService *UserService +} + +// NewUploadService 创建上传服务 +func NewUploadService(s3Client *s3.Client, userService *UserService) *UploadService { + return &UploadService{ + s3Client: s3Client, + userService: userService, + } +} + +// UploadImage 上传图片 +func (s *UploadService) UploadImage(ctx context.Context, file *multipart.FileHeader) (string, error) { + processedData, contentType, ext, err := prepareImageForUpload(file) + if err != nil { + return "", err + } + + // 压缩后再计算哈希,确保同一压缩结果映射同一对象名 + hash := sha256.Sum256(processedData) + hashStr := fmt.Sprintf("%x", hash) + + objectName := fmt.Sprintf("images/%s%s", hashStr, ext) + + url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType) + if err != nil { + return "", fmt.Errorf("failed to upload to S3: %w", err) + } + + return url, nil +} + +// getExtFromContentType 根据Content-Type获取文件扩展名 +func getExtFromContentType(contentType string) string { + baseType, _, err := mime.ParseMediaType(contentType) + if err == nil && baseType != "" { + contentType = baseType + } + + switch contentType { + case "image/jpg", "image/jpeg": + return ".jpg" + case "image/png": + return ".png" + case "image/gif": + return ".gif" + case "image/webp": + return ".webp" + case "image/bmp", "image/x-ms-bmp": + return ".bmp" + case "image/tiff": + return ".tiff" + default: + return "" + } +} + +// UploadAvatar 上传头像 +func (s *UploadService) UploadAvatar(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) { + processedData, contentType, ext, err := prepareImageForUpload(file) + if err != nil { + return "", err + } + + // 压缩后再计算哈希 + hash := sha256.Sum256(processedData) + hashStr := fmt.Sprintf("%x", hash) + + objectName := fmt.Sprintf("avatars/%s%s", hashStr, ext) + + url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType) + if err != nil { + return "", fmt.Errorf("failed to upload to S3: %w", err) + } + + // 更新用户头像 + if s.userService != nil { + user, err := s.userService.GetUserByID(ctx, userID) + if err == nil && user != nil { + user.Avatar = url + err = s.userService.UpdateUser(ctx, user) + if err != nil { + // 更新失败不影响上传结果,只记录日志 + fmt.Printf("[UploadAvatar] failed to update user avatar: %v\n", err) + } + } + } + + return url, nil +} + +// UploadCover 上传头图(个人主页封面) +func (s *UploadService) UploadCover(ctx context.Context, userID string, file *multipart.FileHeader) (string, error) { + processedData, contentType, ext, err := prepareImageForUpload(file) + if err != nil { + return "", err + } + + // 压缩后再计算哈希 + hash := sha256.Sum256(processedData) + hashStr := fmt.Sprintf("%x", hash) + + objectName := fmt.Sprintf("covers/%s%s", hashStr, ext) + + url, err := s.s3Client.UploadData(ctx, objectName, processedData, contentType) + if err != nil { + return "", fmt.Errorf("failed to upload to S3: %w", err) + } + + // 更新用户头图 + if s.userService != nil { + user, err := s.userService.GetUserByID(ctx, userID) + if err == nil && user != nil { + user.CoverURL = url + err = s.userService.UpdateUser(ctx, user) + if err != nil { + // 更新失败不影响上传结果,只记录日志 + fmt.Printf("[UploadCover] failed to update user cover: %v\n", err) + } + } + } + + return url, nil +} + +// GetURL 获取文件URL +func (s *UploadService) GetURL(ctx context.Context, objectName string) (string, error) { + return s.s3Client.GetURL(ctx, objectName) +} + +// Delete 删除文件 +func (s *UploadService) Delete(ctx context.Context, objectName string) error { + return s.s3Client.Delete(ctx, objectName) +} + +func prepareImageForUpload(file *multipart.FileHeader) ([]byte, string, string, error) { + f, err := file.Open() + if err != nil { + return nil, "", "", fmt.Errorf("failed to open file: %w", err) + } + defer f.Close() + + originalData, err := io.ReadAll(f) + if err != nil { + return nil, "", "", fmt.Errorf("failed to read file: %w", err) + } + + // 优先从文件字节探测真实类型,避免前端压缩/转码后 header 与实际格式不一致 + detectedType := normalizeImageContentType(http.DetectContentType(originalData)) + headerType := normalizeImageContentType(file.Header.Get("Content-Type")) + contentType := detectedType + if contentType == "" || contentType == "application/octet-stream" { + contentType = headerType + } + + compressedData, compressedType, err := compressImageData(originalData, contentType) + if err != nil { + // 压缩失败时回退到原图,保证上传可用性 + compressedData = originalData + compressedType = contentType + } + + if compressedType == "" { + compressedType = contentType + } + if compressedType == "" { + compressedType = http.DetectContentType(compressedData) + } + + ext := getExtFromContentType(compressedType) + if ext == "" { + ext = strings.ToLower(filepath.Ext(file.Filename)) + } + if ext == "" { + // 最终兜底,避免对象名无扩展名导致 URL 语义不明确 + ext = ".jpg" + } + + return compressedData, compressedType, ext, nil +} + +func compressImageData(data []byte, contentType string) ([]byte, string, error) { + contentType = normalizeImageContentType(contentType) + + // GIF/WebP 等格式先保留原图,避免动画和透明通道丢失 + if contentType == "image/gif" || contentType == "image/webp" { + return data, contentType, nil + } + + if contentType != "image/jpeg" && + contentType != "image/png" && + contentType != "image/bmp" && + contentType != "image/x-ms-bmp" && + contentType != "image/tiff" { + return data, contentType, nil + } + + img, _, err := image.Decode(bytes.NewReader(data)) + if err != nil { + return nil, "", fmt.Errorf("failed to decode image: %w", err) + } + + var buf bytes.Buffer + switch contentType { + case "image/png": + encoder := png.Encoder{CompressionLevel: png.BestCompression} + if err := encoder.Encode(&buf, img); err != nil { + return nil, "", fmt.Errorf("failed to encode png: %w", err) + } + return buf.Bytes(), "image/png", nil + default: + // BMP/TIFF 等无损大图统一压缩为 JPEG,控制体积 + if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 82}); err != nil { + return nil, "", fmt.Errorf("failed to encode jpeg: %w", err) + } + return buf.Bytes(), "image/jpeg", nil + } +} + +func normalizeImageContentType(contentType string) string { + if contentType == "" { + return "" + } + + baseType, _, err := mime.ParseMediaType(contentType) + if err == nil && baseType != "" { + contentType = baseType + } + + switch strings.ToLower(contentType) { + case "image/jpg": + return "image/jpeg" + case "image/jpeg": + return "image/jpeg" + case "image/png": + return "image/png" + case "image/gif": + return "image/gif" + case "image/webp": + return "image/webp" + case "image/bmp", "image/x-ms-bmp": + return "image/bmp" + case "image/tiff": + return "image/tiff" + default: + return contentType + } +} diff --git a/internal/service/user_service.go b/internal/service/user_service.go new file mode 100644 index 0000000..6d956d0 --- /dev/null +++ b/internal/service/user_service.go @@ -0,0 +1,592 @@ +package service + +import ( + "context" + "fmt" + "strings" + + "carrot_bbs/internal/cache" + "carrot_bbs/internal/model" + "carrot_bbs/internal/pkg/utils" + "carrot_bbs/internal/repository" +) + +// UserService 用户服务 +type UserService struct { + userRepo *repository.UserRepository + systemMessageService SystemMessageService + emailCodeService EmailCodeService +} + +// NewUserService 创建用户服务 +func NewUserService( + userRepo *repository.UserRepository, + systemMessageService SystemMessageService, + emailService EmailService, + cacheBackend cache.Cache, +) *UserService { + return &UserService{ + userRepo: userRepo, + systemMessageService: systemMessageService, + emailCodeService: NewEmailCodeService(emailService, cacheBackend), + } +} + +// SendRegisterCode 发送注册验证码 +func (s *UserService) SendRegisterCode(ctx context.Context, email string) error { + user, err := s.userRepo.GetByEmail(email) + if err == nil && user != nil { + return ErrEmailExists + } + return s.emailCodeService.SendCode(ctx, CodePurposeRegister, email) +} + +// SendPasswordResetCode 发送找回密码验证码 +func (s *UserService) SendPasswordResetCode(ctx context.Context, email string) error { + user, err := s.userRepo.GetByEmail(email) + if err != nil || user == nil { + return ErrUserNotFound + } + return s.emailCodeService.SendCode(ctx, CodePurposePasswordReset, email) +} + +// SendCurrentUserEmailVerifyCode 发送当前用户邮箱验证验证码 +func (s *UserService) SendCurrentUserEmailVerifyCode(ctx context.Context, userID, email string) error { + user, err := s.userRepo.GetByID(userID) + if err != nil || user == nil { + return ErrUserNotFound + } + + targetEmail := strings.TrimSpace(email) + if targetEmail == "" && user.Email != nil { + targetEmail = strings.TrimSpace(*user.Email) + } + if targetEmail == "" || !utils.ValidateEmail(targetEmail) { + return ErrInvalidEmail + } + + if user.EmailVerified && user.Email != nil && strings.EqualFold(strings.TrimSpace(*user.Email), targetEmail) { + return ErrEmailAlreadyVerified + } + + existingUser, queryErr := s.userRepo.GetByEmail(targetEmail) + if queryErr == nil && existingUser != nil && existingUser.ID != userID { + return ErrEmailExists + } + + return s.emailCodeService.SendCode(ctx, CodePurposeEmailVerify, targetEmail) +} + +// VerifyCurrentUserEmail 验证当前用户邮箱 +func (s *UserService) VerifyCurrentUserEmail(ctx context.Context, userID, email, verificationCode string) error { + user, err := s.userRepo.GetByID(userID) + if err != nil || user == nil { + return ErrUserNotFound + } + + targetEmail := strings.TrimSpace(email) + if targetEmail == "" && user.Email != nil { + targetEmail = strings.TrimSpace(*user.Email) + } + if targetEmail == "" || !utils.ValidateEmail(targetEmail) { + return ErrInvalidEmail + } + + if err := s.emailCodeService.VerifyCode(CodePurposeEmailVerify, targetEmail, verificationCode); err != nil { + return err + } + + existingUser, queryErr := s.userRepo.GetByEmail(targetEmail) + if queryErr == nil && existingUser != nil && existingUser.ID != userID { + return ErrEmailExists + } + + user.Email = &targetEmail + user.EmailVerified = true + return s.userRepo.Update(user) +} + +// SendChangePasswordCode 发送修改密码验证码 +func (s *UserService) SendChangePasswordCode(ctx context.Context, userID string) error { + user, err := s.userRepo.GetByID(userID) + if err != nil || user == nil { + return ErrUserNotFound + } + if user.Email == nil || strings.TrimSpace(*user.Email) == "" { + return ErrEmailNotBound + } + return s.emailCodeService.SendCode(ctx, CodePurposeChangePassword, *user.Email) +} + +// Register 用户注册 +func (s *UserService) Register(ctx context.Context, username, email, password, nickname, phone, verificationCode string) (*model.User, error) { + // 验证用户名 + if !utils.ValidateUsername(username) { + return nil, ErrInvalidUsername + } + + // 注册必须提供邮箱并完成验证码校验 + if email == "" || !utils.ValidateEmail(email) { + return nil, ErrInvalidEmail + } + if err := s.emailCodeService.VerifyCode(CodePurposeRegister, email, verificationCode); err != nil { + return nil, err + } + + // 验证密码 + if !utils.ValidatePassword(password) { + return nil, ErrWeakPassword + } + + // 验证手机号(如果提供) + if phone != "" && !utils.ValidatePhone(phone) { + return nil, ErrInvalidPhone + } + + // 检查用户名是否已存在 + existingUser, err := s.userRepo.GetByUsername(username) + if err == nil && existingUser != nil { + return nil, ErrUsernameExists + } + + // 检查邮箱是否已存在(如果提供) + if email != "" { + existingUser, err = s.userRepo.GetByEmail(email) + if err == nil && existingUser != nil { + return nil, ErrEmailExists + } + } + + // 检查手机号是否已存在(如果提供) + if phone != "" { + existingUser, err = s.userRepo.GetByPhone(phone) + if err == nil && existingUser != nil { + return nil, ErrPhoneExists + } + } + + // 密码哈希 + hashedPassword, err := utils.HashPassword(password) + if err != nil { + return nil, err + } + + // 创建用户 + user := &model.User{ + Username: username, + Nickname: nickname, + EmailVerified: true, + PasswordHash: hashedPassword, + Status: model.UserStatusActive, + } + + // 如果提供了邮箱,设置指针值 + if email != "" { + user.Email = &email + } + + // 如果提供了手机号,设置指针值 + if phone != "" { + user.Phone = &phone + } + + err = s.userRepo.Create(user) + if err != nil { + return nil, err + } + + return user, nil +} + +// Login 用户登录 +func (s *UserService) Login(ctx context.Context, account, password string) (*model.User, error) { + account = strings.TrimSpace(account) + var ( + user *model.User + err error + ) + if utils.ValidateEmail(account) { + user, err = s.userRepo.GetByEmail(account) + } else if utils.ValidatePhone(account) { + user, err = s.userRepo.GetByPhone(account) + } else { + user, err = s.userRepo.GetByUsername(account) + } + if err != nil || user == nil { + return nil, ErrInvalidCredentials + } + + if !utils.CheckPasswordHash(password, user.PasswordHash) { + return nil, ErrInvalidCredentials + } + + if user.Status != model.UserStatusActive { + return nil, ErrUserBanned + } + + return user, nil +} + +// GetUserByID 根据ID获取用户 +func (s *UserService) GetUserByID(ctx context.Context, id string) (*model.User, error) { + return s.userRepo.GetByID(id) +} + +// GetUserPostCount 获取用户帖子数(实时计算) +func (s *UserService) GetUserPostCount(ctx context.Context, userID string) (int64, error) { + return s.userRepo.GetPostsCount(userID) +} + +// GetUserPostCountBatch 批量获取用户帖子数(实时计算) +func (s *UserService) GetUserPostCountBatch(ctx context.Context, userIDs []string) (map[string]int64, error) { + return s.userRepo.GetPostsCountBatch(userIDs) +} + +// GetUserByIDWithFollowingStatus 根据ID获取用户(包含当前用户是否关注的状态) +func (s *UserService) GetUserByIDWithFollowingStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, error) { + user, err := s.userRepo.GetByID(userID) + if err != nil { + return nil, false, err + } + + // 如果查询的是当前用户自己,不需要检查关注状态 + if userID == currentUserID { + return user, false, nil + } + + isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID) + if err != nil { + return user, false, err + } + + return user, isFollowing, nil +} + +// GetUserByIDWithMutualFollowStatus 根据ID获取用户(包含双向关注状态) +func (s *UserService) GetUserByIDWithMutualFollowStatus(ctx context.Context, userID, currentUserID string) (*model.User, bool, bool, error) { + user, err := s.userRepo.GetByID(userID) + if err != nil { + return nil, false, false, err + } + + // 如果查询的是当前用户自己,不需要检查关注状态 + if userID == currentUserID { + return user, false, false, nil + } + + // 当前用户是否关注了该用户 + isFollowing, err := s.userRepo.IsFollowing(currentUserID, userID) + if err != nil { + return user, false, false, err + } + + // 该用户是否关注了当前用户 + isFollowingMe, err := s.userRepo.IsFollowing(userID, currentUserID) + if err != nil { + return user, isFollowing, false, err + } + + return user, isFollowing, isFollowingMe, nil +} + +// UpdateUser 更新用户 +func (s *UserService) UpdateUser(ctx context.Context, user *model.User) error { + return s.userRepo.Update(user) +} + +// GetFollowers 获取粉丝 +func (s *UserService) GetFollowers(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) { + return s.userRepo.GetFollowers(userID, page, pageSize) +} + +// GetFollowing 获取关注 +func (s *UserService) GetFollowing(ctx context.Context, userID string, page, pageSize int) ([]*model.User, int64, error) { + return s.userRepo.GetFollowing(userID, page, pageSize) +} + +// FollowUser 关注用户 +func (s *UserService) FollowUser(ctx context.Context, followerID, followeeID string) error { + fmt.Printf("[DEBUG] FollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID) + + blocked, err := s.userRepo.IsBlockedEitherDirection(followerID, followeeID) + if err != nil { + return err + } + if blocked { + return ErrUserBlocked + } + + // 检查是否已经关注 + isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID) + if err != nil { + fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err) + return err + } + if isFollowing { + fmt.Printf("[DEBUG] Already following, skip creation\n") + return nil // 已经关注,直接返回成功 + } + + // 创建关注关系 + follow := &model.Follow{ + FollowerID: followerID, + FollowingID: followeeID, + } + + err = s.userRepo.CreateFollow(follow) + if err != nil { + fmt.Printf("[DEBUG] CreateFollow error: %v\n", err) + return err + } + + fmt.Printf("[DEBUG] Follow record created successfully\n") + + // 刷新关注者的关注数(通过实际计数,更可靠) + err = s.userRepo.RefreshFollowingCount(followerID) + if err != nil { + fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err) + // 不回滚,计数可以通过其他方式修复 + } + + // 刷新被关注者的粉丝数(通过实际计数,更可靠) + err = s.userRepo.RefreshFollowersCount(followeeID) + if err != nil { + fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err) + // 不回滚,计数可以通过其他方式修复 + } + + // 发送关注通知给被关注者 + if s.systemMessageService != nil { + // 异步发送通知,不阻塞主流程 + go func() { + notifyErr := s.systemMessageService.SendFollowNotification(context.Background(), followeeID, followerID) + if notifyErr != nil { + fmt.Printf("[DEBUG] Error sending follow notification: %v\n", notifyErr) + } else { + fmt.Printf("[DEBUG] Follow notification sent successfully to %s\n", followeeID) + } + }() + } + + fmt.Printf("[DEBUG] FollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID) + return nil +} + +// UnfollowUser 取消关注用户 +func (s *UserService) UnfollowUser(ctx context.Context, followerID, followeeID string) error { + fmt.Printf("[DEBUG] UnfollowUser called: followerID=%s, followeeID=%s\n", followerID, followeeID) + + // 检查是否已经关注 + isFollowing, err := s.userRepo.IsFollowing(followerID, followeeID) + if err != nil { + fmt.Printf("[DEBUG] Error checking existing follow: %v\n", err) + return err + } + if !isFollowing { + fmt.Printf("[DEBUG] Not following, skip deletion\n") + return nil // 没有关注,直接返回成功 + } + + // 删除关注关系 + err = s.userRepo.DeleteFollow(followerID, followeeID) + if err != nil { + fmt.Printf("[DEBUG] DeleteFollow error: %v\n", err) + return err + } + + fmt.Printf("[DEBUG] Follow record deleted successfully\n") + + // 刷新关注者的关注数(通过实际计数,更可靠) + err = s.userRepo.RefreshFollowingCount(followerID) + if err != nil { + fmt.Printf("[DEBUG] Error refreshing following count: %v\n", err) + } + + // 刷新被关注者的粉丝数(通过实际计数,更可靠) + err = s.userRepo.RefreshFollowersCount(followeeID) + if err != nil { + fmt.Printf("[DEBUG] Error refreshing followers count: %v\n", err) + } + + fmt.Printf("[DEBUG] UnfollowUser completed: followerID=%s, followeeID=%s\n", followerID, followeeID) + return nil +} + +// BlockUser 拉黑用户,并自动清理双向关注/粉丝关系 +func (s *UserService) BlockUser(ctx context.Context, blockerID, blockedID string) error { + if blockerID == blockedID { + return ErrInvalidOperation + } + return s.userRepo.BlockUserAndCleanupRelations(blockerID, blockedID) +} + +// UnblockUser 取消拉黑 +func (s *UserService) UnblockUser(ctx context.Context, blockerID, blockedID string) error { + if blockerID == blockedID { + return ErrInvalidOperation + } + return s.userRepo.UnblockUser(blockerID, blockedID) +} + +// GetBlockedUsers 获取黑名单列表 +func (s *UserService) GetBlockedUsers(ctx context.Context, blockerID string, page, pageSize int) ([]*model.User, int64, error) { + return s.userRepo.GetBlockedUsers(blockerID, page, pageSize) +} + +// IsBlocked 检查当前用户是否已拉黑目标用户 +func (s *UserService) IsBlocked(ctx context.Context, blockerID, blockedID string) (bool, error) { + return s.userRepo.IsBlocked(blockerID, blockedID) +} + +// GetFollowingList 获取关注列表(字符串参数版本) +func (s *UserService) GetFollowingList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) { + // 转换字符串参数为整数 + pageInt := 1 + pageSizeInt := 20 + if page != "" { + _, err := fmt.Sscanf(page, "%d", &pageInt) + if err != nil { + pageInt = 1 + } + } + if pageSize != "" { + _, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt) + if err != nil { + pageSizeInt = 20 + } + } + + users, _, err := s.userRepo.GetFollowing(userID, pageInt, pageSizeInt) + return users, err +} + +// GetFollowersList 获取粉丝列表(字符串参数版本) +func (s *UserService) GetFollowersList(ctx context.Context, userID, page, pageSize string) ([]*model.User, error) { + // 转换字符串参数为整数 + pageInt := 1 + pageSizeInt := 20 + if page != "" { + _, err := fmt.Sscanf(page, "%d", &pageInt) + if err != nil { + pageInt = 1 + } + } + if pageSize != "" { + _, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt) + if err != nil { + pageSizeInt = 20 + } + } + + users, _, err := s.userRepo.GetFollowers(userID, pageInt, pageSizeInt) + return users, err +} + +// GetMutualFollowStatus 批量获取双向关注状态 +func (s *UserService) GetMutualFollowStatus(ctx context.Context, currentUserID string, targetUserIDs []string) (map[string][2]bool, error) { + return s.userRepo.GetMutualFollowStatus(currentUserID, targetUserIDs) +} + +// CheckUsernameAvailable 检查用户名是否可用 +func (s *UserService) CheckUsernameAvailable(ctx context.Context, username string) (bool, error) { + user, err := s.userRepo.GetByUsername(username) + if err != nil { + return true, nil // 用户不存在,可用 + } + return user == nil, nil +} + +// ChangePassword 修改密码 +func (s *UserService) ChangePassword(ctx context.Context, userID, oldPassword, newPassword, verificationCode string) error { + // 获取用户 + user, err := s.userRepo.GetByID(userID) + if err != nil { + return ErrUserNotFound + } + if user.Email == nil || strings.TrimSpace(*user.Email) == "" { + return ErrEmailNotBound + } + if err := s.emailCodeService.VerifyCode(CodePurposeChangePassword, *user.Email, verificationCode); err != nil { + return err + } + + // 验证旧密码 + if !utils.CheckPasswordHash(oldPassword, user.PasswordHash) { + return ErrInvalidCredentials + } + + // 哈希新密码 + hashedPassword, err := utils.HashPassword(newPassword) + if err != nil { + return err + } + + // 更新密码 + user.PasswordHash = hashedPassword + return s.userRepo.Update(user) +} + +// ResetPasswordByEmail 通过邮箱重置密码 +func (s *UserService) ResetPasswordByEmail(ctx context.Context, email, verificationCode, newPassword string) error { + email = strings.TrimSpace(email) + if !utils.ValidateEmail(email) { + return ErrInvalidEmail + } + if !utils.ValidatePassword(newPassword) { + return ErrWeakPassword + } + if err := s.emailCodeService.VerifyCode(CodePurposePasswordReset, email, verificationCode); err != nil { + return err + } + + user, err := s.userRepo.GetByEmail(email) + if err != nil || user == nil { + return ErrUserNotFound + } + + hashedPassword, err := utils.HashPassword(newPassword) + if err != nil { + return err + } + user.PasswordHash = hashedPassword + return s.userRepo.Update(user) +} + +// Search 搜索用户 +func (s *UserService) Search(ctx context.Context, keyword string, page, pageSize int) ([]*model.User, int64, error) { + return s.userRepo.Search(keyword, page, pageSize) +} + +// 错误定义 +var ( + ErrInvalidUsername = &ServiceError{Code: 400, Message: "invalid username"} + ErrInvalidEmail = &ServiceError{Code: 400, Message: "invalid email"} + ErrInvalidPhone = &ServiceError{Code: 400, Message: "invalid phone number"} + ErrWeakPassword = &ServiceError{Code: 400, Message: "password too weak"} + ErrUsernameExists = &ServiceError{Code: 400, Message: "username already exists"} + ErrEmailExists = &ServiceError{Code: 400, Message: "email already exists"} + ErrPhoneExists = &ServiceError{Code: 400, Message: "phone number already exists"} + ErrUserNotFound = &ServiceError{Code: 404, Message: "user not found"} + ErrUserBanned = &ServiceError{Code: 403, Message: "user is banned"} + ErrUserBlocked = &ServiceError{Code: 403, Message: "blocked relationship exists"} + ErrInvalidOperation = &ServiceError{Code: 400, Message: "invalid operation"} + ErrEmailServiceUnavailable = &ServiceError{Code: 503, Message: "email service unavailable"} + ErrVerificationCodeTooFrequent = &ServiceError{Code: 429, Message: "verification code sent too frequently"} + ErrVerificationCodeInvalid = &ServiceError{Code: 400, Message: "invalid verification code"} + ErrVerificationCodeExpired = &ServiceError{Code: 400, Message: "verification code expired"} + ErrVerificationCodeUnavailable = &ServiceError{Code: 500, Message: "verification code storage unavailable"} + ErrEmailAlreadyVerified = &ServiceError{Code: 400, Message: "email already verified"} + ErrEmailNotBound = &ServiceError{Code: 400, Message: "email not bound"} +) + +// ServiceError 服务错误 +type ServiceError struct { + Code int + Message string +} + +func (e *ServiceError) Error() string { + return e.Message +} + +var ErrInvalidCredentials = &ServiceError{Code: 401, Message: "invalid username or password"} diff --git a/internal/service/vote_service.go b/internal/service/vote_service.go new file mode 100644 index 0000000..46482ff --- /dev/null +++ b/internal/service/vote_service.go @@ -0,0 +1,282 @@ +package service + +import ( + "context" + "errors" + "fmt" + "log" + "strings" + + "carrot_bbs/internal/cache" + "carrot_bbs/internal/dto" + "carrot_bbs/internal/model" + "carrot_bbs/internal/repository" +) + +// VoteService 投票服务 +type VoteService struct { + voteRepo *repository.VoteRepository + postRepo *repository.PostRepository + cache cache.Cache + postAIService *PostAIService + systemMessageService SystemMessageService +} + +// NewVoteService 创建投票服务 +func NewVoteService( + voteRepo *repository.VoteRepository, + postRepo *repository.PostRepository, + cache cache.Cache, + postAIService *PostAIService, + systemMessageService SystemMessageService, +) *VoteService { + return &VoteService{ + voteRepo: voteRepo, + postRepo: postRepo, + cache: cache, + postAIService: postAIService, + systemMessageService: systemMessageService, + } +} + +// CreateVotePost 创建投票帖子 +func (s *VoteService) CreateVotePost(ctx context.Context, userID string, req *dto.CreateVotePostRequest) (*dto.PostResponse, error) { + // 验证投票选项数量 + if len(req.VoteOptions) < 2 { + return nil, errors.New("投票选项至少需要2个") + } + if len(req.VoteOptions) > 10 { + return nil, errors.New("投票选项最多10个") + } + + // 创建普通帖子(设置IsVote=true) + post := &model.Post{ + UserID: userID, + CommunityID: req.CommunityID, + Title: req.Title, + Content: req.Content, + Status: model.PostStatusPending, + IsVote: true, + } + + err := s.postRepo.Create(post, req.Images) + if err != nil { + return nil, err + } + + // 创建投票选项 + err = s.voteRepo.CreateOptions(post.ID, req.VoteOptions) + if err != nil { + return nil, err + } + + // 异步审核 + go s.reviewVotePostAsync(post.ID, userID, req.Title, req.Content, req.Images) + + // 重新查询以获取关联的User和Images + createdPost, err := s.postRepo.GetByID(post.ID) + if err != nil { + return nil, err + } + + // 转换为响应DTO + return s.convertToPostResponse(createdPost, userID), nil +} + +func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string, images []string) { + if s.postAIService == nil || !s.postAIService.IsEnabled() { + if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil { + log.Printf("[WARN] Failed to publish vote post without AI moderation: %v", err) + } + return + } + + err := s.postAIService.ModeratePost(context.Background(), title, content, images) + if err != nil { + var rejectedErr *PostModerationRejectedError + if errors.As(err, &rejectedErr) { + if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil { + log.Printf("[WARN] Failed to reject vote post %s: %v", postID, updateErr) + } + s.notifyModerationRejected(userID, rejectedErr.Reason) + return + } + + if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil { + log.Printf("[WARN] Failed to publish vote post %s after moderation error: %v", postID, updateErr) + } + return + } + + if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil { + log.Printf("[WARN] Failed to publish vote post %s: %v", postID, err) + } +} + +func (s *VoteService) notifyModerationRejected(userID, reason string) { + if s.systemMessageService == nil || strings.TrimSpace(userID) == "" { + return + } + + content := "您发布的投票帖未通过AI审核,请修改后重试。" + if strings.TrimSpace(reason) != "" { + content = fmt.Sprintf("您发布的投票帖未通过AI审核,原因:%s。请修改后重试。", reason) + } + + go func() { + _ = s.systemMessageService.SendSystemAnnouncement( + context.Background(), + []string{userID}, + "投票帖审核未通过", + content, + ) + }() +} + +// GetVoteOptions 获取投票选项 +func (s *VoteService) GetVoteOptions(postID string) ([]dto.VoteOptionDTO, error) { + options, err := s.voteRepo.GetOptionsByPostID(postID) + if err != nil { + return nil, err + } + + result := make([]dto.VoteOptionDTO, 0, len(options)) + for _, option := range options { + result = append(result, dto.VoteOptionDTO{ + ID: option.ID, + Content: option.Content, + VotesCount: option.VotesCount, + }) + } + + return result, nil +} + +// GetVoteResult 获取投票结果(包含用户投票状态) +func (s *VoteService) GetVoteResult(postID, userID string) (*dto.VoteResultDTO, error) { + // 获取所有投票选项 + options, err := s.voteRepo.GetOptionsByPostID(postID) + if err != nil { + return nil, err + } + + // 获取用户的投票记录 + userVote, err := s.voteRepo.GetUserVote(postID, userID) + if err != nil { + return nil, err + } + + // 构建结果 + result := &dto.VoteResultDTO{ + Options: make([]dto.VoteOptionDTO, 0, len(options)), + TotalVotes: 0, + HasVoted: userVote != nil, + } + + if userVote != nil { + result.VotedOptionID = userVote.OptionID + } + + for _, option := range options { + result.Options = append(result.Options, dto.VoteOptionDTO{ + ID: option.ID, + Content: option.Content, + VotesCount: option.VotesCount, + }) + result.TotalVotes += option.VotesCount + } + + return result, nil +} + +// Vote 投票 +func (s *VoteService) Vote(ctx context.Context, postID, userID, optionID string) error { + // 调用voteRepo.Vote + err := s.voteRepo.Vote(postID, userID, optionID) + if err != nil { + return err + } + + // 失效帖子详情缓存 + cache.InvalidatePostDetail(s.cache, postID) + + return nil +} + +// Unvote 取消投票 +func (s *VoteService) Unvote(ctx context.Context, postID, userID string) error { + // 调用voteRepo.Unvote + err := s.voteRepo.Unvote(postID, userID) + if err != nil { + return err + } + + // 失效帖子详情缓存 + cache.InvalidatePostDetail(s.cache, postID) + + return nil +} + +// UpdateVoteOption 更新投票选项(作者权限) +func (s *VoteService) UpdateVoteOption(ctx context.Context, postID, optionID, userID, content string) error { + // 获取帖子信息 + post, err := s.postRepo.GetByID(postID) + if err != nil { + return err + } + + // 验证用户是否为帖子作者 + if post.UserID != userID { + return errors.New("只有帖子作者可以更新投票选项") + } + + // 调用voteRepo.UpdateOption + return s.voteRepo.UpdateOption(optionID, content) +} + +// convertToPostResponse 将Post模型转换为PostResponse DTO +func (s *VoteService) convertToPostResponse(post *model.Post, currentUserID string) *dto.PostResponse { + if post == nil { + return nil + } + + response := &dto.PostResponse{ + ID: post.ID, + UserID: post.UserID, + Title: post.Title, + Content: post.Content, + LikesCount: post.LikesCount, + CommentsCount: post.CommentsCount, + FavoritesCount: post.FavoritesCount, + SharesCount: post.SharesCount, + ViewsCount: post.ViewsCount, + IsPinned: post.IsPinned, + IsLocked: post.IsLocked, + IsVote: post.IsVote, + CreatedAt: dto.FormatTime(post.CreatedAt), + Images: make([]dto.PostImageResponse, 0, len(post.Images)), + } + + // 转换图片 + for _, img := range post.Images { + response.Images = append(response.Images, dto.PostImageResponse{ + ID: img.ID, + URL: img.URL, + ThumbnailURL: img.ThumbnailURL, + Width: img.Width, + Height: img.Height, + }) + } + + // 转换作者信息 + if post.User != nil { + response.Author = &dto.UserResponse{ + ID: post.User.ID, + Username: post.User.Username, + Nickname: post.User.Nickname, + Avatar: post.User.Avatar, + } + } + + return response +} diff --git a/start-docker.sh b/start-docker.sh new file mode 100755 index 0000000..11fcd7c --- /dev/null +++ b/start-docker.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# carrot_bbs Docker 启动脚本 +# 使用前请确保已经构建好镜像: docker build -t carrot_bbs:latest . + +docker run -d \ + --name carrot_bbs \ + --network 1panel-network \ + -p 8080:8080 \ + -e APP_DATABASE_TYPE=postgres \ + -e APP_DATABASE_POSTGRES_HOST=1Panel-postgresql-t0g7 \ + -e APP_DATABASE_POSTGRES_USER=carrot_bbs \ + -e APP_DATABASE_POSTGRES_PASSWORD=We5Zyb6WzCa36tCT \ + -e APP_DATABASE_POSTGRES_DBNAME=carrot_bbs \ + -e APP_REDIS_TYPE=redis \ + -e APP_REDIS_REDIS_HOST=1Panel-redis-dfmM \ + -e APP_REDIS_REDIS_PASSWORD=redis_j8CMza \ + -e APP_S3_ENDPOINT=files.littlelan.cn \ + -e APP_S3_ACCESS_KEY=E6bMcYkQzCldRTrtmhvi \ + -e APP_S3_SECRET_KEY=4R9yjmwKNoHphiBkv05Oa8WGEIFbnlZeTLXfSgx3 \ + -e APP_S3_BUCKET=test \ + -e APP_S3_DOMAIN=files.littlelan.cn \ + -e APP_GORSE_ENABLED=true \ + -e APP_GORSE_ADDRESS=http://111.170.19.33:8088 \ + -e APP_GORSE_IMPORT_PASSWORD=lanyimin123 \ + -e APP_GORSE_EMBEDDING_API_KEY=sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD \ + -e APP_OPENAI_ENABLED=true \ + -e APP_OPENAI_BASE_URL=https://api.littlelan.cn/ \ + -e APP_OPENAI_API_KEY=sk-y7LOeKsNfzbZWTRSFsTs79jd8WYlezbIVgdVPgMvG4Xz2AlV \ + -e APP_OPENAI_MODERATION_MODEL=qwen3.5-122b \ + -e APP_OPENAI_MODERATION_MAX_IMAGES_PER_REQUEST=1 \ + -e APP_OPENAI_REQUEST_TIMEOUT=30 \ + -e APP_OPENAI_STRICT_MODERATION=false \ + -e APP_EMAIL_ENABLED=true \ + -e APP_EMAIL_HOST=smtp.exmail.qq.com \ + -e APP_EMAIL_PORT=465 \ + -e APP_EMAIL_USERNAME=no-reply@qczlit.cn \ + -e APP_EMAIL_PASSWORD=HbvwwVjRyiWg9gsK \ + -e APP_EMAIL_FROM_ADDRESS=no-reply@qczlit.cn \ + -e APP_EMAIL_FROM_NAME="Carrot BBS" \ + -e APP_EMAIL_USE_TLS=true \ + -e APP_EMAIL_INSECURE_SKIP_VERIFY=false \ + -e APP_EMAIL_TIMEOUT=15 \ + carrot-bbs-backend:20260309-181055