Files

106 lines
2.4 KiB
Go
Raw Permalink Normal View History

package gorse
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
)
// EmbeddingConfig holds the settings GetEmbedding uses to call an embedding
// service.
type EmbeddingConfig struct {
	// APIKey is sent as a Bearer token in the Authorization header.
	APIKey string
	// URL is the endpoint that embedding requests are POSTed to.
	URL string
	// Model is the model name placed in the request's "model" JSON field.
	Model string
}
// defaultEmbeddingConfig is the package-level configuration read by
// GetEmbedding and updated by SetEmbeddingConfig / InitEmbeddingWithConfig.
//
// The API key is intentionally empty: credentials must not be committed to
// source control. Supply it at startup via SetEmbeddingConfig or
// InitEmbeddingWithConfig.
// NOTE(review): a key was previously hard-coded here and remains in version
// history; treat it as leaked and rotate it.
var defaultEmbeddingConfig = EmbeddingConfig{
	APIKey: "",
	URL:    "https://api.littlelan.cn/v1/embeddings",
	Model:  "BAAI/bge-m3",
}
// SetEmbeddingConfig overrides fields of the package-level embedding
// configuration that GetEmbedding reads. Each argument is applied only when it
// is non-empty; passing "" leaves the corresponding current value unchanged.
func SetEmbeddingConfig(apiKey, url, model string) {
	overrideIfSet := func(dst *string, value string) {
		if value == "" {
			return
		}
		*dst = value
	}

	overrideIfSet(&defaultEmbeddingConfig.APIKey, apiKey)
	overrideIfSet(&defaultEmbeddingConfig.URL, url)
	overrideIfSet(&defaultEmbeddingConfig.Model, model)
}
// embeddingHTTPClient is shared by all GetEmbedding calls rather than being
// rebuilt per request; the 30s timeout applies to each request.
var embeddingHTTPClient = &http.Client{Timeout: 30 * time.Second}

// maxEmbeddingErrorBody caps how many bytes of an error response body are
// read into the error returned by GetEmbedding.
const maxEmbeddingErrorBody = 4 << 10

// GetEmbedding requests an embedding vector for text from the configured
// embedding service.
//
// It POSTs {"input": text, "model": <Model>} as JSON to the configured URL,
// authenticating with "Authorization: Bearer <APIKey>", and returns the
// "embedding" array of the first element of the response's "data" array.
//
// An error is returned if the request cannot be built or sent, if the service
// responds with a status code >= 400 (the error includes the status and at
// most maxEmbeddingErrorBody bytes of the body), if the response cannot be
// decoded, or if the response contains no non-empty embedding.
func GetEmbedding(text string) ([]float64, error) {
	type embeddingRequest struct {
		Input string `json:"input"`
		Model string `json:"model"`
	}
	type embeddingResponse struct {
		Data []struct {
			Embedding []float64 `json:"embedding"`
		} `json:"data"`
	}

	// Read the configuration once so the model, URL and key used for this
	// request all come from the same snapshot.
	cfg := defaultEmbeddingConfig

	jsonData, err := json.Marshal(embeddingRequest{Input: text, Model: cfg.Model})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, cfg.URL, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+cfg.APIKey)

	resp, err := embeddingHTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		// Bound the read so a huge error page cannot bloat memory or the
		// error message; the read error is irrelevant since we already fail.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxEmbeddingErrorBody))
		return nil, fmt.Errorf("embedding API error: status=%d, body=%s", resp.StatusCode, string(body))
	}

	var result embeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	// Best-effort drain of anything after the JSON value so the transport
	// can reuse the connection; failures here do not affect the result.
	_, _ = io.Copy(io.Discard, resp.Body)

	// An empty or missing vector is not a usable embedding; report it
	// instead of handing callers a zero-length slice with a nil error.
	if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
		return nil, fmt.Errorf("no embedding returned")
	}
	return result.Data[0].Embedding, nil
}
// InitEmbeddingWithConfig 从应用配置初始化embedding
func InitEmbeddingWithConfig(apiKey, url, model string) {
if apiKey == "" {
log.Println("[WARN] Gorse embedding API key not set, using default")
}
defaultEmbeddingConfig.APIKey = apiKey
if url != "" {
defaultEmbeddingConfig.URL = url
}
if model != "" {
defaultEmbeddingConfig.Model = model
}
}