245 lines
5.5 KiB
Go
245 lines
5.5 KiB
Go
|
|
package net
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bufio"
|
||
|
|
"context"
|
||
|
|
"fmt"
|
||
|
|
"net"
|
||
|
|
"net/http"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"go.uber.org/zap"
|
||
|
|
)
|
||
|
|
|
||
|
|
// SSEClient Server-Sent Events 客户端
|
||
|
|
type SSEClient struct {
|
||
|
|
url string
|
||
|
|
accessToken string
|
||
|
|
eventChan chan []byte
|
||
|
|
logger *zap.Logger
|
||
|
|
reconnectDelay time.Duration
|
||
|
|
maxReconnect int
|
||
|
|
ctx context.Context
|
||
|
|
cancel context.CancelFunc
|
||
|
|
eventFilter string
|
||
|
|
}
|
||
|
|
|
||
|
|
// SSEClientConfig holds the options accepted by NewSSEClient.
type SSEClientConfig struct {
	URL         string // SSE endpoint to GET
	AccessToken string // optional; sent as "Authorization: Bearer <token>"

	ReconnectDelay time.Duration // pause between reconnect attempts (0 selects 5s)
	MaxReconnect   int           // reconnect limit; 0 selects -1, i.e. unlimited
	EventFilter    string        // if set, only events of this type (or untyped events) are delivered
	BufferSize     int           // capacity of the events channel (0 selects 100)
}
|
||
|
|
|
||
|
|
// NewSSEClient 创建 SSE 客户端
|
||
|
|
func NewSSEClient(config SSEClientConfig, logger *zap.Logger) *SSEClient {
|
||
|
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
|
|
||
|
|
if config.ReconnectDelay == 0 {
|
||
|
|
config.ReconnectDelay = 5 * time.Second
|
||
|
|
}
|
||
|
|
if config.MaxReconnect == 0 {
|
||
|
|
config.MaxReconnect = -1
|
||
|
|
}
|
||
|
|
if config.BufferSize == 0 {
|
||
|
|
config.BufferSize = 100
|
||
|
|
}
|
||
|
|
|
||
|
|
return &SSEClient{
|
||
|
|
url: config.URL,
|
||
|
|
accessToken: config.AccessToken,
|
||
|
|
eventChan: make(chan []byte, config.BufferSize),
|
||
|
|
logger: logger.Named("sse-client"),
|
||
|
|
reconnectDelay: config.ReconnectDelay,
|
||
|
|
maxReconnect: config.MaxReconnect,
|
||
|
|
eventFilter: config.EventFilter,
|
||
|
|
ctx: ctx,
|
||
|
|
cancel: cancel,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Connect 连接到 SSE 服务器
|
||
|
|
func (c *SSEClient) Connect(ctx context.Context) error {
|
||
|
|
c.logger.Info("Starting SSE client", zap.String("url", c.url))
|
||
|
|
go c.connectLoop(ctx)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// connectLoop 连接循环
|
||
|
|
func (c *SSEClient) connectLoop(ctx context.Context) {
|
||
|
|
reconnectCount := 0
|
||
|
|
|
||
|
|
for {
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
c.logger.Info("SSE client stopped")
|
||
|
|
return
|
||
|
|
case <-c.ctx.Done():
|
||
|
|
c.logger.Info("SSE client stopped")
|
||
|
|
return
|
||
|
|
default:
|
||
|
|
}
|
||
|
|
|
||
|
|
c.logger.Info("Connecting to SSE server",
|
||
|
|
zap.String("url", c.url),
|
||
|
|
zap.Int("reconnect_count", reconnectCount))
|
||
|
|
|
||
|
|
err := c.connect(ctx)
|
||
|
|
if err != nil {
|
||
|
|
c.logger.Error("SSE connection failed", zap.Error(err))
|
||
|
|
}
|
||
|
|
|
||
|
|
if c.maxReconnect >= 0 && reconnectCount >= c.maxReconnect {
|
||
|
|
c.logger.Error("Max reconnect attempts reached", zap.Int("count", reconnectCount))
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
reconnectCount++
|
||
|
|
|
||
|
|
c.logger.Info("Reconnecting after delay",
|
||
|
|
zap.Duration("delay", c.reconnectDelay),
|
||
|
|
zap.Int("attempt", reconnectCount))
|
||
|
|
|
||
|
|
select {
|
||
|
|
case <-time.After(c.reconnectDelay):
|
||
|
|
case <-ctx.Done():
|
||
|
|
return
|
||
|
|
case <-c.ctx.Done():
|
||
|
|
return
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// connect 建立单次连接
|
||
|
|
func (c *SSEClient) connect(ctx context.Context) error {
|
||
|
|
req, err := http.NewRequestWithContext(ctx, "GET", c.url, nil)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to create request: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if c.accessToken != "" {
|
||
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken))
|
||
|
|
}
|
||
|
|
|
||
|
|
req.Header.Set("Accept", "text/event-stream")
|
||
|
|
req.Header.Set("Cache-Control", "no-cache")
|
||
|
|
req.Header.Set("Connection", "keep-alive")
|
||
|
|
|
||
|
|
client := &http.Client{
|
||
|
|
Timeout: 0,
|
||
|
|
Transport: &http.Transport{
|
||
|
|
DialContext: (&net.Dialer{
|
||
|
|
Timeout: 30 * time.Second,
|
||
|
|
KeepAlive: 30 * time.Second,
|
||
|
|
}).DialContext,
|
||
|
|
MaxIdleConns: 100,
|
||
|
|
IdleConnTimeout: 90 * time.Second,
|
||
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
||
|
|
ExpectContinueTimeout: 1 * time.Second,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
resp, err := client.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("failed to connect: %w", err)
|
||
|
|
}
|
||
|
|
defer resp.Body.Close()
|
||
|
|
|
||
|
|
if resp.StatusCode != http.StatusOK {
|
||
|
|
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||
|
|
}
|
||
|
|
|
||
|
|
contentType := resp.Header.Get("Content-Type")
|
||
|
|
if !strings.HasPrefix(contentType, "text/event-stream") {
|
||
|
|
return fmt.Errorf("unexpected content type: %s", contentType)
|
||
|
|
}
|
||
|
|
|
||
|
|
c.logger.Info("SSE connection established")
|
||
|
|
|
||
|
|
return c.readEventStream(ctx, resp)
|
||
|
|
}
|
||
|
|
|
||
|
|
// readEventStream 读取事件流
|
||
|
|
func (c *SSEClient) readEventStream(ctx context.Context, resp *http.Response) error {
|
||
|
|
scanner := bufio.NewScanner(resp.Body)
|
||
|
|
scanner.Split(bufio.ScanLines)
|
||
|
|
|
||
|
|
var eventType string
|
||
|
|
var dataLines []string
|
||
|
|
|
||
|
|
for scanner.Scan() {
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
return ctx.Err()
|
||
|
|
case <-c.ctx.Done():
|
||
|
|
return c.ctx.Err()
|
||
|
|
default:
|
||
|
|
}
|
||
|
|
|
||
|
|
line := scanner.Text()
|
||
|
|
|
||
|
|
if line == "" {
|
||
|
|
if len(dataLines) > 0 {
|
||
|
|
c.processEvent(eventType, dataLines)
|
||
|
|
eventType = ""
|
||
|
|
dataLines = nil
|
||
|
|
}
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
if strings.HasPrefix(line, ":") {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
|
||
|
|
if strings.HasPrefix(line, "event:") {
|
||
|
|
eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||
|
|
} else if strings.HasPrefix(line, "data:") {
|
||
|
|
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||
|
|
dataLines = append(dataLines, data)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := scanner.Err(); err != nil {
|
||
|
|
return fmt.Errorf("scanner error: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return fmt.Errorf("connection closed")
|
||
|
|
}
|
||
|
|
|
||
|
|
// processEvent 处理事件
|
||
|
|
func (c *SSEClient) processEvent(eventType string, dataLines []string) {
|
||
|
|
if c.eventFilter != "" && eventType != c.eventFilter && eventType != "" {
|
||
|
|
c.logger.Debug("Ignoring filtered event", zap.String("event_type", eventType))
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
data := strings.Join(dataLines, "\n")
|
||
|
|
|
||
|
|
c.logger.Debug("Received SSE event",
|
||
|
|
zap.String("event_type", eventType),
|
||
|
|
zap.Int("data_length", len(data)))
|
||
|
|
|
||
|
|
select {
|
||
|
|
case c.eventChan <- []byte(data):
|
||
|
|
default:
|
||
|
|
c.logger.Warn("Event channel full, dropping event")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Events 获取事件通道
|
||
|
|
func (c *SSEClient) Events() <-chan []byte {
|
||
|
|
return c.eventChan
|
||
|
|
}
|
||
|
|
|
||
|
|
// Close 关闭客户端
|
||
|
|
func (c *SSEClient) Close() error {
|
||
|
|
c.cancel()
|
||
|
|
close(c.eventChan)
|
||
|
|
c.logger.Info("SSE client closed")
|
||
|
|
return nil
|
||
|
|
}
|