Files
backend/internal/pkg/openai/client.go
lan 4d8f2ec997 Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
2026-03-09 21:28:58 +08:00

439 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package openai
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"image"
_ "image/gif"
"image/jpeg"
_ "image/png"
"io"
"net/http"
"net/url"
"strings"
"time"
xdraw "golang.org/x/image/draw"
)
// moderationSystemPrompt is the system message sent with every moderation
// request. Written in Chinese, it asks the model to review a post's title,
// body and images jointly. It must reject hard violations: incitement of
// group conflict, severe personal attacks or harassment campaigns,
// doxxing/privacy exposure, and other high-risk content such as crime,
// violence, pornography, scams or malicious ads. It should let ordinary
// memes, fan works, venting and mild teasing through.
// The model is told to reply with a single-line JSON object
// {"approved": bool, "reason": string}. The reason is a 10-30 character
// Chinese summary when rejected and an empty string when approved. The reply
// is parsed by moderateSingleBatch after extractJSONObject strips any
// surrounding text.
const moderationSystemPrompt = "你是中文社区的内容审核助手负责对“帖子标题、正文、配图”做联合审核。目标是平衡社区安全与正常交流必须拦截高风险违规内容但不要误伤正常玩梗、二创、吐槽和轻度调侃。请只输出指定JSON。\n\n审核流程\n1) 先判断是否命中硬性违规;\n2) 再判断语境(玩笑/自嘲/朋友间互动/作品讨论);\n3) 做文图交叉判断(文本+图片合并理解);\n4) 给出 approved 与简短 reason。\n\n硬性违规命中任一项必须 approved=false\nA. 宣传对立与煽动撕裂:\n- 明确煽动群体对立、地域对立、性别对立、民族宗教对立,鼓动仇恨、排斥、报复。\nB. 严重人身攻击与网暴引导:\n- 持续性侮辱贬损、羞辱人格、号召围攻/骚扰/挂人/线下冲突。\nC. 开盒/人肉/隐私暴露:\n- 故意公开、拼接、索取他人可识别隐私信息(姓名+联系方式、身份证号、住址、学校单位、车牌、定位轨迹等);\n- 图片/截图中出现可识别隐私信息并伴随曝光意图,也按违规处理。\nD. 其他高危违规:\n- 违法犯罪、暴力威胁、极端仇恨、色情低俗、诈骗引流、恶意广告等。\n\n放行规则以下通常 approved=true\n- 正常玩梗、表情包、谐音梗、二次创作、无恶意的吐槽;\n- 非定向、轻度口语化吐槽(无明确攻击对象、无网暴号召、无隐私暴露);\n- 对社会事件/作品的理性讨论、观点争论(即使语气尖锐,但未煽动对立或人身攻击)。\n\n边界判定\n- 若只是“梗文化表达”且不指向现实伤害,优先通过;\n- 若存在明确伤害意图(煽动、围攻、曝光隐私),必须拒绝;\n- 对模糊内容不因个别粗口直接拒绝,需结合对象、意图、号召性和可执行性综合判断。\n\nreason 要求:\n- approved=false 时中文10-30字说明核心违规点\n- approved=true 时reason 为空字符串。\n\n输出格式严格\n仅输出一行JSON对象不要Markdown不要额外解释\n{\"approved\": true/false, \"reason\": \"...\"}"
// Tuning knobs for moderation batching, retries and image compression.
const (
	// defaultMaxImagesPerModerationRequest is the batch size used when the
	// configured value is unset or non-positive.
	defaultMaxImagesPerModerationRequest = 1

	// Retry budgets. The outer budget re-runs a whole moderation batch
	// (moderateSingleBatch); the inner budget retries each HTTP call
	// (chatCompletion).
	maxModerationResultRetries = 3
	maxChatCompletionRetries   = 3

	// initialRetryBackoff is the first retry delay; sleepWithBackoff doubles
	// it on every subsequent attempt.
	initialRetryBackoff = time.Second / 2

	// Image download and re-encoding limits.
	maxDownloadImageBytes     = 10 << 20   // 10 MiB read cap for a source image
	maxModerationImageSide    = 1280       // longest edge in pixels after downscaling
	compressedJPEGQuality     = 72         // JPEG quality used when re-encoding
	maxCompressedPayloadBytes = 1536 << 10 // 1.5 MiB cap on the re-encoded JPEG
)
// Client moderates user-generated posts and comments. The implementation in
// this file (NewClient) calls an OpenAI-compatible chat completions endpoint.
type Client interface {
	// IsEnabled reports whether moderation is switched on and fully configured.
	IsEnabled() bool

	// Config returns the configuration the client was built with.
	Config() Config

	// ModeratePost reviews a post's title, body and image URLs together.
	ModeratePost(ctx context.Context, title, content string, images []string) (approved bool, reason string, err error)

	// ModerateComment reviews a comment's text and image URLs together.
	ModerateComment(ctx context.Context, content string, images []string) (approved bool, reason string, err error)
}
// clientImpl is the Client implementation returned by NewClient.
type clientImpl struct {
	cfg        Config       // settings captured at construction time
	httpClient *http.Client // used both for chat completion calls and for image downloads
}
// NewClient builds a Client from cfg. A non-positive RequestTimeoutSeconds
// falls back to 30 seconds. The timeout is set on the single http.Client,
// so it bounds every HTTP call the client makes: chat completions as well as
// image downloads.
func NewClient(cfg Config) Client {
	const fallbackTimeoutSeconds = 30

	seconds := cfg.RequestTimeoutSeconds
	if seconds <= 0 {
		seconds = fallbackTimeoutSeconds
	}
	hc := &http.Client{Timeout: time.Duration(seconds) * time.Second}
	return &clientImpl{cfg: cfg, httpClient: hc}
}
// IsEnabled reports whether moderation should run. It must be switched on
// and have both an API key and a base URL configured.
func (c *clientImpl) IsEnabled() bool {
	switch {
	case !c.cfg.Enabled, c.cfg.APIKey == "", c.cfg.BaseURL == "":
		return false
	default:
		return true
	}
}
// Config returns, by value, the configuration the client was created with.
func (c *clientImpl) Config() Config {
	snapshot := c.cfg
	return snapshot
}
// ModeratePost merges the post's title and body into one text prompt and
// moderates it together with the post's images. When moderation is disabled
// (see IsEnabled) every post is approved without any request being made.
func (c *clientImpl) ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error) {
	if c.IsEnabled() {
		prompt := fmt.Sprintf("帖子标题:%s\n帖子内容%s", title, content)
		return c.moderateContentInBatches(ctx, prompt, images)
	}
	return true, "", nil
}
// ModerateComment moderates a comment's text together with its images. When
// moderation is disabled (see IsEnabled) every comment is approved without
// any request being made.
func (c *clientImpl) ModerateComment(ctx context.Context, content string, images []string) (bool, string, error) {
	if c.IsEnabled() {
		prompt := fmt.Sprintf("评论内容:%s", content)
		return c.moderateContentInBatches(ctx, prompt, images)
	}
	return true, "", nil
}
// moderateContentInBatches moderates contentPrompt together with images,
// sending at most maxImagesPerModerationRequest() images per request. The
// text prompt is repeated in every batch. With no images a single text-only
// request is made.
//
// When images exceed the per-batch limit they are reviewed batch by batch,
// and the first rejected batch rejects the whole item. A non-empty reason is
// prefixed with the batch number when there is more than one batch. Any
// error aborts immediately.
//
// Images are normalized up front but downloaded and compressed one batch at
// a time. An early rejection or error therefore skips the remaining
// downloads, and only one batch of encoded image data is held in memory.
func (c *clientImpl) moderateContentInBatches(ctx context.Context, contentPrompt string, images []string) (bool, string, error) {
	cleanImages := normalizeImageURLs(images)

	perBatch := c.maxImagesPerModerationRequest()
	if perBatch < 1 {
		// Defensive: a zero batch size would divide by zero below.
		perBatch = 1
	}

	totalBatches := 1
	if len(cleanImages) > 0 {
		totalBatches = (len(cleanImages) + perBatch - 1) / perBatch
	}

	for i := 0; i < totalBatches; i++ {
		start := min(i*perBatch, len(cleanImages))
		end := min(start+perBatch, len(cleanImages))
		// Compress only this batch, so later images are never fetched if
		// this batch is rejected.
		batchImages := c.optimizeImagesForModeration(ctx, cleanImages[start:end])

		approved, reason, err := c.moderateSingleBatch(ctx, contentPrompt, batchImages, i+1, totalBatches)
		if err != nil {
			return false, "", err
		}
		if approved {
			continue
		}
		if strings.TrimSpace(reason) != "" && totalBatches > 1 {
			reason = fmt.Sprintf("第%d/%d批图片未通过%s", i+1, totalBatches, reason)
		}
		return false, reason, nil
	}
	return true, "", nil
}
// moderateSingleBatch asks the moderation model about one batch of images
// (plus the shared text prompt) and parses its JSON verdict.
//
// The whole request is re-issued up to maxModerationResultRetries times when
// the call fails or the reply is malformed, with exponential backoff between
// attempts. chatCompletion also retries transport errors internally, so the
// worst-case number of HTTP calls is the product of both budgets.
//
// A reply must carry a boolean "approved" field. A reply that omits it, or
// sets it to null, is treated as malformed and retried. Previously it
// decoded to approved=false and silently rejected the content with an empty
// reason.
//
// Returns (approved, reason, nil) on success. The error is non-nil when ctx
// is cancelled during backoff or when every attempt failed; in that case it
// wraps the last failure.
func (c *clientImpl) moderateSingleBatch(
	ctx context.Context,
	contentPrompt string,
	images []string,
	batchNo, totalBatches int,
) (bool, string, error) {
	userPrompt := fmt.Sprintf(
		"%s\n图片批次%d/%d本次仅提供当前批次图片",
		contentPrompt,
		batchNo,
		totalBatches,
	)

	var lastErr error
	for attempt := 0; attempt < maxModerationResultRetries; attempt++ {
		replyText, err := c.chatCompletion(ctx, c.cfg.ModerationModel, moderationSystemPrompt, userPrompt, images, 0.1, 220)
		if err != nil {
			lastErr = err
		} else {
			// A pointer distinguishes a missing or null "approved" from an
			// explicit false.
			var verdict struct {
				Approved *bool  `json:"approved"`
				Reason   string `json:"reason"`
			}
			switch perr := json.Unmarshal([]byte(extractJSONObject(replyText)), &verdict); {
			case perr != nil:
				lastErr = fmt.Errorf("failed to parse moderation result: %w", perr)
			case verdict.Approved == nil:
				lastErr = fmt.Errorf(`failed to parse moderation result: missing "approved" field`)
			default:
				return *verdict.Approved, verdict.Reason, nil
			}
		}

		if attempt == maxModerationResultRetries-1 {
			break
		}
		if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
			return false, "", sleepErr
		}
	}
	return false, "", fmt.Errorf(
		"moderation batch %d/%d failed after %d attempts: %w",
		batchNo,
		totalBatches,
		maxModerationResultRetries,
		lastErr,
	)
}
// Wire types for the chat completions endpoint. Field order is significant
// only in that encoding/json emits fields in declaration order.
type (
	// chatCompletionsRequest is the JSON request body.
	chatCompletionsRequest struct {
		Model    string        `json:"model"`
		Messages []chatMessage `json:"messages"`
		// NOTE: with omitempty an explicit temperature of 0 is dropped from
		// the request, leaving the server default in effect.
		Temperature float64 `json:"temperature,omitempty"`
		MaxTokens   int     `json:"max_tokens,omitempty"`
	}

	// chatMessage is one conversation turn. In this file Content is either a
	// plain string (system turn) or a []contentPart (multimodal user turn).
	chatMessage struct {
		Role    string `json:"role"`
		Content any    `json:"content"`
	}

	// contentPart is one element of a multimodal message. It is either a
	// "text" part (Text set) or an "image_url" part (ImageURL set).
	contentPart struct {
		Type     string        `json:"type"`
		Text     string        `json:"text,omitempty"`
		ImageURL *imageURLPart `json:"image_url,omitempty"`
	}

	// imageURLPart holds an image reference, which is either an http(s) URL or
	// a data URL.
	imageURLPart struct {
		URL string `json:"url"`
	}

	// chatCompletionsResponse decodes only the parts of the response that
	// are read: each choice's message content.
	chatCompletionsResponse struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
)
// chatCompletion sends a system prompt and one user message to the chat
// completions endpoint and returns the first choice's message content. The
// user message is the text prompt followed by one image_url part per
// non-blank image.
//
// The endpoint is BaseURL + "/v1/chat/completions", or BaseURL +
// "/chat/completions" when BaseURL already ends in "/v1".
//
// Transport errors and retryable HTTP statuses (see isRetryableStatusCode)
// are retried up to maxChatCompletionRetries times with exponential backoff.
// Any other HTTP error status, an undecodable body or an empty choices list
// fails immediately.
func (c *clientImpl) chatCompletion(
	ctx context.Context,
	model string,
	systemPrompt string,
	userPrompt string,
	images []string,
	temperature float64,
	maxTokens int,
) (string, error) {
	if model == "" {
		return "", fmt.Errorf("model is empty")
	}

	urls := normalizeImageURLs(images)
	parts := make([]contentPart, 0, 1+len(urls))
	parts = append(parts, contentPart{Type: "text", Text: userPrompt})
	for _, u := range urls {
		parts = append(parts, contentPart{Type: "image_url", ImageURL: &imageURLPart{URL: u}})
	}

	payload, err := json.Marshal(chatCompletionsRequest{
		Model: model,
		Messages: []chatMessage{
			{Role: "system", Content: systemPrompt},
			{Role: "user", Content: parts},
		},
		Temperature: temperature,
		MaxTokens:   maxTokens,
	})
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}

	// Accept BaseURL both with and without a trailing "/v1".
	base := strings.TrimRight(c.cfg.BaseURL, "/")
	endpoint := base + "/v1/chat/completions"
	if strings.HasSuffix(base, "/v1") {
		endpoint = base + "/chat/completions"
	}

	var lastErr error
	for attempt := 0; attempt < maxChatCompletionRetries; attempt++ {
		// Back off before every attempt except the first.
		if attempt > 0 {
			if sleepErr := sleepWithBackoff(ctx, attempt-1); sleepErr != nil {
				return "", sleepErr
			}
		}

		body, status, reqErr := c.doChatCompletionRequest(ctx, endpoint, payload)
		switch {
		case reqErr != nil:
			lastErr = reqErr
			continue
		case status >= 400:
			lastErr = fmt.Errorf("openai error status=%d body=%s", status, string(body))
			if isRetryableStatusCode(status) {
				continue
			}
			return "", lastErr
		}

		var decoded chatCompletionsResponse
		if decodeErr := json.Unmarshal(body, &decoded); decodeErr != nil {
			return "", fmt.Errorf("decode response: %w", decodeErr)
		}
		if len(decoded.Choices) == 0 {
			return "", fmt.Errorf("empty response choices")
		}
		return decoded.Choices[0].Message.Content, nil
	}
	return "", lastErr
}
// doChatCompletionRequest performs a single POST of data (a JSON body) to
// endpoint with the configured bearer token. It returns the raw response
// body and HTTP status code.
//
// A non-2xx status is not an error here; callers inspect statusCode. An
// error is returned only for request construction, transport or read
// failures, or when the body exceeds maxResponseBytes. The body is capped
// because an unbounded io.ReadAll on a remote response lets a misbehaving
// server or proxy exhaust memory. Moderation replies are limited to a few
// hundred tokens, so the cap is never hit by a legitimate response.
func (c *clientImpl) doChatCompletionRequest(ctx context.Context, endpoint string, data []byte) ([]byte, int, error) {
	const maxResponseBytes = 4 << 20 // 4 MiB

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
	if err != nil {
		return nil, 0, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, 0, fmt.Errorf("request openai: %w", err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap so an oversized body is reported instead of
	// being silently truncated.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
	if err != nil {
		return nil, 0, fmt.Errorf("read response: %w", err)
	}
	if len(body) > maxResponseBytes {
		return nil, 0, fmt.Errorf("read response: body exceeds %d bytes (status=%d)", maxResponseBytes, resp.StatusCode)
	}
	return body, resp.StatusCode, nil
}
// isRetryableStatusCode reports whether an HTTP status from the chat
// completions endpoint is worth retrying: 429 Too Many Requests or any 5xx.
func isRetryableStatusCode(statusCode int) bool {
	switch {
	case statusCode == http.StatusTooManyRequests:
		return true
	case statusCode < 500, statusCode > 599:
		return false
	default:
		return true
	}
}
// sleepWithBackoff waits initialRetryBackoff * 2^attempt, or until ctx is
// done, whichever comes first. It returns nil after a full wait. If the
// context ends first it returns ctx.Err() wrapped as "request cancelled".
func sleepWithBackoff(ctx context.Context, attempt int) error {
	// Shifting left by attempt doubles the delay per attempt.
	wait := time.NewTimer(initialRetryBackoff << attempt)
	defer wait.Stop()

	select {
	case <-wait.C:
		return nil
	case <-ctx.Done():
		return fmt.Errorf("request cancelled: %w", ctx.Err())
	}
}
// normalizeImageURLs returns images with surrounding whitespace trimmed and
// blank entries dropped, preserving order. The input slice is not modified,
// and the result is a fresh, never-nil slice.
func normalizeImageURLs(images []string) []string {
	out := make([]string, 0, len(images))
	for i := range images {
		if u := strings.TrimSpace(images[i]); u != "" {
			out = append(out, u)
		}
	}
	return out
}
// extractJSONObject returns the part of raw from the first '{' to the last
// '}' inclusive. This strips Markdown fences or extra chatter a model may
// wrap around its JSON reply. If no such span exists, it returns the
// whitespace-trimmed input unchanged.
func extractJSONObject(raw string) string {
	trimmed := strings.TrimSpace(raw)
	open := strings.IndexByte(trimmed, '{')
	closing := strings.LastIndexByte(trimmed, '}')
	if open < 0 || closing <= open {
		return trimmed
	}
	return trimmed[open : closing+1]
}
// maxImagesPerModerationRequest returns how many images go into one
// moderation request. Moderation deliberately uses single-image requests to
// keep each payload small and reduce the risk of timeouts, so the configured
// value is clamped. A non-positive value falls back to
// defaultMaxImagesPerModerationRequest, and anything above 1 becomes 1.
func (c *clientImpl) maxImagesPerModerationRequest() int {
	switch n := c.cfg.ModerationMaxImagesPerRequest; {
	case n <= 0:
		return defaultMaxImagesPerModerationRequest
	case n > 1:
		return 1
	default:
		return n
	}
}
// optimizeImagesForModeration maps each image URL through
// tryCompressImageForModeration, keeping order and length. Images that
// cannot be compressed come back unchanged. An empty input is returned as-is.
func (c *clientImpl) optimizeImagesForModeration(ctx context.Context, images []string) []string {
	if len(images) == 0 {
		return images
	}
	out := make([]string, len(images))
	for i, u := range images {
		out[i] = c.tryCompressImageForModeration(ctx, u)
	}
	return out
}
// tryCompressImageForModeration downloads an http(s) image and, when that
// clearly helps, returns it re-encoded as a downscaled JPEG data URL
// ("data:image/jpeg;base64,..."). The function is best-effort and never
// fails. On any error, or when re-encoding would not shrink the payload, it
// returns the original imageURL unchanged so the model fetches the image
// itself.
//
// Safeguards for untrusted uploads:
//   - Bodies larger than maxDownloadImageBytes are skipped. Previously they
//     were truncated and still fed to the decoder.
//   - Only the header is decoded first (image.DecodeConfig). Images above
//     maxDecodePixels are skipped, so a tiny file declaring huge dimensions
//     cannot force a multi-gigabyte allocation.
//   - Transparent pixels are flattened onto white before JPEG encoding.
//     JPEG has no alpha channel and the encoder reads premultiplied colour,
//     so transparency would otherwise become black. Black text on a
//     transparent PNG would then be invisible to the moderator.
//
// NOTE(review): imageURL comes from user content and is fetched
// server-side. Confirm that callers only pass URLs on trusted hosts (SSRF).
func (c *clientImpl) tryCompressImageForModeration(ctx context.Context, imageURL string) string {
	// Upper bound on decoded pixels (about 160 MB as RGBA).
	const maxDecodePixels = 40_000_000

	parsed, err := url.Parse(imageURL)
	if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
		return imageURL
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil)
	if err != nil {
		return imageURL
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return imageURL
	}
	defer resp.Body.Close()
	if resp.StatusCode >= 400 {
		return imageURL
	}
	if !strings.HasPrefix(strings.ToLower(resp.Header.Get("Content-Type")), "image/") {
		return imageURL
	}

	// Read one byte past the cap to distinguish "fits" from "truncated".
	originBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxDownloadImageBytes+1))
	if err != nil || len(originBytes) == 0 || len(originBytes) > maxDownloadImageBytes {
		return imageURL
	}

	hdr, _, err := image.DecodeConfig(bytes.NewReader(originBytes))
	if err != nil || hdr.Width <= 0 || hdr.Height <= 0 ||
		int64(hdr.Width)*int64(hdr.Height) > maxDecodePixels {
		return imageURL
	}
	srcImg, _, err := image.Decode(bytes.NewReader(originBytes))
	if err != nil {
		return imageURL
	}

	dstImg := flattenOntoWhite(resizeIfNeeded(srcImg, maxModerationImageSide))
	var buf bytes.Buffer
	if err := jpeg.Encode(&buf, dstImg, &jpeg.Options{Quality: compressedJPEGQuality}); err != nil {
		return imageURL
	}
	compressed := buf.Bytes()
	if len(compressed) == 0 || len(compressed) > maxCompressedPayloadBytes {
		return imageURL
	}
	// If compression saves less than 5%, keep the original URL to avoid
	// growing the request body.
	if len(compressed) >= int(float64(len(originBytes))*0.95) {
		return imageURL
	}
	return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(compressed)
}

