package milky import ( "bufio" "context" "fmt" "net" "net/http" "strings" "time" "go.uber.org/zap" ) // SSEClient Server-Sent Events 客户端 // 用于接收协议端推送的事件 (GET /event) type SSEClient struct { url string accessToken string eventChan chan []byte logger *zap.Logger reconnectDelay time.Duration maxReconnect int ctx context.Context cancel context.CancelFunc } // NewSSEClient 创建 SSE 客户端 func NewSSEClient(url, accessToken string, logger *zap.Logger) *SSEClient { ctx, cancel := context.WithCancel(context.Background()) return &SSEClient{ url: url, accessToken: accessToken, eventChan: make(chan []byte, 100), logger: logger.Named("sse-client"), reconnectDelay: 5 * time.Second, maxReconnect: -1, // 无限重连 ctx: ctx, cancel: cancel, } } // Connect 连接到 SSE 服务器 func (c *SSEClient) Connect(ctx context.Context) error { c.logger.Info("Starting SSE client", zap.String("url", c.url)) go c.connectLoop(ctx) return nil } // connectLoop 连接循环(支持自动重连) func (c *SSEClient) connectLoop(ctx context.Context) { reconnectCount := 0 for { select { case <-ctx.Done(): c.logger.Info("SSE client stopped") return case <-c.ctx.Done(): c.logger.Info("SSE client stopped") return default: } c.logger.Info("Connecting to SSE server", zap.String("url", c.url), zap.Int("reconnect_count", reconnectCount)) err := c.connect(ctx) if err != nil { c.logger.Error("SSE connection failed", zap.Error(err)) } // 检查是否需要重连 if c.maxReconnect >= 0 && reconnectCount >= c.maxReconnect { c.logger.Error("Max reconnect attempts reached", zap.Int("count", reconnectCount)) return } reconnectCount++ // 等待后重连 c.logger.Info("Reconnecting after delay", zap.Duration("delay", c.reconnectDelay), zap.Int("attempt", reconnectCount)) select { case <-time.After(c.reconnectDelay): case <-ctx.Done(): return case <-c.ctx.Done(): return } } } // connect 建立单次连接 func (c *SSEClient) connect(ctx context.Context) error { // 创建 HTTP 请求 req, err := http.NewRequestWithContext(ctx, "GET", c.url, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } // 设置 Authorization 头 if c.accessToken != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken)) } // 设置 Accept 头 req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") // 发送请求 client := &http.Client{ Timeout: 0, // 无超时,保持长连接 Transport: &http.Transport{ DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, }, } resp, err := client.Do(req) if err != nil { return fmt.Errorf("failed to connect: %w", err) } defer resp.Body.Close() // 检查状态码 if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status code: %d", resp.StatusCode) } // 检查 Content-Type contentType := resp.Header.Get("Content-Type") if !strings.HasPrefix(contentType, "text/event-stream") { return fmt.Errorf("unexpected content type: %s", contentType) } c.logger.Info("SSE connection established") // 读取事件流 return c.readEventStream(ctx, resp) } // readEventStream 读取事件流 func (c *SSEClient) readEventStream(ctx context.Context, resp *http.Response) error { scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) var eventType string var dataLines []string for scanner.Scan() { select { case <-ctx.Done(): return ctx.Err() case <-c.ctx.Done(): return c.ctx.Err() default: } line := scanner.Text() // 空行表示事件结束 if line == "" { if eventType != "" && len(dataLines) > 0 { c.processEvent(eventType, dataLines) eventType = "" dataLines = nil } continue } // 注释行(以 : 开头) if strings.HasPrefix(line, ":") { continue } // 解析字段 if strings.HasPrefix(line, "event:") { eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) } else if strings.HasPrefix(line, "data:") { data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) dataLines = append(dataLines, data) } // 忽略其他字段(id, retry 等) } if err := scanner.Err(); err != nil { return fmt.Errorf("scanner error: %w", err) } return fmt.Errorf("connection closed") } // processEvent 处理事件 func (c *SSEClient) processEvent(eventType string, dataLines []string) { // 只处理 milky_event 类型 if eventType != "milky_event" && eventType != "" { c.logger.Debug("Ignoring non-milky event", zap.String("event_type", eventType)) return } // 合并多行 data data := strings.Join(dataLines, "\n") c.logger.Debug("Received SSE event", zap.String("event_type", eventType), zap.Int("data_length", len(data))) // 发送到事件通道 select { case c.eventChan <- []byte(data): default: c.logger.Warn("Event channel full, dropping event") } } // Events 获取事件通道 func (c *SSEClient) Events() <-chan []byte { return c.eventChan } // Close 关闭客户端 func (c *SSEClient) Close() error { c.cancel() close(c.eventChan) c.logger.Info("SSE client closed") return nil }