feat: enhance event handling and add scheduling capabilities

- Introduced a new scheduler to manage timed tasks within the event dispatcher.
- Updated the dispatcher to support the new scheduler, allowing for improved event processing.
- Enhanced action serialization in the OneBot11 adapter to convert message chains to the appropriate format.
- Added new dependencies for cron scheduling and other indirect packages in go.mod and go.sum.
- Improved logging for event publishing and handler matching, providing better insights during execution.
- Refactored plugin loading to include scheduled job management.
This commit is contained in:
lafay
2026-01-05 04:33:30 +08:00
parent d16261e6bd
commit 64cd81b7f1
14 changed files with 2130 additions and 27 deletions

14
go.mod
View File

@@ -15,6 +15,18 @@ require (
golang.org/x/time v0.14.0 golang.org/x/time v0.14.0
) )
require github.com/robfig/cron/v3 v3.0.1
require (
github.com/chromedp/cdproto v0.0.0-20250724212937-08a3db8b4327 // indirect
github.com/chromedp/chromedp v0.14.2 // indirect
github.com/chromedp/sysutil v1.1.0 // indirect
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gobwas/ws v1.4.0 // indirect
)
require ( require (
github.com/andybalholm/brotli v1.1.1 // indirect github.com/andybalholm/brotli v1.1.1 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/gopkg v0.1.3 // indirect
@@ -30,5 +42,5 @@ require (
go.uber.org/multierr v1.10.0 // indirect go.uber.org/multierr v1.10.0 // indirect
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
golang.org/x/net v0.33.0 // indirect golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect golang.org/x/sys v0.34.0 // indirect
) )

19
go.sum
View File

@@ -10,6 +10,12 @@ github.com/bytedance/sonic v1.14.2 h1:k1twIoe97C1DtYUo+fZQy865IuHia4PR5RPiuGPPII
github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980= github.com/bytedance/sonic v1.14.2/go.mod h1:T80iDELeHiHKSc0C9tubFygiuXoGzrkjKzX2quAx980=
github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o= github.com/bytedance/sonic/loader v0.4.0 h1:olZ7lEqcxtZygCK9EKYKADnpQoYkRQxaeY2NYzevs+o=
github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/bytedance/sonic/loader v0.4.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/chromedp/cdproto v0.0.0-20250724212937-08a3db8b4327 h1:UQ4AU+BGti3Sy/aLU8KVseYKNALcX9UXY6DfpwQ6J8E=
github.com/chromedp/cdproto v0.0.0-20250724212937-08a3db8b4327/go.mod h1:NItd7aLkcfOA/dcMXvl8p1u+lQqioRMq/SqDp71Pb/k=
github.com/chromedp/chromedp v0.14.2 h1:r3b/WtwM50RsBZHMUm9fsNhhzRStTHrKdr2zmwbZSzM=
github.com/chromedp/chromedp v0.14.2/go.mod h1:rHzAv60xDE7VNy/MYtTUrYreSc0ujt2O1/C3bzctYBo=
github.com/chromedp/sysutil v1.1.0 h1:PUFNv5EcprjqXZD9nJb9b/c9ibAbxiYo4exNWZyipwM=
github.com/chromedp/sysutil v1.1.0/go.mod h1:WiThHUdltqCNKGc4gaU50XgYjwjYIhKWoHGPTUfWTJ8=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -19,6 +25,14 @@ github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOU
github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg= github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2 h1:iizUGZ9pEquQS5jTGkh4AqeeHCMbfbjeb0zMt0aEFzs=
github.com/go-json-experiment/json v0.0.0-20250725192818-e39067aee2d2/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
@@ -27,6 +41,8 @@ github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kK
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc= github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc=
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -62,8 +78,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VA
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -303,4 +303,4 @@ func BuildSetGroupCard(groupID, userID int64, card string) map[string]interface{
"user_id": userID, "user_id": userID,
"card": card, "card": card,
} }
} }

View File

@@ -227,9 +227,23 @@ func (a *Adapter) SerializeAction(action protocol.Action) ([]byte, error) {
zap.String("hint", "This action type may not be supported by OneBot11")) zap.String("hint", "This action type may not be supported by OneBot11"))
} }
// 复制参数并转换消息链
params := make(map[string]interface{})
for k, v := range action.GetParams() {
// 检查是否是消息链
if k == "message" {
if chain, ok := v.(protocol.MessageChain); ok {
// 转换为 OneBot11 格式
params[k] = ConvertMessageChainToOB11(chain)
continue
}
}
params[k] = v
}
ob11Action := &OB11Action{ ob11Action := &OB11Action{
Action: ob11ActionName, Action: ob11ActionName,
Params: action.GetParams(), Params: params,
} }
return sonic.Marshal(ob11Action) return sonic.Marshal(ob11Action)

View File

@@ -0,0 +1,238 @@
package onebot11
import (
"fmt"
"cellbot/internal/protocol"
)
// ConvertMessageChainToOB11 将通用消息链转换为 OneBot11 格式
func ConvertMessageChainToOB11(chain protocol.MessageChain) interface{} {
if len(chain) == 0 {
return ""
}
// 转换为 OneBot11 消息段格式
segments := make([]MessageSegment, 0, len(chain))
for _, seg := range chain {
ob11Seg := convertSegmentToOB11(seg)
if ob11Seg != nil {
segments = append(segments, *ob11Seg)
}
}
// 如果只有一个文本消息段,直接返回文本字符串
if len(segments) == 1 && segments[0].Type == SegmentTypeText {
if text, ok := segments[0].Data["text"].(string); ok {
return text
}
}
return segments
}
// convertSegmentToOB11 将通用消息段转换为 OneBot11 消息段
func convertSegmentToOB11(seg protocol.MessageSegment) *MessageSegment {
switch seg.Type {
case protocol.SegmentTypeText:
// 文本消息段
return &MessageSegment{
Type: SegmentTypeText,
Data: map[string]interface{}{
"text": seg.Data["text"],
},
}
case protocol.SegmentTypeMention:
// @提及OneBot12 -> OneBot11
userID := seg.Data["user_id"]
return &MessageSegment{
Type: SegmentTypeAt,
Data: map[string]interface{}{
"qq": userID,
},
}
case protocol.SegmentTypeAt:
// OneBot11 兼容格式
userID, ok := seg.Data["user_id"]
if !ok {
userID = seg.Data["qq"]
}
return &MessageSegment{
Type: SegmentTypeAt,
Data: map[string]interface{}{
"qq": userID,
},
}
case protocol.SegmentTypeImage:
// 图片消息段
fileID, ok := seg.Data["file_id"].(string)
if !ok {
// 兼容 file 字段
fileID, _ = seg.Data["file"].(string)
}
return &MessageSegment{
Type: SegmentTypeImage,
Data: map[string]interface{}{
"file": fileID,
},
}
case protocol.SegmentTypeVoice:
// 语音消息段OneBot12 -> OneBot11
fileID, ok := seg.Data["file_id"].(string)
if !ok {
fileID, _ = seg.Data["file"].(string)
}
return &MessageSegment{
Type: SegmentTypeRecord,
Data: map[string]interface{}{
"file": fileID,
},
}
case protocol.SegmentTypeRecord:
// OneBot11 兼容格式
fileID, ok := seg.Data["file"].(string)
if !ok {
fileID, _ = seg.Data["file_id"].(string)
}
return &MessageSegment{
Type: SegmentTypeRecord,
Data: map[string]interface{}{
"file": fileID,
},
}
case protocol.SegmentTypeVideo:
// 视频消息段
fileID, ok := seg.Data["file_id"].(string)
if !ok {
fileID, _ = seg.Data["file"].(string)
}
return &MessageSegment{
Type: SegmentTypeVideo,
Data: map[string]interface{}{
"file": fileID,
},
}
case protocol.SegmentTypeReply:
// 回复消息段
messageID := seg.Data["message_id"]
return &MessageSegment{
Type: SegmentTypeReply,
Data: map[string]interface{}{
"id": messageID,
},
}
case protocol.SegmentTypeFace:
// 表情消息段
faceID := seg.Data["id"]
return &MessageSegment{
Type: SegmentTypeFace,
Data: map[string]interface{}{
"id": faceID,
},
}
default:
// 其他类型,尝试直接转换
return &MessageSegment{
Type: seg.Type,
Data: seg.Data,
}
}
}
// ConvertOB11ToMessageChain 将 OneBot11 消息段转换为通用消息链
func ConvertOB11ToMessageChain(ob11Message interface{}) (protocol.MessageChain, error) {
chain := protocol.MessageChain{}
// 如果是字符串,转换为文本消息段
if str, ok := ob11Message.(string); ok {
return protocol.NewMessageChain(protocol.NewTextSegment(str)), nil
}
// 如果是数组,解析为消息段数组
if segments, ok := ob11Message.([]MessageSegment); ok {
for _, seg := range segments {
genericSeg := convertOB11SegmentToGeneric(seg)
chain = append(chain, genericSeg)
}
return chain, nil
}
// 如果是接口数组,尝试转换
if segments, ok := ob11Message.([]interface{}); ok {
for _, seg := range segments {
if segMap, ok := seg.(map[string]interface{}); ok {
segType, _ := segMap["type"].(string)
segData, _ := segMap["data"].(map[string]interface{})
genericSeg := convertOB11SegmentToGeneric(MessageSegment{
Type: segType,
Data: segData,
})
chain = append(chain, genericSeg)
}
}
return chain, nil
}
return nil, fmt.Errorf("unsupported message format: %T", ob11Message)
}
// convertOB11SegmentToGeneric 将 OneBot11 消息段转换为通用消息段
func convertOB11SegmentToGeneric(seg MessageSegment) protocol.MessageSegment {
switch seg.Type {
case SegmentTypeText:
return protocol.NewTextSegment(seg.Data["text"].(string))
case SegmentTypeAt:
// OneBot11 @ 转换为 OneBot12 mention
userID := seg.Data["qq"]
return protocol.NewMentionSegment(userID)
case SegmentTypeImage:
fileID, ok := seg.Data["file"].(string)
if !ok {
fileID, _ = seg.Data["file_id"].(string)
}
return protocol.NewImageSegment(fileID)
case SegmentTypeRecord:
// OneBot11 record 转换为 OneBot12 voice
fileID, ok := seg.Data["file"].(string)
if !ok {
fileID, _ = seg.Data["file_id"].(string)
}
return protocol.MessageSegment{
Type: protocol.SegmentTypeVoice,
Data: map[string]interface{}{
"file_id": fileID,
},
}
case SegmentTypeReply:
messageID := seg.Data["id"]
return protocol.NewReplySegment(messageID)
case SegmentTypeFace:
return protocol.MessageSegment{
Type: protocol.SegmentTypeFace,
Data: map[string]interface{}{
"id": seg.Data["id"],
},
}
default:
// 其他类型,直接转换
return protocol.MessageSegment{
Type: seg.Type,
Data: seg.Data,
}
}
}

