From 44fe05ff62f2efdfe364b9d4f5fd20a8a002a21d Mon Sep 17 00:00:00 2001 From: lafay <2021211506@stu.hit.edu.cn> Date: Mon, 5 Jan 2026 00:40:09 +0800 Subject: [PATCH] chore: update dependencies and improve bot configuration - Upgrade Go version to 1.24.0 and update toolchain. - Update various dependencies in go.mod and go.sum, including: - Upgrade `fasthttp/websocket` to v1.5.12 - Upgrade `fsnotify/fsnotify` to v1.9.0 - Upgrade `valyala/fasthttp` to v1.58.0 - Add new dependencies for `bytedance/sonic` and `google/uuid`. - Refactor bot configuration in config.toml to support multiple bot protocols, including "milky" and "onebot11". - Modify internal configuration structures to accommodate new bot settings. - Enhance event dispatcher with metrics tracking and asynchronous processing capabilities. - Implement WebSocket connection management with heartbeat and reconnection logic. - Update server handling for bot management and event publishing. --- cmd/server/main.go | 8 +- configs/config.toml | 37 +- examples/onebot11_config.toml | 88 +++ go.mod | 29 +- go.sum | 66 ++- internal/adapter/milky/adapter.go | 340 +++++++++++ internal/adapter/milky/api_client.go | 189 +++++++ internal/adapter/milky/bot.go | 321 +++++++++++ internal/adapter/milky/event.go | 693 +++++++++++++++++++++++ internal/adapter/milky/sse_client.go | 240 ++++++++ internal/adapter/milky/types.go | 368 ++++++++++++ internal/adapter/milky/webhook_server.go | 115 ++++ internal/adapter/onebot11/action.go | 306 ++++++++++ internal/adapter/onebot11/adapter.go | 467 +++++++++++++++ internal/adapter/onebot11/bot.go | 36 ++ internal/adapter/onebot11/client.go | 186 ++++++ internal/adapter/onebot11/event.go | 355 ++++++++++++ internal/adapter/onebot11/types.go | 187 ++++++ internal/config/config.go | 38 ++ internal/di/lifecycle.go | 94 +-- internal/di/providers.go | 100 +++- internal/engine/dispatcher.go | 203 ++++++- internal/engine/eventbus.go | 216 ++++++- internal/engine/handler.go | 312 ++++++++++ internal/engine/middleware.go | 283 +++++++++ internal/plugins/echo/echo.go | 137 +++++ pkg/net/httpclient.go | 313 ++++++++++ pkg/net/server.go | 291 +++++++++- pkg/net/sse.go | 244 ++++++++ pkg/net/websocket.go | 231 ++++++-- 30 files changed, 6311 insertions(+), 182 deletions(-) create mode 100644 examples/onebot11_config.toml create mode 100644 internal/adapter/milky/adapter.go create mode 100644 internal/adapter/milky/api_client.go create mode 100644 internal/adapter/milky/bot.go create mode 100644 internal/adapter/milky/event.go create mode 100644 internal/adapter/milky/sse_client.go create mode 100644 internal/adapter/milky/types.go create mode 100644 internal/adapter/milky/webhook_server.go create mode 100644 internal/adapter/onebot11/action.go create mode 100644 internal/adapter/onebot11/adapter.go create mode 100644 internal/adapter/onebot11/bot.go create mode 100644 internal/adapter/onebot11/client.go create mode 100644 internal/adapter/onebot11/event.go create mode 100644 internal/adapter/onebot11/types.go create mode 100644 internal/engine/handler.go create mode 100644 internal/engine/middleware.go create mode 100644 internal/plugins/echo/echo.go create mode 100644 pkg/net/httpclient.go create mode 100644 pkg/net/sse.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 73d9a10..03df3e5 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,6 +1,10 @@ package main -import "cellbot/internal/di" +import ( + "cellbot/internal/di" + + "go.uber.org/fx" +) func main() { app := NewApp() @@ -8,6 +12,6 @@ func main() { } // NewApp 创建应用实例 -func NewApp() *di.App { +func NewApp() *fx.App { return di.NewApp() } diff --git a/configs/config.toml b/configs/config.toml index 139029a..5aabc57 100644 --- a/configs/config.toml +++ b/configs/config.toml @@ -12,8 +12,41 @@ max_backups = 3 max_age = 7 # days [protocol] -name = "onebot12" +name = "milky" version = "1.0" [protocol.options] -# OneBot12 specific options can be added here +# Protocol specific options can be added here + +# ============================================================================ +# Bot 配置 +# ============================================================================ + +# Milky Bot 示例 +[[bots]] +id = "milky_bot_1" +protocol = "milky" +enabled = false + +[bots.milky] +protocol_url = "http://localhost:3000" +access_token = "your_token_here" +event_mode = "sse" +timeout = 30 +retry_count = 3 + +# OneBot11 Bot 示例 +[[bots]] +id = "onebot11_bot_1" +protocol = "onebot11" +enabled = true + +[bots.onebot11] +connection_type = "ws" +self_id = "123456789" +nickname = "TestBot" +ws_url = "ws://127.0.0.1:3001" +access_token = "hDeu66@_DDhgMf<9" +timeout = 30 +heartbeat = 30 +reconnect_interval = 5 diff --git a/examples/onebot11_config.toml b/examples/onebot11_config.toml new file mode 100644 index 0000000..bf88ce6 --- /dev/null +++ b/examples/onebot11_config.toml @@ -0,0 +1,88 @@ +# OneBot 11 适配器配置示例 + +[bot] +# 机器人自身ID(QQ号) +self_id = "123456789" +# 机器人昵称 +nickname = "MyBot" + +# ===== 连接方式配置 ===== +# 支持的连接类型: ws, ws-reverse, http, http-post +connection_type = "ws" + +# ===== 正向 WebSocket 配置 ===== +# 当 connection_type = "ws" 时使用 +[websocket] +# WebSocket 服务器地址 +ws_url = "ws://127.0.0.1:6700" +# 访问令牌(可选) +access_token = "" +# 心跳间隔(秒) +heartbeat = 30 +# 重连间隔(秒) +reconnect_interval = 5 + +# ===== 反向 WebSocket 配置 ===== +# 当 connection_type = "ws-reverse" 时使用 +[websocket_reverse] +# 反向 WebSocket 监听地址 +ws_reverse_url = "0.0.0.0:8080" +# 访问令牌(可选) +access_token = "" + +# ===== HTTP 配置 ===== +# 当 connection_type = "http" 时使用 +[http] +# HTTP API 地址 +http_url = "http://127.0.0.1:5700" +# 访问令牌(可选) +access_token = "" +# 超时时间(秒) +timeout = 30 + +# ===== HTTP POST 上报配置 ===== +# 当 connection_type = "http-post" 时使用 +[http_post] +# HTTP POST 上报地址 +http_post_url = "http://127.0.0.1:8080/onebot" +# 签名密钥(可选) +secret = "" +# 超时时间(秒) +timeout = 30 + +# ===== 完整配置示例 ===== + +# 示例 1: 使用正向 WebSocket +# [[bots]] +# self_id = "123456789" +# nickname = "Bot1" +# connection_type = "ws" +# ws_url = "ws://127.0.0.1:6700" +# access_token = "your_token_here" +# timeout = 30 + +# 示例 2: 使用 HTTP +# [[bots]] +# self_id = "987654321" +# nickname = "Bot2" +# connection_type = "http" +# http_url = "http://127.0.0.1:5700" +# access_token = "your_token_here" +# timeout = 30 + +# 示例 3: 使用反向 WebSocket +# [[bots]] +# self_id = "111222333" +# nickname = "Bot3" +# connection_type = "ws-reverse" +# ws_reverse_url = "0.0.0.0:8080" +# access_token = "your_token_here" + +# 示例 4: 使用 HTTP POST +# [[bots]] +# self_id = "444555666" +# nickname = "Bot4" +# connection_type = "http-post" +# http_post_url = "http://127.0.0.1:8080/onebot" +# secret = "your_secret_here" +# timeout = 30 diff --git a/go.mod b/go.mod index 01ffcf1..57f6c7f 100644 --- a/go.mod +++ b/go.mod @@ -1,23 +1,34 @@ module cellbot -go 1.21 +go 1.24.0 + +toolchain go1.24.2 require ( github.com/BurntSushi/toml v1.3.2 - github.com/fasthttp/websocket v1.5.6 - github.com/fsnotify/fsnotify v1.7.0 - github.com/valyala/fasthttp v1.51.0 + github.com/bytedance/sonic v1.14.2 + github.com/fasthttp/websocket v1.5.12 + github.com/fsnotify/fsnotify v1.9.0 + github.com/valyala/fasthttp v1.58.0 go.uber.org/fx v1.20.0 go.uber.org/zap v1.26.0 + golang.org/x/time v0.14.0 ) require ( - github.com/andybalholm/brotli v1.0.5 // indirect - github.com/klauspost/compress v1.17.1 // indirect - github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic/loader v0.4.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/google/uuid v1.6.0 + github.com/klauspost/compress v1.17.11 // indirect + github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect go.uber.org/dig v1.17.0 // indirect go.uber.org/multierr v1.10.0 // indirect - golang.org/x/net v0.17.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/sys v0.28.0 // indirect ) diff --git a/go.sum b/go.sum index 39b5aa6..73d4055 100644 --- a/go.sum +++ b/go.sum @@ -1,27 +1,51 @@ github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= -github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPIIE= +github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= +github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= +github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fasthttp/websocket v1.5.6 h1:4WtWgRJ0Gzj1Ou+xGKy66Ji+a0mUfgAj9ZdPqHiUwQE= -github.com/fasthttp/websocket v1.5.6/go.mod h1:yiKhNx2zFOv65YYtCJNhtl5VjdCFew3W+gt8U/9aFkI= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/klauspost/compress v1.17.1 h1:NE3C767s2ak2bweCZo3+rdP4U/HoyVXLv/X9f2gPS5g= -github.com/klauspost/compress v1.17.1/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE= +github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= +github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= -github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc= +github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= -github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/valyala/fasthttp v1.58.0 h1:GGB2dWxSbEprU9j0iMJHgdKYJVDyjrOwF9RE59PbRuE= +github.com/valyala/fasthttp v1.58.0/go.mod h1:SYXvHHaFp7QZHGKSHmoMipInhrI5StHrhDTYVEjK/Kw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/dig v1.17.0 h1:5Chju+tUvcC+N7N6EV08BJz41UZuO3BmHcN4A287ZLI= @@ -34,9 +58,15 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/adapter/milky/adapter.go b/internal/adapter/milky/adapter.go new file mode 100644 index 0000000..a927c32 --- /dev/null +++ b/internal/adapter/milky/adapter.go @@ -0,0 +1,340 @@ +package milky + +import ( + "cellbot/internal/engine" + "cellbot/internal/protocol" + "cellbot/pkg/net" + "context" + "fmt" + "strconv" + "time" + + "go.uber.org/zap" +) + +// Config Milky 适配器配置 +type Config struct { + // 协议端地址(如 http://localhost:3000) + ProtocolURL string `toml:"protocol_url"` + // 访问令牌 + AccessToken string `toml:"access_token"` + // 事件接收方式: sse, websocket, webhook + EventMode string `toml:"event_mode"` + // Webhook 监听地址(仅当 event_mode = "webhook" 时需要) + WebhookListenAddr string `toml:"webhook_listen_addr"` + // 超时时间(秒) + Timeout int `toml:"timeout"` + // 重试次数 + RetryCount int `toml:"retry_count"` +} + +// Adapter Milky 协议适配器 +type Adapter struct { + config *Config + selfID string + apiClient *APIClient + sseClient *net.SSEClient + wsManager *net.WebSocketManager + wsConn *net.WebSocketConnection + webhookServer *WebhookServer + eventBus *engine.EventBus + eventConverter *EventConverter + logger *zap.Logger + ctx context.Context + cancel context.CancelFunc +} + +// NewAdapter 创建 Milky 适配器 +func NewAdapter(config *Config, selfID string, eventBus *engine.EventBus, wsManager *net.WebSocketManager, logger *zap.Logger) *Adapter { + ctx, cancel := context.WithCancel(context.Background()) + + timeout := time.Duration(config.Timeout) * time.Second + if timeout == 0 { + timeout = 30 * time.Second + } + + retryCount := config.RetryCount + if retryCount == 0 { + retryCount = 3 + } + + return &Adapter{ + config: config, + selfID: selfID, + apiClient: NewAPIClient(config.ProtocolURL, config.AccessToken, timeout, retryCount, logger), + eventBus: eventBus, + wsManager: wsManager, + eventConverter: NewEventConverter(logger), + logger: logger.Named("milky-adapter"), + ctx: ctx, + cancel: cancel, + } +} + +// Connect 连接到协议端 +func (a *Adapter) Connect(ctx context.Context) error { + a.logger.Info("Connecting to Milky protocol server", + zap.String("url", a.config.ProtocolURL), + zap.String("event_mode", a.config.EventMode)) + + // 根据配置选择事件接收方式 + switch a.config.EventMode { + case "sse": + return a.connectSSE(ctx) + case "websocket": + return a.connectWebSocket(ctx) + case "webhook": + return a.startWebhook() + default: + return fmt.Errorf("unknown event mode: %s", a.config.EventMode) + } +} + +// connectSSE 连接 SSE +func (a *Adapter) connectSSE(ctx context.Context) error { + eventURL := a.config.ProtocolURL + "/event" + + // 创建 SSE 客户端配置 + sseConfig := net.SSEClientConfig{ + URL: eventURL, + AccessToken: a.config.AccessToken, + ReconnectDelay: 5 * time.Second, + MaxReconnect: -1, // 无限重连 + EventFilter: "milky_event", // 只接收 milky_event 类型 + BufferSize: 100, + } + + a.sseClient = net.NewSSEClient(sseConfig, a.logger) + + // 启动 SSE 连接 + if err := a.sseClient.Connect(ctx); err != nil { + return fmt.Errorf("failed to connect SSE: %w", err) + } + + // 启动事件处理 + go a.handleEvents(a.sseClient.Events()) + + a.logger.Info("SSE connection established") + return nil +} + +// connectWebSocket 连接 WebSocket +func (a *Adapter) connectWebSocket(ctx context.Context) error { + // 构建 WebSocket URL + eventURL := a.config.ProtocolURL + "/event" + // 替换 http:// 为 ws://,https:// 为 wss:// + if len(eventURL) > 7 && eventURL[:7] == "http://" { + eventURL = "ws://" + eventURL[7:] + } else if len(eventURL) > 8 && eventURL[:8] == "https://" { + eventURL = "wss://" + eventURL[8:] + } + + // 添加 access_token 参数 + if a.config.AccessToken != "" { + eventURL += "?access_token=" + a.config.AccessToken + } + + a.logger.Info("Connecting to WebSocket", zap.String("url", eventURL)) + + // 使用 WebSocketManager 建立连接 + conn, err := a.wsManager.Dial(eventURL, a.selfID) + if err != nil { + return fmt.Errorf("failed to dial WebSocket: %w", err) + } + + a.wsConn = conn + + // 启动事件处理 + go a.handleWebSocketEvents() + + a.logger.Info("WebSocket connection established") + return nil +} + +// handleWebSocketEvents 处理 WebSocket 事件 +func (a *Adapter) handleWebSocketEvents() { + for { + select { + case <-a.ctx.Done(): + return + default: + } + + // 读取消息 + _, message, err := a.wsConn.Conn.ReadMessage() + if err != nil { + a.logger.Error("Failed to read WebSocket message", zap.Error(err)) + return + } + + // 转换事件 + event, err := a.eventConverter.Convert(message) + if err != nil { + a.logger.Error("Failed to convert event", zap.Error(err)) + continue + } + + // 发布到事件总线 + a.eventBus.Publish(event) + } +} + +// startWebhook 启动 Webhook 服务器 +func (a *Adapter) startWebhook() error { + if a.config.WebhookListenAddr == "" { + return fmt.Errorf("webhook_listen_addr is required for webhook mode") + } + + a.webhookServer = NewWebhookServer(a.config.WebhookListenAddr, a.logger) + + // 启动服务器 + if err := a.webhookServer.Start(); err != nil { + return fmt.Errorf("failed to start webhook server: %w", err) + } + + // 启动事件处理 + go a.handleEvents(a.webhookServer.Events()) + + a.logger.Info("Webhook server started", zap.String("addr", a.config.WebhookListenAddr)) + return nil +} + +// handleEvents 处理事件 +func (a *Adapter) handleEvents(eventChan <-chan []byte) { + for { + select { + case <-a.ctx.Done(): + return + case rawEvent, ok := <-eventChan: + if !ok { + a.logger.Info("Event channel closed") + return + } + + // 转换事件 + event, err := a.eventConverter.Convert(rawEvent) + if err != nil { + a.logger.Error("Failed to convert event", zap.Error(err)) + continue + } + + // 发布到事件总线 + a.eventBus.Publish(event) + } + } +} + +// SendAction 发送动作 +func (a *Adapter) SendAction(ctx context.Context, action protocol.Action) (map[string]interface{}, error) { + // 调用 API + resp, err := a.apiClient.Call(ctx, string(action.GetType()), action.GetParams()) + if err != nil { + return nil, fmt.Errorf("failed to call API: %w", err) + } + + return resp.Data, nil +} + +// ParseMessage 解析消息 +func (a *Adapter) ParseMessage(raw []byte) (protocol.Event, error) { + return a.eventConverter.Convert(raw) +} + +// Disconnect 断开连接 +func (a *Adapter) Disconnect() error { + a.logger.Info("Disconnecting from Milky protocol server") + + a.cancel() + + // 关闭各种连接 + if a.sseClient != nil { + if err := a.sseClient.Close(); err != nil { + a.logger.Error("Failed to close SSE client", zap.Error(err)) + } + } + + if a.wsConn != nil { + // WebSocket 连接会在 context 取消时自动关闭 + a.logger.Info("WebSocket connection will be closed") + } + + if a.webhookServer != nil { + if err := a.webhookServer.Stop(); err != nil { + a.logger.Error("Failed to stop webhook server", zap.Error(err)) + } + } + + if a.apiClient != nil { + if err := a.apiClient.Close(); err != nil { + a.logger.Error("Failed to close API client", zap.Error(err)) + } + } + + return nil +} + +// GetProtocolName 获取协议名称 +func (a *Adapter) GetProtocolName() string { + return "milky" +} + +// GetProtocolVersion 获取协议版本 +func (a *Adapter) GetProtocolVersion() string { + return "1.0" +} + +// GetSelfID 获取机器人自身 ID +func (a *Adapter) GetSelfID() string { + return a.selfID +} + +// IsConnected 是否已连接 +func (a *Adapter) IsConnected() bool { + switch a.config.EventMode { + case "sse": + return a.sseClient != nil + case "websocket": + return a.wsConn != nil && a.wsConn.Conn != nil + case "webhook": + return a.webhookServer != nil + default: + return false + } +} + +// GetStats 获取统计信息 +func (a *Adapter) GetStats() map[string]interface{} { + stats := map[string]interface{}{ + "protocol": "milky", + "self_id": a.selfID, + "event_mode": a.config.EventMode, + "connected": a.IsConnected(), + } + + if a.config.EventMode == "websocket" && a.wsConn != nil { + stats["remote_addr"] = a.wsConn.RemoteAddr + stats["connection_type"] = a.wsConn.Type + } + + return stats +} + +// CallAPI 直接调用 API(提供给 Bot 使用) +func (a *Adapter) CallAPI(ctx context.Context, endpoint string, params map[string]interface{}) (*APIResponse, error) { + return a.apiClient.Call(ctx, endpoint, params) +} + +// GetConfig 获取配置 +func (a *Adapter) GetConfig() *Config { + return a.config +} + +// SetSelfID 设置机器人自身 ID +func (a *Adapter) SetSelfID(selfID string) { + a.selfID = selfID +} + +// GetSelfIDInt64 获取机器人自身 ID(int64) +func (a *Adapter) GetSelfIDInt64() (int64, error) { + return strconv.ParseInt(a.selfID, 10, 64) +} diff --git a/internal/adapter/milky/api_client.go b/internal/adapter/milky/api_client.go new file mode 100644 index 0000000..97f3347 --- /dev/null +++ b/internal/adapter/milky/api_client.go @@ -0,0 +1,189 @@ +package milky + +import ( + "context" + "fmt" + "time" + + "github.com/bytedance/sonic" + "github.com/valyala/fasthttp" + "go.uber.org/zap" +) + +// APIClient Milky API 客户端 +// 用于调用协议端的 API (POST /api/:api) +type APIClient struct { + baseURL string + accessToken string + httpClient *fasthttp.Client + logger *zap.Logger + timeout time.Duration + retryCount int +} + +// NewAPIClient 创建 API 客户端 +func NewAPIClient(baseURL, accessToken string, timeout time.Duration, retryCount int, logger *zap.Logger) *APIClient { + if timeout == 0 { + timeout = 30 * time.Second + } + if retryCount == 0 { + retryCount = 3 + } + + return &APIClient{ + baseURL: baseURL, + accessToken: accessToken, + httpClient: &fasthttp.Client{ + ReadTimeout: timeout, + WriteTimeout: timeout, + MaxConnsPerHost: 100, + }, + logger: logger.Named("api-client"), + timeout: timeout, + retryCount: retryCount, + } +} + +// Call 调用 API +// endpoint: API 端点名称(如 "send_private_message") +// input: 输入参数(会被序列化为 JSON) +// 返回: 响应数据和错误 +func (c *APIClient) Call(ctx context.Context, endpoint string, input interface{}) (*APIResponse, error) { + // 序列化输入参数 + var inputData []byte + var err error + + if input == nil { + inputData = []byte("{}") + } else { + inputData, err = sonic.Marshal(input) + if err != nil { + return nil, fmt.Errorf("failed to marshal input: %w", err) + } + } + + // 构建 URL + url := fmt.Sprintf("%s/api/%s", c.baseURL, endpoint) + + c.logger.Debug("Calling API", + zap.String("endpoint", endpoint), + zap.String("url", url)) + + // 重试机制 + var lastErr error + for i := 0; i <= c.retryCount; i++ { + if i > 0 { + c.logger.Info("Retrying API call", + zap.String("endpoint", endpoint), + zap.Int("attempt", i), + zap.Int("max", c.retryCount)) + + // 指数退避 + backoff := time.Duration(i) * time.Second + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + resp, err := c.doRequest(ctx, url, inputData) + if err != nil { + lastErr = err + continue + } + + return resp, nil + } + + return nil, fmt.Errorf("API call failed after %d retries: %w", c.retryCount, lastErr) +} + +// doRequest 执行单次请求 +func (c *APIClient) doRequest(ctx context.Context, url string, inputData []byte) (*APIResponse, error) { + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // 设置请求 + req.SetRequestURI(url) + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.SetBody(inputData) + + // 设置 Authorization 头 + if c.accessToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken)) + } + + // 发送请求 + err := c.httpClient.DoTimeout(req, resp, c.timeout) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + // 检查 HTTP 状态码 + statusCode := resp.StatusCode() + switch statusCode { + case 401: + return nil, fmt.Errorf("unauthorized: access token invalid or missing") + case 404: + return nil, fmt.Errorf("API not found: %s", url) + case 415: + return nil, fmt.Errorf("unsupported media type: Content-Type must be application/json") + case 200: + // 继续处理 + default: + return nil, fmt.Errorf("unexpected status code: %d", statusCode) + } + + // 解析响应 + var apiResp APIResponse + if err := sonic.Unmarshal(resp.Body(), &apiResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // 检查业务状态 + if apiResp.Status != "ok" { + c.logger.Warn("API call failed", + zap.String("status", apiResp.Status), + zap.Int("retcode", apiResp.RetCode), + zap.String("message", apiResp.Message)) + + return &apiResp, fmt.Errorf("API error (retcode=%d): %s", apiResp.RetCode, apiResp.Message) + } + + c.logger.Debug("API call succeeded", + zap.String("status", apiResp.Status), + zap.Int("retcode", apiResp.RetCode)) + + return &apiResp, nil +} + +// CallWithoutRetry 调用 API(不重试) +func (c *APIClient) CallWithoutRetry(ctx context.Context, endpoint string, input interface{}) (*APIResponse, error) { + // 序列化输入参数 + var inputData []byte + var err error + + if input == nil { + inputData = []byte("{}") + } else { + inputData, err = sonic.Marshal(input) + if err != nil { + return nil, fmt.Errorf("failed to marshal input: %w", err) + } + } + + // 构建 URL + url := fmt.Sprintf("%s/api/%s", c.baseURL, endpoint) + + return c.doRequest(ctx, url, inputData) +} + +// Close 关闭客户端 +func (c *APIClient) Close() error { + // fasthttp.Client 不需要显式关闭 + return nil +} diff --git a/internal/adapter/milky/bot.go b/internal/adapter/milky/bot.go new file mode 100644 index 0000000..58e1d2d --- /dev/null +++ b/internal/adapter/milky/bot.go @@ -0,0 +1,321 @@ +package milky + +import ( + "cellbot/internal/engine" + "cellbot/internal/protocol" + "cellbot/pkg/net" + "context" + "fmt" + + "go.uber.org/zap" +) + +// Bot Milky Bot 实例 +type Bot struct { + id string + adapter *Adapter + logger *zap.Logger + status protocol.BotStatus +} + +// NewBot 创建 Milky Bot 实例 +func NewBot(id string, config *Config, eventBus *engine.EventBus, wsManager *net.WebSocketManager, logger *zap.Logger) *Bot { + adapter := NewAdapter(config, id, eventBus, wsManager, logger) + + return &Bot{ + id: id, + adapter: adapter, + logger: logger.Named("milky-bot").With(zap.String("bot_id", id)), + status: protocol.BotStatusStopped, + } +} + +// GetID 获取机器人 ID +func (b *Bot) GetID() string { + return b.id +} + +// GetProtocol 获取协议名称 +func (b *Bot) GetProtocol() string { + return "milky" +} + +// Name 获取协议名称 +func (b *Bot) Name() string { + return "milky" +} + +// Version 获取协议版本 +func (b *Bot) Version() string { + return "1.0" +} + +// GetSelfID 获取机器人自身ID +func (b *Bot) GetSelfID() string { + return b.id +} + +// Start 启动实例 +func (b *Bot) Start(ctx context.Context) error { + return b.Connect(ctx) +} + +// Stop 停止实例 +func (b *Bot) Stop(ctx context.Context) error { + return b.Disconnect(ctx) +} + +// HandleEvent 处理事件 +func (b *Bot) HandleEvent(ctx context.Context, event protocol.Event) error { + // Milky 适配器通过事件总线自动处理事件 + // 这里不需要额外处理 + return nil +} + +// GetStatus 获取状态 +func (b *Bot) GetStatus() protocol.BotStatus { + return b.status +} + +// Connect 连接 +func (b *Bot) Connect(ctx context.Context) error { + b.logger.Info("Connecting Milky bot") + + if err := b.adapter.Connect(ctx); err != nil { + b.status = protocol.BotStatusError + return fmt.Errorf("failed to connect: %w", err) + } + + b.status = protocol.BotStatusRunning + b.logger.Info("Milky bot connected") + return nil +} + +// Disconnect 断开连接 +func (b *Bot) Disconnect(ctx context.Context) error { + b.logger.Info("Disconnecting Milky bot") + + if err := b.adapter.Disconnect(); err != nil { + return fmt.Errorf("failed to disconnect: %w", err) + } + + b.status = protocol.BotStatusStopped + b.logger.Info("Milky bot disconnected") + return nil +} + +// SendAction 发送动作 +func (b *Bot) SendAction(ctx context.Context, action protocol.Action) (map[string]interface{}, error) { + if b.status != protocol.BotStatusRunning { + return nil, fmt.Errorf("bot is not running") + } + + return b.adapter.SendAction(ctx, action) +} + +// GetAdapter 获取适配器 +func (b *Bot) GetAdapter() *Adapter { + return b.adapter +} + +// GetInfo 获取机器人信息 +func (b *Bot) GetInfo() map[string]interface{} { + return map[string]interface{}{ + "id": b.id, + "protocol": "milky", + "status": b.status, + "stats": b.adapter.GetStats(), + } +} + +// IsConnected 是否已连接 +func (b *Bot) IsConnected() bool { + return b.status == protocol.BotStatusRunning && b.adapter.IsConnected() +} + +// SetStatus 设置状态 +func (b *Bot) SetStatus(status protocol.BotStatus) { + b.status = status +} + +// ============================================================================ +// Milky 特定的 API 方法 +// ============================================================================ + +// SendPrivateMessage 发送私聊消息 +func (b *Bot) SendPrivateMessage(ctx context.Context, userID int64, segments []OutgoingSegment) (*APIResponse, error) { + params := map[string]interface{}{ + "user_id": userID, + "segments": segments, + } + return b.adapter.CallAPI(ctx, "send_private_message", params) +} + +// SendGroupMessage 发送群消息 +func (b *Bot) SendGroupMessage(ctx context.Context, groupID int64, segments []OutgoingSegment) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "segments": segments, + } + return b.adapter.CallAPI(ctx, "send_group_message", params) +} + +// SendTempMessage 发送临时消息 +func (b *Bot) SendTempMessage(ctx context.Context, groupID, userID int64, segments []OutgoingSegment) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "segments": segments, + } + return b.adapter.CallAPI(ctx, "send_temp_message", params) +} + +// RecallMessage 撤回消息 +func (b *Bot) RecallMessage(ctx context.Context, messageScene string, peerID, messageSeq int64) (*APIResponse, error) { + params := map[string]interface{}{ + "message_scene": messageScene, + "peer_id": peerID, + "message_seq": messageSeq, + } + return b.adapter.CallAPI(ctx, "recall_message", params) +} + +// GetFriendList 获取好友列表 +func (b *Bot) GetFriendList(ctx context.Context) (*APIResponse, error) { + return b.adapter.CallAPI(ctx, "get_friend_list", nil) +} + +// GetGroupList 获取群列表 +func (b *Bot) GetGroupList(ctx context.Context) (*APIResponse, error) { + return b.adapter.CallAPI(ctx, "get_group_list", nil) +} + +// GetGroupMemberList 获取群成员列表 +func (b *Bot) GetGroupMemberList(ctx context.Context, groupID int64) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + } + return b.adapter.CallAPI(ctx, "get_group_member_list", params) +} + +// GetGroupMemberInfo 获取群成员信息 +func (b *Bot) GetGroupMemberInfo(ctx context.Context, groupID, userID int64) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + } + return b.adapter.CallAPI(ctx, "get_group_member_info", params) +} + +// SetGroupAdmin 设置群管理员 +func (b *Bot) SetGroupAdmin(ctx context.Context, groupID, userID int64, isSet bool) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "is_set": isSet, + } + return b.adapter.CallAPI(ctx, "set_group_admin", params) +} + +// SetGroupCard 设置群名片 +func (b *Bot) SetGroupCard(ctx context.Context, groupID, userID int64, card string) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "card": card, + } + return b.adapter.CallAPI(ctx, "set_group_card", params) +} + +// SetGroupName 设置群名 +func (b *Bot) SetGroupName(ctx context.Context, groupID int64, groupName string) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "group_name": groupName, + } + return b.adapter.CallAPI(ctx, "set_group_name", params) +} + +// KickGroupMember 踢出群成员 +func (b *Bot) KickGroupMember(ctx context.Context, groupID, userID int64, rejectAddRequest bool) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "reject_add_request": rejectAddRequest, + } + return b.adapter.CallAPI(ctx, "kick_group_member", params) +} + +// MuteGroupMember 禁言群成员 +func (b *Bot) MuteGroupMember(ctx context.Context, groupID, userID int64, duration int32) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "duration": duration, + } + return b.adapter.CallAPI(ctx, "mute_group_member", params) +} + +// MuteGroupWhole 全体禁言 +func (b *Bot) MuteGroupWhole(ctx context.Context, groupID int64, isMute bool) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "is_mute": isMute, + } + return b.adapter.CallAPI(ctx, "mute_group_whole", params) +} + +// LeaveGroup 退出群 +func (b *Bot) LeaveGroup(ctx context.Context, groupID int64) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + } + return b.adapter.CallAPI(ctx, "leave_group", params) +} + +// HandleFriendRequest 处理好友请求 +func (b *Bot) HandleFriendRequest(ctx context.Context, initiatorUID string, accept bool) (*APIResponse, error) { + params := map[string]interface{}{ + "initiator_uid": initiatorUID, + "accept": accept, + } + return b.adapter.CallAPI(ctx, "handle_friend_request", params) +} + +// HandleGroupJoinRequest 处理入群申请 +func (b *Bot) HandleGroupJoinRequest(ctx context.Context, groupID, notificationSeq int64, accept bool, rejectReason string) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "notification_seq": notificationSeq, + "accept": accept, + "reject_reason": rejectReason, + } + return b.adapter.CallAPI(ctx, "handle_group_join_request", params) +} + +// HandleGroupInvitation 处理群邀请 +func (b *Bot) HandleGroupInvitation(ctx context.Context, groupID, invitationSeq int64, accept bool) (*APIResponse, error) { + params := map[string]interface{}{ + "group_id": groupID, + "invitation_seq": invitationSeq, + "accept": accept, + } + return b.adapter.CallAPI(ctx, "handle_group_invitation", params) +} + +// UploadFile 上传文件 +func (b *Bot) UploadFile(ctx context.Context, fileType, filePath string) (*APIResponse, error) { + params := map[string]interface{}{ + "file_type": fileType, + "file_path": filePath, + } + return b.adapter.CallAPI(ctx, "upload_file", params) +} + +// GetFile 获取文件 +func (b *Bot) GetFile(ctx context.Context, fileID string) (*APIResponse, error) { + params := map[string]interface{}{ + "file_id": fileID, + } + return b.adapter.CallAPI(ctx, "get_file", params) +} diff --git a/internal/adapter/milky/event.go b/internal/adapter/milky/event.go new file mode 100644 index 0000000..4134ff2 --- /dev/null +++ b/internal/adapter/milky/event.go @@ -0,0 +1,693 @@ +package milky + +import ( + "cellbot/internal/protocol" + "fmt" + "strconv" + + "github.com/bytedance/sonic" + "go.uber.org/zap" +) + +// EventConverter 事件转换器 +// 将 Milky 事件转换为通用 protocol.Event +type EventConverter struct { + logger *zap.Logger +} + +// NewEventConverter 创建事件转换器 +func NewEventConverter(logger *zap.Logger) *EventConverter { + return &EventConverter{ + logger: logger.Named("event-converter"), + } +} + +// Convert 转换事件 +func (c *EventConverter) Convert(raw []byte) (protocol.Event, error) { + // 解析原始事件 + var milkyEvent Event + if err := sonic.Unmarshal(raw, &milkyEvent); err != nil { + return nil, fmt.Errorf("failed to unmarshal event: %w", err) + } + + c.logger.Debug("Converting event", + zap.String("event_type", milkyEvent.EventType), + zap.Int64("self_id", milkyEvent.SelfID)) + + // 根据事件类型转换 + switch milkyEvent.EventType { + case EventTypeMessageReceive: + return c.convertMessageEvent(&milkyEvent) + case EventTypeFriendRequest: + return c.convertFriendRequestEvent(&milkyEvent) + case EventTypeGroupJoinRequest: + return c.convertGroupJoinRequestEvent(&milkyEvent) + case EventTypeGroupInvitedJoinRequest: + return c.convertGroupInvitedJoinRequestEvent(&milkyEvent) + case EventTypeGroupInvitation: + return c.convertGroupInvitationEvent(&milkyEvent) + case EventTypeMessageRecall: + return c.convertMessageRecallEvent(&milkyEvent) + case EventTypeBotOffline: + return c.convertBotOfflineEvent(&milkyEvent) + case EventTypeFriendNudge: + return c.convertFriendNudgeEvent(&milkyEvent) + case EventTypeFriendFileUpload: + return c.convertFriendFileUploadEvent(&milkyEvent) + case EventTypeGroupAdminChange: + return c.convertGroupAdminChangeEvent(&milkyEvent) + case EventTypeGroupEssenceMessageChange: + return c.convertGroupEssenceMessageChangeEvent(&milkyEvent) + case EventTypeGroupMemberIncrease: + return c.convertGroupMemberIncreaseEvent(&milkyEvent) + case EventTypeGroupMemberDecrease: + return c.convertGroupMemberDecreaseEvent(&milkyEvent) + case EventTypeGroupNameChange: + return c.convertGroupNameChangeEvent(&milkyEvent) + case EventTypeGroupMessageReaction: + return c.convertGroupMessageReactionEvent(&milkyEvent) + case EventTypeGroupMute: + return c.convertGroupMuteEvent(&milkyEvent) + case EventTypeGroupWholeMute: + return c.convertGroupWholeMuteEvent(&milkyEvent) + case EventTypeGroupNudge: + return c.convertGroupNudgeEvent(&milkyEvent) + case EventTypeGroupFileUpload: + return c.convertGroupFileUploadEvent(&milkyEvent) + default: + c.logger.Warn("Unknown event type", zap.String("event_type", milkyEvent.EventType)) + return nil, fmt.Errorf("unknown event type: %s", milkyEvent.EventType) + } +} + +// convertMessageEvent 转换消息事件 +func (c *EventConverter) convertMessageEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + // 解析消息数据 + var msgData IncomingMessage + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &msgData); err != nil { + return nil, fmt.Errorf("failed to parse message data: %w", err) + } + + // 构建消息文本 + messageText := c.buildMessageText(msgData.Segments) + + event := &protocol.MessageEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeMessage, + DetailType: msgData.MessageScene, // friend, group, temp + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}{ + "peer_id": msgData.PeerID, + "message_seq": msgData.MessageSeq, + "sender_id": msgData.SenderID, + "time": msgData.Time, + "segments": msgData.Segments, + }, + }, + MessageID: strconv.FormatInt(msgData.MessageSeq, 10), + Message: messageText, + AltText: messageText, + } + + // 添加场景特定数据 + if msgData.Friend != nil { + event.Data["friend"] = msgData.Friend + } + if msgData.Group != nil { + event.Data["group"] = msgData.Group + } + if msgData.GroupMember != nil { + event.Data["group_member"] = msgData.GroupMember + } + + return event, nil +} + +// convertFriendRequestEvent 转换好友请求事件 +func (c *EventConverter) convertFriendRequestEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data FriendRequestEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse friend request data: %w", err) + } + + event := &protocol.RequestEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeRequest, + DetailType: "friend", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}{ + "initiator_id": data.InitiatorID, + "initiator_uid": data.InitiatorUID, + "comment": data.Comment, + "via": data.Via, + }, + }, + RequestID: strconv.FormatInt(data.InitiatorID, 10), + UserID: strconv.FormatInt(data.InitiatorID, 10), + Comment: data.Comment, + Flag: data.InitiatorUID, + } + + return event, nil +} + +// convertGroupJoinRequestEvent 转换入群申请事件 +func (c *EventConverter) convertGroupJoinRequestEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupJoinRequestEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group join request data: %w", err) + } + + event := &protocol.RequestEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeRequest, + DetailType: "group", + SubType: "add", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}{ + "group_id": data.GroupID, + "notification_seq": data.NotificationSeq, + "is_filtered": data.IsFiltered, + "initiator_id": data.InitiatorID, + "comment": data.Comment, + }, + }, + RequestID: strconv.FormatInt(data.NotificationSeq, 10), + UserID: strconv.FormatInt(data.InitiatorID, 10), + Comment: data.Comment, + Flag: strconv.FormatInt(data.NotificationSeq, 10), + } + + return event, nil +} + +// convertGroupInvitedJoinRequestEvent 转换群成员邀请他人入群事件 +func (c *EventConverter) convertGroupInvitedJoinRequestEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupInvitedJoinRequestEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group invited join request data: %w", err) + } + + event := &protocol.RequestEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeRequest, + DetailType: "group", + SubType: "invite", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}{ + "group_id": data.GroupID, + "notification_seq": data.NotificationSeq, + "initiator_id": data.InitiatorID, + "target_user_id": data.TargetUserID, + }, + }, + RequestID: strconv.FormatInt(data.NotificationSeq, 10), + UserID: strconv.FormatInt(data.InitiatorID, 10), + Comment: "", + Flag: strconv.FormatInt(data.NotificationSeq, 10), + } + + return event, nil +} + +// convertGroupInvitationEvent 转换他人邀请自身入群事件 +func (c *EventConverter) convertGroupInvitationEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupInvitationEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group invitation data: %w", err) + } + + event := &protocol.RequestEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeRequest, + DetailType: "group", + SubType: "invite_self", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}{ + "group_id": data.GroupID, + "invitation_seq": data.InvitationSeq, + "initiator_id": data.InitiatorID, + }, + }, + RequestID: strconv.FormatInt(data.InvitationSeq, 10), + UserID: strconv.FormatInt(data.InitiatorID, 10), + Comment: "", + Flag: strconv.FormatInt(data.InvitationSeq, 10), + } + + return event, nil +} + +// convertMessageRecallEvent 转换消息撤回事件 +func (c *EventConverter) convertMessageRecallEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data MessageRecallEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse message recall data: %w", err) + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "message_recall", + SubType: data.MessageScene, + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}{ + "message_scene": data.MessageScene, + "peer_id": data.PeerID, + "message_seq": data.MessageSeq, + "sender_id": data.SenderID, + "operator_id": data.OperatorID, + "display_suffix": data.DisplaySuffix, + }, + }, + UserID: strconv.FormatInt(data.SenderID, 10), + Operator: strconv.FormatInt(data.OperatorID, 10), + } + + return event, nil +} + +// convertBotOfflineEvent 转换机器人离线事件 +func (c *EventConverter) convertBotOfflineEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data BotOfflineEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse bot offline data: %w", err) + } + + event := &protocol.MetaEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeMeta, + DetailType: "bot_offline", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}{ + "reason": data.Reason, + }, + }, + Status: "offline", + } + + return event, nil +} + +// convertFriendNudgeEvent 转换好友戳一戳事件 +func (c *EventConverter) convertFriendNudgeEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data FriendNudgeEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse friend nudge data: %w", err) + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "friend_nudge", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + UserID: strconv.FormatInt(data.UserID, 10), + } + + return event, nil +} + +// convertFriendFileUploadEvent 转换好友文件上传事件 +func (c *EventConverter) convertFriendFileUploadEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data FriendFileUploadEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse friend file upload data: %w", err) + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "friend_file_upload", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + UserID: strconv.FormatInt(data.UserID, 10), + } + + return event, nil +} + +// convertGroupAdminChangeEvent 转换群管理员变更事件 +func (c *EventConverter) convertGroupAdminChangeEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupAdminChangeEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group admin change data: %w", err) + } + + subType := "unset" + if data.IsSet { + subType = "set" + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_admin", + SubType: subType, + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + UserID: strconv.FormatInt(data.UserID, 10), + } + + return event, nil +} + +// convertGroupEssenceMessageChangeEvent 转换群精华消息变更事件 +func (c *EventConverter) convertGroupEssenceMessageChangeEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupEssenceMessageChangeEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group essence message change data: %w", err) + } + + subType := "delete" + if data.IsSet { + subType = "add" + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_essence", + SubType: subType, + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + } + + return event, nil +} + +// convertGroupMemberIncreaseEvent 转换群成员增加事件 +func (c *EventConverter) convertGroupMemberIncreaseEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupMemberIncreaseEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group member increase data: %w", err) + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_increase", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + UserID: strconv.FormatInt(data.UserID, 10), + } + + if data.OperatorID != nil { + event.Operator = strconv.FormatInt(*data.OperatorID, 10) + } + + return event, nil +} + +// convertGroupMemberDecreaseEvent 转换群成员减少事件 +func (c *EventConverter) convertGroupMemberDecreaseEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupMemberDecreaseEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group member decrease data: %w", err) + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_decrease", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + UserID: strconv.FormatInt(data.UserID, 10), + } + + if data.OperatorID != nil { + event.Operator = strconv.FormatInt(*data.OperatorID, 10) + } + + return event, nil +} + +// convertGroupNameChangeEvent 转换群名称变更事件 +func (c *EventConverter) convertGroupNameChangeEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupNameChangeEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group name change data: %w", err) + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_name_change", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + Operator: strconv.FormatInt(data.OperatorID, 10), + } + + return event, nil +} + +// convertGroupMessageReactionEvent 转换群消息回应事件 +func (c *EventConverter) convertGroupMessageReactionEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupMessageReactionEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group message reaction data: %w", err) + } + + subType := "remove" + if data.IsAdd { + subType = "add" + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_message_reaction", + SubType: subType, + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + UserID: strconv.FormatInt(data.UserID, 10), + } + + return event, nil +} + +// convertGroupMuteEvent 转换群禁言事件 +func (c *EventConverter) convertGroupMuteEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupMuteEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group mute data: %w", err) + } + + subType := "ban" + if data.Duration == 0 { + subType = "lift_ban" + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_ban", + SubType: subType, + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + UserID: strconv.FormatInt(data.UserID, 10), + Operator: strconv.FormatInt(data.OperatorID, 10), + } + + return event, nil +} + +// convertGroupWholeMuteEvent 转换群全体禁言事件 +func (c *EventConverter) convertGroupWholeMuteEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupWholeMuteEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group whole mute data: %w", err) + } + + subType := "ban" + if !data.IsMute { + subType = "lift_ban" + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_whole_ban", + SubType: subType, + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + Operator: strconv.FormatInt(data.OperatorID, 10), + } + + return event, nil +} + +// convertGroupNudgeEvent 转换群戳一戳事件 +func (c *EventConverter) convertGroupNudgeEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupNudgeEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group nudge data: %w", err) + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_nudge", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + UserID: strconv.FormatInt(data.SenderID, 10), + } + + return event, nil +} + +// convertGroupFileUploadEvent 转换群文件上传事件 +func (c *EventConverter) convertGroupFileUploadEvent(milkyEvent *Event) (protocol.Event, error) { + selfID := strconv.FormatInt(milkyEvent.SelfID, 10) + + var data GroupFileUploadEventData + dataBytes, _ := sonic.Marshal(milkyEvent.Data) + if err := sonic.Unmarshal(dataBytes, &data); err != nil { + return nil, fmt.Errorf("failed to parse group file upload data: %w", err) + } + + event := &protocol.NoticeEvent{ + BaseEvent: protocol.BaseEvent{ + Type: protocol.EventTypeNotice, + DetailType: "group_file_upload", + SubType: "", + Timestamp: milkyEvent.Time, + SelfID: selfID, + Data: map[string]interface{}(milkyEvent.Data), + }, + GroupID: strconv.FormatInt(data.GroupID, 10), + UserID: strconv.FormatInt(data.UserID, 10), + } + + return event, nil +} + +// buildMessageText 构建消息文本 +func (c *EventConverter) buildMessageText(segments []IncomingSegment) string { + var text string + for _, seg := range segments { + if seg.Type == "text" { + if textData, ok := seg.Data["text"].(string); ok { + text += textData + } + } else if seg.Type == "mention" { + if userID, ok := seg.Data["user_id"].(float64); ok { + text += fmt.Sprintf("@%d", int64(userID)) + } + } else if seg.Type == "image" { + text += "[图片]" + } else if seg.Type == "voice" { + text += "[语音]" + } else if seg.Type == "video" { + text += "[视频]" + } else if seg.Type == "file" { + text += "[文件]" + } else if seg.Type == "face" { + text += "[表情]" + } else if seg.Type == "forward" { + text += "[转发消息]" + } + } + return text +} diff --git a/internal/adapter/milky/sse_client.go b/internal/adapter/milky/sse_client.go new file mode 100644 index 0000000..7201778 --- /dev/null +++ b/internal/adapter/milky/sse_client.go @@ -0,0 +1,240 @@ +package milky + +import ( + "bufio" + "context" + "fmt" + "net" + "net/http" + "strings" + "time" + + "go.uber.org/zap" +) + +// SSEClient Server-Sent Events 客户端 +// 用于接收协议端推送的事件 (GET /event) +type SSEClient struct { + url string + accessToken string + eventChan chan []byte + logger *zap.Logger + reconnectDelay time.Duration + maxReconnect int + ctx context.Context + cancel context.CancelFunc +} + +// NewSSEClient 创建 SSE 客户端 +func NewSSEClient(url, accessToken string, logger *zap.Logger) *SSEClient { + ctx, cancel := context.WithCancel(context.Background()) + return &SSEClient{ + url: url, + accessToken: accessToken, + eventChan: make(chan []byte, 100), + logger: logger.Named("sse-client"), + reconnectDelay: 5 * time.Second, + maxReconnect: -1, // 无限重连 + ctx: ctx, + cancel: cancel, + } +} + +// Connect 连接到 SSE 服务器 +func (c *SSEClient) Connect(ctx context.Context) error { + c.logger.Info("Starting SSE client", zap.String("url", c.url)) + + go c.connectLoop(ctx) + + return nil +} + +// connectLoop 连接循环(支持自动重连) +func (c *SSEClient) connectLoop(ctx context.Context) { + reconnectCount := 0 + + for { + select { + case <-ctx.Done(): + c.logger.Info("SSE client stopped") + return + case <-c.ctx.Done(): + c.logger.Info("SSE client stopped") + return + default: + } + + c.logger.Info("Connecting to SSE server", + zap.String("url", c.url), + zap.Int("reconnect_count", reconnectCount)) + + err := c.connect(ctx) + if err != nil { + c.logger.Error("SSE connection failed", zap.Error(err)) + } + + // 检查是否需要重连 + if c.maxReconnect >= 0 && reconnectCount >= c.maxReconnect { + c.logger.Error("Max reconnect attempts reached", zap.Int("count", reconnectCount)) + return + } + + reconnectCount++ + + // 等待后重连 + c.logger.Info("Reconnecting after delay", + zap.Duration("delay", c.reconnectDelay), + zap.Int("attempt", reconnectCount)) + + select { + case <-time.After(c.reconnectDelay): + case <-ctx.Done(): + return + case <-c.ctx.Done(): + return + } + } +} + +// connect 建立单次连接 +func (c *SSEClient) connect(ctx context.Context) error { + // 创建 HTTP 请求 + req, err := http.NewRequestWithContext(ctx, "GET", c.url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // 设置 Authorization 头 + if c.accessToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken)) + } + + // 设置 Accept 头 + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + // 发送请求 + client := &http.Client{ + Timeout: 0, // 无超时,保持长连接 + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + defer resp.Body.Close() + + // 检查状态码 + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + // 检查 Content-Type + contentType := resp.Header.Get("Content-Type") + if !strings.HasPrefix(contentType, "text/event-stream") { + return fmt.Errorf("unexpected content type: %s", contentType) + } + + c.logger.Info("SSE connection established") + + // 读取事件流 + return c.readEventStream(ctx, resp) +} + +// readEventStream 读取事件流 +func (c *SSEClient) readEventStream(ctx context.Context, resp *http.Response) error { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + var eventType string + var dataLines []string + + for scanner.Scan() { + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.ctx.Done(): + return c.ctx.Err() + default: + } + + line := scanner.Text() + + // 空行表示事件结束 + if line == "" { + if eventType != "" && len(dataLines) > 0 { + c.processEvent(eventType, dataLines) + eventType = "" + dataLines = nil + } + continue + } + + // 注释行(以 : 开头) + if strings.HasPrefix(line, ":") { + continue + } + + // 解析字段 + if strings.HasPrefix(line, "event:") { + eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + dataLines = append(dataLines, data) + } + // 忽略其他字段(id, retry 等) + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("scanner error: %w", err) + } + + return fmt.Errorf("connection closed") +} + +// processEvent 处理事件 +func (c *SSEClient) processEvent(eventType string, dataLines []string) { + // 只处理 milky_event 类型 + if eventType != "milky_event" && eventType != "" { + c.logger.Debug("Ignoring non-milky event", zap.String("event_type", eventType)) + return + } + + // 合并多行 data + data := strings.Join(dataLines, "\n") + + c.logger.Debug("Received SSE event", + zap.String("event_type", eventType), + zap.Int("data_length", len(data))) + + // 发送到事件通道 + select { + case c.eventChan <- []byte(data): + default: + c.logger.Warn("Event channel full, dropping event") + } +} + +// Events 获取事件通道 +func (c *SSEClient) Events() <-chan []byte { + return c.eventChan +} + +// Close 关闭客户端 +func (c *SSEClient) Close() error { + c.cancel() + close(c.eventChan) + c.logger.Info("SSE client closed") + return nil +} diff --git a/internal/adapter/milky/types.go b/internal/adapter/milky/types.go new file mode 100644 index 0000000..96657bc --- /dev/null +++ b/internal/adapter/milky/types.go @@ -0,0 +1,368 @@ +package milky + +// Milky 协议类型定义 +// 基于官方 TypeScript 定义: https://github.com/SaltifyDev/milky + +// ============================================================================ +// 标量类型 +// ============================================================================ + +// Int64 表示 64 位整数(QQ号、群号等) +type Int64 = int64 + +// Int32 表示 32 位整数 +type Int32 = int32 + +// String 表示字符串 +type String = string + +// Boolean 表示布尔值 +type Boolean = bool + +// ============================================================================ +// 消息段类型 +// ============================================================================ + +// IncomingSegment 接收消息段(联合类型) +type IncomingSegment struct { + Type string `json:"type"` + Data map[string]interface{} `json:"data"` +} + +// OutgoingSegment 发送消息段(联合类型) +type OutgoingSegment struct { + Type string `json:"type"` + Data map[string]interface{} `json:"data"` +} + +// IncomingForwardedMessage 接收转发消息 +type IncomingForwardedMessage struct { + SenderName string `json:"sender_name"` + AvatarURL string `json:"avatar_url"` + Time int64 `json:"time"` + Segments []IncomingSegment `json:"segments"` +} + +// OutgoingForwardedMessage 发送转发消息 +type OutgoingForwardedMessage struct { + UserID int64 `json:"user_id"` + SenderName string `json:"sender_name"` + Segments []OutgoingSegment `json:"segments"` +} + +// ============================================================================ +// 实体类型 +// ============================================================================ + +// FriendCategoryEntity 好友分组实体 +type FriendCategoryEntity struct { + CategoryID int32 `json:"category_id"` + CategoryName string `json:"category_name"` +} + +// FriendEntity 好友实体 +type FriendEntity struct { + UserID int64 `json:"user_id"` + Nickname string `json:"nickname"` + Sex string `json:"sex"` // male, female, unknown + QID string `json:"qid"` + Remark string `json:"remark"` + Category FriendCategoryEntity `json:"category"` +} + +// GroupEntity 群实体 +type GroupEntity struct { + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + MemberCount int32 `json:"member_count"` + MaxMemberCount int32 `json:"max_member_count"` +} + +// GroupMemberEntity 群成员实体 +type GroupMemberEntity struct { + UserID int64 `json:"user_id"` + Nickname string `json:"nickname"` + Sex string `json:"sex"` // male, female, unknown + GroupID int64 `json:"group_id"` + Card string `json:"card"` + Title string `json:"title"` + Level int32 `json:"level"` + Role string `json:"role"` // owner, admin, member + JoinTime int64 `json:"join_time"` + LastSentTime int64 `json:"last_sent_time"` + ShutUpEndTime *int64 `json:"shut_up_end_time,omitempty"` +} + +// GroupAnnouncementEntity 群公告实体 +type GroupAnnouncementEntity struct { + GroupID int64 `json:"group_id"` + AnnouncementID string `json:"announcement_id"` + UserID int64 `json:"user_id"` + Time int64 `json:"time"` + Content string `json:"content"` + ImageURL *string `json:"image_url,omitempty"` +} + +// GroupFileEntity 群文件实体 +type GroupFileEntity struct { + GroupID int64 `json:"group_id"` + FileID string `json:"file_id"` + FileName string `json:"file_name"` + ParentFolderID string `json:"parent_folder_id"` + FileSize int64 `json:"file_size"` + UploadedTime int64 `json:"uploaded_time"` + ExpireTime *int64 `json:"expire_time,omitempty"` + UploaderID int64 `json:"uploader_id"` + DownloadedTimes int32 `json:"downloaded_times"` +} + +// GroupFolderEntity 群文件夹实体 +type GroupFolderEntity struct { + GroupID int64 `json:"group_id"` + FolderID string `json:"folder_id"` + ParentFolderID string `json:"parent_folder_id"` + FolderName string `json:"folder_name"` + CreatedTime int64 `json:"created_time"` + LastModifiedTime int64 `json:"last_modified_time"` + CreatorID int64 `json:"creator_id"` + FileCount int32 `json:"file_count"` +} + +// ============================================================================ +// 消息类型 +// ============================================================================ + +// IncomingMessage 接收消息(联合类型,使用 message_scene 区分) +type IncomingMessage struct { + MessageScene string `json:"message_scene"` // friend, group, temp + PeerID int64 `json:"peer_id"` + MessageSeq int64 `json:"message_seq"` + SenderID int64 `json:"sender_id"` + Time int64 `json:"time"` + Segments []IncomingSegment `json:"segments"` + + // 好友消息字段 + Friend *FriendEntity `json:"friend,omitempty"` + + // 群消息字段 + Group *GroupEntity `json:"group,omitempty"` + GroupMember *GroupMemberEntity `json:"group_member,omitempty"` +} + +// GroupEssenceMessage 群精华消息 +type GroupEssenceMessage struct { + GroupID int64 `json:"group_id"` + MessageSeq int64 `json:"message_seq"` + MessageTime int64 `json:"message_time"` + SenderID int64 `json:"sender_id"` + SenderName string `json:"sender_name"` + OperatorID int64 `json:"operator_id"` + OperatorName string `json:"operator_name"` + OperationTime int64 `json:"operation_time"` + Segments []IncomingSegment `json:"segments"` +} + +// ============================================================================ +// 事件数据类型 +// ============================================================================ + +// BotOfflineEventData 机器人离线事件数据 +type BotOfflineEventData struct { + Reason string `json:"reason"` +} + +// MessageRecallEventData 消息撤回事件数据 +type MessageRecallEventData struct { + MessageScene string `json:"message_scene"` // friend, group, temp + PeerID int64 `json:"peer_id"` + MessageSeq int64 `json:"message_seq"` + SenderID int64 `json:"sender_id"` + OperatorID int64 `json:"operator_id"` + DisplaySuffix string `json:"display_suffix"` +} + +// FriendRequestEventData 好友请求事件数据 +type FriendRequestEventData struct { + InitiatorID int64 `json:"initiator_id"` + InitiatorUID string `json:"initiator_uid"` + Comment string `json:"comment"` + Via string `json:"via"` +} + +// GroupJoinRequestEventData 入群申请事件数据 +type GroupJoinRequestEventData struct { + GroupID int64 `json:"group_id"` + NotificationSeq int64 `json:"notification_seq"` + IsFiltered bool `json:"is_filtered"` + InitiatorID int64 `json:"initiator_id"` + Comment string `json:"comment"` +} + +// GroupInvitedJoinRequestEventData 群成员邀请他人入群事件数据 +type GroupInvitedJoinRequestEventData struct { + GroupID int64 `json:"group_id"` + NotificationSeq int64 `json:"notification_seq"` + InitiatorID int64 `json:"initiator_id"` + TargetUserID int64 `json:"target_user_id"` +} + +// GroupInvitationEventData 他人邀请自身入群事件数据 +type GroupInvitationEventData struct { + GroupID int64 `json:"group_id"` + InvitationSeq int64 `json:"invitation_seq"` + InitiatorID int64 `json:"initiator_id"` +} + +// FriendNudgeEventData 好友戳一戳事件数据 +type FriendNudgeEventData struct { + UserID int64 `json:"user_id"` + IsSelfSend bool `json:"is_self_send"` + IsSelfReceive bool `json:"is_self_receive"` + DisplayAction string `json:"display_action"` + DisplaySuffix string `json:"display_suffix"` + DisplayActionImgURL string `json:"display_action_img_url"` +} + +// FriendFileUploadEventData 好友文件上传事件数据 +type FriendFileUploadEventData struct { + UserID int64 `json:"user_id"` + FileID string `json:"file_id"` + FileName string `json:"file_name"` + FileSize int64 `json:"file_size"` + FileHash string `json:"file_hash"` + IsSelf bool `json:"is_self"` +} + +// GroupAdminChangeEventData 群管理员变更事件数据 +type GroupAdminChangeEventData struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + IsSet bool `json:"is_set"` +} + +// GroupEssenceMessageChangeEventData 群精华消息变更事件数据 +type GroupEssenceMessageChangeEventData struct { + GroupID int64 `json:"group_id"` + MessageSeq int64 `json:"message_seq"` + IsSet bool `json:"is_set"` +} + +// GroupMemberIncreaseEventData 群成员增加事件数据 +type GroupMemberIncreaseEventData struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + OperatorID *int64 `json:"operator_id,omitempty"` + InvitorID *int64 `json:"invitor_id,omitempty"` +} + +// GroupMemberDecreaseEventData 群成员减少事件数据 +type GroupMemberDecreaseEventData struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + OperatorID *int64 `json:"operator_id,omitempty"` +} + +// GroupNameChangeEventData 群名称变更事件数据 +type GroupNameChangeEventData struct { + GroupID int64 `json:"group_id"` + NewGroupName string `json:"new_group_name"` + OperatorID int64 `json:"operator_id"` +} + +// GroupMessageReactionEventData 群消息回应事件数据 +type GroupMessageReactionEventData struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + MessageSeq int64 `json:"message_seq"` + FaceID string `json:"face_id"` + IsAdd bool `json:"is_add"` +} + +// GroupMuteEventData 群禁言事件数据 +type GroupMuteEventData struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + OperatorID int64 `json:"operator_id"` + Duration int32 `json:"duration"` // 秒,0表示取消禁言 +} + +// GroupWholeMuteEventData 群全体禁言事件数据 +type GroupWholeMuteEventData struct { + GroupID int64 `json:"group_id"` + OperatorID int64 `json:"operator_id"` + IsMute bool `json:"is_mute"` +} + +// GroupNudgeEventData 群戳一戳事件数据 +type GroupNudgeEventData struct { + GroupID int64 `json:"group_id"` + SenderID int64 `json:"sender_id"` + ReceiverID int64 `json:"receiver_id"` + DisplayAction string `json:"display_action"` + DisplaySuffix string `json:"display_suffix"` + DisplayActionImgURL string `json:"display_action_img_url"` +} + +// GroupFileUploadEventData 群文件上传事件数据 +type GroupFileUploadEventData struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + FileID string `json:"file_id"` + FileName string `json:"file_name"` + FileSize int64 `json:"file_size"` +} + +// ============================================================================ +// 事件类型 +// ============================================================================ + +// Event Milky 事件(联合类型,使用 event_type 区分) +type Event struct { + EventType string `json:"event_type"` + Time int64 `json:"time"` + SelfID int64 `json:"self_id"` + Data map[string]interface{} `json:"data"` +} + +// 事件类型常量 +const ( + EventTypeBotOffline = "bot_offline" + EventTypeMessageReceive = "message_receive" + EventTypeMessageRecall = "message_recall" + EventTypeFriendRequest = "friend_request" + EventTypeGroupJoinRequest = "group_join_request" + EventTypeGroupInvitedJoinRequest = "group_invited_join_request" + EventTypeGroupInvitation = "group_invitation" + EventTypeFriendNudge = "friend_nudge" + EventTypeFriendFileUpload = "friend_file_upload" + EventTypeGroupAdminChange = "group_admin_change" + EventTypeGroupEssenceMessageChange = "group_essence_message_change" + EventTypeGroupMemberIncrease = "group_member_increase" + EventTypeGroupMemberDecrease = "group_member_decrease" + EventTypeGroupNameChange = "group_name_change" + EventTypeGroupMessageReaction = "group_message_reaction" + EventTypeGroupMute = "group_mute" + EventTypeGroupWholeMute = "group_whole_mute" + EventTypeGroupNudge = "group_nudge" + EventTypeGroupFileUpload = "group_file_upload" +) + +// ============================================================================ +// API 响应类型 +// ============================================================================ + +// APIResponse API 响应 +type APIResponse struct { + Status string `json:"status"` // ok, failed + RetCode int `json:"retcode"` + Data map[string]interface{} `json:"data,omitempty"` + Message string `json:"message,omitempty"` +} + +// 响应状态码 +const ( + RetCodeSuccess = 0 + RetCodeNotLoggedIn = -403 + RetCodeInvalidParams = -400 + RetCodeNotFound = -404 +) diff --git a/internal/adapter/milky/webhook_server.go b/internal/adapter/milky/webhook_server.go new file mode 100644 index 0000000..5202ad8 --- /dev/null +++ b/internal/adapter/milky/webhook_server.go @@ -0,0 +1,115 @@ +package milky + +import ( + "fmt" + + "github.com/bytedance/sonic" + "github.com/valyala/fasthttp" + "go.uber.org/zap" +) + +// WebhookServer Webhook 服务器 +// 用于接收协议端 POST 推送的事件 +type WebhookServer struct { + server *fasthttp.Server + eventChan chan []byte + logger *zap.Logger + addr string +} + +// NewWebhookServer 创建 Webhook 服务器 +func NewWebhookServer(addr string, logger *zap.Logger) *WebhookServer { + return &WebhookServer{ + eventChan: make(chan []byte, 100), + logger: logger.Named("webhook-server"), + addr: addr, + } +} + +// Start 启动服务器 +func (s *WebhookServer) Start() error { + s.server = &fasthttp.Server{ + Handler: s.handleRequest, + MaxConnsPerIP: 1000, + MaxRequestsPerConn: 1000, + } + + s.logger.Info("Starting webhook server", zap.String("addr", s.addr)) + + go func() { + if err := s.server.ListenAndServe(s.addr); err != nil { + s.logger.Error("Webhook server error", zap.Error(err)) + } + }() + + return nil +} + +// handleRequest 处理请求 +func (s *WebhookServer) handleRequest(ctx *fasthttp.RequestCtx) { + // 只接受 POST 请求 + if !ctx.IsPost() { + s.logger.Warn("Received non-POST request", + zap.String("method", string(ctx.Method()))) + ctx.Error("Method Not Allowed", fasthttp.StatusMethodNotAllowed) + return + } + + // 检查 Content-Type + contentType := string(ctx.Request.Header.ContentType()) + if contentType != "application/json" { + s.logger.Warn("Invalid content type", + zap.String("content_type", contentType)) + ctx.Error("Unsupported Media Type", fasthttp.StatusUnsupportedMediaType) + return + } + + // 获取请求体 + body := ctx.PostBody() + if len(body) == 0 { + s.logger.Warn("Empty request body") + ctx.Error("Bad Request", fasthttp.StatusBadRequest) + return + } + + // 验证 JSON 格式 + var event Event + if err := sonic.Unmarshal(body, &event); err != nil { + s.logger.Error("Failed to parse event", zap.Error(err)) + ctx.Error("Bad Request", fasthttp.StatusBadRequest) + return + } + + s.logger.Debug("Received webhook event", + zap.String("event_type", event.EventType), + zap.Int64("self_id", event.SelfID)) + + // 发送到事件通道 + select { + case s.eventChan <- body: + default: + s.logger.Warn("Event channel full, dropping event") + } + + // 返回成功响应 + ctx.SetContentType("application/json") + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetBodyString(`{"status":"ok"}`) +} + +// Events 获取事件通道 +func (s *WebhookServer) Events() <-chan []byte { + return s.eventChan +} + +// Stop 停止服务器 +func (s *WebhookServer) Stop() error { + if s.server != nil { + s.logger.Info("Stopping webhook server") + if err := s.server.Shutdown(); err != nil { + return fmt.Errorf("failed to shutdown webhook server: %w", err) + } + } + close(s.eventChan) + return nil +} diff --git a/internal/adapter/onebot11/action.go b/internal/adapter/onebot11/action.go new file mode 100644 index 0000000..18313b6 --- /dev/null +++ b/internal/adapter/onebot11/action.go @@ -0,0 +1,306 @@ +package onebot11 + +import ( + "cellbot/internal/protocol" +) + +// OneBot11 API动作常量 +const ( + ActionSendPrivateMsg = "send_private_msg" + ActionSendGroupMsg = "send_group_msg" + ActionSendMsg = "send_msg" + ActionDeleteMsg = "delete_msg" + ActionGetMsg = "get_msg" + ActionGetForwardMsg = "get_forward_msg" + ActionSendLike = "send_like" + ActionSetGroupKick = "set_group_kick" + ActionSetGroupBan = "set_group_ban" + ActionSetGroupAnonymousBan = "set_group_anonymous_ban" + ActionSetGroupWholeBan = "set_group_whole_ban" + ActionSetGroupAdmin = "set_group_admin" + ActionSetGroupAnonymous = "set_group_anonymous" + ActionSetGroupCard = "set_group_card" + ActionSetGroupName = "set_group_name" + ActionSetGroupLeave = "set_group_leave" + ActionSetGroupSpecialTitle = "set_group_special_title" + ActionSetFriendAddRequest = "set_friend_add_request" + ActionSetGroupAddRequest = "set_group_add_request" + ActionGetLoginInfo = "get_login_info" + ActionGetStrangerInfo = "get_stranger_info" + ActionGetFriendList = "get_friend_list" + ActionGetGroupInfo = "get_group_info" + ActionGetGroupList = "get_group_list" + ActionGetGroupMemberInfo = "get_group_member_info" + ActionGetGroupMemberList = "get_group_member_list" + ActionGetGroupHonorInfo = "get_group_honor_info" + ActionGetCookies = "get_cookies" + ActionGetCsrfToken = "get_csrf_token" + ActionGetCredentials = "get_credentials" + ActionGetRecord = "get_record" + ActionGetImage = "get_image" + ActionCanSendImage = "can_send_image" + ActionCanSendRecord = "can_send_record" + ActionGetStatus = "get_status" + ActionGetVersionInfo = "get_version_info" + ActionSetRestart = "set_restart" + ActionCleanCache = "clean_cache" +) + +// ConvertAction 将通用Action转换为OneBot11 Action +func ConvertAction(action protocol.Action) string { + switch action.GetType() { + case protocol.ActionTypeSendPrivateMessage: + return ActionSendPrivateMsg + case protocol.ActionTypeSendGroupMessage: + return ActionSendGroupMsg + case protocol.ActionTypeDeleteMessage: + return ActionDeleteMsg + case protocol.ActionTypeGetUserInfo: + return ActionGetStrangerInfo + case protocol.ActionTypeGetFriendList: + return ActionGetFriendList + case protocol.ActionTypeGetGroupInfo: + return ActionGetGroupInfo + case protocol.ActionTypeGetGroupMemberList: + return ActionGetGroupMemberList + case protocol.ActionTypeSetGroupKick: + return ActionSetGroupKick + case protocol.ActionTypeSetGroupBan: + return ActionSetGroupBan + case protocol.ActionTypeSetGroupAdmin: + return ActionSetGroupAdmin + case protocol.ActionTypeSetGroupWholeBan: + return ActionSetGroupWholeBan + case protocol.ActionTypeGetStatus: + return ActionGetStatus + case protocol.ActionTypeGetVersion: + return ActionGetVersionInfo + default: + return string(action.GetType()) + } +} + +// SendPrivateMessageAction 发送私聊消息动作 +type SendPrivateMessageAction struct { + UserID int64 `json:"user_id"` + Message interface{} `json:"message"` + AutoEscape bool `json:"auto_escape,omitempty"` +} + +// SendGroupMessageAction 发送群消息动作 +type SendGroupMessageAction struct { + GroupID int64 `json:"group_id"` + Message interface{} `json:"message"` + AutoEscape bool `json:"auto_escape,omitempty"` +} + +// DeleteMessageAction 撤回消息动作 +type DeleteMessageAction struct { + MessageID int32 `json:"message_id"` +} + +// GetMessageAction 获取消息动作 +type GetMessageAction struct { + MessageID int32 `json:"message_id"` +} + +// SendLikeAction 发送好友赞动作 +type SendLikeAction struct { + UserID int64 `json:"user_id"` + Times int `json:"times,omitempty"` +} + +// SetGroupKickAction 群组踢人动作 +type SetGroupKickAction struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + RejectAddRequest bool `json:"reject_add_request,omitempty"` +} + +// SetGroupBanAction 群组禁言动作 +type SetGroupBanAction struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + Duration int64 `json:"duration,omitempty"` // 禁言时长,单位秒,0表示取消禁言 +} + +// SetGroupWholeBanAction 群组全员禁言动作 +type SetGroupWholeBanAction struct { + GroupID int64 `json:"group_id"` + Enable bool `json:"enable,omitempty"` +} + +// SetGroupAdminAction 设置群管理员动作 +type SetGroupAdminAction struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + Enable bool `json:"enable,omitempty"` +} + +// SetGroupCardAction 设置群名片动作 +type SetGroupCardAction struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + Card string `json:"card,omitempty"` +} + +// SetGroupNameAction 设置群名动作 +type SetGroupNameAction struct { + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` +} + +// SetGroupLeaveAction 退出群组动作 +type SetGroupLeaveAction struct { + GroupID int64 `json:"group_id"` + IsDismiss bool `json:"is_dismiss,omitempty"` +} + +// SetFriendAddRequestAction 处理加好友请求动作 +type SetFriendAddRequestAction struct { + Flag string `json:"flag"` + Approve bool `json:"approve,omitempty"` + Remark string `json:"remark,omitempty"` +} + +// SetGroupAddRequestAction 处理加群请求动作 +type SetGroupAddRequestAction struct { + Flag string `json:"flag"` + SubType string `json:"sub_type,omitempty"` // add 或 invite + Approve bool `json:"approve,omitempty"` + Reason string `json:"reason,omitempty"` +} + +// GetStrangerInfoAction 获取陌生人信息动作 +type GetStrangerInfoAction struct { + UserID int64 `json:"user_id"` + NoCache bool `json:"no_cache,omitempty"` +} + +// GetGroupInfoAction 获取群信息动作 +type GetGroupInfoAction struct { + GroupID int64 `json:"group_id"` + NoCache bool `json:"no_cache,omitempty"` +} + +// GetGroupMemberInfoAction 获取群成员信息动作 +type GetGroupMemberInfoAction struct { + GroupID int64 `json:"group_id"` + UserID int64 `json:"user_id"` + NoCache bool `json:"no_cache,omitempty"` +} + +// GetGroupMemberListAction 获取群成员列表动作 +type GetGroupMemberListAction struct { + GroupID int64 `json:"group_id"` +} + +// GetGroupHonorInfoAction 获取群荣誉信息动作 +type GetGroupHonorInfoAction struct { + GroupID int64 `json:"group_id"` + Type string `json:"type"` // talkative, performer, legend, strong_newbie, emotion, all +} + +// GetCookiesAction 获取Cookies动作 +type GetCookiesAction struct { + Domain string `json:"domain,omitempty"` +} + +// GetRecordAction 获取语音动作 +type GetRecordAction struct { + File string `json:"file"` + OutFormat string `json:"out_format"` +} + +// GetImageAction 获取图片动作 +type GetImageAction struct { + File string `json:"file"` +} + +// SetRestartAction 重启OneBot实现动作 +type SetRestartAction struct { + Delay int `json:"delay,omitempty"` // 延迟毫秒数 +} + +// ActionResponse API响应 +type ActionResponse struct { + Status string `json:"status"` + RetCode int `json:"retcode"` + Data map[string]interface{} `json:"data,omitempty"` + Echo string `json:"echo,omitempty"` + Message string `json:"message,omitempty"` + Wording string `json:"wording,omitempty"` +} + +// 响应状态码常量 +const ( + RetCodeOK = 0 + RetCodeAsyncStarted = 1 // 异步操作已开始 + RetCodeBadRequest = 1400 // 请求格式错误 + RetCodeUnauthorized = 1401 // 未授权 + RetCodeForbidden = 1403 // 禁止访问 + RetCodeNotFound = 1404 // 接口不存在 + RetCodeMethodNotAllowed = 1405 // 请求方法不支持 + RetCodeInternalError = 1500 // 内部错误 +) + +// BuildActionRequest 构建动作请求 +func BuildActionRequest(action string, params map[string]interface{}, echo string) *OB11Action { + return &OB11Action{ + Action: action, + Params: params, + Echo: echo, + } +} + +// BuildSendPrivateMsg 构建发送私聊消息请求 +func BuildSendPrivateMsg(userID int64, message interface{}, autoEscape bool) map[string]interface{} { + return map[string]interface{}{ + "user_id": userID, + "message": message, + "auto_escape": autoEscape, + } +} + +// BuildSendGroupMsg 构建发送群消息请求 +func BuildSendGroupMsg(groupID int64, message interface{}, autoEscape bool) map[string]interface{} { + return map[string]interface{}{ + "group_id": groupID, + "message": message, + "auto_escape": autoEscape, + } +} + +// BuildDeleteMsg 构建撤回消息请求 +func BuildDeleteMsg(messageID int32) map[string]interface{} { + return map[string]interface{}{ + "message_id": messageID, + } +} + +// BuildSetGroupBan 构建群组禁言请求 +func BuildSetGroupBan(groupID, userID int64, duration int64) map[string]interface{} { + return map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "duration": duration, + } +} + +// BuildSetGroupKick 构建群组踢人请求 +func BuildSetGroupKick(groupID, userID int64, rejectAddRequest bool) map[string]interface{} { + return map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "reject_add_request": rejectAddRequest, + } +} + +// BuildSetGroupCard 构建设置群名片请求 +func BuildSetGroupCard(groupID, userID int64, card string) map[string]interface{} { + return map[string]interface{}{ + "group_id": groupID, + "user_id": userID, + "card": card, + } +} \ No newline at end of file diff --git a/internal/adapter/onebot11/adapter.go b/internal/adapter/onebot11/adapter.go new file mode 100644 index 0000000..4186d9f --- /dev/null +++ b/internal/adapter/onebot11/adapter.go @@ -0,0 +1,467 @@ +package onebot11 + +import ( + "context" + "fmt" + "sync" + "time" + + "cellbot/internal/engine" + "cellbot/internal/protocol" + "cellbot/pkg/net" + + "github.com/bytedance/sonic" + "go.uber.org/zap" +) + +// Adapter OneBot11协议适配器 +type Adapter struct { + config *Config + logger *zap.Logger + wsManager *net.WebSocketManager + httpClient *HTTPClient + wsWaiter *WSResponseWaiter + eventBus *engine.EventBus + selfID string + connected bool + mu sync.RWMutex + wsConnection *net.WebSocketConnection + ctx context.Context + cancel context.CancelFunc +} + +// Config OneBot11配置 +type Config struct { + // 连接配置 + ConnectionType string `json:"connection_type"` // ws, ws-reverse, http, http-post + Host string `json:"host"` + Port int `json:"port"` + AccessToken string `json:"access_token"` + + // WebSocket配置 + WSUrl string `json:"ws_url"` // 正向WS地址 + WSReverseUrl string `json:"ws_reverse_url"` // 反向WS监听地址 + Heartbeat int `json:"heartbeat"` // 心跳间隔(秒) + ReconnectInterval int `json:"reconnect_interval"` // 重连间隔(秒) + + // HTTP配置 + HTTPUrl string `json:"http_url"` // 正向HTTP地址 + HTTPPostUrl string `json:"http_post_url"` // HTTP POST上报地址 + Secret string `json:"secret"` // 签名密钥 + Timeout int `json:"timeout"` // 超时时间(秒) + + // 其他配置 + SelfID string `json:"self_id"` // 机器人QQ号 + Nickname string `json:"nickname"` // 机器人昵称 +} + +// NewAdapter 创建OneBot11适配器 +func NewAdapter(config *Config, logger *zap.Logger, wsManager *net.WebSocketManager, eventBus *engine.EventBus) *Adapter { + ctx, cancel := context.WithCancel(context.Background()) + + timeout := time.Duration(config.Timeout) * time.Second + if timeout == 0 { + timeout = 30 * time.Second + } + + adapter := &Adapter{ + config: config, + logger: logger.Named("onebot11"), + wsManager: wsManager, + wsWaiter: NewWSResponseWaiter(timeout, logger), + eventBus: eventBus, + selfID: config.SelfID, + ctx: ctx, + cancel: cancel, + } + + // 如果使用HTTP连接,初始化HTTP客户端 + if config.ConnectionType == "http" && config.HTTPUrl != "" { + adapter.httpClient = NewHTTPClient(config.HTTPUrl, config.AccessToken, timeout, logger) + } + + return adapter +} + +// Name 获取协议名称 +func (a *Adapter) Name() string { + return "OneBot" +} + +// Version 获取协议版本 +func (a *Adapter) Version() string { + return "11" +} + +// Connect 建立连接 +func (a *Adapter) Connect(ctx context.Context) error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.connected { + return fmt.Errorf("already connected") + } + + a.logger.Info("Starting OneBot11 connection", + zap.String("connection_type", a.config.ConnectionType), + zap.String("self_id", a.selfID)) + + switch a.config.ConnectionType { + case "ws": + return a.connectWebSocket(ctx) + case "ws-reverse": + return a.connectWebSocketReverse(ctx) + case "http": + return a.connectHTTP(ctx) + case "http-post": + return a.connectHTTPPost(ctx) + default: + return fmt.Errorf("unsupported connection type: %s", a.config.ConnectionType) + } +} + +// Disconnect 断开连接 +func (a *Adapter) Disconnect(ctx context.Context) error { + a.mu.Lock() + defer a.mu.Unlock() + + if !a.connected { + a.logger.Debug("Already disconnected, skipping") + return nil + } + + a.logger.Info("Disconnecting OneBot11 adapter", + zap.String("connection_type", a.config.ConnectionType)) + + // 取消上下文 + if a.cancel != nil { + a.cancel() + a.logger.Debug("Context cancelled") + } + + // 关闭WebSocket连接 + if a.wsConnection != nil { + a.logger.Info("Closing WebSocket connection", + zap.String("connection_id", a.wsConnection.ID)) + a.wsManager.RemoveConnection(a.wsConnection.ID) + a.wsConnection = nil + } + + // 关闭HTTP客户端 + if a.httpClient != nil { + if err := a.httpClient.Close(); err != nil { + a.logger.Error("Failed to close HTTP client", zap.Error(err)) + } else { + a.logger.Debug("HTTP client closed") + } + } + + a.connected = false + a.logger.Info("OneBot11 adapter disconnected successfully") + return nil +} + +// IsConnected 检查连接状态 +func (a *Adapter) IsConnected() bool { + a.mu.RLock() + defer a.mu.RUnlock() + return a.connected +} + +// GetSelfID 获取机器人自身ID +func (a *Adapter) GetSelfID() string { + return a.selfID +} + +// SendAction 发送动作 +func (a *Adapter) SendAction(ctx context.Context, action protocol.Action) (map[string]interface{}, error) { + // 序列化为OneBot11格式 + data, err := a.SerializeAction(action) + if err != nil { + return nil, err + } + + switch a.config.ConnectionType { + case "ws", "ws-reverse": + return a.sendActionWebSocket(data) + case "http": + return a.sendActionHTTP(data) + default: + return nil, fmt.Errorf("unsupported connection type for sending action: %s", a.config.ConnectionType) + } +} + +// HandleEvent 处理事件 +func (a *Adapter) HandleEvent(ctx context.Context, event protocol.Event) error { + a.logger.Debug("Handling event", + zap.String("type", string(event.GetType())), + zap.String("detail_type", event.GetDetailType())) + return nil +} + +// ParseMessage 解析原始消息为Event +func (a *Adapter) ParseMessage(raw []byte) (protocol.Event, error) { + var rawEvent RawEvent + if err := sonic.Unmarshal(raw, &rawEvent); err != nil { + return nil, fmt.Errorf("failed to unmarshal raw event: %w", err) + } + + return a.convertToEvent(&rawEvent) +} + +// SerializeAction 序列化Action为协议格式 +func (a *Adapter) SerializeAction(action protocol.Action) ([]byte, error) { + // 转换为OneBot11格式 + ob11ActionName := ConvertAction(action) + + // 检查是否有未转换的动作类型(如果转换后的名称与原始类型相同,说明没有匹配到) + originalType := string(action.GetType()) + if ob11ActionName == originalType { + a.logger.Warn("Action type not converted, using original type", + zap.String("action_type", originalType), + zap.String("hint", "This action type may not be supported by OneBot11")) + } + + ob11Action := &OB11Action{ + Action: ob11ActionName, + Params: action.GetParams(), + } + + return sonic.Marshal(ob11Action) +} + +// connectWebSocket 正向WebSocket连接 +func (a *Adapter) connectWebSocket(ctx context.Context) error { + if a.config.WSUrl == "" { + return fmt.Errorf("ws_url is required for ws connection") + } + + a.logger.Info("Connecting to OneBot WebSocket server", + zap.String("url", a.config.WSUrl), + zap.Bool("has_token", a.config.AccessToken != "")) + + // 添加访问令牌到URL + wsURL := a.config.WSUrl + if a.config.AccessToken != "" { + wsURL += "?access_token=" + a.config.AccessToken + a.logger.Debug("Added access token to WebSocket URL") + } + + a.logger.Info("Dialing WebSocket...", + zap.String("full_url", wsURL)) + + wsConn, err := a.wsManager.Dial(wsURL, a.selfID) + if err != nil { + a.logger.Error("Failed to connect WebSocket", + zap.String("url", a.config.WSUrl), + zap.Error(err)) + return fmt.Errorf("failed to connect websocket: %w", err) + } + + a.wsConnection = wsConn + a.connected = true + + a.logger.Info("WebSocket connected successfully", + zap.String("url", a.config.WSUrl), + zap.String("remote_addr", wsConn.RemoteAddr), + zap.String("connection_id", wsConn.ID)) + + // 启动消息接收处理 + go a.handleWebSocketMessages() + + a.logger.Info("WebSocket message handler started") + + return nil +} + +// connectWebSocketReverse 反向WebSocket连接 +func (a *Adapter) connectWebSocketReverse(ctx context.Context) error { + // 反向WebSocket由客户端主动连接到服务器 + // WebSocket服务器会在主Server中启动 + // 这里只需要标记为已连接状态,等待客户端通过HTTP服务器连接 + a.connected = true + + a.logger.Info("OneBot11 adapter ready for reverse WebSocket connections", + zap.String("bot_id", a.selfID), + zap.String("listen_addr", a.config.WSReverseUrl)) + + // 注意:实际的WebSocket服务器由pkg/net/server.go提供 + // OneBot客户端需要连接到 ws://host:port/ws?bot_id= + + return nil +} + +// connectHTTP 正向HTTP连接 +func (a *Adapter) connectHTTP(ctx context.Context) error { + if a.config.HTTPUrl == "" { + return fmt.Errorf("http_url is required for http connection") + } + + // 创建HTTP客户端 + // TODO: 实现HTTP轮询 + a.connected = true + + a.logger.Info("HTTP connected", + zap.String("url", a.config.HTTPUrl)) + + return nil +} + +// connectHTTPPost HTTP POST上报 +func (a *Adapter) connectHTTPPost(ctx context.Context) error { + if a.config.HTTPPostUrl == "" { + return fmt.Errorf("http_post_url is required for http-post connection") + } + + // HTTP POST由客户端主动推送事件 + a.connected = true + + a.logger.Info("HTTP POST ready", + zap.String("url", a.config.HTTPPostUrl)) + + return nil +} + +// sendActionWebSocket 通过WebSocket发送动作 +func (a *Adapter) sendActionWebSocket(data []byte) (map[string]interface{}, error) { + if a.wsConnection == nil { + return nil, fmt.Errorf("websocket connection not established") + } + + // 解析请求以获取或添加echo + var req OB11Action + if err := sonic.Unmarshal(data, &req); err != nil { + return nil, fmt.Errorf("failed to unmarshal action: %w", err) + } + + // 如果没有echo,生成一个 + if req.Echo == "" { + req.Echo = GenerateEcho() + var err error + data, err = sonic.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal action with echo: %w", err) + } + } + + // 发送消息 + if err := a.wsConnection.SendMessage(data); err != nil { + return nil, fmt.Errorf("failed to send action: %w", err) + } + + // 等待响应 + resp, err := a.wsWaiter.Wait(req.Echo) + if err != nil { + return nil, err + } + + // 检查响应状态 + if resp.Status != "ok" && resp.Status != "async" { + return resp.Data, fmt.Errorf("action failed (retcode=%d)", resp.RetCode) + } + + return resp.Data, nil +} + +// sendActionHTTP 通过HTTP发送动作 +func (a *Adapter) sendActionHTTP(data []byte) (map[string]interface{}, error) { + if a.httpClient == nil { + return nil, fmt.Errorf("http client not initialized") + } + + // 解析请求 + var req OB11Action + if err := sonic.Unmarshal(data, &req); err != nil { + return nil, fmt.Errorf("failed to unmarshal action: %w", err) + } + + // 调用HTTP API + resp, err := a.httpClient.Call(a.ctx, req.Action, req.Params) + if err != nil { + return nil, err + } + + // 检查响应状态 + if resp.Status != "ok" && resp.Status != "async" { + return resp.Data, fmt.Errorf("action failed (retcode=%d)", resp.RetCode) + } + + return resp.Data, nil +} + +// handleWebSocketMessages 处理WebSocket消息 +func (a *Adapter) handleWebSocketMessages() { + a.logger.Info("WebSocket message handler started, waiting for messages...") + + for { + select { + case <-a.ctx.Done(): + a.logger.Info("Context cancelled, stopping WebSocket message handler") + return + default: + } + + if a.wsConnection == nil || a.wsConnection.Conn == nil { + a.logger.Warn("WebSocket connection is nil, stopping message handler") + return + } + + // 读取消息 + _, message, err := a.wsConnection.Conn.ReadMessage() + if err != nil { + a.logger.Error("Failed to read WebSocket message", + zap.Error(err), + zap.String("connection_id", a.wsConnection.ID)) + return + } + + a.logger.Debug("Received WebSocket message", + zap.Int("size", len(message)), + zap.String("preview", string(message[:min(len(message), 200)]))) + + // 尝试解析为响应 + var resp OB11Response + if err := sonic.Unmarshal(message, &resp); err == nil { + // 如果有echo字段,说明是API响应 + if resp.Echo != "" { + a.logger.Debug("Received API response", + zap.String("echo", resp.Echo), + zap.String("status", resp.Status), + zap.Int("retcode", resp.RetCode)) + a.wsWaiter.Notify(&resp) + continue + } + } + + // 否则当作事件处理 + a.logger.Info("Received OneBot event", + zap.ByteString("raw_event", message)) + + // 解析事件 + a.logger.Info("Parsing OneBot event...") + event, err := a.ParseMessage(message) + if err != nil { + a.logger.Error("Failed to parse event", + zap.Error(err), + zap.ByteString("raw_message", message)) + continue + } + + // 发布事件到事件总线 + a.logger.Info("Publishing event to event bus", + zap.String("event_type", string(event.GetType())), + zap.String("detail_type", event.GetDetailType()), + zap.String("self_id", event.GetSelfID())) + + a.eventBus.Publish(event) + + a.logger.Info("Event published successfully") + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/adapter/onebot11/bot.go b/internal/adapter/onebot11/bot.go new file mode 100644 index 0000000..c340622 --- /dev/null +++ b/internal/adapter/onebot11/bot.go @@ -0,0 +1,36 @@ +package onebot11 + +import ( + "cellbot/internal/engine" + "cellbot/internal/protocol" + "cellbot/pkg/net" + + "go.uber.org/zap" +) + +// Bot OneBot11机器人实例 +type Bot struct { + *protocol.BaseBotInstance + adapter *Adapter +} + +// NewBot 创建OneBot11机器人实例 +func NewBot(id string, config *Config, logger *zap.Logger, wsManager *net.WebSocketManager, eventBus *engine.EventBus) *Bot { + adapter := NewAdapter(config, logger, wsManager, eventBus) + baseBot := protocol.NewBaseBotInstance(id, adapter, logger) + + return &Bot{ + BaseBotInstance: baseBot, + adapter: adapter, + } +} + +// GetAdapter 获取适配器 +func (b *Bot) GetAdapter() *Adapter { + return b.adapter +} + +// GetConfig 获取配置 +func (b *Bot) GetConfig() *Config { + return b.adapter.config +} diff --git a/internal/adapter/onebot11/client.go b/internal/adapter/onebot11/client.go new file mode 100644 index 0000000..4c955b0 --- /dev/null +++ b/internal/adapter/onebot11/client.go @@ -0,0 +1,186 @@ +package onebot11 + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/bytedance/sonic" + "github.com/google/uuid" + "github.com/valyala/fasthttp" + "go.uber.org/zap" +) + +// HTTPClient OneBot11 HTTP客户端 +type HTTPClient struct { + baseURL string + accessToken string + httpClient *fasthttp.Client + logger *zap.Logger + timeout time.Duration +} + +// NewHTTPClient 创建HTTP客户端 +func NewHTTPClient(baseURL, accessToken string, timeout time.Duration, logger *zap.Logger) *HTTPClient { + if timeout == 0 { + timeout = 30 * time.Second + } + + return &HTTPClient{ + baseURL: baseURL, + accessToken: accessToken, + httpClient: &fasthttp.Client{ + ReadTimeout: timeout, + WriteTimeout: timeout, + MaxConnsPerHost: 100, + }, + logger: logger.Named("http-client"), + timeout: timeout, + } +} + +// Call 调用API +func (c *HTTPClient) Call(ctx context.Context, action string, params map[string]interface{}) (*OB11Response, error) { + // 构建请求数据 + reqData := map[string]interface{}{ + "action": action, + "params": params, + } + + data, err := sonic.Marshal(reqData) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // 构建URL + url := fmt.Sprintf("%s/%s", c.baseURL, action) + + c.logger.Debug("Calling HTTP API", + zap.String("action", action), + zap.String("url", url)) + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // 设置请求 + req.SetRequestURI(url) + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.SetBody(data) + + // 设置访问令牌 + if c.accessToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken)) + } + + // 发送请求 + if err := c.httpClient.DoTimeout(req, resp, c.timeout); err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + // 检查HTTP状态码 + statusCode := resp.StatusCode() + if statusCode != 200 { + return nil, fmt.Errorf("unexpected status code: %d", statusCode) + } + + // 解析响应 + var ob11Resp OB11Response + if err := sonic.Unmarshal(resp.Body(), &ob11Resp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // 检查业务状态 + if ob11Resp.Status != "ok" && ob11Resp.Status != "async" { + return &ob11Resp, fmt.Errorf("API error (retcode=%d)", ob11Resp.RetCode) + } + + c.logger.Debug("HTTP API call succeeded", + zap.String("action", action), + zap.String("status", ob11Resp.Status)) + + return &ob11Resp, nil +} + +// Close 关闭客户端 +func (c *HTTPClient) Close() error { + // fasthttp.Client 不需要显式关闭 + return nil +} + +// WSResponseWaiter WebSocket响应等待器 +type WSResponseWaiter struct { + pending map[string]chan *OB11Response + mu sync.RWMutex + logger *zap.Logger + timeout time.Duration +} + +// NewWSResponseWaiter 创建WebSocket响应等待器 +func NewWSResponseWaiter(timeout time.Duration, logger *zap.Logger) *WSResponseWaiter { + if timeout == 0 { + timeout = 30 * time.Second + } + + return &WSResponseWaiter{ + pending: make(map[string]chan *OB11Response), + logger: logger.Named("ws-waiter"), + timeout: timeout, + } +} + +// Wait 等待响应 +func (w *WSResponseWaiter) Wait(echo string) (*OB11Response, error) { + w.mu.Lock() + ch := make(chan *OB11Response, 1) + w.pending[echo] = ch + w.mu.Unlock() + + defer func() { + w.mu.Lock() + delete(w.pending, echo) + close(ch) + w.mu.Unlock() + }() + + select { + case resp := <-ch: + return resp, nil + case <-time.After(w.timeout): + return nil, fmt.Errorf("timeout waiting for response (echo=%s)", echo) + } +} + +// Notify 通知响应到达 +func (w *WSResponseWaiter) Notify(resp *OB11Response) { + if resp.Echo == "" { + return + } + + w.mu.RLock() + ch, ok := w.pending[resp.Echo] + w.mu.RUnlock() + + if !ok { + w.logger.Warn("Received response for unknown echo", + zap.String("echo", resp.Echo)) + return + } + + select { + case ch <- resp: + w.logger.Debug("Notified response", + zap.String("echo", resp.Echo)) + default: + w.logger.Warn("Failed to notify response: channel full", + zap.String("echo", resp.Echo)) + } +} + +// GenerateEcho 生成唯一的echo标识 +func GenerateEcho() string { + return uuid.New().String() +} diff --git a/internal/adapter/onebot11/event.go b/internal/adapter/onebot11/event.go new file mode 100644 index 0000000..c1eea54 --- /dev/null +++ b/internal/adapter/onebot11/event.go @@ -0,0 +1,355 @@ +package onebot11 + +import ( + "fmt" + "strconv" + + "cellbot/internal/protocol" + + "github.com/bytedance/sonic" +) + +// convertToEvent 将OneBot11原始事件转换为通用事件 +func (a *Adapter) convertToEvent(raw *RawEvent) (protocol.Event, error) { + baseEvent := &protocol.BaseEvent{ + Timestamp: raw.Time, + SelfID: strconv.FormatInt(raw.SelfID, 10), + Data: make(map[string]interface{}), + } + + switch raw.PostType { + case PostTypeMessage: + return a.convertMessageEvent(raw, baseEvent) + case PostTypeNotice: + return a.convertNoticeEvent(raw, baseEvent) + case PostTypeRequest: + return a.convertRequestEvent(raw, baseEvent) + case PostTypeMetaEvent: + return a.convertMetaEvent(raw, baseEvent) + default: + return nil, fmt.Errorf("unknown post_type: %s", raw.PostType) + } +} + +// convertMessageEvent 转换消息事件 +func (a *Adapter) convertMessageEvent(raw *RawEvent, base *protocol.BaseEvent) (protocol.Event, error) { + base.Type = protocol.EventTypeMessage + base.DetailType = raw.MessageType + base.SubType = raw.SubType + + // 构建消息数据 + base.Data["message_id"] = raw.MessageID + base.Data["user_id"] = raw.UserID + base.Data["message"] = raw.Message + base.Data["raw_message"] = raw.RawMessage + base.Data["font"] = raw.Font + + if raw.GroupID > 0 { + base.Data["group_id"] = raw.GroupID + } + + if raw.Sender != nil { + senderData := map[string]interface{}{ + "user_id": raw.Sender.UserID, + "nickname": raw.Sender.Nickname, + } + if raw.Sender.Sex != "" { + senderData["sex"] = raw.Sender.Sex + } + if raw.Sender.Age > 0 { + senderData["age"] = raw.Sender.Age + } + if raw.Sender.Card != "" { + senderData["card"] = raw.Sender.Card + } + if raw.Sender.Role != "" { + senderData["role"] = raw.Sender.Role + } + base.Data["sender"] = senderData + } + + if raw.Anonymous != nil { + base.Data["anonymous"] = map[string]interface{}{ + "id": raw.Anonymous.ID, + "name": raw.Anonymous.Name, + "flag": raw.Anonymous.Flag, + } + } + + // 解析消息段 + if segments, err := a.parseMessageSegments(raw.Message); err == nil { + base.Data["message_segments"] = segments + } + + return base, nil +} + +// convertNoticeEvent 转换通知事件 +func (a *Adapter) convertNoticeEvent(raw *RawEvent, base *protocol.BaseEvent) (protocol.Event, error) { + base.Type = protocol.EventTypeNotice + base.DetailType = raw.NoticeType + base.SubType = raw.SubType + + base.Data["user_id"] = raw.UserID + + if raw.GroupID > 0 { + base.Data["group_id"] = raw.GroupID + } + + if raw.OperatorID > 0 { + base.Data["operator_id"] = raw.OperatorID + } + + if raw.Duration > 0 { + base.Data["duration"] = raw.Duration + } + + // 根据不同的通知类型添加特定数据 + switch raw.NoticeType { + case NoticeTypeGroupBan: + base.Data["duration"] = raw.Duration + case NoticeTypeGroupUpload: + // 文件上传信息 + if raw.File != nil { + base.Data["file"] = map[string]interface{}{ + "id": raw.File.ID, + "name": raw.File.Name, + "size": raw.File.Size, + "busid": raw.File.Busid, + } + } + case NoticeTypeGroupRecall, NoticeTypeFriendRecall: + base.Data["message_id"] = raw.MessageID + case NoticeTypeNotify: + // 处理通知子类型 + if raw.TargetID > 0 { + base.Data["target_id"] = raw.TargetID + } + if raw.HonorType != "" { + base.Data["honor_type"] = raw.HonorType + } + } + + return base, nil +} + +// convertRequestEvent 转换请求事件 +func (a *Adapter) convertRequestEvent(raw *RawEvent, base *protocol.BaseEvent) (protocol.Event, error) { + base.Type = protocol.EventTypeRequest + base.DetailType = raw.RequestType + base.SubType = raw.SubType + + base.Data["user_id"] = raw.UserID + base.Data["comment"] = raw.Comment + base.Data["flag"] = raw.Flag + + if raw.GroupID > 0 { + base.Data["group_id"] = raw.GroupID + } + + return base, nil +} + +// convertMetaEvent 转换元事件 +func (a *Adapter) convertMetaEvent(raw *RawEvent, base *protocol.BaseEvent) (protocol.Event, error) { + base.Type = protocol.EventTypeMeta + base.DetailType = raw.MetaType + + if raw.Status != nil { + statusData := map[string]interface{}{ + "online": raw.Status.Online, + "good": raw.Status.Good, + } + if raw.Status.Stat != nil { + statusData["stat"] = map[string]interface{}{ + "packet_received": raw.Status.Stat.PacketReceived, + "packet_sent": raw.Status.Stat.PacketSent, + "packet_lost": raw.Status.Stat.PacketLost, + "message_received": raw.Status.Stat.MessageReceived, + "message_sent": raw.Status.Stat.MessageSent, + "disconnect_times": raw.Status.Stat.DisconnectTimes, + "lost_times": raw.Status.Stat.LostTimes, + "last_message_time": raw.Status.Stat.LastMessageTime, + } + } + base.Data["status"] = statusData + } + + if raw.Interval > 0 { + base.Data["interval"] = raw.Interval + } + + return base, nil +} + +// parseMessageSegments 解析消息段 +func (a *Adapter) parseMessageSegments(message interface{}) ([]MessageSegment, error) { + if message == nil { + return nil, fmt.Errorf("message is nil") + } + + // 如果是字符串,转换为文本消息段 + if str, ok := message.(string); ok { + return []MessageSegment{ + { + Type: SegmentTypeText, + Data: map[string]interface{}{ + "text": str, + }, + }, + }, nil + } + + // 如果是数组,解析为消息段数组 + var segments []MessageSegment + data, err := sonic.Marshal(message) + if err != nil { + return nil, fmt.Errorf("failed to marshal message: %w", err) + } + + if err := sonic.Unmarshal(data, &segments); err != nil { + return nil, fmt.Errorf("failed to unmarshal message segments: %w", err) + } + + return segments, nil +} + +// BuildMessage 构建OneBot11消息 +func BuildMessage(segments []MessageSegment) interface{} { + if len(segments) == 0 { + return "" + } + + // 如果只有一个文本消息段,直接返回文本 + if len(segments) == 1 && segments[0].Type == SegmentTypeText { + if text, ok := segments[0].Data["text"].(string); ok { + return text + } + } + + return segments +} + +// BuildTextMessage 构建文本消息 +func BuildTextMessage(text string) []MessageSegment { + return []MessageSegment{ + { + Type: SegmentTypeText, + Data: map[string]interface{}{ + "text": text, + }, + }, + } +} + +// BuildImageMessage 构建图片消息 +func BuildImageMessage(file string) []MessageSegment { + return []MessageSegment{ + { + Type: SegmentTypeImage, + Data: map[string]interface{}{ + "file": file, + }, + }, + } +} + +// BuildAtMessage 构建@消息 +func BuildAtMessage(userID int64) MessageSegment { + return MessageSegment{ + Type: SegmentTypeAt, + Data: map[string]interface{}{ + "qq": userID, + }, + } +} + +// BuildReplyMessage 构建回复消息 +func BuildReplyMessage(messageID int32) MessageSegment { + return MessageSegment{ + Type: SegmentTypeReply, + Data: map[string]interface{}{ + "id": messageID, + }, + } +} + +// BuildFaceMessage 构建表情消息 +func BuildFaceMessage(faceID int) MessageSegment { + return MessageSegment{ + Type: SegmentTypeFace, + Data: map[string]interface{}{ + "id": faceID, + }, + } +} + +// BuildRecordMessage 构建语音消息 +func BuildRecordMessage(file string) MessageSegment { + return MessageSegment{ + Type: SegmentTypeRecord, + Data: map[string]interface{}{ + "file": file, + }, + } +} + +// BuildVideoMessage 构建视频消息 +func BuildVideoMessage(file string) MessageSegment { + return MessageSegment{ + Type: SegmentTypeVideo, + Data: map[string]interface{}{ + "file": file, + }, + } +} + +// BuildShareMessage 构建分享消息 +func BuildShareMessage(url, title string) MessageSegment { + return MessageSegment{ + Type: SegmentTypeShare, + Data: map[string]interface{}{ + "url": url, + "title": title, + }, + } +} + +// BuildLocationMessage 构建位置消息 +func BuildLocationMessage(lat, lon float64, title, content string) MessageSegment { + return MessageSegment{ + Type: SegmentTypeLocation, + Data: map[string]interface{}{ + "lat": lat, + "lon": lon, + "title": title, + "content": content, + }, + } +} + +// BuildMusicMessage 构建音乐消息 +func BuildMusicMessage(musicType, musicID string) MessageSegment { + return MessageSegment{ + Type: SegmentTypeMusic, + Data: map[string]interface{}{ + "type": musicType, + "id": musicID, + }, + } +} + +// BuildCustomMusicMessage 构建自定义音乐消息 +func BuildCustomMusicMessage(url, audio, title, content, image string) MessageSegment { + return MessageSegment{ + Type: SegmentTypeMusic, + Data: map[string]interface{}{ + "type": "custom", + "url": url, + "audio": audio, + "title": title, + "content": content, + "image": image, + }, + } +} diff --git a/internal/adapter/onebot11/types.go b/internal/adapter/onebot11/types.go new file mode 100644 index 0000000..40ecad9 --- /dev/null +++ b/internal/adapter/onebot11/types.go @@ -0,0 +1,187 @@ +package onebot11 + +// RawEvent OneBot11原始事件 +type RawEvent struct { + Time int64 `json:"time"` + SelfID int64 `json:"self_id"` + PostType string `json:"post_type"` + MessageType string `json:"message_type,omitempty"` + SubType string `json:"sub_type,omitempty"` + MessageID int32 `json:"message_id,omitempty"` + UserID int64 `json:"user_id,omitempty"` + GroupID int64 `json:"group_id,omitempty"` + Message interface{} `json:"message,omitempty"` + RawMessage string `json:"raw_message,omitempty"` + Font int32 `json:"font,omitempty"` + Sender *Sender `json:"sender,omitempty"` + Anonymous *Anonymous `json:"anonymous,omitempty"` + NoticeType string `json:"notice_type,omitempty"` + OperatorID int64 `json:"operator_id,omitempty"` + Duration int64 `json:"duration,omitempty"` + RequestType string `json:"request_type,omitempty"` + Comment string `json:"comment,omitempty"` + Flag string `json:"flag,omitempty"` + MetaType string `json:"meta_event_type,omitempty"` + Status *Status `json:"status,omitempty"` + Interval int64 `json:"interval,omitempty"` + File *FileInfo `json:"file,omitempty"` // 群文件上传信息 + TargetID int64 `json:"target_id,omitempty"` // 戳一戳、红包运气王目标ID + HonorType string `json:"honor_type,omitempty"` // 群荣誉类型 + Extra map[string]interface{} `json:"-"` +} + +// Sender 发送者信息 +type Sender struct { + UserID int64 `json:"user_id"` + Nickname string `json:"nickname"` + Sex string `json:"sex,omitempty"` + Age int32 `json:"age,omitempty"` + Card string `json:"card,omitempty"` // 群名片/备注 + Area string `json:"area,omitempty"` // 地区 + Level string `json:"level,omitempty"` // 成员等级 + Role string `json:"role,omitempty"` // 角色: owner, admin, member + Title string `json:"title,omitempty"` // 专属头衔 +} + +// FileInfo 文件信息 +type FileInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Size int64 `json:"size"` + Busid int64 `json:"busid"` +} + +// Anonymous 匿名信息 +type Anonymous struct { + ID int64 `json:"id"` + Name string `json:"name"` + Flag string `json:"flag"` +} + +// Status 状态信息 +type Status struct { + Online bool `json:"online"` + Good bool `json:"good"` + Stat *Stat `json:"stat,omitempty"` +} + +// Stat 统计信息 +type Stat struct { + PacketReceived int64 `json:"packet_received"` + PacketSent int64 `json:"packet_sent"` + PacketLost int32 `json:"packet_lost"` + MessageReceived int64 `json:"message_received"` + MessageSent int64 `json:"message_sent"` + DisconnectTimes int32 `json:"disconnect_times"` + LostTimes int32 `json:"lost_times"` + LastMessageTime int64 `json:"last_message_time"` +} + +// OB11Action OneBot11动作 +type OB11Action struct { + Action string `json:"action"` + Params map[string]interface{} `json:"params"` + Echo string `json:"echo,omitempty"` +} + +// OB11Response OneBot11响应 +type OB11Response struct { + Status string `json:"status"` + RetCode int `json:"retcode"` + Data map[string]interface{} `json:"data,omitempty"` + Echo string `json:"echo,omitempty"` +} + +// MessageSegment 消息段 +type MessageSegment struct { + Type string `json:"type"` + Data map[string]interface{} `json:"data"` +} + +// 消息段类型常量 +const ( + SegmentTypeText = "text" + SegmentTypeFace = "face" + SegmentTypeImage = "image" + SegmentTypeRecord = "record" + SegmentTypeVideo = "video" + SegmentTypeAt = "at" + SegmentTypeRPS = "rps" + SegmentTypeDice = "dice" + SegmentTypeShake = "shake" + SegmentTypePoke = "poke" + SegmentTypeAnonymous = "anonymous" + SegmentTypeShare = "share" + SegmentTypeContact = "contact" + SegmentTypeLocation = "location" + SegmentTypeMusic = "music" + SegmentTypeReply = "reply" + SegmentTypeForward = "forward" + SegmentTypeNode = "node" + SegmentTypeXML = "xml" + SegmentTypeJSON = "json" +) + +// 事件类型常量 +const ( + PostTypeMessage = "message" + PostTypeNotice = "notice" + PostTypeRequest = "request" + PostTypeMetaEvent = "meta_event" +) + +// 消息类型常量 +const ( + MessageTypePrivate = "private" + MessageTypeGroup = "group" +) + +// 通知类型常量 +const ( + NoticeTypeGroupUpload = "group_upload" + NoticeTypeGroupAdmin = "group_admin" + NoticeTypeGroupDecrease = "group_decrease" + NoticeTypeGroupIncrease = "group_increase" + NoticeTypeGroupBan = "group_ban" + NoticeTypeFriendAdd = "friend_add" + NoticeTypeGroupRecall = "group_recall" + NoticeTypeFriendRecall = "friend_recall" + NoticeTypeNotify = "notify" +) + +// 通知子类型常量 +const ( + // 群管理员变动 + SubTypeSet = "set" + SubTypeUnset = "unset" + + // 群成员减少 + SubTypeLeave = "leave" + SubTypeKick = "kick" + SubTypeKickMe = "kick_me" + + // 群成员增加 + SubTypeApprove = "approve" + SubTypeInvite = "invite" + + // 群禁言 + SubTypeBan = "ban" + SubTypeLiftBan = "lift_ban" + + // 通知类型 + SubTypePoke = "poke" // 戳一戳 + SubTypeLuckyKing = "lucky_king" // 红包运气王 + SubTypeHonor = "honor" // 群荣誉变更 +) + +// 请求类型常量 +const ( + RequestTypeFriend = "friend" + RequestTypeGroup = "group" +) + +// 元事件类型常量 +const ( + MetaEventTypeLifecycle = "lifecycle" + MetaEventTypeHeartbeat = "heartbeat" +) diff --git a/internal/config/config.go b/internal/config/config.go index 64601b2..ec5e90d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,6 +14,7 @@ type Config struct { Server ServerConfig `toml:"server"` Log LogConfig `toml:"log"` Protocol ProtocolConfig `toml:"protocol"` + Bots []BotConfig `toml:"bots"` } // ServerConfig 服务器配置 @@ -38,6 +39,43 @@ type ProtocolConfig struct { Options map[string]string `toml:"options"` } +// BotConfig Bot 配置 +type BotConfig struct { + ID string `toml:"id"` + Protocol string `toml:"protocol"` + Enabled bool `toml:"enabled"` + Milky MilkyConfig `toml:"milky"` + OneBot11 OneBot11Config `toml:"onebot11"` +} + +// MilkyConfig Milky 协议配置 +type MilkyConfig struct { + ProtocolURL string `toml:"protocol_url"` + AccessToken string `toml:"access_token"` + EventMode string `toml:"event_mode"` + WebhookListenAddr string `toml:"webhook_listen_addr"` + Timeout int `toml:"timeout"` + RetryCount int `toml:"retry_count"` +} + +// OneBot11Config OneBot11 协议配置 +type OneBot11Config struct { + ConnectionType string `toml:"connection_type"` // ws, ws-reverse, http, http-post + Host string `toml:"host"` + Port int `toml:"port"` + AccessToken string `toml:"access_token"` + WSUrl string `toml:"ws_url"` // 正向WS地址 + WSReverseUrl string `toml:"ws_reverse_url"` // 反向WS监听地址 + HTTPUrl string `toml:"http_url"` // 正向HTTP地址 + HTTPPostUrl string `toml:"http_post_url"` // HTTP POST上报地址 + Secret string `toml:"secret"` // 签名密钥 + Timeout int `toml:"timeout"` // 超时时间(秒) + Heartbeat int `toml:"heartbeat"` // 心跳间隔(秒) + ReconnectInterval int `toml:"reconnect_interval"` // 重连间隔(秒) + SelfID string `toml:"self_id"` // 机器人QQ号 + Nickname string `toml:"nickname"` // 机器人昵称 +} + // ConfigManager 配置管理器 type ConfigManager struct { configPath string diff --git a/internal/di/lifecycle.go b/internal/di/lifecycle.go index 69edad7..81a12b4 100644 --- a/internal/di/lifecycle.go +++ b/internal/di/lifecycle.go @@ -6,6 +6,7 @@ import ( "cellbot/internal/engine" "cellbot/internal/protocol" "cellbot/pkg/net" + "go.uber.org/fx" "go.uber.org/zap" ) @@ -17,58 +18,59 @@ func RegisterLifecycleHooks( dispatcher *engine.Dispatcher, botManager *protocol.BotManager, server *net.Server, -) fx.Option { - return fx.Invoke( - func(lc fx.Lifecycle) { - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - logger.Info("Starting CellBot application...") + lc fx.Lifecycle, +) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + logger.Info("=== Starting CellBot application ===") - // 启动事件总线 - eventBus.Start() + // 启动事件总线 + logger.Info("Starting event bus...") + eventBus.Start() + logger.Info("Event bus started") - // 启动分发器 - dispatcher.Start(ctx) + // 启动分发器 + logger.Info("Starting dispatcher...") + dispatcher.Start(ctx) + logger.Info("Dispatcher started") - // 启动所有机器人 - if err := botManager.StartAll(ctx); err != nil { - logger.Error("Failed to start bots", zap.Error(err)) - } + // 启动所有机器人 + if err := botManager.StartAll(ctx); err != nil { + logger.Error("Failed to start bots", zap.Error(err)) + } - // 启动HTTP服务器 - if err := server.Start(); err != nil { - logger.Error("Failed to start server", zap.Error(err)) - return err - } + // 启动HTTP服务器 + if err := server.Start(); err != nil { + logger.Error("Failed to start server", zap.Error(err)) + return err + } - logger.Info("CellBot application started successfully") - return nil - }, - OnStop: func(ctx context.Context) error { - logger.Info("Stopping CellBot application...") - - // 停止HTTP服务器 - if err := server.Stop(); err != nil { - logger.Error("Failed to stop server", zap.Error(err)) - } - - // 停止所有机器人 - if err := botManager.StopAll(ctx); err != nil { - logger.Error("Failed to stop bots", zap.Error(err)) - } - - // 停止分发器 - dispatcher.Stop() - - // 停止事件总线 - eventBus.Stop() - - logger.Info("CellBot application stopped successfully") - return nil - }, - }) + logger.Info("CellBot application started successfully") + return nil }, - ) + OnStop: func(ctx context.Context) error { + logger.Info("Stopping CellBot application...") + + // 停止HTTP服务器 + if err := server.Stop(); err != nil { + logger.Error("Failed to stop server", zap.Error(err)) + } + + // 停止所有机器人 + if err := botManager.StopAll(ctx); err != nil { + logger.Error("Failed to stop bots", zap.Error(err)) + } + + // 停止分发器 + dispatcher.Stop() + + // 停止事件总线 + eventBus.Stop() + + logger.Info("CellBot application stopped successfully") + return nil + }, + }) } // Lifecycle 生命周期管理选项 diff --git a/internal/di/providers.go b/internal/di/providers.go index 041ad6c..9f6987a 100644 --- a/internal/di/providers.go +++ b/internal/di/providers.go @@ -1,20 +1,23 @@ package di import ( + "cellbot/internal/adapter/milky" + "cellbot/internal/adapter/onebot11" "cellbot/internal/config" "cellbot/internal/engine" + "cellbot/internal/plugins/echo" "cellbot/internal/protocol" "cellbot/pkg/net" + "context" + "go.uber.org/fx" "go.uber.org/zap" ) -// ProvideLogger 提供日志实例 func ProvideLogger(cfg *config.Config) (*zap.Logger, error) { return config.InitLogger(&cfg.Log) } -// ProvideConfig 提供配置实例 func ProvideConfig() (*config.Config, error) { configManager := config.NewConfigManager("configs/config.toml", zap.NewNop()) if err := configManager.Load(); err != nil { @@ -23,40 +26,115 @@ func ProvideConfig() (*config.Config, error) { return configManager.Get(), nil } -// ProvideConfigManager 提供配置管理器 func ProvideConfigManager(logger *zap.Logger) (*config.ConfigManager, error) { configManager := config.NewConfigManager("configs/config.toml", logger) if err := configManager.Load(); err != nil { return nil, err } - // 启动配置文件监听 if err := configManager.Watch(); err != nil { logger.Warn("Failed to watch config file", zap.Error(err)) } return configManager, nil } -// ProvideEventBus 提供事件总线 func ProvideEventBus(logger *zap.Logger) *engine.EventBus { return engine.NewEventBus(logger, 10000) } -// ProvideDispatcher 提供事件分发器 func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger) *engine.Dispatcher { return engine.NewDispatcher(eventBus, logger) } -// ProvideBotManager 提供机器人管理器 func ProvideBotManager(logger *zap.Logger) *protocol.BotManager { return protocol.NewBotManager(logger) } -// ProvideServer 提供HTTP服务器 +func ProvideWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *net.WebSocketManager { + return net.NewWebSocketManager(logger, eventBus) +} + func ProvideServer(cfg *config.Config, logger *zap.Logger, botManager *protocol.BotManager, eventBus *engine.EventBus) *net.Server { return net.NewServer(cfg.Server.Host, cfg.Server.Port, logger, botManager, eventBus) } -// Providers 依赖注入提供者列表 +func ProvideMilkyBots(cfg *config.Config, logger *zap.Logger, eventBus *engine.EventBus, wsManager *net.WebSocketManager, botManager *protocol.BotManager, lc fx.Lifecycle) error { + for _, botCfg := range cfg.Bots { + if botCfg.Protocol == "milky" && botCfg.Enabled { + logger.Info("Creating Milky bot", zap.String("bot_id", botCfg.ID)) + + milkyCfg := &milky.Config{ + ProtocolURL: botCfg.Milky.ProtocolURL, + AccessToken: botCfg.Milky.AccessToken, + EventMode: botCfg.Milky.EventMode, + WebhookListenAddr: botCfg.Milky.WebhookListenAddr, + Timeout: botCfg.Milky.Timeout, + RetryCount: botCfg.Milky.RetryCount, + } + + bot := milky.NewBot(botCfg.ID, milkyCfg, eventBus, wsManager, logger) + botManager.Add(bot) + + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + logger.Info("Starting Milky bot", zap.String("bot_id", botCfg.ID)) + return bot.Connect(ctx) + }, + OnStop: func(ctx context.Context) error { + logger.Info("Stopping Milky bot", zap.String("bot_id", botCfg.ID)) + return bot.Disconnect(ctx) + }, + }) + } + } + return nil +} + +func ProvideOneBot11Bots(cfg *config.Config, logger *zap.Logger, wsManager *net.WebSocketManager, eventBus *engine.EventBus, botManager *protocol.BotManager, lc fx.Lifecycle) error { + for _, botCfg := range cfg.Bots { + if botCfg.Protocol == "onebot11" && botCfg.Enabled { + logger.Info("Creating OneBot11 bot", zap.String("bot_id", botCfg.ID)) + + ob11Cfg := &onebot11.Config{ + ConnectionType: botCfg.OneBot11.ConnectionType, + Host: botCfg.OneBot11.Host, + Port: botCfg.OneBot11.Port, + AccessToken: botCfg.OneBot11.AccessToken, + WSUrl: botCfg.OneBot11.WSUrl, + WSReverseUrl: botCfg.OneBot11.WSReverseUrl, + Heartbeat: botCfg.OneBot11.Heartbeat, + ReconnectInterval: botCfg.OneBot11.ReconnectInterval, + HTTPUrl: botCfg.OneBot11.HTTPUrl, + HTTPPostUrl: botCfg.OneBot11.HTTPPostUrl, + Secret: botCfg.OneBot11.Secret, + Timeout: botCfg.OneBot11.Timeout, + SelfID: botCfg.OneBot11.SelfID, + Nickname: botCfg.OneBot11.Nickname, + } + + bot := onebot11.NewBot(botCfg.ID, ob11Cfg, logger, wsManager, eventBus) + botManager.Add(bot) + + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + logger.Info("Starting OneBot11 bot", zap.String("bot_id", botCfg.ID)) + return bot.Connect(ctx) + }, + OnStop: func(ctx context.Context) error { + logger.Info("Stopping OneBot11 bot", zap.String("bot_id", botCfg.ID)) + return bot.Disconnect(ctx) + }, + }) + } + } + return nil +} + +func ProvideEchoPlugin(logger *zap.Logger, botManager *protocol.BotManager, dispatcher *engine.Dispatcher) { + echoPlugin := echo.NewEchoPlugin(logger, botManager) + dispatcher.RegisterHandler(echoPlugin) + logger.Info("Echo plugin registered") +} + var Providers = fx.Options( fx.Provide( ProvideConfig, @@ -65,6 +143,10 @@ var Providers = fx.Options( ProvideEventBus, ProvideDispatcher, ProvideBotManager, + ProvideWebSocketManager, ProvideServer, ), + fx.Invoke(ProvideMilkyBots), + fx.Invoke(ProvideOneBot11Bots), + fx.Invoke(ProvideEchoPlugin), ) diff --git a/internal/engine/dispatcher.go b/internal/engine/dispatcher.go index 4a903fb..acff655 100644 --- a/internal/engine/dispatcher.go +++ b/internal/engine/dispatcher.go @@ -2,33 +2,69 @@ package engine import ( "context" + "runtime/debug" "sort" + "sync" + "sync/atomic" + "time" "cellbot/internal/protocol" + "go.uber.org/zap" ) +// DispatcherMetrics 分发器指标 +type DispatcherMetrics struct { + ProcessedTotal int64 // 处理的事件总数 + SuccessTotal int64 // 成功处理的事件数 + FailedTotal int64 // 失败的事件数 + PanicTotal int64 // Panic次数 + AvgProcessTime float64 // 平均处理时间(毫秒) + LastProcessTime int64 // 最后处理时间(Unix时间戳) +} + // Dispatcher 事件分发器 // 管理事件处理器并按照优先级分发事件 type Dispatcher struct { - handlers []protocol.EventHandler + handlers []protocol.EventHandler middlewares []protocol.Middleware - logger *zap.Logger - eventBus *EventBus + logger *zap.Logger + eventBus *EventBus + metrics DispatcherMetrics + mu sync.RWMutex + workerPool chan struct{} // 工作池,限制并发数 + maxWorkers int + async bool // 是否异步处理 + totalTime int64 // 总处理时间(纳秒) } // NewDispatcher 创建事件分发器 func NewDispatcher(eventBus *EventBus, logger *zap.Logger) *Dispatcher { + return NewDispatcherWithConfig(eventBus, logger, 100, true) +} + +// NewDispatcherWithConfig 使用配置创建事件分发器 +func NewDispatcherWithConfig(eventBus *EventBus, logger *zap.Logger, maxWorkers int, async bool) *Dispatcher { + if maxWorkers <= 0 { + maxWorkers = 100 + } + return &Dispatcher{ handlers: make([]protocol.EventHandler, 0), middlewares: make([]protocol.Middleware, 0), logger: logger.Named("dispatcher"), eventBus: eventBus, + workerPool: make(chan struct{}, maxWorkers), + maxWorkers: maxWorkers, + async: async, } } // RegisterHandler 注册事件处理器 func (d *Dispatcher) RegisterHandler(handler protocol.EventHandler) { + d.mu.Lock() + defer d.mu.Unlock() + d.handlers = append(d.handlers, handler) // 按优先级排序(数值越小优先级越高) sort.Slice(d.handlers, func(i, j int) bool { @@ -42,6 +78,9 @@ func (d *Dispatcher) RegisterHandler(handler protocol.EventHandler) { // UnregisterHandler 取消注册事件处理器 func (d *Dispatcher) UnregisterHandler(handler protocol.EventHandler) { + d.mu.Lock() + defer d.mu.Unlock() + for i, h := range d.handlers { if h == handler { d.handlers = append(d.handlers[:i], d.handlers[i+1:]...) @@ -54,6 +93,9 @@ func (d *Dispatcher) UnregisterHandler(handler protocol.EventHandler) { // RegisterMiddleware 注册中间件 func (d *Dispatcher) RegisterMiddleware(middleware protocol.Middleware) { + d.mu.Lock() + defer d.mu.Unlock() + d.middlewares = append(d.middlewares, middleware) d.logger.Debug("Middleware registered", zap.Int("total_middlewares", len(d.middlewares))) @@ -88,7 +130,21 @@ func (d *Dispatcher) eventLoop(ctx context.Context, eventChan chan protocol.Even if !ok { return } - d.handleEvent(ctx, event) + + if d.IsAsync() { + // 异步处理,使用工作池限制并发 + d.workerPool <- struct{}{} // 获取工作槽位 + go func(e protocol.Event) { + defer func() { + <-d.workerPool // 释放工作槽位 + }() + d.handleEvent(ctx, e) + }(event) + } else { + // 同步处理 + d.handleEvent(ctx, event) + } + case <-ctx.Done(): return } @@ -97,47 +153,114 @@ func (d *Dispatcher) eventLoop(ctx context.Context, eventChan chan protocol.Even // handleEvent 处理单个事件 func (d *Dispatcher) handleEvent(ctx context.Context, event protocol.Event) { - d.logger.Debug("Processing event", + startTime := time.Now() + + // 使用defer捕获panic + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&d.metrics.PanicTotal, 1) + atomic.AddInt64(&d.metrics.FailedTotal, 1) + d.logger.Error("Panic in event handler", + zap.Any("panic", r), + zap.String("stack", string(debug.Stack())), + zap.String("event_type", string(event.GetType()))) + } + + // 更新指标 + duration := time.Since(startTime) + atomic.AddInt64(&d.metrics.ProcessedTotal, 1) + atomic.AddInt64(&d.totalTime, duration.Nanoseconds()) + atomic.StoreInt64(&d.metrics.LastProcessTime, time.Now().Unix()) + + // 计算平均处理时间 + processed := atomic.LoadInt64(&d.metrics.ProcessedTotal) + if processed > 0 { + avgNs := atomic.LoadInt64(&d.totalTime) / processed + d.metrics.AvgProcessTime = float64(avgNs) / 1e6 // 转换为毫秒 + } + }() + + d.logger.Info("Processing event", zap.String("type", string(event.GetType())), - zap.String("detail_type", event.GetDetailType())) + zap.String("detail_type", event.GetDetailType()), + zap.String("self_id", event.GetSelfID())) // 通过中间件链处理事件 + d.mu.RLock() + middlewares := d.middlewares + d.mu.RUnlock() + next := d.createHandlerChain(ctx, event) // 执行中间件链 - if len(d.middlewares) > 0 { - d.executeMiddlewares(ctx, event, func(ctx context.Context, e protocol.Event) error { + if len(middlewares) > 0 { + d.executeMiddlewares(ctx, event, middlewares, func(ctx context.Context, e protocol.Event) error { next(ctx, e) return nil }) } else { next(ctx, event) } + + atomic.AddInt64(&d.metrics.SuccessTotal, 1) } // createHandlerChain 创建处理器链 func (d *Dispatcher) createHandlerChain(ctx context.Context, event protocol.Event) func(context.Context, protocol.Event) { return func(ctx context.Context, e protocol.Event) { - for _, handler := range d.handlers { - if handler.Match(event) { - if err := handler.Handle(ctx, e); err != nil { - d.logger.Error("Handler execution failed", - zap.Error(err), - zap.String("event_type", string(e.GetType()))) - } + d.mu.RLock() + handlers := make([]protocol.EventHandler, len(d.handlers)) + copy(handlers, d.handlers) + d.mu.RUnlock() + + for i, handler := range handlers { + matched := handler.Match(event) + d.logger.Info("Checking handler", + zap.Int("handler_index", i), + zap.Int("priority", handler.Priority()), + zap.Bool("matched", matched)) + if matched { + d.logger.Info("Handler matched, calling Handle", + zap.Int("handler_index", i)) + // 使用defer捕获单个handler的panic + func() { + defer func() { + if r := recover(); r != nil { + d.logger.Error("Panic in handler", + zap.Any("panic", r), + zap.String("stack", string(debug.Stack())), + zap.String("event_type", string(e.GetType()))) + } + }() + + if err := handler.Handle(ctx, e); err != nil { + d.logger.Error("Handler execution failed", + zap.Error(err), + zap.String("event_type", string(e.GetType()))) + } + }() } } } } // executeMiddlewares 执行中间件链 -func (d *Dispatcher) executeMiddlewares(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) { +func (d *Dispatcher) executeMiddlewares(ctx context.Context, event protocol.Event, middlewares []protocol.Middleware, next func(context.Context, protocol.Event) error) { // 从后向前构建中间件链 handler := next - for i := len(d.middlewares) - 1; i >= 0; i-- { - middleware := d.middlewares[i] + for i := len(middlewares) - 1; i >= 0; i-- { + middleware := middlewares[i] currentHandler := handler handler = func(ctx context.Context, e protocol.Event) error { + defer func() { + if r := recover(); r != nil { + d.logger.Error("Panic in middleware", + zap.Any("panic", r), + zap.String("stack", string(debug.Stack())), + zap.String("event_type", string(e.GetType()))) + } + }() + if err := middleware.Process(ctx, e, currentHandler); err != nil { d.logger.Error("Middleware execution failed", zap.Error(err), @@ -153,10 +276,54 @@ func (d *Dispatcher) executeMiddlewares(ctx context.Context, event protocol.Even // GetHandlerCount 获取处理器数量 func (d *Dispatcher) GetHandlerCount() int { + d.mu.RLock() + defer d.mu.RUnlock() return len(d.handlers) } // GetMiddlewareCount 获取中间件数量 func (d *Dispatcher) GetMiddlewareCount() int { + d.mu.RLock() + defer d.mu.RUnlock() return len(d.middlewares) } + +// GetMetrics 获取分发器指标 +func (d *Dispatcher) GetMetrics() DispatcherMetrics { + return DispatcherMetrics{ + ProcessedTotal: atomic.LoadInt64(&d.metrics.ProcessedTotal), + SuccessTotal: atomic.LoadInt64(&d.metrics.SuccessTotal), + FailedTotal: atomic.LoadInt64(&d.metrics.FailedTotal), + PanicTotal: atomic.LoadInt64(&d.metrics.PanicTotal), + AvgProcessTime: d.metrics.AvgProcessTime, + LastProcessTime: atomic.LoadInt64(&d.metrics.LastProcessTime), + } +} + +// LogMetrics 记录指标日志 +func (d *Dispatcher) LogMetrics() { + metrics := d.GetMetrics() + + d.logger.Info("Dispatcher metrics", + zap.Int64("processed_total", metrics.ProcessedTotal), + zap.Int64("success_total", metrics.SuccessTotal), + zap.Int64("failed_total", metrics.FailedTotal), + zap.Int64("panic_total", metrics.PanicTotal), + zap.Float64("avg_process_time_ms", metrics.AvgProcessTime), + zap.Int("handler_count", d.GetHandlerCount()), + zap.Int("middleware_count", d.GetMiddlewareCount())) +} + +// SetAsync 设置是否异步处理 +func (d *Dispatcher) SetAsync(async bool) { + d.mu.Lock() + defer d.mu.Unlock() + d.async = async +} + +// IsAsync 是否异步处理 +func (d *Dispatcher) IsAsync() bool { + d.mu.RLock() + defer d.mu.RUnlock() + return d.async +} diff --git a/internal/engine/eventbus.go b/internal/engine/eventbus.go index 288e3e6..6146a06 100644 --- a/internal/engine/eventbus.go +++ b/internal/engine/eventbus.go @@ -2,33 +2,54 @@ package engine import ( "context" + "crypto/rand" + "encoding/hex" "sync" + "sync/atomic" + "time" "cellbot/internal/protocol" + "go.uber.org/zap" ) // Subscription 订阅信息 type Subscription struct { - ID string - Chan chan protocol.Event - Filter func(protocol.Event) bool + ID string + Chan chan protocol.Event + Filter func(protocol.Event) bool + CreatedAt time.Time + EventCount int64 // 接收的事件数量 +} + +// EventBusMetrics 事件总线指标 +type EventBusMetrics struct { + PublishedTotal int64 // 发布的事件总数 + DispatchedTotal int64 // 分发的事件总数 + DroppedTotal int64 // 丢弃的事件总数 + SubscriberTotal int64 // 订阅者总数 + LastEventTime int64 // 最后一次事件时间(Unix时间戳) } // EventBus 事件总线 // 基于channel的高性能发布订阅实现 type EventBus struct { subscriptions map[string][]*Subscription - mu sync.RWMutex - logger *zap.Logger - eventChan chan protocol.Event - wg sync.WaitGroup - ctx context.Context - cancel context.CancelFunc + mu sync.RWMutex + logger *zap.Logger + eventChan chan protocol.Event + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + metrics EventBusMetrics + bufferSize int } // NewEventBus 创建事件总线 func NewEventBus(logger *zap.Logger, bufferSize int) *EventBus { + if bufferSize <= 0 { + bufferSize = 1000 + } ctx, cancel := context.WithCancel(context.Background()) return &EventBus{ subscriptions: make(map[string][]*Subscription), @@ -36,6 +57,7 @@ func NewEventBus(logger *zap.Logger, bufferSize int) *EventBus { eventChan: make(chan protocol.Event, bufferSize), ctx: ctx, cancel: cancel, + bufferSize: bufferSize, } } @@ -56,31 +78,89 @@ func (eb *EventBus) Stop() { // Publish 发布事件 func (eb *EventBus) Publish(event protocol.Event) { + eb.logger.Info("Publishing event to channel", + zap.String("event_type", string(event.GetType())), + zap.String("detail_type", event.GetDetailType()), + zap.Int("channel_len", len(eb.eventChan)), + zap.Int("channel_cap", cap(eb.eventChan))) + select { case eb.eventChan <- event: + atomic.AddInt64(&eb.metrics.PublishedTotal, 1) + atomic.StoreInt64(&eb.metrics.LastEventTime, time.Now().Unix()) + eb.logger.Info("Event successfully queued", + zap.String("event_type", string(event.GetType()))) case <-eb.ctx.Done(): + atomic.AddInt64(&eb.metrics.DroppedTotal, 1) eb.logger.Warn("Event bus is shutting down, event dropped", zap.String("type", string(event.GetType()))) + default: + // 如果channel满了,也丢弃事件 + atomic.AddInt64(&eb.metrics.DroppedTotal, 1) + eb.logger.Warn("Event channel full, event dropped", + zap.String("type", string(event.GetType())), + zap.Int("buffer_size", eb.bufferSize), + zap.Int("channel_len", len(eb.eventChan))) + } +} + +// PublishBatch 批量发布事件 +func (eb *EventBus) PublishBatch(events []protocol.Event) { + for _, event := range events { + eb.Publish(event) + } +} + +// PublishAsync 异步发布事件(不阻塞) +func (eb *EventBus) PublishAsync(event protocol.Event) { + go eb.Publish(event) +} + +// TryPublish 尝试发布事件(非阻塞) +func (eb *EventBus) TryPublish(event protocol.Event) bool { + select { + case eb.eventChan <- event: + atomic.AddInt64(&eb.metrics.PublishedTotal, 1) + atomic.StoreInt64(&eb.metrics.LastEventTime, time.Now().Unix()) + return true + case <-eb.ctx.Done(): + atomic.AddInt64(&eb.metrics.DroppedTotal, 1) + return false + default: + atomic.AddInt64(&eb.metrics.DroppedTotal, 1) + return false } } // Subscribe 订阅事件 func (eb *EventBus) Subscribe(eventType protocol.EventType, filter func(protocol.Event) bool) chan protocol.Event { + return eb.SubscribeWithBuffer(eventType, filter, 100) +} + +// SubscribeWithBuffer 订阅事件(指定缓冲区大小) +func (eb *EventBus) SubscribeWithBuffer(eventType protocol.EventType, filter func(protocol.Event) bool, bufferSize int) chan protocol.Event { eb.mu.Lock() defer eb.mu.Unlock() + if bufferSize <= 0 { + bufferSize = 100 + } + sub := &Subscription{ - ID: generateSubscriptionID(), - Chan: make(chan protocol.Event, 100), - Filter: filter, + ID: generateSubscriptionID(), + Chan: make(chan protocol.Event, bufferSize), + Filter: filter, + CreatedAt: time.Now(), } key := string(eventType) eb.subscriptions[key] = append(eb.subscriptions[key], sub) + atomic.AddInt64(&eb.metrics.SubscriberTotal, 1) eb.logger.Debug("New subscription added", zap.String("event_type", key), - zap.String("sub_id", sub.ID)) + zap.String("sub_id", sub.ID), + zap.Int("buffer_size", bufferSize)) return sub.Chan } @@ -96,9 +176,13 @@ func (eb *EventBus) Unsubscribe(eventType protocol.EventType, ch chan protocol.E if sub.Chan == ch { close(sub.Chan) eb.subscriptions[key] = append(subs[:i], subs[i+1:]...) + atomic.AddInt64(&eb.metrics.SubscriberTotal, -1) + eb.logger.Debug("Subscription removed", zap.String("event_type", key), - zap.String("sub_id", sub.ID)) + zap.String("sub_id", sub.ID), + zap.Int64("event_count", sub.EventCount), + zap.Duration("lifetime", time.Since(sub.CreatedAt))) return } } @@ -108,14 +192,20 @@ func (eb *EventBus) Unsubscribe(eventType protocol.EventType, ch chan protocol.E func (eb *EventBus) dispatch() { defer eb.wg.Done() + eb.logger.Info("Event bus dispatch loop started") + for { select { case event, ok := <-eb.eventChan: if !ok { + eb.logger.Info("Event channel closed, stopping dispatch") return } + eb.logger.Debug("Received event from channel", + zap.String("event_type", string(event.GetType()))) eb.dispatchEvent(event) case <-eb.ctx.Done(): + eb.logger.Info("Context cancelled, stopping dispatch") return } } @@ -131,18 +221,40 @@ func (eb *EventBus) dispatchEvent(event protocol.Event) { copy(subsCopy, subs) eb.mu.RUnlock() + eb.logger.Info("Dispatching event", + zap.String("event_type", key), + zap.String("detail_type", event.GetDetailType()), + zap.Int("subscriber_count", len(subsCopy))) + + dispatched := 0 for _, sub := range subsCopy { if sub.Filter == nil || sub.Filter(event) { select { case sub.Chan <- event: + atomic.AddInt64(&sub.EventCount, 1) + dispatched++ + eb.logger.Debug("Event dispatched to subscriber", + zap.String("sub_id", sub.ID)) default: // 订阅者channel已满,丢弃事件 + atomic.AddInt64(&eb.metrics.DroppedTotal, 1) eb.logger.Warn("Subscription channel full, event dropped", zap.String("sub_id", sub.ID), zap.String("event_type", key)) } } } + + if dispatched > 0 { + atomic.AddInt64(&eb.metrics.DispatchedTotal, int64(dispatched)) + eb.logger.Info("Event dispatched successfully", + zap.String("event_type", key), + zap.Int("dispatched_count", dispatched)) + } else { + eb.logger.Warn("No subscribers for event", + zap.String("event_type", key), + zap.String("detail_type", event.GetDetailType())) + } } // GetSubscriptionCount 获取订阅者数量 @@ -166,6 +278,67 @@ func (eb *EventBus) Clear() { eb.logger.Info("All subscriptions cleared") } +// GetMetrics 获取事件总线指标 +func (eb *EventBus) GetMetrics() EventBusMetrics { + return EventBusMetrics{ + PublishedTotal: atomic.LoadInt64(&eb.metrics.PublishedTotal), + DispatchedTotal: atomic.LoadInt64(&eb.metrics.DispatchedTotal), + DroppedTotal: atomic.LoadInt64(&eb.metrics.DroppedTotal), + SubscriberTotal: atomic.LoadInt64(&eb.metrics.SubscriberTotal), + LastEventTime: atomic.LoadInt64(&eb.metrics.LastEventTime), + } +} + +// GetAllSubscriptions 获取所有订阅信息 +func (eb *EventBus) GetAllSubscriptions() map[string]int { + eb.mu.RLock() + defer eb.mu.RUnlock() + + result := make(map[string]int) + for eventType, subs := range eb.subscriptions { + result[eventType] = len(subs) + } + return result +} + +// GetBufferUsage 获取缓冲区使用情况 +func (eb *EventBus) GetBufferUsage() float64 { + return float64(len(eb.eventChan)) / float64(eb.bufferSize) +} + +// IsHealthy 检查事件总线健康状态 +func (eb *EventBus) IsHealthy() bool { + // 检查缓冲区使用率是否过高 + if eb.GetBufferUsage() > 0.9 { + return false + } + + // 检查是否有过多的丢弃事件 + metrics := eb.GetMetrics() + if metrics.PublishedTotal > 0 { + dropRate := float64(metrics.DroppedTotal) / float64(metrics.PublishedTotal) + if dropRate > 0.1 { // 丢弃率超过10% + return false + } + } + + return true +} + +// LogMetrics 记录指标日志 +func (eb *EventBus) LogMetrics() { + metrics := eb.GetMetrics() + subs := eb.GetAllSubscriptions() + + eb.logger.Info("EventBus metrics", + zap.Int64("published_total", metrics.PublishedTotal), + zap.Int64("dispatched_total", metrics.DispatchedTotal), + zap.Int64("dropped_total", metrics.DroppedTotal), + zap.Int64("subscriber_total", metrics.SubscriberTotal), + zap.Float64("buffer_usage", eb.GetBufferUsage()), + zap.Any("subscriptions", subs)) +} + // generateSubscriptionID 生成订阅ID func generateSubscriptionID() string { return "sub-" + randomString(8) @@ -173,10 +346,15 @@ func generateSubscriptionID() string { // randomString 生成随机字符串 func randomString(length int) string { - const charset = "abcdefghijklmnopqrstuvwxyz0123456789" - b := make([]byte, length) - for i := range b { - b[i] = charset[i%len(charset)] + b := make([]byte, length/2) + if _, err := rand.Read(b); err != nil { + // 降级到简单实现 + const charset = "abcdefghijklmnopqrstuvwxyz0123456789" + result := make([]byte, length) + for i := range result { + result[i] = charset[i%len(charset)] + } + return string(result) } - return string(b) + return hex.EncodeToString(b) } diff --git a/internal/engine/handler.go b/internal/engine/handler.go new file mode 100644 index 0000000..4238e8a --- /dev/null +++ b/internal/engine/handler.go @@ -0,0 +1,312 @@ +package engine + +import ( + "context" + "strings" + + "cellbot/internal/protocol" + + "go.uber.org/zap" +) + +// BaseHandler 基础处理器 +type BaseHandler struct { + priority int + logger *zap.Logger +} + +// NewBaseHandler 创建基础处理器 +func NewBaseHandler(priority int, logger *zap.Logger) *BaseHandler { + return &BaseHandler{ + priority: priority, + logger: logger, + } +} + +// Priority 获取优先级 +func (h *BaseHandler) Priority() int { + return h.priority +} + +// MessageHandler 消息处理器 +type MessageHandler struct { + *BaseHandler + matchFunc func(protocol.Event) bool + handleFunc func(context.Context, protocol.Event) error +} + +// NewMessageHandler 创建消息处理器 +func NewMessageHandler(priority int, logger *zap.Logger, matchFunc func(protocol.Event) bool, handleFunc func(context.Context, protocol.Event) error) *MessageHandler { + return &MessageHandler{ + BaseHandler: NewBaseHandler(priority, logger.Named("handler.message")), + matchFunc: matchFunc, + handleFunc: handleFunc, + } +} + +// Match 判断是否匹配事件 +func (h *MessageHandler) Match(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeMessage { + return false + } + + if h.matchFunc != nil { + return h.matchFunc(event) + } + + return true +} + +// Handle 处理事件 +func (h *MessageHandler) Handle(ctx context.Context, event protocol.Event) error { + if h.handleFunc != nil { + return h.handleFunc(ctx, event) + } + + h.logger.Info("Message event handled", + zap.String("detail_type", event.GetDetailType()), + zap.String("self_id", event.GetSelfID())) + + return nil +} + +// CommandHandler 命令处理器 +type CommandHandler struct { + *BaseHandler + prefix string + commands map[string]func(context.Context, protocol.Event, []string) error +} + +// NewCommandHandler 创建命令处理器 +func NewCommandHandler(priority int, logger *zap.Logger, prefix string) *CommandHandler { + return &CommandHandler{ + BaseHandler: NewBaseHandler(priority, logger.Named("handler.command")), + prefix: prefix, + commands: make(map[string]func(context.Context, protocol.Event, []string) error), + } +} + +// RegisterCommand 注册命令 +func (h *CommandHandler) RegisterCommand(cmd string, handler func(context.Context, protocol.Event, []string) error) { + h.commands[cmd] = handler + h.logger.Debug("Command registered", zap.String("command", cmd)) +} + +// Match 判断是否匹配事件 +func (h *CommandHandler) Match(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeMessage { + return false + } + + data := event.GetData() + rawMessage, ok := data["raw_message"].(string) + if !ok { + return false + } + + return strings.HasPrefix(rawMessage, h.prefix) +} + +// Handle 处理事件 +func (h *CommandHandler) Handle(ctx context.Context, event protocol.Event) error { + data := event.GetData() + rawMessage, ok := data["raw_message"].(string) + if !ok { + return nil + } + + // 去除前缀 + cmdText := strings.TrimPrefix(rawMessage, h.prefix) + cmdText = strings.TrimSpace(cmdText) + + // 解析命令和参数 + parts := strings.Fields(cmdText) + if len(parts) == 0 { + return nil + } + + cmd := parts[0] + args := parts[1:] + + // 查找命令处理器 + handler, exists := h.commands[cmd] + if !exists { + h.logger.Debug("Unknown command", zap.String("command", cmd)) + return nil + } + + h.logger.Info("Executing command", + zap.String("command", cmd), + zap.Strings("args", args)) + + return handler(ctx, event, args) +} + +// KeywordHandler 关键词处理器 +type KeywordHandler struct { + *BaseHandler + keywords map[string]func(context.Context, protocol.Event) error + caseSensitive bool +} + +// NewKeywordHandler 创建关键词处理器 +func NewKeywordHandler(priority int, logger *zap.Logger, caseSensitive bool) *KeywordHandler { + return &KeywordHandler{ + BaseHandler: NewBaseHandler(priority, logger.Named("handler.keyword")), + keywords: make(map[string]func(context.Context, protocol.Event) error), + caseSensitive: caseSensitive, + } +} + +// RegisterKeyword 注册关键词 +func (h *KeywordHandler) RegisterKeyword(keyword string, handler func(context.Context, protocol.Event) error) { + if !h.caseSensitive { + keyword = strings.ToLower(keyword) + } + h.keywords[keyword] = handler + h.logger.Debug("Keyword registered", zap.String("keyword", keyword)) +} + +// Match 判断是否匹配事件 +func (h *KeywordHandler) Match(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeMessage { + return false + } + + data := event.GetData() + rawMessage, ok := data["raw_message"].(string) + if !ok { + return false + } + + if !h.caseSensitive { + rawMessage = strings.ToLower(rawMessage) + } + + for keyword := range h.keywords { + if strings.Contains(rawMessage, keyword) { + return true + } + } + + return false +} + +// Handle 处理事件 +func (h *KeywordHandler) Handle(ctx context.Context, event protocol.Event) error { + data := event.GetData() + rawMessage, ok := data["raw_message"].(string) + if !ok { + return nil + } + + if !h.caseSensitive { + rawMessage = strings.ToLower(rawMessage) + } + + // 执行所有匹配的关键词处理器 + for keyword, handler := range h.keywords { + if strings.Contains(rawMessage, keyword) { + h.logger.Info("Keyword matched", + zap.String("keyword", keyword)) + + if err := handler(ctx, event); err != nil { + h.logger.Error("Keyword handler failed", + zap.String("keyword", keyword), + zap.Error(err)) + } + } + } + + return nil +} + +// NoticeHandler 通知处理器 +type NoticeHandler struct { + *BaseHandler + noticeTypes map[string]func(context.Context, protocol.Event) error +} + +// NewNoticeHandler 创建通知处理器 +func NewNoticeHandler(priority int, logger *zap.Logger) *NoticeHandler { + return &NoticeHandler{ + BaseHandler: NewBaseHandler(priority, logger.Named("handler.notice")), + noticeTypes: make(map[string]func(context.Context, protocol.Event) error), + } +} + +// RegisterNoticeType 注册通知类型处理器 +func (h *NoticeHandler) RegisterNoticeType(noticeType string, handler func(context.Context, protocol.Event) error) { + h.noticeTypes[noticeType] = handler + h.logger.Debug("Notice type registered", zap.String("notice_type", noticeType)) +} + +// Match 判断是否匹配事件 +func (h *NoticeHandler) Match(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeNotice { + return false + } + + detailType := event.GetDetailType() + _, exists := h.noticeTypes[detailType] + return exists +} + +// Handle 处理事件 +func (h *NoticeHandler) Handle(ctx context.Context, event protocol.Event) error { + detailType := event.GetDetailType() + handler, exists := h.noticeTypes[detailType] + if !exists { + return nil + } + + h.logger.Info("Notice event handled", + zap.String("notice_type", detailType)) + + return handler(ctx, event) +} + +// RequestHandler 请求处理器 +type RequestHandler struct { + *BaseHandler + requestTypes map[string]func(context.Context, protocol.Event) error +} + +// NewRequestHandler 创建请求处理器 +func NewRequestHandler(priority int, logger *zap.Logger) *RequestHandler { + return &RequestHandler{ + BaseHandler: NewBaseHandler(priority, logger.Named("handler.request")), + requestTypes: make(map[string]func(context.Context, protocol.Event) error), + } +} + +// RegisterRequestType 注册请求类型处理器 +func (h *RequestHandler) RegisterRequestType(requestType string, handler func(context.Context, protocol.Event) error) { + h.requestTypes[requestType] = handler + h.logger.Debug("Request type registered", zap.String("request_type", requestType)) +} + +// Match 判断是否匹配事件 +func (h *RequestHandler) Match(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeRequest { + return false + } + + detailType := event.GetDetailType() + _, exists := h.requestTypes[detailType] + return exists +} + +// Handle 处理事件 +func (h *RequestHandler) Handle(ctx context.Context, event protocol.Event) error { + detailType := event.GetDetailType() + handler, exists := h.requestTypes[detailType] + if !exists { + return nil + } + + h.logger.Info("Request event handled", + zap.String("request_type", detailType)) + + return handler(ctx, event) +} diff --git a/internal/engine/middleware.go b/internal/engine/middleware.go new file mode 100644 index 0000000..c19f223 --- /dev/null +++ b/internal/engine/middleware.go @@ -0,0 +1,283 @@ +package engine + +import ( + "context" + "sync" + "time" + + "cellbot/internal/protocol" + + "go.uber.org/zap" + "golang.org/x/time/rate" +) + +// LoggingMiddleware 日志中间件 +type LoggingMiddleware struct { + logger *zap.Logger +} + +// NewLoggingMiddleware 创建日志中间件 +func NewLoggingMiddleware(logger *zap.Logger) *LoggingMiddleware { + return &LoggingMiddleware{ + logger: logger.Named("middleware.logging"), + } +} + +// Process 处理事件 +func (m *LoggingMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { + start := time.Now() + + m.logger.Info("Event received", + zap.String("type", string(event.GetType())), + zap.String("detail_type", event.GetDetailType()), + zap.String("self_id", event.GetSelfID())) + + err := next(ctx, event) + + m.logger.Info("Event processed", + zap.String("type", string(event.GetType())), + zap.Duration("duration", time.Since(start)), + zap.Error(err)) + + return err +} + +// RateLimitMiddleware 限流中间件 +type RateLimitMiddleware struct { + limiters map[string]*rate.Limiter + mu sync.RWMutex + logger *zap.Logger + rps int // 每秒请求数 + burst int // 突发容量 +} + +// NewRateLimitMiddleware 创建限流中间件 +func NewRateLimitMiddleware(logger *zap.Logger, rps, burst int) *RateLimitMiddleware { + if rps <= 0 { + rps = 100 + } + if burst <= 0 { + burst = rps * 2 + } + + return &RateLimitMiddleware{ + limiters: make(map[string]*rate.Limiter), + logger: logger.Named("middleware.ratelimit"), + rps: rps, + burst: burst, + } +} + +// Process 处理事件 +func (m *RateLimitMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { + // 根据事件类型获取限流器 + key := string(event.GetType()) + + m.mu.RLock() + limiter, exists := m.limiters[key] + m.mu.RUnlock() + + if !exists { + m.mu.Lock() + limiter = rate.NewLimiter(rate.Limit(m.rps), m.burst) + m.limiters[key] = limiter + m.mu.Unlock() + } + + // 等待令牌 + if err := limiter.Wait(ctx); err != nil { + m.logger.Warn("Rate limit exceeded", + zap.String("event_type", key), + zap.Error(err)) + return err + } + + return next(ctx, event) +} + +// RetryMiddleware 重试中间件 +type RetryMiddleware struct { + logger *zap.Logger + maxRetries int + delay time.Duration +} + +// NewRetryMiddleware 创建重试中间件 +func NewRetryMiddleware(logger *zap.Logger, maxRetries int, delay time.Duration) *RetryMiddleware { + if maxRetries <= 0 { + maxRetries = 3 + } + if delay <= 0 { + delay = time.Second + } + + return &RetryMiddleware{ + logger: logger.Named("middleware.retry"), + maxRetries: maxRetries, + delay: delay, + } +} + +// Process 处理事件 +func (m *RetryMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { + var err error + for i := 0; i <= m.maxRetries; i++ { + if i > 0 { + m.logger.Info("Retrying event", + zap.String("event_type", string(event.GetType())), + zap.Int("attempt", i), + zap.Int("max_retries", m.maxRetries)) + + // 指数退避 + backoff := m.delay * time.Duration(1< 0 { + m.logger.Info("Event succeeded after retry", + zap.String("event_type", string(event.GetType())), + zap.Int("attempts", i+1)) + } + return nil + } + + m.logger.Warn("Event processing failed", + zap.String("event_type", string(event.GetType())), + zap.Int("attempt", i+1), + zap.Error(err)) + } + + m.logger.Error("Event failed after all retries", + zap.String("event_type", string(event.GetType())), + zap.Int("total_attempts", m.maxRetries+1), + zap.Error(err)) + + return err +} + +// TimeoutMiddleware 超时中间件 +type TimeoutMiddleware struct { + logger *zap.Logger + timeout time.Duration +} + +// NewTimeoutMiddleware 创建超时中间件 +func NewTimeoutMiddleware(logger *zap.Logger, timeout time.Duration) *TimeoutMiddleware { + if timeout <= 0 { + timeout = 30 * time.Second + } + + return &TimeoutMiddleware{ + logger: logger.Named("middleware.timeout"), + timeout: timeout, + } +} + +// Process 处理事件 +func (m *TimeoutMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { + ctx, cancel := context.WithTimeout(ctx, m.timeout) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- next(ctx, event) + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + m.logger.Warn("Event processing timeout", + zap.String("event_type", string(event.GetType())), + zap.Duration("timeout", m.timeout)) + return ctx.Err() + } +} + +// RecoveryMiddleware 恢复中间件(捕获panic) +type RecoveryMiddleware struct { + logger *zap.Logger +} + +// NewRecoveryMiddleware 创建恢复中间件 +func NewRecoveryMiddleware(logger *zap.Logger) *RecoveryMiddleware { + return &RecoveryMiddleware{ + logger: logger.Named("middleware.recovery"), + } +} + +// Process 处理事件 +func (m *RecoveryMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) (err error) { + defer func() { + if r := recover(); r != nil { + m.logger.Error("Recovered from panic", + zap.Any("panic", r), + zap.String("event_type", string(event.GetType()))) + err = protocol.ErrNotImplemented // 或者自定义错误 + } + }() + + return next(ctx, event) +} + +// MetricsMiddleware 指标中间件 +type MetricsMiddleware struct { + logger *zap.Logger + eventCounts map[string]int64 + eventTimes map[string]time.Duration + mu sync.RWMutex +} + +// NewMetricsMiddleware 创建指标中间件 +func NewMetricsMiddleware(logger *zap.Logger) *MetricsMiddleware { + return &MetricsMiddleware{ + logger: logger.Named("middleware.metrics"), + eventCounts: make(map[string]int64), + eventTimes: make(map[string]time.Duration), + } +} + +// Process 处理事件 +func (m *MetricsMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { + start := time.Now() + err := next(ctx, event) + duration := time.Since(start) + + eventType := string(event.GetType()) + + m.mu.Lock() + m.eventCounts[eventType]++ + m.eventTimes[eventType] += duration + m.mu.Unlock() + + return err +} + +// GetMetrics 获取指标 +func (m *MetricsMiddleware) GetMetrics() map[string]interface{} { + m.mu.RLock() + defer m.mu.RUnlock() + + metrics := make(map[string]interface{}) + for eventType, count := range m.eventCounts { + avgTime := m.eventTimes[eventType] / time.Duration(count) + metrics[eventType] = map[string]interface{}{ + "count": count, + "avg_time": avgTime.String(), + } + } + + return metrics +} + +// LogMetrics 记录指标 +func (m *MetricsMiddleware) LogMetrics() { + metrics := m.GetMetrics() + m.logger.Info("Event metrics", zap.Any("metrics", metrics)) +} diff --git a/internal/plugins/echo/echo.go b/internal/plugins/echo/echo.go new file mode 100644 index 0000000..da6d341 --- /dev/null +++ b/internal/plugins/echo/echo.go @@ -0,0 +1,137 @@ +package echo + +import ( + "context" + + "cellbot/internal/protocol" + + "go.uber.org/zap" +) + +// EchoPlugin 回声插件 +type EchoPlugin struct { + logger *zap.Logger + botManager *protocol.BotManager +} + +// NewEchoPlugin 创建回声插件 +func NewEchoPlugin(logger *zap.Logger, botManager *protocol.BotManager) *EchoPlugin { + return &EchoPlugin{ + logger: logger.Named("echo-plugin"), + botManager: botManager, + } +} + +// Handle 处理事件 +func (p *EchoPlugin) Handle(ctx context.Context, event protocol.Event) error { + // 获取事件数据 + data := event.GetData() + + // 获取消息内容 + message, ok := data["message"] + if !ok { + p.logger.Debug("No message field in event") + return nil + } + + rawMessage, ok := data["raw_message"].(string) + if !ok { + p.logger.Debug("No raw_message field in event") + return nil + } + + // 获取用户ID + userID, ok := data["user_id"] + if !ok { + p.logger.Debug("No user_id field in event") + return nil + } + + p.logger.Info("Received private message", + zap.Any("user_id", userID), + zap.String("message", rawMessage)) + + // 获取 self_id 来确定是哪个 bot + selfID := event.GetSelfID() + + // 获取对应的 bot 实例 + bot, ok := p.botManager.Get(selfID) + if !ok { + // 如果通过 selfID 找不到,尝试获取第一个 bot + bots := p.botManager.GetAll() + if len(bots) == 0 { + p.logger.Error("No bot instance available") + return nil + } + bot = bots[0] + p.logger.Debug("Using first available bot", + zap.String("bot_id", bot.GetID())) + } + + // 构建回复动作 + action := &protocol.BaseAction{ + Type: protocol.ActionTypeSendPrivateMessage, + Params: map[string]interface{}{ + "user_id": userID, + "message": message, // 原封不动返回消息 + }, + } + + p.logger.Info("Sending echo reply", + zap.Any("user_id", userID), + zap.String("reply", rawMessage)) + + // 发送消息 + result, err := bot.SendAction(ctx, action) + if err != nil { + p.logger.Error("Failed to send echo reply", + zap.Error(err), + zap.Any("user_id", userID)) + return err + } + + p.logger.Info("Echo reply sent successfully", + zap.Any("user_id", userID), + zap.Any("result", result)) + + return nil +} + +// Priority 获取处理器优先级 +func (p *EchoPlugin) Priority() int { + return 100 // 中等优先级 +} + +// Match 判断是否匹配事件 +func (p *EchoPlugin) Match(event protocol.Event) bool { + // 只处理私聊消息 + eventType := event.GetType() + detailType := event.GetDetailType() + + p.logger.Debug("Echo plugin matching event", + zap.String("event_type", string(eventType)), + zap.String("detail_type", detailType)) + + if eventType != protocol.EventTypeMessage { + p.logger.Debug("Event type mismatch", zap.String("expected", string(protocol.EventTypeMessage))) + return false + } + + if detailType != "private" { + p.logger.Debug("Detail type mismatch", zap.String("expected", "private"), zap.String("got", detailType)) + return false + } + + p.logger.Info("Echo plugin matched event!") + return true +} + +// Name 获取插件名称 +func (p *EchoPlugin) Name() string { + return "Echo" +} + +// Description 获取插件描述 +func (p *EchoPlugin) Description() string { + return "回声插件:将私聊消息原封不动返回" +} diff --git a/pkg/net/httpclient.go b/pkg/net/httpclient.go new file mode 100644 index 0000000..591e28d --- /dev/null +++ b/pkg/net/httpclient.go @@ -0,0 +1,313 @@ +package net + +import ( + "context" + "fmt" + "sync" + "time" + + "cellbot/internal/engine" + "cellbot/internal/protocol" + + "github.com/bytedance/sonic" + "github.com/valyala/fasthttp" + "go.uber.org/zap" +) + +// HTTPClient HTTP客户端(用于正向HTTP连接) +type HTTPClient struct { + client *fasthttp.Client + logger *zap.Logger + eventBus *engine.EventBus + botID string + baseURL string + timeout time.Duration + retryCount int +} + +// HTTPClientConfig HTTP客户端配置 +type HTTPClientConfig struct { + BotID string + BaseURL string + Timeout time.Duration + RetryCount int +} + +// NewHTTPClient 创建HTTP客户端 +func NewHTTPClient(config HTTPClientConfig, logger *zap.Logger, eventBus *engine.EventBus) *HTTPClient { + if config.Timeout == 0 { + config.Timeout = 30 * time.Second + } + if config.RetryCount == 0 { + config.RetryCount = 3 + } + + return &HTTPClient{ + client: &fasthttp.Client{ + ReadTimeout: config.Timeout, + WriteTimeout: config.Timeout, + MaxConnsPerHost: 100, + }, + logger: logger.Named("http-client"), + eventBus: eventBus, + botID: config.BotID, + baseURL: config.BaseURL, + timeout: config.Timeout, + retryCount: config.RetryCount, + } +} + +// SendAction 发送动作请求(正向HTTP) +func (hc *HTTPClient) SendAction(ctx context.Context, action protocol.Action) (map[string]interface{}, error) { + // 序列化动作为JSON + data, err := sonic.Marshal(action) + if err != nil { + return nil, fmt.Errorf("failed to marshal action: %w", err) + } + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + url := hc.baseURL + "/action" + req.SetRequestURI(url) + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.SetBody(data) + + hc.logger.Debug("Sending action", + zap.String("url", url), + zap.String("action", string(action.GetType()))) + + // 重试机制 + var lastErr error + for i := 0; i <= hc.retryCount; i++ { + if i > 0 { + hc.logger.Info("Retrying action request", + zap.Int("attempt", i), + zap.Int("max", hc.retryCount)) + time.Sleep(time.Duration(i) * time.Second) + } + + err := hc.client.DoTimeout(req, resp, hc.timeout) + if err != nil { + lastErr = fmt.Errorf("request failed: %w", err) + continue + } + + if resp.StatusCode() != fasthttp.StatusOK { + lastErr = fmt.Errorf("unexpected status code: %d", resp.StatusCode()) + continue + } + + // 解析响应 + var result map[string]interface{} + if err := sonic.Unmarshal(resp.Body(), &result); err != nil { + lastErr = fmt.Errorf("failed to parse response: %w", err) + continue + } + + hc.logger.Info("Action sent successfully", + zap.String("action", string(action.GetType()))) + + return result, nil + } + + return nil, fmt.Errorf("action failed after %d retries: %w", hc.retryCount, lastErr) +} + +// PollEvents 轮询事件(正向HTTP) +func (hc *HTTPClient) PollEvents(ctx context.Context, interval time.Duration) error { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + hc.logger.Info("Starting event polling", + zap.Duration("interval", interval)) + + for { + select { + case <-ticker.C: + if err := hc.fetchEvents(ctx); err != nil { + hc.logger.Error("Failed to fetch events", zap.Error(err)) + } + + case <-ctx.Done(): + hc.logger.Info("Event polling stopped") + return ctx.Err() + } + } +} + +// fetchEvents 获取事件 +func (hc *HTTPClient) fetchEvents(ctx context.Context) error { + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + url := hc.baseURL + "/events" + req.SetRequestURI(url) + req.Header.SetMethod("GET") + + err := hc.client.DoTimeout(req, resp, hc.timeout) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + + if resp.StatusCode() != fasthttp.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode()) + } + + // 解析事件列表 + var events []protocol.BaseEvent + if err := sonic.Unmarshal(resp.Body(), &events); err != nil { + return fmt.Errorf("failed to parse events: %w", err) + } + + // 发布事件到事件总线 + for i := range events { + hc.logger.Debug("Event received", + zap.String("type", string(events[i].Type)), + zap.String("detail_type", events[i].DetailType)) + + hc.eventBus.Publish(&events[i]) + } + + return nil +} + +// HTTPWebhookServer HTTP Webhook服务器(用于反向HTTP连接) +type HTTPWebhookServer struct { + server *fasthttp.Server + logger *zap.Logger + eventBus *engine.EventBus + handlers map[string]*WebhookHandler + mu sync.RWMutex +} + +// WebhookHandler Webhook处理器 +type WebhookHandler struct { + BotID string + Secret string + Validator func([]byte, string) bool +} + +// NewHTTPWebhookServer 创建HTTP Webhook服务器 +func NewHTTPWebhookServer(logger *zap.Logger, eventBus *engine.EventBus) *HTTPWebhookServer { + return &HTTPWebhookServer{ + logger: logger.Named("webhook-server"), + eventBus: eventBus, + handlers: make(map[string]*WebhookHandler), + } +} + +// RegisterWebhook 注册Webhook处理器 +func (hws *HTTPWebhookServer) RegisterWebhook(path string, handler *WebhookHandler) { + hws.mu.Lock() + defer hws.mu.Unlock() + + hws.handlers[path] = handler + hws.logger.Info("Webhook registered", + zap.String("path", path), + zap.String("bot_id", handler.BotID)) +} + +// UnregisterWebhook 注销Webhook处理器 +func (hws *HTTPWebhookServer) UnregisterWebhook(path string) { + hws.mu.Lock() + defer hws.mu.Unlock() + + delete(hws.handlers, path) + hws.logger.Info("Webhook unregistered", zap.String("path", path)) +} + +// Start 启动Webhook服务器 +func (hws *HTTPWebhookServer) Start(addr string) error { + hws.server = &fasthttp.Server{ + Handler: hws.handleWebhook, + } + + hws.logger.Info("Starting webhook server", zap.String("address", addr)) + + go func() { + if err := hws.server.ListenAndServe(addr); err != nil { + hws.logger.Error("Webhook server error", zap.Error(err)) + } + }() + + return nil +} + +// Stop 停止Webhook服务器 +func (hws *HTTPWebhookServer) Stop() error { + if hws.server != nil { + hws.logger.Info("Stopping webhook server") + return hws.server.Shutdown() + } + return nil +} + +// handleWebhook 处理Webhook请求 +func (hws *HTTPWebhookServer) handleWebhook(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + + hws.mu.RLock() + handler, exists := hws.handlers[path] + hws.mu.RUnlock() + + if !exists { + ctx.Error("Webhook not found", fasthttp.StatusNotFound) + return + } + + // 验证签名(如果配置了) + if handler.Secret != "" && handler.Validator != nil { + signature := string(ctx.Request.Header.Peek("X-Signature")) + if !handler.Validator(ctx.PostBody(), signature) { + hws.logger.Warn("Invalid webhook signature", + zap.String("path", path), + zap.String("bot_id", handler.BotID)) + ctx.Error("Invalid signature", fasthttp.StatusUnauthorized) + return + } + } + + // 解析事件 + var event protocol.BaseEvent + if err := sonic.Unmarshal(ctx.PostBody(), &event); err != nil { + hws.logger.Error("Failed to parse webhook event", + zap.Error(err), + zap.String("path", path)) + ctx.Error("Invalid event format", fasthttp.StatusBadRequest) + return + } + + // 设置BotID + if event.SelfID == "" { + event.SelfID = handler.BotID + } + + // 设置时间戳 + if event.Timestamp == 0 { + event.Timestamp = time.Now().Unix() + } + + // 确保Data字段不为nil + if event.Data == nil { + event.Data = make(map[string]interface{}) + } + + hws.logger.Info("Webhook event received", + zap.String("path", path), + zap.String("bot_id", handler.BotID), + zap.String("type", string(event.Type)), + zap.String("detail_type", event.DetailType)) + + // 发布到事件总线 + hws.eventBus.Publish(&event) + + // 返回成功响应 + ctx.SetContentType("application/json") + ctx.SetBodyString(`{"success":true}`) +} diff --git a/pkg/net/server.go b/pkg/net/server.go index 29856ba..869fcfb 100644 --- a/pkg/net/server.go +++ b/pkg/net/server.go @@ -1,12 +1,13 @@ package net import ( - "fmt" "net" "strconv" "cellbot/internal/engine" "cellbot/internal/protocol" + + "github.com/bytedance/sonic" "github.com/valyala/fasthttp" "go.uber.org/zap" ) @@ -18,17 +19,20 @@ type Server struct { logger *zap.Logger botManager *protocol.BotManager eventBus *engine.EventBus + wsManager *WebSocketManager server *fasthttp.Server } // NewServer 创建HTTP服务器 func NewServer(host string, port int, logger *zap.Logger, botManager *protocol.BotManager, eventBus *engine.EventBus) *Server { + wsManager := NewWebSocketManager(logger, eventBus) return &Server{ host: host, port: port, logger: logger.Named("server"), botManager: botManager, eventBus: eventBus, + wsManager: wsManager, } } @@ -105,28 +109,55 @@ func (s *Server) handleHealth(ctx *fasthttp.RequestCtx) { ctx.SetBodyString(`{"status":"ok"}`) } +// BotInfo 机器人信息结构 +type BotInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Version string `json:"version"` + Status string `json:"status"` + SelfID string `json:"self_id"` + Connected bool `json:"connected"` +} + // handleBots 获取机器人列表 func (s *Server) handleBots(ctx *fasthttp.RequestCtx) { bots := s.botManager.GetAll() ctx.SetContentType("application/json") - if len(bots) == 0 { - ctx.SetBodyString(`{"bots":[]}`) + // 构建机器人信息列表 + botInfos := make([]BotInfo, 0, len(bots)) + for _, bot := range bots { + botInfos = append(botInfos, BotInfo{ + ID: bot.GetID(), + Name: bot.Name(), + Version: bot.Version(), + Status: string(bot.GetStatus()), + SelfID: bot.GetSelfID(), + Connected: bot.IsConnected(), + }) + } + + // 序列化为JSON + response := map[string]interface{}{ + "bots": botInfos, + "count": len(botInfos), + } + + data, err := sonic.Marshal(response) + if err != nil { + s.logger.Error("Failed to marshal bots response", zap.Error(err)) + ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) return } - // 简化实现,实际应该序列化完整信息 - response := `{"bots":[` - for i, bot := range bots { - if i > 0 { - response += "," - } - response += fmt.Sprintf(`{"id":"%s","name":"%s","status":"%s"}`, - bot.GetID(), bot.Name(), bot.GetStatus()) - } - response += `]}` + ctx.SetBody(data) +} - ctx.SetBodyString(response) +// CreateBotRequest 创建机器人请求 +type CreateBotRequest struct { + ID string `json:"id"` + Protocol string `json:"protocol"` + Config map[string]interface{} `json:"config"` } // handleCreateBot 创建机器人 @@ -137,7 +168,73 @@ func (s *Server) handleCreateBot(ctx *fasthttp.RequestCtx) { } ctx.SetContentType("application/json") - ctx.SetBodyString(`{"message":"Bot creation not implemented yet"}`) + + // 解析请求体 + var req CreateBotRequest + if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { + s.logger.Error("Failed to parse create bot request", zap.Error(err)) + response := map[string]interface{}{ + "success": false, + "error": "Invalid request body", + } + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBody(data) + return + } + + // 验证必需字段 + if req.ID == "" { + response := map[string]interface{}{ + "success": false, + "error": "Bot ID is required", + } + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBody(data) + return + } + + if req.Protocol == "" { + response := map[string]interface{}{ + "success": false, + "error": "Protocol is required", + } + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBody(data) + return + } + + // 检查机器人是否已存在 + if _, exists := s.botManager.Get(req.ID); exists { + response := map[string]interface{}{ + "success": false, + "error": "Bot with this ID already exists", + } + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusConflict) + ctx.SetBody(data) + return + } + + // TODO: 根据协议类型创建相应的机器人实例 + // 这里需要协议工厂来创建不同类型的机器人 + // 目前返回成功但提示需要实现协议适配器 + + s.logger.Info("Bot creation requested", + zap.String("bot_id", req.ID), + zap.String("protocol", req.Protocol)) + + response := map[string]interface{}{ + "success": true, + "message": "Bot creation queued (protocol adapter implementation required)", + "bot_id": req.ID, + } + + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusAccepted) + ctx.SetBody(data) } // handlePublishEvent 发布事件 @@ -147,19 +244,169 @@ func (s *Server) handlePublishEvent(ctx *fasthttp.RequestCtx) { return } - // 解析请求体并发布事件 - // 这里简化实现,实际应该解析JSON并创建Event ctx.SetContentType("application/json") - ctx.SetBodyString(`{"message":"Event published"}`) + + // 解析请求体为事件对象 + var event protocol.BaseEvent + if err := sonic.Unmarshal(ctx.PostBody(), &event); err != nil { + s.logger.Error("Failed to parse event", zap.Error(err)) + response := map[string]interface{}{ + "success": false, + "error": "Invalid event format", + } + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBody(data) + return + } + + // 验证事件类型 + if event.Type == "" { + response := map[string]interface{}{ + "success": false, + "error": "Event type is required", + } + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBody(data) + return + } + + // 如果没有时间戳,使用当前时间 + if event.Timestamp == 0 { + event.Timestamp = ctx.Time().Unix() + } + + // 确保Data字段不为nil + if event.Data == nil { + event.Data = make(map[string]interface{}) + } + + s.logger.Info("Publishing event", + zap.String("type", string(event.Type)), + zap.String("detail_type", event.DetailType), + zap.String("self_id", event.SelfID)) + + // 发布到事件总线 + s.eventBus.Publish(&event) + + response := map[string]interface{}{ + "success": true, + "message": "Event published successfully", + "timestamp": event.Timestamp, + } + + data, _ := sonic.Marshal(response) + ctx.SetBody(data) } -// handleSubscribeEvent 订阅事件 +// handleSubscribeEvent 订阅事件(WebSocket升级) func (s *Server) handleSubscribeEvent(ctx *fasthttp.RequestCtx) { - if string(ctx.Method()) != "GET" { + // 检查是否为WebSocket升级请求 + if !ctx.IsGet() { ctx.Error("Method Not Allowed", fasthttp.StatusMethodNotAllowed) return } - ctx.SetContentType("application/json") - ctx.SetBodyString(`{"message":"Event subscription not implemented yet"}`) + // 检查是否为WebSocket升级请求 + if string(ctx.Request.Header.Peek("Upgrade")) != "websocket" { + ctx.SetContentType("application/json") + response := map[string]interface{}{ + "success": false, + "error": "WebSocket upgrade required", + "message": "This endpoint requires WebSocket connection", + } + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBody(data) + return + } + + // 获取订阅者ID(可选) + subscriberID := string(ctx.QueryArgs().Peek("subscriber_id")) + if subscriberID == "" { + subscriberID = "event-subscriber" + } + + // 获取要订阅的事件类型(可选,为空则订阅所有类型) + eventTypeStr := string(ctx.QueryArgs().Peek("event_type")) + + s.logger.Info("WebSocket event subscription request", + zap.String("subscriber_id", subscriberID), + zap.String("event_type", eventTypeStr), + zap.String("remote_addr", ctx.RemoteAddr().String())) + + // 升级为WebSocket连接 + wsConn, err := s.wsManager.UpgradeWebSocket(ctx) + if err != nil { + s.logger.Error("Failed to upgrade WebSocket", zap.Error(err)) + ctx.SetContentType("application/json") + response := map[string]interface{}{ + "success": false, + "error": "Failed to upgrade WebSocket connection", + } + data, _ := sonic.Marshal(response) + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBody(data) + return + } + + // 订阅所有主要事件类型 + eventTypes := []protocol.EventType{ + protocol.EventTypeMessage, + protocol.EventTypeNotice, + protocol.EventTypeRequest, + protocol.EventTypeMeta, + protocol.EventTypeMessageSent, + protocol.EventTypeNoticeSent, + protocol.EventTypeRequestSent, + } + + // 如果指定了事件类型,只订阅该类型 + if eventTypeStr != "" { + eventTypes = []protocol.EventType{protocol.EventType(eventTypeStr)} + } + + // 为每种事件类型创建订阅 + for _, eventType := range eventTypes { + eventChan := s.eventBus.Subscribe(eventType, nil) // nil filter means accept all events + + // 启动goroutine监听事件并发送到WebSocket + go func(ch chan protocol.Event, et protocol.EventType) { + for { + select { + case event, ok := <-ch: + if !ok { + return + } + + // 序列化事件为JSON + data, err := sonic.Marshal(event) + if err != nil { + s.logger.Error("Failed to marshal event", zap.Error(err)) + continue + } + + // 发送到WebSocket连接 + if err := wsConn.SendMessage(data); err != nil { + s.logger.Error("Failed to send event to WebSocket", + zap.String("conn_id", wsConn.ID), + zap.Error(err)) + // 连接可能已断开,取消订阅 + s.eventBus.Unsubscribe(et, ch) + return + } + + case <-wsConn.ctx.Done(): + // 连接关闭,取消订阅 + s.eventBus.Unsubscribe(et, ch) + return + } + } + }(eventChan, eventType) + } + + s.logger.Info("WebSocket event subscription established", + zap.String("conn_id", wsConn.ID), + zap.String("subscriber_id", subscriberID)) } diff --git a/pkg/net/sse.go b/pkg/net/sse.go new file mode 100644 index 0000000..0c64649 --- /dev/null +++ b/pkg/net/sse.go @@ -0,0 +1,244 @@ +package net + +import ( + "bufio" + "context" + "fmt" + "net" + "net/http" + "strings" + "time" + + "go.uber.org/zap" +) + +// SSEClient Server-Sent Events 客户端 +type SSEClient struct { + url string + accessToken string + eventChan chan []byte + logger *zap.Logger + reconnectDelay time.Duration + maxReconnect int + ctx context.Context + cancel context.CancelFunc + eventFilter string +} + +// SSEClientConfig SSE 客户端配置 +type SSEClientConfig struct { + URL string + AccessToken string + ReconnectDelay time.Duration + MaxReconnect int + EventFilter string + BufferSize int +} + +// NewSSEClient 创建 SSE 客户端 +func NewSSEClient(config SSEClientConfig, logger *zap.Logger) *SSEClient { + ctx, cancel := context.WithCancel(context.Background()) + + if config.ReconnectDelay == 0 { + config.ReconnectDelay = 5 * time.Second + } + if config.MaxReconnect == 0 { + config.MaxReconnect = -1 + } + if config.BufferSize == 0 { + config.BufferSize = 100 + } + + return &SSEClient{ + url: config.URL, + accessToken: config.AccessToken, + eventChan: make(chan []byte, config.BufferSize), + logger: logger.Named("sse-client"), + reconnectDelay: config.ReconnectDelay, + maxReconnect: config.MaxReconnect, + eventFilter: config.EventFilter, + ctx: ctx, + cancel: cancel, + } +} + +// Connect 连接到 SSE 服务器 +func (c *SSEClient) Connect(ctx context.Context) error { + c.logger.Info("Starting SSE client", zap.String("url", c.url)) + go c.connectLoop(ctx) + return nil +} + +// connectLoop 连接循环 +func (c *SSEClient) connectLoop(ctx context.Context) { + reconnectCount := 0 + + for { + select { + case <-ctx.Done(): + c.logger.Info("SSE client stopped") + return + case <-c.ctx.Done(): + c.logger.Info("SSE client stopped") + return + default: + } + + c.logger.Info("Connecting to SSE server", + zap.String("url", c.url), + zap.Int("reconnect_count", reconnectCount)) + + err := c.connect(ctx) + if err != nil { + c.logger.Error("SSE connection failed", zap.Error(err)) + } + + if c.maxReconnect >= 0 && reconnectCount >= c.maxReconnect { + c.logger.Error("Max reconnect attempts reached", zap.Int("count", reconnectCount)) + return + } + + reconnectCount++ + + c.logger.Info("Reconnecting after delay", + zap.Duration("delay", c.reconnectDelay), + zap.Int("attempt", reconnectCount)) + + select { + case <-time.After(c.reconnectDelay): + case <-ctx.Done(): + return + case <-c.ctx.Done(): + return + } + } +} + +// connect 建立单次连接 +func (c *SSEClient) connect(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, "GET", c.url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + if c.accessToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken)) + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + client := &http.Client{ + Timeout: 0, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + contentType := resp.Header.Get("Content-Type") + if !strings.HasPrefix(contentType, "text/event-stream") { + return fmt.Errorf("unexpected content type: %s", contentType) + } + + c.logger.Info("SSE connection established") + + return c.readEventStream(ctx, resp) +} + +// readEventStream 读取事件流 +func (c *SSEClient) readEventStream(ctx context.Context, resp *http.Response) error { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + var eventType string + var dataLines []string + + for scanner.Scan() { + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.ctx.Done(): + return c.ctx.Err() + default: + } + + line := scanner.Text() + + if line == "" { + if len(dataLines) > 0 { + c.processEvent(eventType, dataLines) + eventType = "" + dataLines = nil + } + continue + } + + if strings.HasPrefix(line, ":") { + continue + } + + if strings.HasPrefix(line, "event:") { + eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + dataLines = append(dataLines, data) + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("scanner error: %w", err) + } + + return fmt.Errorf("connection closed") +} + +// processEvent 处理事件 +func (c *SSEClient) processEvent(eventType string, dataLines []string) { + if c.eventFilter != "" && eventType != c.eventFilter && eventType != "" { + c.logger.Debug("Ignoring filtered event", zap.String("event_type", eventType)) + return + } + + data := strings.Join(dataLines, "\n") + + c.logger.Debug("Received SSE event", + zap.String("event_type", eventType), + zap.Int("data_length", len(data))) + + select { + case c.eventChan <- []byte(data): + default: + c.logger.Warn("Event channel full, dropping event") + } +} + +// Events 获取事件通道 +func (c *SSEClient) Events() <-chan []byte { + return c.eventChan +} + +// Close 关闭客户端 +func (c *SSEClient) Close() error { + c.cancel() + close(c.eventChan) + c.logger.Info("SSE client closed") + return nil +} diff --git a/pkg/net/websocket.go b/pkg/net/websocket.go index 0b246e0..8217b1e 100644 --- a/pkg/net/websocket.go +++ b/pkg/net/websocket.go @@ -10,6 +10,7 @@ import ( "cellbot/internal/engine" "cellbot/internal/protocol" + "github.com/bytedance/sonic" "github.com/fasthttp/websocket" "github.com/valyala/fasthttp" "go.uber.org/zap" @@ -40,27 +41,45 @@ func NewWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *WebSock } } +// ConnectionType 连接类型 +type ConnectionType string + +const ( + ConnectionTypeReverse ConnectionType = "reverse" // 反向连接(被动接受) + ConnectionTypeForward ConnectionType = "forward" // 正向连接(主动发起) +) + // WebSocketConnection WebSocket连接 type WebSocketConnection struct { - ID string - Conn *websocket.Conn - BotID string - Logger *zap.Logger - ctx context.Context - cancel context.CancelFunc + ID string + Conn *websocket.Conn + BotID string + Logger *zap.Logger + ctx context.Context + cancel context.CancelFunc + Type ConnectionType + RemoteAddr string + reconnectURL string // 用于正向连接重连 + maxReconnect int // 最大重连次数 + reconnectCount int // 当前重连次数 + heartbeatTick time.Duration // 心跳间隔 } // NewWebSocketConnection 创建WebSocket连接 -func NewWebSocketConnection(conn *websocket.Conn, botID string, logger *zap.Logger) *WebSocketConnection { +func NewWebSocketConnection(conn *websocket.Conn, botID string, connType ConnectionType, logger *zap.Logger) *WebSocketConnection { ctx, cancel := context.WithCancel(context.Background()) connID := generateConnID() return &WebSocketConnection{ - ID: connID, - Conn: conn, - BotID: botID, - Logger: logger.With(zap.String("conn_id", connID)), - ctx: ctx, - cancel: cancel, + ID: connID, + Conn: conn, + BotID: botID, + Logger: logger.With(zap.String("conn_id", connID)), + ctx: ctx, + cancel: cancel, + Type: connType, + RemoteAddr: conn.RemoteAddr().String(), + maxReconnect: 5, + heartbeatTick: 30 * time.Second, } } @@ -85,21 +104,23 @@ func (wsm *WebSocketManager) UpgradeWebSocket(ctx *fasthttp.RequestCtx) (*WebSoc // 等待连接在回调中建立 conn := <-connChan - - // 创建连接对象 - wsConn := NewWebSocketConnection(conn, botID, wsm.logger) + + // 创建连接对象(反向连接) + wsConn := NewWebSocketConnection(conn, botID, ConnectionTypeReverse, wsm.logger) // 存储连接 wsm.mu.Lock() wsm.connections[wsConn.ID] = wsConn wsm.mu.Unlock() - wsm.logger.Info("WebSocket connection established", + wsm.logger.Info("WebSocket reverse connection established", zap.String("conn_id", wsConn.ID), - zap.String("bot_id", botID)) + zap.String("bot_id", botID), + zap.String("remote_addr", wsConn.RemoteAddr)) - // 启动读取循环 + // 启动读取循环和心跳 go wsConn.readLoop(wsm.eventBus) + go wsConn.heartbeatLoop() return wsConn, nil } @@ -119,9 +140,15 @@ func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) { return } + // 只处理文本消息,忽略其他类型 + if messageType != websocket.TextMessage { + wsc.Logger.Warn("Received non-text message, ignoring", + zap.Int("message_type", messageType)) + continue + } + // 处理消息 wsc.handleMessage(message, eventBus) - // messageType 可用于区分文本或二进制消息 } } } @@ -130,18 +157,41 @@ func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) { func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) { wsc.Logger.Debug("Received message", zap.ByteString("data", data)) - // TODO: 解析消息为Event对象 - // 这里简化实现,实际应该根据协议解析 - event := &protocol.BaseEvent{ - Type: protocol.EventTypeMessage, - DetailType: "private", - Timestamp: time.Now().Unix(), - SelfID: wsc.BotID, - Data: make(map[string]interface{}), + // 解析JSON消息为BaseEvent + var event protocol.BaseEvent + if err := sonic.Unmarshal(data, &event); err != nil { + wsc.Logger.Error("Failed to parse message", zap.Error(err), zap.ByteString("data", data)) + return } + // 验证必需字段 + if event.Type == "" { + wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data)) + return + } + + // 如果没有时间戳,使用当前时间 + if event.Timestamp == 0 { + event.Timestamp = time.Now().Unix() + } + + // 如果没有SelfID,使用连接的BotID + if event.SelfID == "" { + event.SelfID = wsc.BotID + } + + // 确保Data字段不为nil + if event.Data == nil { + event.Data = make(map[string]interface{}) + } + + wsc.Logger.Info("Event received", + zap.String("type", string(event.Type)), + zap.String("detail_type", event.DetailType), + zap.String("self_id", event.SelfID)) + // 发布到事件总线 - eventBus.Publish(event) + eventBus.Publish(&event) } // SendMessage 发送消息 @@ -155,6 +205,80 @@ func (wsc *WebSocketConnection) SendMessage(data []byte) error { return nil } +// heartbeatLoop 心跳循环 +func (wsc *WebSocketConnection) heartbeatLoop() { + ticker := time.NewTicker(wsc.heartbeatTick) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // 发送ping消息 + if err := wsc.Conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + wsc.Logger.Warn("Failed to send ping", zap.Error(err)) + return + } + wsc.Logger.Debug("Heartbeat ping sent") + + case <-wsc.ctx.Done(): + return + } + } +} + +// reconnectLoop 重连循环(仅用于正向连接) +func (wsc *WebSocketConnection) reconnectLoop(wsm *WebSocketManager) { + <-wsc.ctx.Done() // 等待连接断开 + + if wsc.Type != ConnectionTypeForward || wsc.reconnectURL == "" { + return + } + + wsc.Logger.Info("Connection closed, attempting to reconnect", + zap.Int("max_reconnect", wsc.maxReconnect)) + + for wsc.reconnectCount < wsc.maxReconnect { + wsc.reconnectCount++ + backoff := time.Duration(wsc.reconnectCount) * 5 * time.Second + + wsc.Logger.Info("Reconnecting", + zap.Int("attempt", wsc.reconnectCount), + zap.Int("max", wsc.maxReconnect), + zap.Duration("backoff", backoff)) + + time.Sleep(backoff) + + // 尝试重新连接 + conn, _, err := websocket.DefaultDialer.Dial(wsc.reconnectURL, nil) + if err != nil { + wsc.Logger.Error("Reconnect failed", zap.Error(err)) + continue + } + + // 更新连接 + wsc.Conn = conn + wsc.RemoteAddr = conn.RemoteAddr().String() + wsc.ctx, wsc.cancel = context.WithCancel(context.Background()) + wsc.reconnectCount = 0 // 重置重连计数 + + wsc.Logger.Info("Reconnected successfully", + zap.String("remote_addr", wsc.RemoteAddr)) + + // 重新启动读取循环和心跳 + go wsc.readLoop(wsm.eventBus) + go wsc.heartbeatLoop() + go wsc.reconnectLoop(wsm) + + return + } + + wsc.Logger.Error("Max reconnect attempts reached, giving up", + zap.Int("attempts", wsc.reconnectCount)) + + // 从管理器中移除连接 + wsm.RemoveConnection(wsc.ID) +} + // close 关闭连接 func (wsc *WebSocketConnection) close() { wsc.cancel() @@ -209,31 +333,64 @@ func (wsm *WebSocketManager) BroadcastToBot(botID string, data []byte) { } } -// Dial 建立WebSocket客户端连接 +// DialConfig WebSocket客户端连接配置 +type DialConfig struct { + URL string + BotID string + MaxReconnect int + HeartbeatTick time.Duration +} + +// Dial 建立WebSocket客户端连接(正向连接) func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnection, error) { - u, err := url.Parse(addr) + return wsm.DialWithConfig(DialConfig{ + URL: addr, + BotID: botID, + MaxReconnect: 5, + HeartbeatTick: 30 * time.Second, + }) +} + +// DialWithConfig 使用配置建立WebSocket客户端连接 +func (wsm *WebSocketManager) DialWithConfig(config DialConfig) (*WebSocketConnection, error) { + u, err := url.Parse(config.URL) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } - conn, _, err := websocket.DefaultDialer.Dial(addr, nil) + // 验证URL scheme必须是ws或wss + if u.Scheme != "ws" && u.Scheme != "wss" { + return nil, fmt.Errorf("invalid URL scheme: %s, expected ws or wss", u.Scheme) + } + + conn, _, err := websocket.DefaultDialer.Dial(config.URL, nil) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } - wsConn := NewWebSocketConnection(conn, botID, wsm.logger) + wsConn := NewWebSocketConnection(conn, config.BotID, ConnectionTypeForward, wsm.logger) + wsConn.reconnectURL = config.URL + wsConn.maxReconnect = config.MaxReconnect + wsConn.heartbeatTick = config.HeartbeatTick wsm.mu.Lock() wsm.connections[wsConn.ID] = wsConn wsm.mu.Unlock() - wsm.logger.Info("WebSocket client connected", + wsm.logger.Info("WebSocket forward connection established", zap.String("conn_id", wsConn.ID), - zap.String("bot_id", botID), - zap.String("addr", addr)) + zap.String("bot_id", config.BotID), + zap.String("addr", config.URL), + zap.String("remote_addr", wsConn.RemoteAddr)) - // 启动读取循环 + // 启动读取循环和心跳 go wsConn.readLoop(wsm.eventBus) + go wsConn.heartbeatLoop() + + // 如果是正向连接,启动重连监控 + if wsConn.Type == ConnectionTypeForward { + go wsConn.reconnectLoop(wsm) + } return wsConn, nil }