// flattenOntoWhite composites img over an opaque white canvas so that
// transparent regions become white. Images that report themselves as fully
// opaque are returned unchanged. The result's bounds start at (0,0).
func flattenOntoWhite(img image.Image) image.Image {
	if o, ok := img.(interface{ Opaque() bool }); ok && o.Opaque() {
		return img
	}
	b := img.Bounds()
	canvas := image.NewRGBA(image.Rect(0, 0, b.Dx(), b.Dy()))
	xdraw.Draw(canvas, canvas.Bounds(), image.White, image.Point{}, xdraw.Src)
	xdraw.Draw(canvas, canvas.Bounds(), img, b.Min, xdraw.Over)
	return canvas
}
// resizeIfNeeded returns src unchanged when it is empty or its longest edge
// is at most maxSide. Otherwise it returns a new RGBA image, scaled with
// Catmull-Rom interpolation, whose longest edge is maxSide. The aspect ratio
// is preserved, with the shorter edge truncated and both edges clamped to at
// least 1 pixel.
func resizeIfNeeded(src image.Image, maxSide int) image.Image {
	r := src.Bounds()
	width, height := r.Dx(), r.Dy()
	if width <= 0 || height <= 0 {
		return src
	}
	if max(width, height) <= maxSide {
		return src
	}

	var newW, newH int
	if width >= height {
		newW = maxSide
		newH = int(float64(height) * (float64(maxSide) / float64(width)))
	} else {
		newW = int(float64(width) * (float64(maxSide) / float64(height)))
		newH = maxSide
	}
	newW, newH = max(newW, 1), max(newH, 1)

	scaled := image.NewRGBA(image.Rect(0, 0, newW, newH))
	xdraw.CatmullRom.Scale(scaled, scaled.Bounds(), src, r, xdraw.Over, nil)
	return scaled
}