View File

@@ -7,7 +7,8 @@ import (
"cellbot/internal/adapter/onebot11" "cellbot/internal/adapter/onebot11"
"cellbot/internal/config" "cellbot/internal/config"
"cellbot/internal/engine" "cellbot/internal/engine"
_ "cellbot/internal/plugins/echo" // 导入插件以触发 init 函数 _ "cellbot/internal/plugins/echo" // 导入插件以触发 init 函数
_ "cellbot/internal/plugins/welcome" // 导入插件以触发 init 函数
"cellbot/internal/protocol" "cellbot/internal/protocol"
"cellbot/pkg/net" "cellbot/pkg/net"
@@ -42,8 +43,12 @@ func ProvideEventBus(logger *zap.Logger) *engine.EventBus {
return engine.NewEventBus(logger, 10000) return engine.NewEventBus(logger, 10000)
} }
func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger, cfg *config.Config) *engine.Dispatcher { func ProvideScheduler(logger *zap.Logger) *engine.Scheduler {
dispatcher := engine.NewDispatcher(eventBus, logger) return engine.NewScheduler(logger)
}
func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger, cfg *config.Config, scheduler *engine.Scheduler) *engine.Dispatcher {
dispatcher := engine.NewDispatcherWithScheduler(eventBus, logger, scheduler)
// 注册限流中间件 // 注册限流中间件
if cfg.Engine.RateLimit.Enabled { if cfg.Engine.RateLimit.Enabled {
@@ -97,7 +102,18 @@ func ProvideMilkyBots(cfg *config.Config, logger *zap.Logger, eventBus *engine.E
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error { OnStart: func(ctx context.Context) error {
logger.Info("Starting Milky bot", zap.String("bot_id", botCfg.ID)) logger.Info("Starting Milky bot", zap.String("bot_id", botCfg.ID))
return bot.Connect(ctx) // 在后台启动连接,失败时只记录错误,不终止应用
go func() {
if err := bot.Connect(context.Background()); err != nil {
logger.Error("Failed to connect Milky bot, will retry in background",
zap.String("bot_id", botCfg.ID),
zap.Error(err))
// 可以在这里实现重试逻辑
} else {
logger.Info("Milky bot connected successfully", zap.String("bot_id", botCfg.ID))
}
}()
return nil
}, },
OnStop: func(ctx context.Context) error { OnStop: func(ctx context.Context) error {
logger.Info("Stopping Milky bot", zap.String("bot_id", botCfg.ID)) logger.Info("Stopping Milky bot", zap.String("bot_id", botCfg.ID))
@@ -137,7 +153,18 @@ func ProvideOneBot11Bots(cfg *config.Config, logger *zap.Logger, wsManager *net.
lc.Append(fx.Hook{ lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error { OnStart: func(ctx context.Context) error {
logger.Info("Starting OneBot11 bot", zap.String("bot_id", botCfg.ID)) logger.Info("Starting OneBot11 bot", zap.String("bot_id", botCfg.ID))
return bot.Connect(ctx) // 在后台启动连接,失败时只记录错误,不终止应用
go func() {
if err := bot.Connect(context.Background()); err != nil {
logger.Error("Failed to connect OneBot11 bot, will retry in background",
zap.String("bot_id", botCfg.ID),
zap.Error(err))
// 可以在这里实现重试逻辑
} else {
logger.Info("OneBot11 bot connected successfully", zap.String("bot_id", botCfg.ID))
}
}()
return nil
}, },
OnStop: func(ctx context.Context) error { OnStop: func(ctx context.Context) error {
logger.Info("Stopping OneBot11 bot", zap.String("bot_id", botCfg.ID)) logger.Info("Stopping OneBot11 bot", zap.String("bot_id", botCfg.ID))
@@ -171,12 +198,18 @@ func LoadPlugins(logger *zap.Logger, botManager *protocol.BotManager, registry *
zap.Strings("plugins", engine.GetRegisteredPlugins())) zap.Strings("plugins", engine.GetRegisteredPlugins()))
} }
// LoadScheduledJobs 加载所有定时任务(由依赖注入系统调用)
func LoadScheduledJobs(scheduler *engine.Scheduler, logger *zap.Logger) error {
return engine.LoadAllJobs(scheduler, logger)
}
var Providers = fx.Options( var Providers = fx.Options(
fx.Provide( fx.Provide(
ProvideConfig, ProvideConfig,
ProvideConfigManager, ProvideConfigManager,
ProvideLogger, ProvideLogger,
ProvideEventBus, ProvideEventBus,
ProvideScheduler,
ProvideDispatcher, ProvideDispatcher,
ProvidePluginRegistry, ProvidePluginRegistry,
ProvideBotManager, ProvideBotManager,
@@ -186,4 +219,5 @@ var Providers = fx.Options(
fx.Invoke(ProvideMilkyBots), fx.Invoke(ProvideMilkyBots),
fx.Invoke(ProvideOneBot11Bots), fx.Invoke(ProvideOneBot11Bots),
fx.Invoke(LoadPlugins), fx.Invoke(LoadPlugins),
fx.Invoke(LoadScheduledJobs),
) )

View File

@@ -30,6 +30,7 @@ type Dispatcher struct {
middlewares []protocol.Middleware middlewares []protocol.Middleware
logger *zap.Logger logger *zap.Logger
eventBus *EventBus eventBus *EventBus
scheduler *Scheduler
metrics DispatcherMetrics metrics DispatcherMetrics
mu sync.RWMutex mu sync.RWMutex
workerPool chan struct{} // 工作池,限制并发数 workerPool chan struct{} // 工作池,限制并发数
@@ -43,6 +44,13 @@ func NewDispatcher(eventBus *EventBus, logger *zap.Logger) *Dispatcher {
return NewDispatcherWithConfig(eventBus, logger, 100, true) return NewDispatcherWithConfig(eventBus, logger, 100, true)
} }
// NewDispatcherWithScheduler 创建带调度器的事件分发器
func NewDispatcherWithScheduler(eventBus *EventBus, logger *zap.Logger, scheduler *Scheduler) *Dispatcher {
dispatcher := NewDispatcherWithConfig(eventBus, logger, 100, true)
dispatcher.scheduler = scheduler
return dispatcher
}
// NewDispatcherWithConfig 使用配置创建事件分发器 // NewDispatcherWithConfig 使用配置创建事件分发器
func NewDispatcherWithConfig(eventBus *EventBus, logger *zap.Logger, maxWorkers int, async bool) *Dispatcher { func NewDispatcherWithConfig(eventBus *EventBus, logger *zap.Logger, maxWorkers int, async bool) *Dispatcher {
if maxWorkers <= 0 { if maxWorkers <= 0 {
@@ -114,14 +122,37 @@ func (d *Dispatcher) Start(ctx context.Context) {
go d.eventLoop(ctx, eventChan) go d.eventLoop(ctx, eventChan)
} }
// 启动调度器
if d.scheduler != nil {
if err := d.scheduler.Start(); err != nil {
d.logger.Error("Failed to start scheduler", zap.Error(err))
} else {
d.logger.Info("Scheduler started")
}
}
d.logger.Info("Dispatcher started") d.logger.Info("Dispatcher started")
} }
// Stop 停止分发器 // Stop 停止分发器
func (d *Dispatcher) Stop() { func (d *Dispatcher) Stop() {
// 停止调度器
if d.scheduler != nil {
if err := d.scheduler.Stop(); err != nil {
d.logger.Error("Failed to stop scheduler", zap.Error(err))
} else {
d.logger.Info("Scheduler stopped")
}
}
d.logger.Info("Dispatcher stopped") d.logger.Info("Dispatcher stopped")
} }
// GetScheduler 获取调度器
func (d *Dispatcher) GetScheduler() *Scheduler {
return d.scheduler
}
// eventLoop 事件循环 // eventLoop 事件循环
func (d *Dispatcher) eventLoop(ctx context.Context, eventChan chan protocol.Event) { func (d *Dispatcher) eventLoop(ctx context.Context, eventChan chan protocol.Event) {
for { for {
@@ -215,13 +246,16 @@ func (d *Dispatcher) createHandlerChain(ctx context.Context, event protocol.Even
for i, handler := range handlers { for i, handler := range handlers {
matched := handler.Match(event) matched := handler.Match(event)
d.logger.Info("Checking handler", d.logger.Debug("Checking handler",
zap.Int("handler_index", i), zap.Int("handler_index", i),
zap.String("handler_name", handler.Name()),
zap.Int("priority", handler.Priority()), zap.Int("priority", handler.Priority()),
zap.Bool("matched", matched)) zap.Bool("matched", matched))
if matched { if matched {
d.logger.Info("Handler matched, calling Handle", d.logger.Info("Handler matched, calling Handle",
zap.Int("handler_index", i)) zap.Int("handler_index", i),
zap.String("handler_name", handler.Name()),
zap.String("handler_description", handler.Description()))
// 使用defer捕获单个handler的panic // 使用defer捕获单个handler的panic
func() { func() {
defer func() { defer func() {

View File

@@ -78,9 +78,10 @@ func (eb *EventBus) Stop() {
// Publish 发布事件 // Publish 发布事件
func (eb *EventBus) Publish(event protocol.Event) { func (eb *EventBus) Publish(event protocol.Event) {
eb.logger.Info("Publishing event to channel", eb.logger.Debug("Publishing event to channel",
zap.String("event_type", string(event.GetType())), zap.String("event_type", string(event.GetType())),
zap.String("detail_type", event.GetDetailType()), zap.String("detail_type", event.GetDetailType()),
zap.String("self_id", event.GetSelfID()),
zap.Int("channel_len", len(eb.eventChan)), zap.Int("channel_len", len(eb.eventChan)),
zap.Int("channel_cap", cap(eb.eventChan))) zap.Int("channel_cap", cap(eb.eventChan)))
@@ -88,8 +89,10 @@ func (eb *EventBus) Publish(event protocol.Event) {
case eb.eventChan <- event: case eb.eventChan <- event:
atomic.AddInt64(&eb.metrics.PublishedTotal, 1) atomic.AddInt64(&eb.metrics.PublishedTotal, 1)
atomic.StoreInt64(&eb.metrics.LastEventTime, time.Now().Unix()) atomic.StoreInt64(&eb.metrics.LastEventTime, time.Now().Unix())
eb.logger.Info("Event successfully queued", eb.logger.Info("Event published successfully",
zap.String("event_type", string(event.GetType()))) zap.String("event_type", string(event.GetType())),
zap.String("detail_type", event.GetDetailType()),
zap.String("self_id", event.GetSelfID()))
case <-eb.ctx.Done(): case <-eb.ctx.Done():
atomic.AddInt64(&eb.metrics.DroppedTotal, 1) atomic.AddInt64(&eb.metrics.DroppedTotal, 1)
eb.logger.Warn("Event bus is shutting down, event dropped", eb.logger.Warn("Event bus is shutting down, event dropped",

View File

@@ -3,6 +3,7 @@ package engine
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -208,11 +209,16 @@ var (
// HandlerFunc 处理函数类型(支持依赖注入) // HandlerFunc 处理函数类型(支持依赖注入)
type HandlerFunc func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error type HandlerFunc func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error
// HandlerMiddleware 处理器中间件函数类型
// 返回 true 表示通过中间件检查false 表示不通过
type HandlerMiddleware func(event protocol.Event) bool
// HandlerBuilder 处理器构建器(类似 ZeroBot 的 API // HandlerBuilder 处理器构建器(类似 ZeroBot 的 API
type HandlerBuilder struct { type HandlerBuilder struct {
matchFunc func(protocol.Event) bool matchFunc func(protocol.Event) bool
priority int priority int
handleFunc HandlerFunc handleFunc HandlerFunc
middlewares []HandlerMiddleware
} }
// OnPrivateMessage 匹配私聊消息 // OnPrivateMessage 匹配私聊消息
@@ -246,20 +252,82 @@ func OnMessage() *HandlerBuilder {
} }
// OnNotice 匹配通知事件 // OnNotice 匹配通知事件
func OnNotice() *HandlerBuilder { // 用法:
//
// OnNotice() - 匹配所有通知事件
// OnNotice("group_increase") - 匹配群成员增加事件
// OnNotice("group_increase", "group_decrease") - 匹配群成员增加或减少事件
func OnNotice(detailTypes ...string) *HandlerBuilder {
return &HandlerBuilder{ return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool { matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeNotice if event.GetType() != protocol.EventTypeNotice {
return false
}
// 如果没有指定类型,匹配所有通知事件
if len(detailTypes) == 0 {
return true
}
// 检查 detail_type 是否在指定列表中
eventDetailType := event.GetDetailType()
for _, dt := range detailTypes {
if dt == eventDetailType {
return true
}
}
return false
}, },
priority: 100, priority: 100,
} }
} }
// OnRequest 匹配请求事件 // OnRequest 匹配请求事件
func OnRequest() *HandlerBuilder { // 用法:
//
// OnRequest() - 匹配所有请求事件
// OnRequest("friend") - 匹配好友请求事件
// OnRequest("friend", "group") - 匹配好友或群请求事件
func OnRequest(detailTypes ...string) *HandlerBuilder {
return &HandlerBuilder{ return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool { matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeRequest if event.GetType() != protocol.EventTypeRequest {
return false
}
// 如果没有指定类型,匹配所有请求事件
if len(detailTypes) == 0 {
return true
}
// 检查 detail_type 是否在指定列表中
eventDetailType := event.GetDetailType()
for _, dt := range detailTypes {
if dt == eventDetailType {
return true
}
}
return false
},
priority: 100,
}
}
// OnEvent 匹配指定类型的事件(可传一个或多个 EventType
// 用法:
//
// OnEvent() - 匹配所有事件
// OnEvent(protocol.EventTypeMessage) - 匹配消息事件
// OnEvent(protocol.EventTypeMessage, protocol.EventTypeNotice) - 匹配消息和通知事件
func OnEvent(eventTypes ...protocol.EventType) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if len(eventTypes) == 0 {
return true // 不传参数时匹配所有事件
}
// 检查事件类型是否在指定列表中
for _, et := range eventTypes {
if event.GetType() == et {
return true
}
}
return false
}, },
priority: 100, priority: 100,
} }
@@ -273,8 +341,13 @@ func On(matchFunc func(protocol.Event) bool) *HandlerBuilder {
} }
} }
// OnCommand 匹配命令(以指定前缀开头的消息) // OnCommand 匹配命令
func OnCommand(prefix string) *HandlerBuilder { // 用法:
//
// OnCommand("/help") - 匹配 /help 命令(前缀为 /,命令为 help
// OnCommand("/", "help") - 匹配以 / 开头且命令为 help 的消息
// OnCommand("/", "help", "h") - 匹配以 / 开头且命令为 help 或 h 的消息
func OnCommand(prefix string, commands ...string) *HandlerBuilder {
return &HandlerBuilder{ return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool { matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage { if event.GetType() != protocol.EventTypeMessage {
@@ -285,9 +358,35 @@ func OnCommand(prefix string) *HandlerBuilder {
if !ok { if !ok {
return false return false
} }
// 检查是否以命令前缀开头
if len(rawMessage) > 0 && len(prefix) > 0 { // 检查是否以前缀开头
return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix if len(rawMessage) < len(prefix) || rawMessage[:len(prefix)] != prefix {
return false
}
// 如果没有指定具体命令,匹配所有以该前缀开头的消息
if len(commands) == 0 {
return true
}
// 提取命令部分(去除前缀和空格)
cmdText := strings.TrimSpace(rawMessage[len(prefix):])
if cmdText == "" {
return false
}
// 获取第一个单词作为命令
parts := strings.Fields(cmdText)
if len(parts) == 0 {
return false
}
cmd := parts[0]
// 检查是否匹配指定的命令
for _, c := range commands {
if cmd == c {
return true
}
} }
return false return false
}, },
@@ -364,12 +463,56 @@ func contains(s, substr string) bool {
return false return false
} }
// OnFullMatch 完全匹配文本
func OnFullMatch(text string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
return rawMessage == text
},
priority: 100,
}
}
// OnDetailType 匹配指定 detail_type 的事件
func OnDetailType(detailType string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetDetailType() == detailType
},
priority: 100,
}
}
// OnSubType 匹配指定 sub_type 的事件
func OnSubType(subType string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetSubType() == subType
},
priority: 100,
}
}
// Priority 设置优先级 // Priority 设置优先级
func (b *HandlerBuilder) Priority(priority int) *HandlerBuilder { func (b *HandlerBuilder) Priority(priority int) *HandlerBuilder {
b.priority = priority b.priority = priority
return b return b
} }
// Use 添加中间件(链式调用)
func (b *HandlerBuilder) Use(middleware HandlerMiddleware) *HandlerBuilder {
b.middlewares = append(b.middlewares, middleware)
return b
}
// Handle 注册处理函数(在 init 中调用) // Handle 注册处理函数(在 init 中调用)
func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) { func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) {
globalHandlerMu.Lock() globalHandlerMu.Lock()
@@ -379,6 +522,16 @@ func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) {
globalHandlerRegistry = append(globalHandlerRegistry, b) globalHandlerRegistry = append(globalHandlerRegistry, b)
} }
// applyMiddlewares 应用所有中间件
func (b *HandlerBuilder) applyMiddlewares(event protocol.Event) bool {
for _, middleware := range b.middlewares {
if !middleware(event) {
return false
}
}
return true
}
// generateHandlerName 生成处理器名称 // generateHandlerName 生成处理器名称
var handlerCounter int64 var handlerCounter int64
@@ -406,11 +559,21 @@ func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []prot
return builder.handleFunc(ctx, event, botManager, logger) return builder.handleFunc(ctx, event, botManager, logger)
} }
// 创建包装的匹配函数,应用中间件
matchFunc := func(event protocol.Event) bool {
// 先检查基础匹配
if builder.matchFunc != nil && !builder.matchFunc(event) {
return false
}
// 再应用中间件
return builder.applyMiddlewares(event)
}
handler := &simplePlugin{ handler := &simplePlugin{
name: pluginName, name: pluginName,
description: "Handler registered via OnXXX().Handle()", description: "Handler registered via OnXXX().Handle()",
priority: builder.priority, priority: builder.priority,
matchFunc: builder.matchFunc, matchFunc: matchFunc,
handleFunc: handleFunc, handleFunc: handleFunc,
} }
@@ -419,3 +582,129 @@ func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []prot
return handlers return handlers
} }
// ============================================================================
// 常用中间件(类似 NoneBot 风格)
// ============================================================================
// OnlyToMe 只响应@机器人的消息(群聊中)
func OnlyToMe() HandlerMiddleware {
return func(event protocol.Event) bool {
// 只对群消息生效
if event.GetType() != protocol.EventTypeMessage || event.GetDetailType() != "group" {
return true // 非群消息不检查,让其他中间件处理
}
data := event.GetData()
selfID := event.GetSelfID()
// 检查消息段中是否包含@机器人的消息
if segments, ok := data["message_segments"].([]interface{}); ok {
for _, seg := range segments {
if segMap, ok := seg.(map[string]interface{}); ok {
segType, _ := segMap["type"].(string)
if segType == "at" || segType == "mention" {
segData, _ := segMap["data"].(map[string]interface{})
// 检查是否@了机器人
if userID, ok := segData["user_id"]; ok {
if userIDStr := fmt.Sprintf("%v", userID); userIDStr == selfID {
return true
}
}
if qq, ok := segData["qq"]; ok {
if qqStr := fmt.Sprintf("%v", qq); qqStr == selfID {
return true
}
}
}
}
}
}
// 检查 raw_message 中是否包含@机器人的信息(兼容性检查)
if rawMessage, ok := data["raw_message"].(string); ok {
// 简单的检查:消息是否以 @机器人 开头
// 注意:这里需要根据实际协议调整
if strings.Contains(rawMessage, fmt.Sprintf("[CQ:at,qq=%s]", selfID)) {
return true
}
}
return false
}
}
// OnlyPrivate 只在私聊中响应
func OnlyPrivate() HandlerMiddleware {
return func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "private"
}
}
// OnlyGroup 只在群聊中响应(消息事件)或群相关事件(通知/请求事件)
func OnlyGroup() HandlerMiddleware {
return func(event protocol.Event) bool {
// 消息事件:检查 detail_type
if event.GetType() == protocol.EventTypeMessage {
return event.GetDetailType() == "group"
}
// 通知/请求事件:检查是否有 group_id
data := event.GetData()
_, hasGroupID := data["group_id"]
return hasGroupID
}
}
// OnlySuperuser 只允许超级用户(需要从配置或数据中获取)
func OnlySuperuser(superusers []string) HandlerMiddleware {
return func(event protocol.Event) bool {
data := event.GetData()
userID, ok := data["user_id"]
if !ok {
return false
}
userIDStr := fmt.Sprintf("%v", userID)
for _, su := range superusers {
if su == userIDStr {
return true
}
}
return false
}
}
// BlockPrivate 阻止私聊消息
func BlockPrivate() HandlerMiddleware {
return func(event protocol.Event) bool {
return !(event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "private")
}
}
// BlockGroup 阻止群聊消息
func BlockGroup() HandlerMiddleware {
return func(event protocol.Event) bool {
// 消息事件:检查 detail_type
if event.GetType() == protocol.EventTypeMessage {
return event.GetDetailType() != "group"
}
// 通知/请求事件:检查是否有 group_id
data := event.GetData()
_, hasGroupID := data["group_id"]
return !hasGroupID
}
}
// OnlyDetailType 只匹配指定的 detail_type
func OnlyDetailType(detailType string) HandlerMiddleware {
return func(event protocol.Event) bool {
return event.GetDetailType() == detailType
}
}
// OnlySubType 只匹配指定的 sub_type
func OnlySubType(subType string) HandlerMiddleware {
return func(event protocol.Event) bool {
return event.GetSubType() == subType
}
}

View File

@@ -0,0 +1,631 @@
package engine
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/robfig/cron/v3"
"go.uber.org/zap"
)
// Job 定时任务接口
type Job interface {
// ID 返回任务唯一标识
ID() string
// Start 启动任务
Start(ctx context.Context) error
// Stop 停止任务
Stop() error
// IsRunning 检查任务是否正在运行
IsRunning() bool
// NextRun 返回下次执行时间
NextRun() time.Time
}
// JobFunc 任务执行函数类型
type JobFunc func(ctx context.Context) error
// Scheduler 定时任务调度器
type Scheduler struct {
jobs map[string]Job
mu sync.RWMutex
logger *zap.Logger
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
running int32
}
// NewScheduler 创建新的调度器
func NewScheduler(logger *zap.Logger) *Scheduler {
ctx, cancel := context.WithCancel(context.Background())
return &Scheduler{
jobs: make(map[string]Job),
logger: logger.Named("scheduler"),
ctx: ctx,
cancel: cancel,
}
}
// Start 启动调度器
func (s *Scheduler) Start() error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return fmt.Errorf("scheduler is already running")
}
s.mu.RLock()
defer s.mu.RUnlock()
// 启动所有任务
for id, job := range s.jobs {
if err := job.Start(s.ctx); err != nil {
s.logger.Error("Failed to start job",
zap.String("job_id", id),
zap.Error(err))
continue
}
s.logger.Info("Job started", zap.String("job_id", id))
}
s.logger.Info("Scheduler started", zap.Int("job_count", len(s.jobs)))
return nil
}
// Stop 停止调度器
func (s *Scheduler) Stop() error {
if !atomic.CompareAndSwapInt32(&s.running, 1, 0) {
return fmt.Errorf("scheduler is not running")
}
s.cancel()
s.mu.RLock()
defer s.mu.RUnlock()
// 停止所有任务
for id, job := range s.jobs {
if err := job.Stop(); err != nil {
s.logger.Error("Failed to stop job",
zap.String("job_id", id),
zap.Error(err))
continue
}
s.logger.Info("Job stopped", zap.String("job_id", id))
}
s.wg.Wait()
s.logger.Info("Scheduler stopped")
return nil
}
// AddJob 添加任务
func (s *Scheduler) AddJob(job Job) error {
s.mu.Lock()
defer s.mu.Unlock()
id := job.ID()
if _, exists := s.jobs[id]; exists {
return fmt.Errorf("job with id %s already exists", id)
}
s.jobs[id] = job
// 如果调度器正在运行,立即启动任务
if atomic.LoadInt32(&s.running) == 1 {
if err := job.Start(s.ctx); err != nil {
delete(s.jobs, id)
return fmt.Errorf("failed to start job: %w", err)
}
s.logger.Info("Job added and started", zap.String("job_id", id))
} else {
s.logger.Info("Job added", zap.String("job_id", id))
}
return nil
}
// RemoveJob 移除任务
func (s *Scheduler) RemoveJob(id string) error {
s.mu.Lock()
defer s.mu.Unlock()
job, exists := s.jobs[id]
if !exists {
return fmt.Errorf("job with id %s not found", id)
}
if err := job.Stop(); err != nil {
s.logger.Error("Failed to stop job during removal",
zap.String("job_id", id),
zap.Error(err))
}
delete(s.jobs, id)
s.logger.Info("Job removed", zap.String("job_id", id))
return nil
}
// GetJob 获取任务
func (s *Scheduler) GetJob(id string) (Job, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
job, exists := s.jobs[id]
return job, exists
}
// GetAllJobs 获取所有任务
func (s *Scheduler) GetAllJobs() map[string]Job {
s.mu.RLock()
defer s.mu.RUnlock()
result := make(map[string]Job, len(s.jobs))
for id, job := range s.jobs {
result[id] = job
}
return result
}
// IsRunning 检查调度器是否正在运行
func (s *Scheduler) IsRunning() bool {
return atomic.LoadInt32(&s.running) == 1
}
// ============================================================================
// Job 实现
// ============================================================================
// CronJob 基于 Cron 表达式的任务
type CronJob struct {
id string
spec string
handler JobFunc
cron *cron.Cron
logger *zap.Logger
running int32
nextRun time.Time
mu sync.RWMutex
}
// NewCronJob 创建 Cron 任务
func NewCronJob(id, spec string, handler JobFunc, logger *zap.Logger) (*CronJob, error) {
parser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
c := cron.New(cron.WithParser(parser), cron.WithChain(cron.Recover(cron.DefaultLogger)))
job := &CronJob{
id: id,
spec: spec,
handler: handler,
cron: c,
logger: logger.Named("cron-job").With(zap.String("job_id", id)),
}
// 添加任务到 cron
_, err := c.AddFunc(spec, func() {
ctx := context.Background()
if err := handler(ctx); err != nil {
job.logger.Error("Cron job execution failed", zap.Error(err))
}
// 更新下次执行时间
entries := c.Entries()
job.mu.Lock()
if len(entries) > 0 {
// 找到最近的执行时间
next := entries[0].Next
for _, entry := range entries {
if entry.Next.Before(next) {
next = entry.Next
}
}
job.nextRun = next
}
job.mu.Unlock()
})
if err != nil {
return nil, fmt.Errorf("invalid cron spec: %w", err)
}
// 计算初始下次执行时间(需要先启动 cron 才能计算)
// 这里先设置为零值,在 Start 时再计算
job.mu.Lock()
job.nextRun = time.Time{}
job.mu.Unlock()
return job, nil
}
func (j *CronJob) ID() string {
return j.id
}
func (j *CronJob) Start(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&j.running, 0, 1) {
return fmt.Errorf("job is already running")
}
j.cron.Start()
j.logger.Info("Cron job started", zap.String("spec", j.spec))
// 更新下次执行时间
entries := j.cron.Entries()
if len(entries) > 0 {
j.mu.Lock()
// 找到最近的执行时间
next := entries[0].Next
for _, entry := range entries {
if entry.Next.Before(next) {
next = entry.Next
}
}
j.nextRun = next
j.mu.Unlock()
}
return nil
}
func (j *CronJob) Stop() error {
if !atomic.CompareAndSwapInt32(&j.running, 1, 0) {
return fmt.Errorf("job is not running")
}
ctx := j.cron.Stop()
<-ctx.Done()
j.logger.Info("Cron job stopped")
return nil
}
func (j *CronJob) IsRunning() bool {
return atomic.LoadInt32(&j.running) == 1
}
func (j *CronJob) NextRun() time.Time {
j.mu.RLock()
defer j.mu.RUnlock()
return j.nextRun
}
// IntervalJob 固定间隔的任务
type IntervalJob struct {
id string
interval time.Duration
handler JobFunc
logger *zap.Logger
running int32
nextRun time.Time
mu sync.RWMutex
ticker *time.Ticker
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewIntervalJob 创建固定间隔任务
func NewIntervalJob(id string, interval time.Duration, handler JobFunc, logger *zap.Logger) *IntervalJob {
return &IntervalJob{
id: id,
interval: interval,
handler: handler,
logger: logger.Named("interval-job").With(zap.String("job_id", id)),
}
}
func (j *IntervalJob) ID() string {
return j.id
}
func (j *IntervalJob) Start(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&j.running, 0, 1) {
return fmt.Errorf("job is already running")
}
j.ctx, j.cancel = context.WithCancel(ctx)
j.ticker = time.NewTicker(j.interval)
j.mu.Lock()
j.nextRun = time.Now().Add(j.interval)
j.mu.Unlock()
j.wg.Add(1)
go j.run()
j.logger.Info("Interval job started", zap.Duration("interval", j.interval))
return nil
}
func (j *IntervalJob) run() {
defer j.wg.Done()
// 立即执行一次(可选,根据需求调整)
// if err := j.handler(j.ctx); err != nil {
// j.logger.Error("Interval job execution failed", zap.Error(err))
// }
for {
select {
case <-j.ticker.C:
j.mu.Lock()
j.nextRun = time.Now().Add(j.interval)
j.mu.Unlock()
if err := j.handler(j.ctx); err != nil {
j.logger.Error("Interval job execution failed", zap.Error(err))
}
case <-j.ctx.Done():
return
}
}
}
func (j *IntervalJob) Stop() error {
if !atomic.CompareAndSwapInt32(&j.running, 1, 0) {
return fmt.Errorf("job is not running")
}
if j.cancel != nil {
j.cancel()
}
if j.ticker != nil {
j.ticker.Stop()
}
j.wg.Wait()
j.logger.Info("Interval job stopped")
return nil
}
func (j *IntervalJob) IsRunning() bool {
return atomic.LoadInt32(&j.running) == 1
}
func (j *IntervalJob) NextRun() time.Time {
j.mu.RLock()
defer j.mu.RUnlock()
return j.nextRun
}
// OnceJob 单次延迟执行的任务
type OnceJob struct {
id string
delay time.Duration
handler JobFunc
logger *zap.Logger
running int32
nextRun time.Time
mu sync.RWMutex
timer *time.Timer
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewOnceJob 创建单次延迟执行任务
func NewOnceJob(id string, delay time.Duration, handler JobFunc, logger *zap.Logger) *OnceJob {
return &OnceJob{
id: id,
delay: delay,
handler: handler,
logger: logger.Named("once-job").With(zap.String("job_id", id)),
}
}
func (j *OnceJob) ID() string {
return j.id
}
func (j *OnceJob) Start(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&j.running, 0, 1) {
return fmt.Errorf("job is already running")
}
j.ctx, j.cancel = context.WithCancel(ctx)
j.timer = time.NewTimer(j.delay)
j.mu.Lock()
j.nextRun = time.Now().Add(j.delay)
j.mu.Unlock()
j.wg.Add(1)
go j.run()
j.logger.Info("Once job started", zap.Duration("delay", j.delay))
return nil
}
func (j *OnceJob) run() {
defer j.wg.Done()
select {
case <-j.timer.C:
if err := j.handler(j.ctx); err != nil {
j.logger.Error("Once job execution failed", zap.Error(err))
}
atomic.StoreInt32(&j.running, 0)
case <-j.ctx.Done():
if !j.timer.Stop() {
<-j.timer.C
}
return
}
}
func (j *OnceJob) Stop() error {
if !atomic.CompareAndSwapInt32(&j.running, 1, 0) {
return fmt.Errorf("job is not running")
}
if j.cancel != nil {
j.cancel()
}
if j.timer != nil {
if !j.timer.Stop() {
<-j.timer.C
}
}
j.wg.Wait()
j.logger.Info("Once job stopped")
return nil
}
func (j *OnceJob) IsRunning() bool {
return atomic.LoadInt32(&j.running) == 1
}
func (j *OnceJob) NextRun() time.Time {
j.mu.RLock()
defer j.mu.RUnlock()
return j.nextRun
}
// ============================================================================
// 全局调度器 API链式风格延迟注册
// ============================================================================
var (
globalJobRegistry = make([]JobBuilder, 0)
globalJobMu sync.RWMutex
jobCounter int64
)
// JobBuilder 任务构建器接口(延迟注册)
type JobBuilder interface {
// Build 构建任务实例(由依赖注入系统调用)
Build(logger *zap.Logger) (Job, error)
}
// generateJobID 生成任务 ID
func generateJobID(prefix string) string {
counter := atomic.AddInt64(&jobCounter, 1)
return fmt.Sprintf("%s_%d", prefix, counter)
}
// CronJobBuilder Cron 任务构建器
type CronJobBuilder struct {
id string
spec string
handler JobFunc
}
// Cron 创建 Cron 任务构建器(在 init 函数中调用)
func Cron(spec string) *CronJobBuilder {
return &CronJobBuilder{
id: generateJobID("cron"),
spec: spec,
}
}
// Handle 设置处理函数并注册到全局注册表(延迟注册)
func (b *CronJobBuilder) Handle(handler JobFunc) {
b.handler = handler
if b.handler == nil {
panic("scheduler: handler cannot be nil")
}
globalJobMu.Lock()
defer globalJobMu.Unlock()
globalJobRegistry = append(globalJobRegistry, b)
}
// Build 构建 Cron 任务
func (b *CronJobBuilder) Build(logger *zap.Logger) (Job, error) {
return NewCronJob(b.id, b.spec, b.handler, logger)
}
// IntervalJobBuilder 固定间隔任务构建器
type IntervalJobBuilder struct {
id string
interval time.Duration
handler JobFunc
}
// Interval 创建固定间隔任务构建器(在 init 函数中调用)
func Interval(interval time.Duration) *IntervalJobBuilder {
return &IntervalJobBuilder{
id: generateJobID("interval"),
interval: interval,
}
}
// Handle 设置处理函数并注册到全局注册表(延迟注册)
func (b *IntervalJobBuilder) Handle(handler JobFunc) {
b.handler = handler
if b.handler == nil {
panic("scheduler: handler cannot be nil")
}
globalJobMu.Lock()
defer globalJobMu.Unlock()
globalJobRegistry = append(globalJobRegistry, b)
}
// Build 构建固定间隔任务
func (b *IntervalJobBuilder) Build(logger *zap.Logger) (Job, error) {
return NewIntervalJob(b.id, b.interval, b.handler, logger), nil
}
// OnceJobBuilder 单次延迟任务构建器
type OnceJobBuilder struct {
id string
delay time.Duration
handler JobFunc
}
// Once 创建单次延迟任务构建器(在 init 函数中调用)
func Once(delay time.Duration) *OnceJobBuilder {
return &OnceJobBuilder{
id: generateJobID("once"),
delay: delay,
}
}
// Handle 设置处理函数并注册到全局注册表(延迟注册)
func (b *OnceJobBuilder) Handle(handler JobFunc) {
b.handler = handler
if b.handler == nil {
panic("scheduler: handler cannot be nil")
}
globalJobMu.Lock()
defer globalJobMu.Unlock()
globalJobRegistry = append(globalJobRegistry, b)
}
// Build 构建单次延迟任务
func (b *OnceJobBuilder) Build(logger *zap.Logger) (Job, error) {
return NewOnceJob(b.id, b.delay, b.handler, logger), nil
}
// LoadAllJobs 加载所有已注册的任务(由依赖注入系统调用)
func LoadAllJobs(scheduler *Scheduler, logger *zap.Logger) error {
globalJobMu.RLock()
defer globalJobMu.RUnlock()
for i, builder := range globalJobRegistry {
job, err := builder.Build(logger)
if err != nil {
logger.Error("Failed to build job",
zap.Int("index", i),
zap.Error(err))
continue
}
if err := scheduler.AddJob(job); err != nil {
logger.Error("Failed to add job to scheduler",
zap.String("job_id", job.ID()),
zap.Error(err))
continue
}
logger.Debug("Job loaded",
zap.String("job_id", job.ID()))
}
logger.Info("All scheduled jobs loaded",
zap.Int("job_count", len(globalJobRegistry)))
return nil
}

View File

@@ -12,8 +12,9 @@ import (
func init() { func init() {
// 在 init 函数中注册多个处理函数(类似 ZeroBot 风格) // 在 init 函数中注册多个处理函数(类似 ZeroBot 风格)
// 处理私聊消息 // 处理私聊消息(使用 OnlyPrivate 中间件,虽然 OnPrivateMessage 已经匹配私聊,这里作为示例)
engine.OnPrivateMessage(). engine.OnPrivateMessage().
Use(engine.OnlyPrivate()). // 只在私聊中响应
Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error { Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
// 获取消息内容 // 获取消息内容
data := event.GetData() data := event.GetData()

View File

@@ -0,0 +1,347 @@
package welcome
import (
"context"
"fmt"
"time"
"cellbot/internal/engine"
"cellbot/internal/protocol"
"cellbot/pkg/utils"
"go.uber.org/zap"
)
func init() {
// 监听群成员增加通知事件
engine.OnNotice("group_increase").
Priority(50). // 设置较高优先级,确保及时响应
Handle(handleWelcomeEvent)
}
// handleWelcomeEvent 处理群成员加入欢迎事件
func handleWelcomeEvent(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
logger.Info("Welcome event received",
zap.String("event_type", string(event.GetType())),
zap.String("detail_type", event.GetDetailType()),
zap.String("sub_type", event.GetSubType()),
zap.String("self_id", event.GetSelfID()))
// 注意:中间件已经过滤了 detail_type这里可以简化检查
data := event.GetData()
logger.Debug("Event data", zap.Any("data", data))
// 获取群ID和用户ID
groupID, ok := data["group_id"]
if !ok {
logger.Warn("Missing group_id in event data")
return nil
}
userID, ok := data["user_id"]
if !ok {
logger.Warn("Missing user_id in event data")
return nil
}
logger.Info("Processing welcome event",
zap.Any("group_id", groupID),
zap.Any("user_id", userID))
// 获取操作者ID邀请者或审批者
var operatorID interface{}
if opID, exists := data["operator_id"]; exists {
operatorID = opID
}
// 获取子类型approve: 管理员同意, invite: 邀请)
subType := event.GetSubType()
// 获取 bot 实例
selfID := event.GetSelfID()
bot, ok := botManager.Get(selfID)
if !ok {
bots := botManager.GetAll()
if len(bots) == 0 {
logger.Error("No bot instance available")
return nil
}
bot = bots[0]
}
// 构建欢迎消息链使用HTML模板渲染图片
logger.Info("Building welcome message",
zap.Any("user_id", userID),
zap.Any("operator_id", operatorID),
zap.String("sub_type", subType))
welcomeChain, err := buildWelcomeMessage(ctx, userID, operatorID, subType, logger)
if err != nil {
logger.Error("Failed to build welcome message",
zap.Any("user_id", userID),
zap.Any("operator_id", operatorID),
zap.String("sub_type", subType),
zap.Error(err))
return err
}
logger.Info("Welcome message built successfully",
zap.Int("chain_length", len(welcomeChain)))
// 发送群消息
logger.Info("Sending welcome message",
zap.Any("group_id", groupID),
zap.Any("user_id", userID))
action := &protocol.BaseAction{
Type: protocol.ActionTypeSendGroupMessage,
Params: map[string]interface{}{
"group_id": groupID,
"message": welcomeChain,
},
}
result, err := bot.SendAction(ctx, action)
if err != nil {
logger.Error("Failed to send welcome message",
zap.Any("group_id", groupID),
zap.Any("user_id", userID),
zap.Any("action_type", action.Type),
zap.Error(err))
return err
}
logger.Info("Welcome message sent successfully",
zap.Any("group_id", groupID),
zap.Any("user_id", userID),
zap.Any("result", result))
logger.Info("Welcome message sent",
zap.Any("group_id", groupID),
zap.Any("user_id", userID),
zap.String("sub_type", subType))
return nil
}
// welcomeTemplate HTML欢迎消息模板
const welcomeTemplate = `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 40px 20px;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
}
.container {
background: white;
border-radius: 20px;
padding: 40px;
max-width: 600px;
width: 100%;
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
}
.header {
text-align: center;
margin-bottom: 30px;
}
.welcome-icon {
font-size: 64px;
margin-bottom: 10px;
}
.title {
font-size: 32px;
font-weight: bold;
color: #333;
margin-bottom: 10px;
}
.subtitle {
font-size: 18px;
color: #666;
}
.content {
margin: 30px 0;
}
.user-info {
background: #f8f9fa;
border-radius: 10px;
padding: 20px;
margin: 20px 0;
}
.info-item {
display: flex;
justify-content: space-between;
padding: 10px 0;
border-bottom: 1px solid #e9ecef;
}
.info-item:last-child {
border-bottom: none;
}
.info-label {
font-weight: 600;
color: #495057;
}
.info-value {
color: #212529;
}
.join-type {
text-align: center;
padding: 15px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
margin: 20px 0;
font-weight: 600;
}
.tips {
background: #e7f3ff;
border-left: 4px solid #2196F3;
padding: 15px;
border-radius: 5px;
margin: 20px 0;
}
.tips-title {
font-weight: 600;
color: #1976D2;
margin-bottom: 10px;
}
.tips-list {
list-style: none;
padding-left: 0;
}
.tips-list li {
padding: 5px 0;
color: #424242;
}
.tips-list li:before {
content: "• ";
color: #2196F3;
font-weight: bold;
}
.footer {
text-align: center;
margin-top: 30px;
color: #999;
font-size: 14px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<div class="welcome-icon">🎉</div>
<div class="title">欢迎加入!</div>
<div class="subtitle">Welcome to the Group</div>
</div>
<div class="content">
<div class="user-info">
<div class="info-item">
<span class="info-label">用户ID</span>
<span class="info-value">{{.UserID}}</span>
</div>
{{if .OperatorID}}
<div class="info-item">
<span class="info-label">{{.OperatorLabel}}</span>
<span class="info-value">{{.OperatorID}}</span>
</div>
{{end}}
</div>
{{if .JoinType}}
<div class="join-type">
{{.JoinType}}
</div>
{{end}}
<div class="tips">
<div class="tips-title">💡 温馨提示</div>
<ul class="tips-list">
<li>请遵守群规,文明交流</li>
<li>如有问题可以@管理员</li>
<li>发送 /help 查看帮助</li>
</ul>
</div>
</div>
<div class="footer">
希望你在群里玩得开心!
</div>
</div>
</body>
</html>
`
// buildWelcomeMessage 构建欢迎消息链使用HTML模板渲染图片
func buildWelcomeMessage(ctx context.Context, userID, operatorID interface{}, subType string, logger *zap.Logger) (protocol.MessageChain, error) {
logger.Debug("Starting to build welcome message",
zap.Any("user_id", userID),
zap.Any("operator_id", operatorID),
zap.String("sub_type", subType))
// 准备模板数据
data := map[string]interface{}{
"UserID": fmt.Sprintf("%v", userID),
}
logger.Debug("Template data prepared", zap.Any("data", data))
// 根据加入方式设置不同的信息
switch subType {
case "approve":
data["JoinType"] = "✅ 管理员审批通过"
if operatorID != nil {
data["OperatorID"] = fmt.Sprintf("%v", operatorID)
data["OperatorLabel"] = "审批管理员"
}
case "invite":
data["JoinType"] = "👥 被邀请加入"
if operatorID != nil {
data["OperatorID"] = fmt.Sprintf("%v", operatorID)
data["OperatorLabel"] = "邀请人"
}
default:
data["JoinType"] = "🎊 加入群聊"
}
// 配置截图选项
opts := &utils.ScreenshotOptions{
Width: 800,
Height: 600,
Timeout: 60 * time.Second, // 增加超时时间到60秒
WaitTime: 3 * time.Second, // 增加等待时间,确保页面完全加载
FullPage: false,
Quality: 90,
Logger: logger,
}
// 渲染模板并截图
logger.Info("Rendering template and taking screenshot",
zap.Int("width", opts.Width),
zap.Int("height", opts.Height),
zap.Duration("timeout", opts.Timeout))
chain, err := utils.ScreenshotTemplateToMessageChain(ctx, welcomeTemplate, data, opts)
if err != nil {
logger.Error("Failed to render welcome template",
zap.Any("data", data),
zap.Error(err))
return nil, fmt.Errorf("failed to render welcome template: %w", err)
}
logger.Info("Template rendered and screenshot taken successfully",
zap.Int("chain_length", len(chain)))
return chain, nil
}

View File

@@ -0,0 +1,133 @@
package protocol
import "fmt"
// MessageSegment 消息段(基于 OneBot12 设计)
// 通用消息段结构,适配器负责转换为具体协议格式
type MessageSegment struct {
Type string `json:"type"` // 消息段类型
Data map[string]interface{} `json:"data"` // 消息段数据
}
// MessageChain 消息链(基于 OneBot12 设计)
// 消息由多个消息段组成
type MessageChain []MessageSegment
// 消息段类型常量(基于 OneBot12
const (
SegmentTypeText = "text" // 文本
SegmentTypeMention = "mention" // @提及OneBot12
SegmentTypeImage = "image" // 图片
SegmentTypeVoice = "voice" // 语音
SegmentTypeVideo = "video" // 视频
SegmentTypeFile = "file" // 文件
SegmentTypeLocation = "location" // 位置
SegmentTypeReply = "reply" // 回复
SegmentTypeForward = "forward" // 转发
SegmentTypeFace = "face" // 表情QQ
SegmentTypeAt = "at" // @提及OneBot11 兼容)
SegmentTypeRecord = "record" // 语音OneBot11 兼容)
)
// NewTextSegment 创建文本消息段
func NewTextSegment(text string) MessageSegment {
return MessageSegment{
Type: SegmentTypeText,
Data: map[string]interface{}{
"text": text,
},
}
}
// NewMentionSegment 创建@提及消息段OneBot12 标准)
func NewMentionSegment(userID interface{}) MessageSegment {
return MessageSegment{
Type: SegmentTypeMention,
Data: map[string]interface{}{
"user_id": userID,
},
}
}
// NewImageSegment 创建图片消息段
func NewImageSegment(fileID string) MessageSegment {
return MessageSegment{
Type: SegmentTypeImage,
Data: map[string]interface{}{
"file_id": fileID,
},
}
}
// NewImageSegmentFromBase64 从base64字符串创建图片消息段
func NewImageSegmentFromBase64(base64Data string) MessageSegment {
return MessageSegment{
Type: SegmentTypeImage,
Data: map[string]interface{}{
"file": fmt.Sprintf("base64://%s", base64Data),
},
}
}
// NewReplySegment 创建回复消息段
func NewReplySegment(messageID interface{}) MessageSegment {
return MessageSegment{
Type: SegmentTypeReply,
Data: map[string]interface{}{
"message_id": messageID,
},
}
}
// NewMessageChain 创建消息链
func NewMessageChain(segments ...MessageSegment) MessageChain {
return MessageChain(segments)
}
// Append 追加消息段到消息链
func (mc MessageChain) Append(segments ...MessageSegment) MessageChain {
return append(mc, segments...)
}
// AppendText 追加文本到消息链
func (mc MessageChain) AppendText(text string) MessageChain {
return mc.Append(NewTextSegment(text))
}
// AppendMention 追加@提及到消息链
func (mc MessageChain) AppendMention(userID interface{}) MessageChain {
return mc.Append(NewMentionSegment(userID))
}
// AppendImage 追加图片到消息链
func (mc MessageChain) AppendImage(fileID string) MessageChain {
return mc.Append(NewImageSegment(fileID))
}
// AppendImageFromBase64 从base64追加图片到消息链
func (mc MessageChain) AppendImageFromBase64(base64Data string) MessageChain {
return mc.Append(NewImageSegmentFromBase64(base64Data))
}
// ToString 将消息链转换为字符串(用于调试)
func (mc MessageChain) ToString() string {
result := ""
for _, seg := range mc {
switch seg.Type {
case SegmentTypeText:
if text, ok := seg.Data["text"].(string); ok {
result += text
}
case SegmentTypeMention, SegmentTypeAt:
if userID, ok := seg.Data["user_id"]; ok {
result += fmt.Sprintf("@%v", userID)
} else if qq, ok := seg.Data["qq"]; ok {
// OneBot11 兼容
result += fmt.Sprintf("@%v", qq)
}
default:
result += fmt.Sprintf("[%s]", seg.Type)
}
}
return result
}

348
pkg/utils/screenshot.go Normal file
View File

@@ -0,0 +1,348 @@
package utils
import (
"bytes"
"cellbot/internal/protocol"
"context"
"encoding/base64"
"fmt"
"html/template"
"os"
"path/filepath"
"time"
"github.com/chromedp/chromedp"
"go.uber.org/zap"
)
// ScreenshotOptions 截图选项
type ScreenshotOptions struct {
Width int // 视口宽度(像素)
Height int // 视口高度(像素)
Timeout time.Duration // 超时时间
WaitTime time.Duration // 等待时间(页面加载后等待)
FullPage bool // 是否截取整个页面
Quality int // 图片质量0-100仅PNG格式
Format string // 图片格式png, jpeg
Logger *zap.Logger // 日志记录器
}
// DefaultScreenshotOptions 默认截图选项
func DefaultScreenshotOptions() *ScreenshotOptions {
return &ScreenshotOptions{
Width: 1920,
Height: 1080,
Timeout: 30 * time.Second,
WaitTime: 1 * time.Second,
FullPage: false,
Quality: 90,
Format: "png",
Logger: zap.NewNop(),
}
}
// ScreenshotURL 对指定URL进行截图并返回base64编码
func ScreenshotURL(ctx context.Context, url string, opts *ScreenshotOptions) (string, error) {
if opts == nil {
opts = DefaultScreenshotOptions()
}
// 创建上下文,添加优化选项
allocCtx, cancel := chromedp.NewExecAllocator(ctx,
chromedp.NoSandbox,
chromedp.NoFirstRun,
chromedp.NoDefaultBrowserCheck,
chromedp.Headless,
chromedp.DisableGPU,
)
defer cancel()
ctx, cancel = chromedp.NewContext(allocCtx, chromedp.WithLogf(func(format string, v ...interface{}) {
if opts.Logger != nil {
opts.Logger.Debug(fmt.Sprintf(format, v...))
}
}))
defer cancel()
// 设置超时
ctx, cancel = context.WithTimeout(ctx, opts.Timeout)
defer cancel()
var buf []byte
// 执行截图任务
var err error
if opts.FullPage {
err = chromedp.Run(ctx,
chromedp.Navigate(url),
chromedp.WaitReady("body", chromedp.ByQuery), // 使用 WaitReady 等待页面完全加载
chromedp.Sleep(opts.WaitTime),
chromedp.FullScreenshot(&buf, opts.Quality),
)
} else {
err = chromedp.Run(ctx,
chromedp.Navigate(url),
chromedp.WaitReady("body", chromedp.ByQuery), // 使用 WaitReady 等待页面完全加载
chromedp.Sleep(opts.WaitTime),
chromedp.CaptureScreenshot(&buf),
)
}
if err != nil {
return "", fmt.Errorf("failed to capture screenshot: %w", err)
}
// 转换为base64
base64Str := base64.StdEncoding.EncodeToString(buf)
return base64Str, nil
}
// ScreenshotHTML 对HTML内容进行截图并返回base64编码
func ScreenshotHTML(ctx context.Context, htmlContent string, opts *ScreenshotOptions) (string, error) {
if opts == nil {
opts = DefaultScreenshotOptions()
}
if opts.Logger != nil {
opts.Logger.Info("Starting HTML screenshot",
zap.Int("html_length", len(htmlContent)),
zap.Int("width", opts.Width),
zap.Int("height", opts.Height),
zap.Duration("timeout", opts.Timeout))
}
// 创建上下文
allocCtx, cancel := chromedp.NewExecAllocator(ctx,
chromedp.NoSandbox,
chromedp.NoFirstRun,
chromedp.NoDefaultBrowserCheck,
chromedp.Headless,
chromedp.DisableGPU,
)
defer cancel()
if opts.Logger != nil {
opts.Logger.Debug("Chrome allocator created")
}
ctx, cancel = chromedp.NewContext(allocCtx, chromedp.WithLogf(func(format string, v ...interface{}) {
if opts.Logger != nil {
opts.Logger.Debug(fmt.Sprintf(format, v...))
}
}))
defer cancel()
// 设置超时
ctx, cancel = context.WithTimeout(ctx, opts.Timeout)
defer cancel()
var buf []byte
// 使用 base64 编码的 data URL避免 URL 编码导致的 + 号问题
htmlBytes := []byte(htmlContent)
htmlBase64 := base64.StdEncoding.EncodeToString(htmlBytes)
dataURL := fmt.Sprintf("data:text/html;charset=utf-8;base64,%s", htmlBase64)
if opts.Logger != nil {
opts.Logger.Debug("Navigating to base64 data URL",
zap.Int("html_length", len(htmlContent)),
zap.Int("base64_length", len(htmlBase64)))
}
// 执行截图任务
var err error
if opts.FullPage {
if opts.Logger != nil {
opts.Logger.Debug("Taking full page screenshot")
}
err = chromedp.Run(ctx,
chromedp.Navigate(dataURL),
chromedp.WaitReady("body", chromedp.ByQuery),
chromedp.Sleep(opts.WaitTime),
chromedp.FullScreenshot(&buf, opts.Quality),
)
} else {
if opts.Logger != nil {
opts.Logger.Debug("Taking viewport screenshot")
}
err = chromedp.Run(ctx,
chromedp.Navigate(dataURL),
chromedp.WaitReady("body", chromedp.ByQuery),
chromedp.Sleep(opts.WaitTime),
chromedp.CaptureScreenshot(&buf),
)
}
if err != nil {
if opts.Logger != nil {
opts.Logger.Error("Failed to capture screenshot", zap.Error(err))
}
return "", fmt.Errorf("failed to capture screenshot: %w", err)
}
if opts.Logger != nil {
opts.Logger.Info("Screenshot captured successfully",
zap.Int("image_size", len(buf)))
}
// 转换为base64
base64Str := base64.StdEncoding.EncodeToString(buf)
if opts.Logger != nil {
opts.Logger.Info("Screenshot converted to base64",
zap.Int("base64_length", len(base64Str)))
}
return base64Str, nil
}
// ScreenshotElement 对页面中的特定元素进行截图
func ScreenshotElement(ctx context.Context, url string, selector string, opts *ScreenshotOptions) (string, error) {
if opts == nil {
opts = DefaultScreenshotOptions()
}
// 创建上下文
allocCtx, cancel := chromedp.NewExecAllocator(ctx,
chromedp.NoSandbox,
chromedp.NoFirstRun,
chromedp.NoDefaultBrowserCheck,
chromedp.Headless,
chromedp.DisableGPU,
)
defer cancel()
ctx, cancel = chromedp.NewContext(allocCtx, chromedp.WithLogf(func(format string, v ...interface{}) {
if opts.Logger != nil {
opts.Logger.Debug(fmt.Sprintf(format, v...))
}
}))
defer cancel()
// 设置超时
ctx, cancel = context.WithTimeout(ctx, opts.Timeout)
defer cancel()
var buf []byte
// 执行截图任务
err := chromedp.Run(ctx,
chromedp.Navigate(url),
chromedp.WaitVisible(selector, chromedp.ByQuery),
chromedp.Sleep(opts.WaitTime),
chromedp.Screenshot(selector, &buf, chromedp.NodeVisible),
)
if err != nil {
return "", fmt.Errorf("failed to capture element screenshot: %w", err)
}
// 转换为base64
base64Str := base64.StdEncoding.EncodeToString(buf)
return base64Str, nil
}
// ScreenshotHTMLToMessageChain 对HTML内容进行截图并返回包含图片的消息链
func ScreenshotHTMLToMessageChain(ctx context.Context, htmlContent string, opts *ScreenshotOptions) (protocol.MessageChain, error) {
base64Data, err := ScreenshotHTML(ctx, htmlContent, opts)
if err != nil {
return nil, err
}
chain := protocol.NewMessageChain()
chain = chain.AppendImageFromBase64(base64Data)
return chain, nil
}
// ScreenshotURLToMessageChain 对URL进行截图并返回包含图片的消息链
func ScreenshotURLToMessageChain(ctx context.Context, url string, opts *ScreenshotOptions) (protocol.MessageChain, error) {
base64Data, err := ScreenshotURL(ctx, url, opts)
if err != nil {
return nil, err
}
chain := protocol.NewMessageChain()
chain = chain.AppendImageFromBase64(base64Data)
return chain, nil
}
// ============================================================================
// HTML 模板功能
// ============================================================================
// RenderHTMLTemplate 渲染HTML模板并返回HTML字符串
func RenderHTMLTemplate(tmplContent string, data interface{}) (string, error) {
tmpl, err := template.New("html").Parse(tmplContent)
if err != nil {
return "", fmt.Errorf("failed to parse template: %w", err)
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return "", fmt.Errorf("failed to execute template: %w", err)
}
return buf.String(), nil
}
// RenderHTMLTemplateFromFile 从文件加载并渲染HTML模板
func RenderHTMLTemplateFromFile(tmplPath string, data interface{}) (string, error) {
content, err := os.ReadFile(tmplPath)
if err != nil {
return "", fmt.Errorf("failed to read template file: %w", err)
}
tmpl, err := template.New(filepath.Base(tmplPath)).Parse(string(content))
if err != nil {
return "", fmt.Errorf("failed to parse template: %w", err)
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return "", fmt.Errorf("failed to execute template: %w", err)
}
return buf.String(), nil
}
// ScreenshotTemplate 渲染HTML模板并截图返回base64编码
func ScreenshotTemplate(ctx context.Context, tmplContent string, data interface{}, opts *ScreenshotOptions) (string, error) {
html, err := RenderHTMLTemplate(tmplContent, data)
if err != nil {
return "", err
}
return ScreenshotHTML(ctx, html, opts)
}
// ScreenshotTemplateFromFile 从文件加载模板渲染并截图返回base64编码
func ScreenshotTemplateFromFile(ctx context.Context, tmplPath string, data interface{}, opts *ScreenshotOptions) (string, error) {
html, err := RenderHTMLTemplateFromFile(tmplPath, data)
if err != nil {
return "", err
}
return ScreenshotHTML(ctx, html, opts)
}
// ScreenshotTemplateToMessageChain 渲染HTML模板并截图返回包含图片的消息链
func ScreenshotTemplateToMessageChain(ctx context.Context, tmplContent string, data interface{}, opts *ScreenshotOptions) (protocol.MessageChain, error) {
base64Data, err := ScreenshotTemplate(ctx, tmplContent, data, opts)
if err != nil {
return nil, err
}
chain := protocol.NewMessageChain()
chain = chain.AppendImageFromBase64(base64Data)
return chain, nil
}
// ScreenshotTemplateFromFileToMessageChain 从文件加载模板,渲染并截图,返回包含图片的消息链
func ScreenshotTemplateFromFileToMessageChain(ctx context.Context, tmplPath string, data interface{}, opts *ScreenshotOptions) (protocol.MessageChain, error) {
base64Data, err := ScreenshotTemplateFromFile(ctx, tmplPath, data, opts)
if err != nil {
return nil, err
}
chain := protocol.NewMessageChain()
chain = chain.AppendImageFromBase64(base64Data)
return chain, nil
}