package net import ( "bufio" "context" "fmt" "net" "net/http" "strings" "time" "go.uber.org/zap" ) // SSEClient Server-Sent Events 客户端 type SSEClient struct { url string accessToken string eventChan chan []byte logger *zap.Logger reconnectDelay time.Duration maxReconnect int ctx context.Context cancel context.CancelFunc eventFilter string } // SSEClientConfig SSE 客户端配置 type SSEClientConfig struct { URL string AccessToken string ReconnectDelay time.Duration MaxReconnect int EventFilter string BufferSize int } // NewSSEClient 创建 SSE 客户端 func NewSSEClient(config SSEClientConfig, logger *zap.Logger) *SSEClient { ctx, cancel := context.WithCancel(context.Background()) if config.ReconnectDelay == 0 { config.ReconnectDelay = 5 * time.Second } if config.MaxReconnect == 0 { config.MaxReconnect = -1 } if config.BufferSize == 0 { config.BufferSize = 100 } return &SSEClient{ url: config.URL, accessToken: config.AccessToken, eventChan: make(chan []byte, config.BufferSize), logger: logger.Named("sse-client"), reconnectDelay: config.ReconnectDelay, maxReconnect: config.MaxReconnect, eventFilter: config.EventFilter, ctx: ctx, cancel: cancel, } } // Connect 连接到 SSE 服务器 func (c *SSEClient) Connect(ctx context.Context) error { c.logger.Info("Starting SSE client", zap.String("url", c.url)) go c.connectLoop(ctx) return nil } // connectLoop 连接循环 func (c *SSEClient) connectLoop(ctx context.Context) { reconnectCount := 0 for { select { case <-ctx.Done(): c.logger.Info("SSE client stopped") return case <-c.ctx.Done(): c.logger.Info("SSE client stopped") return default: } c.logger.Info("Connecting to SSE server", zap.String("url", c.url), zap.Int("reconnect_count", reconnectCount)) err := c.connect(ctx) if err != nil { c.logger.Error("SSE connection failed", zap.Error(err)) } if c.maxReconnect >= 0 && reconnectCount >= c.maxReconnect { c.logger.Error("Max reconnect attempts reached", zap.Int("count", reconnectCount)) return } reconnectCount++ c.logger.Info("Reconnecting after delay", zap.Duration("delay", c.reconnectDelay), zap.Int("attempt", reconnectCount)) select { case <-time.After(c.reconnectDelay): case <-ctx.Done(): return case <-c.ctx.Done(): return } } } // connect 建立单次连接 func (c *SSEClient) connect(ctx context.Context) error { req, err := http.NewRequestWithContext(ctx, "GET", c.url, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) } if c.accessToken != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.accessToken)) } req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") client := &http.Client{ Timeout: 0, Transport: &http.Transport{ DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, }, } resp, err := client.Do(req) if err != nil { return fmt.Errorf("failed to connect: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("unexpected status code: %d", resp.StatusCode) } contentType := resp.Header.Get("Content-Type") if !strings.HasPrefix(contentType, "text/event-stream") { return fmt.Errorf("unexpected content type: %s", contentType) } c.logger.Info("SSE connection established") return c.readEventStream(ctx, resp) } // readEventStream 读取事件流 func (c *SSEClient) readEventStream(ctx context.Context, resp *http.Response) error { scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) var eventType string var dataLines []string for scanner.Scan() { select { case <-ctx.Done(): return ctx.Err() case <-c.ctx.Done(): return c.ctx.Err() default: } line := scanner.Text() if line == "" { if len(dataLines) > 0 { c.processEvent(eventType, dataLines) eventType = "" dataLines = nil } continue } if strings.HasPrefix(line, ":") { continue } if strings.HasPrefix(line, "event:") { eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) } else if strings.HasPrefix(line, "data:") { data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) dataLines = append(dataLines, data) } } if err := scanner.Err(); err != nil { return fmt.Errorf("scanner error: %w", err) } return fmt.Errorf("connection closed") } // processEvent 处理事件 func (c *SSEClient) processEvent(eventType string, dataLines []string) { if c.eventFilter != "" && eventType != c.eventFilter && eventType != "" { c.logger.Debug("Ignoring filtered event", zap.String("event_type", eventType)) return } data := strings.Join(dataLines, "\n") c.logger.Debug("Received SSE event", zap.String("event_type", eventType), zap.Int("data_length", len(data))) select { case c.eventChan <- []byte(data): default: c.logger.Warn("Event channel full, dropping event") } } // Events 获取事件通道 func (c *SSEClient) Events() <-chan []byte { return c.eventChan } // Close 关闭客户端 func (c *SSEClient) Close() error { c.cancel() close(c.eventChan) c.logger.Info("SSE client closed") return nil }