Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor
This commit is contained in:
115
internal/pkg/avatar/avatar.go
Normal file
115
internal/pkg/avatar/avatar.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package avatar
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// 预定义一组好看的颜色
|
||||
var colors = []string{
|
||||
"#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
|
||||
"#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
|
||||
"#BB8FCE", "#85C1E9", "#F8B500", "#00CED1",
|
||||
"#E74C3C", "#3498DB", "#2ECC71", "#9B59B6",
|
||||
"#1ABC9C", "#F39C12", "#E67E22", "#16A085",
|
||||
}
|
||||
|
||||
// SVG模板
|
||||
const svgTemplate = `<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100" width="%d" height="%d">
|
||||
<rect width="100" height="100" fill="%s"/>
|
||||
<text x="50" y="50" font-family="Arial, sans-serif" font-size="40" font-weight="bold" fill="#ffffff" text-anchor="middle" dominant-baseline="central">%s</text>
|
||||
</svg>`
|
||||
|
||||
// GenerateSVGAvatar 根据用户名生成SVG头像
|
||||
// username: 用户名
|
||||
// size: 头像尺寸(像素)
|
||||
func GenerateSVGAvatar(username string, size int) string {
|
||||
initials := getInitials(username)
|
||||
color := stringToColor(username)
|
||||
return fmt.Sprintf(svgTemplate, size, size, color, initials)
|
||||
}
|
||||
|
||||
// GenerateAvatarDataURI 生成Data URI格式的头像
|
||||
// 可以直接在HTML img标签或CSS background-image中使用
|
||||
func GenerateAvatarDataURI(username string, size int) string {
|
||||
svg := GenerateSVGAvatar(username, size)
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(svg))
|
||||
return fmt.Sprintf("data:image/svg+xml;base64,%s", encoded)
|
||||
}
|
||||
|
||||
// getInitials 获取用户名首字母
|
||||
// 中文取第一个字,英文取首字母(最多2个)
|
||||
func getInitials(username string) string {
|
||||
if username == "" {
|
||||
return "?"
|
||||
}
|
||||
|
||||
// 检查是否是中文字符
|
||||
firstRune, _ := utf8.DecodeRuneInString(username)
|
||||
if isChinese(firstRune) {
|
||||
// 中文直接返回第一个字符
|
||||
return string(firstRune)
|
||||
}
|
||||
|
||||
// 英文处理:取前两个单词的首字母
|
||||
// 例如: "John Doe" -> "JD", "john" -> "J"
|
||||
result := []rune{}
|
||||
for i, r := range username {
|
||||
if i == 0 {
|
||||
result = append(result, toUpper(r))
|
||||
} else if r == ' ' || r == '_' || r == '-' {
|
||||
// 找到下一个字符作为第二个首字母
|
||||
nextIdx := i + 1
|
||||
if nextIdx < len(username) {
|
||||
nextRune, _ := utf8.DecodeRuneInString(username[nextIdx:])
|
||||
if nextRune != utf8.RuneError && nextRune != ' ' {
|
||||
result = append(result, toUpper(nextRune))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return "?"
|
||||
}
|
||||
|
||||
// 最多返回2个字符
|
||||
if len(result) > 2 {
|
||||
result = result[:2]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// isChinese 判断是否是中文字符
|
||||
func isChinese(r rune) bool {
|
||||
return r >= 0x4E00 && r <= 0x9FFF
|
||||
}
|
||||
|
||||
// toUpper 将字母转换为大写
|
||||
func toUpper(r rune) rune {
|
||||
if r >= 'a' && r <= 'z' {
|
||||
return r - 32
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// stringToColor 根据字符串生成颜色
|
||||
// 使用简单的哈希算法确保同一用户名每次生成的颜色一致
|
||||
func stringToColor(s string) string {
|
||||
if s == "" {
|
||||
return colors[0]
|
||||
}
|
||||
|
||||
hash := 0
|
||||
for _, r := range s {
|
||||
hash = (hash*31 + int(r)) % len(colors)
|
||||
}
|
||||
if hash < 0 {
|
||||
hash = -hash
|
||||
}
|
||||
|
||||
return colors[hash%len(colors)]
|
||||
}
|
||||
118
internal/pkg/avatar/avatar_test.go
Normal file
118
internal/pkg/avatar/avatar_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package avatar
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetInitials(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
want string
|
||||
}{
|
||||
{"中文用户名", "张三", "张"},
|
||||
{"英文用户名", "John", "J"},
|
||||
{"英文全名", "John Doe", "JD"},
|
||||
{"带下划线", "john_doe", "JD"},
|
||||
{"带连字符", "john-doe", "JD"},
|
||||
{"空字符串", "", "?"},
|
||||
{"小写英文", "alice", "A"},
|
||||
{"中文复合", "李小龙", "李"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getInitials(tt.username)
|
||||
if got != tt.want {
|
||||
t.Errorf("getInitials(%q) = %q, want %q", tt.username, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToColor(t *testing.T) {
|
||||
// 测试同一用户名生成的颜色一致
|
||||
color1 := stringToColor("张三")
|
||||
color2 := stringToColor("张三")
|
||||
if color1 != color2 {
|
||||
t.Errorf("stringToColor should return consistent colors for the same input")
|
||||
}
|
||||
|
||||
// 测试不同用户名生成不同颜色(大概率)
|
||||
color3 := stringToColor("李四")
|
||||
if color1 == color3 {
|
||||
t.Logf("Warning: different usernames generated the same color (possible but unlikely)")
|
||||
}
|
||||
|
||||
// 测试空字符串
|
||||
color4 := stringToColor("")
|
||||
if color4 == "" {
|
||||
t.Errorf("stringToColor should return a color for empty string")
|
||||
}
|
||||
|
||||
// 验证颜色格式
|
||||
if !strings.HasPrefix(color4, "#") {
|
||||
t.Errorf("stringToColor should return hex color format starting with #")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSVGAvatar(t *testing.T) {
|
||||
svg := GenerateSVGAvatar("张三", 100)
|
||||
|
||||
// 验证SVG结构
|
||||
if !strings.Contains(svg, "<svg") {
|
||||
t.Errorf("SVG should contain <svg tag")
|
||||
}
|
||||
if !strings.Contains(svg, "</svg>") {
|
||||
t.Errorf("SVG should contain </svg> tag")
|
||||
}
|
||||
if !strings.Contains(svg, "width=\"100\"") {
|
||||
t.Errorf("SVG should have width=100")
|
||||
}
|
||||
if !strings.Contains(svg, "height=\"100\"") {
|
||||
t.Errorf("SVG should have height=100")
|
||||
}
|
||||
if !strings.Contains(svg, "张") {
|
||||
t.Errorf("SVG should contain the initial character")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAvatarDataURI(t *testing.T) {
|
||||
dataURI := GenerateAvatarDataURI("张三", 100)
|
||||
|
||||
// 验证Data URI格式
|
||||
if !strings.HasPrefix(dataURI, "data:image/svg+xml;base64,") {
|
||||
t.Errorf("Data URI should start with data:image/svg+xml;base64,")
|
||||
}
|
||||
|
||||
// 验证base64部分不为空
|
||||
parts := strings.Split(dataURI, ",")
|
||||
if len(parts) != 2 {
|
||||
t.Errorf("Data URI should have two parts separated by comma")
|
||||
}
|
||||
if parts[1] == "" {
|
||||
t.Errorf("Base64 part should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsChinese(t *testing.T) {
|
||||
tests := []struct {
|
||||
r rune
|
||||
want bool
|
||||
}{
|
||||
{'中', true},
|
||||
{'文', true},
|
||||
{'a', false},
|
||||
{'Z', false},
|
||||
{'0', false},
|
||||
{'_', false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := isChinese(tt.r)
|
||||
if got != tt.want {
|
||||
t.Errorf("isChinese(%q) = %v, want %v", tt.r, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
131
internal/pkg/email/client.go
Normal file
131
internal/pkg/email/client.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gomail "gopkg.in/gomail.v2"
|
||||
)
|
||||
|
||||
// Message 发信参数
|
||||
type Message struct {
|
||||
To []string
|
||||
Cc []string
|
||||
Bcc []string
|
||||
ReplyTo []string
|
||||
Subject string
|
||||
TextBody string
|
||||
HTMLBody string
|
||||
Attachments []string
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
IsEnabled() bool
|
||||
Config() Config
|
||||
Send(ctx context.Context, msg Message) error
|
||||
}
|
||||
|
||||
type clientImpl struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func NewClient(cfg Config) Client {
|
||||
return &clientImpl{cfg: cfg}
|
||||
}
|
||||
|
||||
func (c *clientImpl) IsEnabled() bool {
|
||||
return c.cfg.Enabled &&
|
||||
strings.TrimSpace(c.cfg.Host) != "" &&
|
||||
c.cfg.Port > 0 &&
|
||||
strings.TrimSpace(c.cfg.FromAddress) != ""
|
||||
}
|
||||
|
||||
func (c *clientImpl) Config() Config {
|
||||
return c.cfg
|
||||
}
|
||||
|
||||
func (c *clientImpl) Send(ctx context.Context, msg Message) error {
|
||||
if !c.IsEnabled() {
|
||||
return fmt.Errorf("email client is disabled or misconfigured")
|
||||
}
|
||||
if len(msg.To) == 0 {
|
||||
return fmt.Errorf("email recipient is empty")
|
||||
}
|
||||
if strings.TrimSpace(msg.Subject) == "" {
|
||||
return fmt.Errorf("email subject is empty")
|
||||
}
|
||||
if strings.TrimSpace(msg.TextBody) == "" && strings.TrimSpace(msg.HTMLBody) == "" {
|
||||
return fmt.Errorf("email body is empty")
|
||||
}
|
||||
|
||||
m := gomail.NewMessage()
|
||||
m.SetAddressHeader("From", c.cfg.FromAddress, c.cfg.FromName)
|
||||
m.SetHeader("To", msg.To...)
|
||||
if len(msg.Cc) > 0 {
|
||||
m.SetHeader("Cc", msg.Cc...)
|
||||
}
|
||||
if len(msg.Bcc) > 0 {
|
||||
m.SetHeader("Bcc", msg.Bcc...)
|
||||
}
|
||||
if len(msg.ReplyTo) > 0 {
|
||||
m.SetHeader("Reply-To", msg.ReplyTo...)
|
||||
}
|
||||
m.SetHeader("Subject", msg.Subject)
|
||||
|
||||
if strings.TrimSpace(msg.TextBody) != "" && strings.TrimSpace(msg.HTMLBody) != "" {
|
||||
m.SetBody("text/plain", msg.TextBody)
|
||||
m.AddAlternative("text/html", msg.HTMLBody)
|
||||
} else if strings.TrimSpace(msg.HTMLBody) != "" {
|
||||
m.SetBody("text/html", msg.HTMLBody)
|
||||
} else {
|
||||
m.SetBody("text/plain", msg.TextBody)
|
||||
}
|
||||
|
||||
for _, attachment := range msg.Attachments {
|
||||
if strings.TrimSpace(attachment) == "" {
|
||||
continue
|
||||
}
|
||||
m.Attach(attachment)
|
||||
}
|
||||
|
||||
timeout := c.cfg.TimeoutSeconds
|
||||
if timeout <= 0 {
|
||||
timeout = 15
|
||||
}
|
||||
dialer := gomail.NewDialer(c.cfg.Host, c.cfg.Port, c.cfg.Username, c.cfg.Password)
|
||||
if c.cfg.UseTLS {
|
||||
dialer.TLSConfig = &tls.Config{
|
||||
ServerName: c.cfg.Host,
|
||||
InsecureSkipVerify: c.cfg.InsecureSkipVerify,
|
||||
}
|
||||
// 465 端口通常要求直接 TLS(Implicit TLS)。
|
||||
if c.cfg.Port == 465 {
|
||||
dialer.SSL = true
|
||||
}
|
||||
}
|
||||
|
||||
sendCtx := ctx
|
||||
cancel := func() {}
|
||||
if timeout > 0 {
|
||||
sendCtx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- dialer.DialAndSend(m)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sendCtx.Done():
|
||||
return fmt.Errorf("send email canceled: %w", sendCtx.Err())
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("send email failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
33
internal/pkg/email/config.go
Normal file
33
internal/pkg/email/config.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package email
|
||||
|
||||
import "carrot_bbs/internal/config"
|
||||
|
||||
// Config SMTP 邮件配置(由应用配置转换)
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
FromAddress string
|
||||
FromName string
|
||||
UseTLS bool
|
||||
InsecureSkipVerify bool
|
||||
TimeoutSeconds int
|
||||
}
|
||||
|
||||
// ConfigFromAppConfig 从应用配置转换
|
||||
func ConfigFromAppConfig(cfg *config.EmailConfig) Config {
|
||||
return Config{
|
||||
Enabled: cfg.Enabled,
|
||||
Host: cfg.Host,
|
||||
Port: cfg.Port,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
FromAddress: cfg.FromAddress,
|
||||
FromName: cfg.FromName,
|
||||
UseTLS: cfg.UseTLS,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
TimeoutSeconds: cfg.Timeout,
|
||||
}
|
||||
}
|
||||
286
internal/pkg/gorse/client.go
Normal file
286
internal/pkg/gorse/client.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package gorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
gorseio "github.com/gorse-io/gorse-go"
|
||||
)
|
||||
|
||||
// FeedbackType 反馈类型
|
||||
type FeedbackType string
|
||||
|
||||
const (
|
||||
FeedbackTypeLike FeedbackType = "like" // 点赞
|
||||
FeedbackTypeStar FeedbackType = "star" // 收藏
|
||||
FeedbackTypeComment FeedbackType = "comment" // 评论
|
||||
FeedbackTypeRead FeedbackType = "read" // 浏览
|
||||
)
|
||||
|
||||
// Score 非个性化推荐返回的评分项
|
||||
type Score struct {
|
||||
Id string `json:"Id"`
|
||||
Score float64 `json:"Score"`
|
||||
}
|
||||
|
||||
// Client Gorse客户端接口
|
||||
type Client interface {
|
||||
// InsertFeedback 插入用户反馈
|
||||
InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error
|
||||
// DeleteFeedback 删除用户反馈
|
||||
DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error
|
||||
// GetRecommend 获取个性化推荐列表
|
||||
GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error)
|
||||
// GetNonPersonalized 获取非个性化推荐(通过名称)
|
||||
GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error)
|
||||
// UpsertItem 插入或更新物品(无embedding)
|
||||
UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error
|
||||
// UpsertItemWithEmbedding 插入或更新物品(带embedding)
|
||||
UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error
|
||||
// DeleteItem 删除物品
|
||||
DeleteItem(ctx context.Context, itemID string) error
|
||||
// UpsertUser 插入或更新用户
|
||||
UpsertUser(ctx context.Context, userID string, labels map[string]any) error
|
||||
// IsEnabled 检查是否启用
|
||||
IsEnabled() bool
|
||||
}
|
||||
|
||||
// client Gorse客户端实现
|
||||
type client struct {
|
||||
config Config
|
||||
gorse *gorseio.GorseClient
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient 创建新的Gorse客户端
|
||||
func NewClient(cfg Config) Client {
|
||||
if !cfg.Enabled {
|
||||
return &noopClient{}
|
||||
}
|
||||
|
||||
gorse := gorseio.NewGorseClient(cfg.Address, cfg.APIKey)
|
||||
return &client{
|
||||
config: cfg,
|
||||
gorse: gorse,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled 检查是否启用
|
||||
func (c *client) IsEnabled() bool {
|
||||
return c.config.Enabled
|
||||
}
|
||||
|
||||
// InsertFeedback 插入用户反馈
|
||||
func (c *client) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.InsertFeedback(ctx, []gorseio.Feedback{
|
||||
{
|
||||
FeedbackType: string(feedbackType),
|
||||
UserId: userID,
|
||||
ItemId: itemID,
|
||||
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteFeedback 删除用户反馈
|
||||
func (c *client) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.DeleteFeedback(ctx, string(feedbackType), userID, itemID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRecommend 获取个性化推荐列表
|
||||
func (c *client) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) {
|
||||
if !c.config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result, err := c.gorse.GetRecommend(ctx, userID, "", n, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetNonPersonalized 获取非个性化推荐
|
||||
// name: 推荐器名称,如 "most_liked_weekly"
|
||||
// n: 返回数量
|
||||
// offset: 偏移量
|
||||
// userID: 可选,用于排除用户已读物品
|
||||
func (c *client) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) {
|
||||
if !c.config.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 构建URL
|
||||
url := fmt.Sprintf("%s/api/non-personalized/%s?n=%d&offset=%d", c.config.Address, name, n, offset)
|
||||
if userID != "" {
|
||||
url += fmt.Sprintf("&user-id=%s", userID)
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// 设置API Key
|
||||
if c.config.APIKey != "" {
|
||||
req.Header.Set("X-API-Key", c.config.APIKey)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("gorse api error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var scores []Score
|
||||
if err := json.Unmarshal(body, &scores); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
// 提取ID
|
||||
ids := make([]string, len(scores))
|
||||
for i, score := range scores {
|
||||
ids[i] = score.Id
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// UpsertItem 插入或更新物品
|
||||
func (c *client) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.InsertItem(ctx, gorseio.Item{
|
||||
ItemId: itemID,
|
||||
IsHidden: false,
|
||||
Categories: categories,
|
||||
Comment: comment,
|
||||
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// UpsertItemWithEmbedding 插入或更新物品(带embedding)
|
||||
func (c *client) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成embedding
|
||||
var embedding []float64
|
||||
if textToEmbed != "" {
|
||||
var err error
|
||||
embedding, err = GetEmbedding(textToEmbed)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to get embedding for item %s: %v, using zero vector", itemID, err)
|
||||
embedding = make([]float64, 1024)
|
||||
}
|
||||
} else {
|
||||
embedding = make([]float64, 1024)
|
||||
}
|
||||
|
||||
_, err := c.gorse.InsertItem(ctx, gorseio.Item{
|
||||
ItemId: itemID,
|
||||
IsHidden: false,
|
||||
Categories: categories,
|
||||
Comment: comment,
|
||||
Timestamp: time.Now().UTC().Truncate(time.Second),
|
||||
Labels: map[string]any{
|
||||
"embedding": embedding,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteItem 删除物品
|
||||
func (c *client) DeleteItem(ctx context.Context, itemID string) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.DeleteItem(ctx, itemID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpsertUser 插入或更新用户
|
||||
func (c *client) UpsertUser(ctx context.Context, userID string, labels map[string]any) error {
|
||||
if !c.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := c.gorse.InsertUser(ctx, gorseio.User{
|
||||
UserId: userID,
|
||||
Labels: labels,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// noopClient 空操作客户端(用于未启用推荐功能时)
|
||||
type noopClient struct{}
|
||||
|
||||
func (c *noopClient) IsEnabled() bool { return false }
|
||||
func (c *noopClient) InsertFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *noopClient) DeleteFeedback(ctx context.Context, feedbackType FeedbackType, userID, itemID string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *noopClient) GetRecommend(ctx context.Context, userID string, n int, offset int) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *noopClient) GetNonPersonalized(ctx context.Context, name string, n int, offset int, userID string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *noopClient) UpsertItem(ctx context.Context, itemID string, categories []string, comment string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *noopClient) UpsertItemWithEmbedding(ctx context.Context, itemID string, categories []string, comment string, textToEmbed string) error {
|
||||
return nil
|
||||
}
|
||||
func (c *noopClient) DeleteItem(ctx context.Context, itemID string) error { return nil }
|
||||
func (c *noopClient) UpsertUser(ctx context.Context, userID string, labels map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 确保实现了接口
|
||||
var _ Client = (*client)(nil)
|
||||
var _ Client = (*noopClient)(nil)
|
||||
|
||||
// log 用于内部日志
|
||||
func init() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
23
internal/pkg/gorse/config.go
Normal file
23
internal/pkg/gorse/config.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package gorse
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/config"
|
||||
)
|
||||
|
||||
// Config Gorse客户端配置(从config.GorseConfig转换)
|
||||
type Config struct {
|
||||
Address string
|
||||
APIKey string
|
||||
Enabled bool
|
||||
Dashboard string
|
||||
}
|
||||
|
||||
// ConfigFromAppConfig 从应用配置创建Gorse配置
|
||||
func ConfigFromAppConfig(cfg *config.GorseConfig) Config {
|
||||
return Config{
|
||||
Address: cfg.Address,
|
||||
APIKey: cfg.APIKey,
|
||||
Enabled: cfg.Enabled,
|
||||
Dashboard: cfg.Dashboard,
|
||||
}
|
||||
}
|
||||
106
internal/pkg/gorse/embedding.go
Normal file
106
internal/pkg/gorse/embedding.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package gorse
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EmbeddingConfig embedding服务配置
|
||||
type EmbeddingConfig struct {
|
||||
APIKey string
|
||||
URL string
|
||||
Model string
|
||||
}
|
||||
|
||||
var defaultEmbeddingConfig = EmbeddingConfig{
|
||||
APIKey: "sk-ZPN5NMPSqEaOGCPfD2LqndZ5Wwmw3DC4CQgzgKhM35fI3RpD",
|
||||
URL: "https://api.littlelan.cn/v1/embeddings",
|
||||
Model: "BAAI/bge-m3",
|
||||
}
|
||||
|
||||
// SetEmbeddingConfig 设置embedding配置
|
||||
func SetEmbeddingConfig(apiKey, url, model string) {
|
||||
if apiKey != "" {
|
||||
defaultEmbeddingConfig.APIKey = apiKey
|
||||
}
|
||||
if url != "" {
|
||||
defaultEmbeddingConfig.URL = url
|
||||
}
|
||||
if model != "" {
|
||||
defaultEmbeddingConfig.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmbedding 获取文本的embedding
|
||||
func GetEmbedding(text string) ([]float64, error) {
|
||||
type embeddingRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type embeddingResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
reqBody := embeddingRequest{
|
||||
Input: text,
|
||||
Model: defaultEmbeddingConfig.Model,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", defaultEmbeddingConfig.URL, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+defaultEmbeddingConfig.APIKey)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("embedding API error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result embeddingResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Data) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
|
||||
return result.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
// InitEmbeddingWithConfig 从应用配置初始化embedding
|
||||
func InitEmbeddingWithConfig(apiKey, url, model string) {
|
||||
if apiKey == "" {
|
||||
log.Println("[WARN] Gorse embedding API key not set, using default")
|
||||
}
|
||||
defaultEmbeddingConfig.APIKey = apiKey
|
||||
if url != "" {
|
||||
defaultEmbeddingConfig.URL = url
|
||||
}
|
||||
if model != "" {
|
||||
defaultEmbeddingConfig.Model = model
|
||||
}
|
||||
}
|
||||
105
internal/pkg/jwt/jwt.go
Normal file
105
internal/pkg/jwt/jwt.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrExpiredToken = errors.New("token has expired")
|
||||
)
|
||||
|
||||
// Claims JWT 声明
|
||||
type Claims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// JWT JWT工具
|
||||
type JWT struct {
|
||||
secretKey string
|
||||
accessTokenExpire time.Duration
|
||||
refreshTokenExpire time.Duration
|
||||
}
|
||||
|
||||
// New 创建JWT实例
|
||||
func New(secret string, accessExpire, refreshExpire time.Duration) *JWT {
|
||||
return &JWT{
|
||||
secretKey: secret,
|
||||
accessTokenExpire: accessExpire,
|
||||
refreshTokenExpire: refreshExpire,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAccessToken 生成访问令牌
|
||||
func (j *JWT) GenerateAccessToken(userID, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(j.accessTokenExpire)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: "carrot_bbs",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(j.secretKey))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新令牌
|
||||
func (j *JWT) GenerateRefreshToken(userID, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(j.refreshTokenExpire)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
Issuer: "carrot_bbs",
|
||||
ID: "refresh",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(j.secretKey))
|
||||
}
|
||||
|
||||
// ParseToken 解析令牌
|
||||
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(j.secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// ValidateToken 验证令牌
|
||||
func (j *JWT) ValidateToken(tokenString string) error {
|
||||
claims, err := j.ParseToken(tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查是否是刷新令牌
|
||||
if claims.ID == "refresh" {
|
||||
return errors.New("cannot use refresh token as access token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
438
internal/pkg/openai/client.go
Normal file
438
internal/pkg/openai/client.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
"image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xdraw "golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
const moderationSystemPrompt = "你是中文社区的内容审核助手,负责对“帖子标题、正文、配图”做联合审核。目标是平衡社区安全与正常交流:必须拦截高风险违规内容,但不要误伤正常玩梗、二创、吐槽和轻度调侃。请只输出指定JSON。\n\n审核流程:\n1) 先判断是否命中硬性违规;\n2) 再判断语境(玩笑/自嘲/朋友间互动/作品讨论);\n3) 做文图交叉判断(文本+图片合并理解);\n4) 给出 approved 与简短 reason。\n\n硬性违规(命中任一项必须 approved=false):\nA. 宣传对立与煽动撕裂:\n- 明确煽动群体对立、地域对立、性别对立、民族宗教对立,鼓动仇恨、排斥、报复。\nB. 严重人身攻击与网暴引导:\n- 持续性侮辱贬损、羞辱人格、号召围攻/骚扰/挂人/线下冲突。\nC. 开盒/人肉/隐私暴露:\n- 故意公开、拼接、索取他人可识别隐私信息(姓名+联系方式、身份证号、住址、学校单位、车牌、定位轨迹等);\n- 图片/截图中出现可识别隐私信息并伴随曝光意图,也按违规处理。\nD. 其他高危违规:\n- 违法犯罪、暴力威胁、极端仇恨、色情低俗、诈骗引流、恶意广告等。\n\n放行规则(以下通常 approved=true):\n- 正常玩梗、表情包、谐音梗、二次创作、无恶意的吐槽;\n- 非定向、轻度口语化吐槽(无明确攻击对象、无网暴号召、无隐私暴露);\n- 对社会事件/作品的理性讨论、观点争论(即使语气尖锐,但未煽动对立或人身攻击)。\n\n边界判定:\n- 若只是“梗文化表达”且不指向现实伤害,优先通过;\n- 若存在明确伤害意图(煽动、围攻、曝光隐私),必须拒绝;\n- 对模糊内容不因个别粗口直接拒绝,需结合对象、意图、号召性和可执行性综合判断。\n\nreason 要求:\n- approved=false 时:中文10-30字,说明核心违规点;\n- approved=true 时:reason 为空字符串。\n\n输出格式(严格):\n仅输出一行JSON对象,不要Markdown,不要额外解释:\n{\"approved\": true/false, \"reason\": \"...\"}"
|
||||
|
||||
const (
|
||||
defaultMaxImagesPerModerationRequest = 1
|
||||
maxModerationResultRetries = 3
|
||||
maxChatCompletionRetries = 3
|
||||
initialRetryBackoff = 500 * time.Millisecond
|
||||
maxDownloadImageBytes = 10 * 1024 * 1024
|
||||
maxModerationImageSide = 1280
|
||||
compressedJPEGQuality = 72
|
||||
maxCompressedPayloadBytes = 1536 * 1024
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
IsEnabled() bool
|
||||
Config() Config
|
||||
ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error)
|
||||
ModerateComment(ctx context.Context, content string, images []string) (bool, string, error)
|
||||
}
|
||||
|
||||
type clientImpl struct {
|
||||
cfg Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(cfg Config) Client {
|
||||
timeout := cfg.RequestTimeoutSeconds
|
||||
if timeout <= 0 {
|
||||
timeout = 30
|
||||
}
|
||||
return &clientImpl{
|
||||
cfg: cfg,
|
||||
httpClient: &http.Client{
|
||||
Timeout: time.Duration(timeout) * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientImpl) IsEnabled() bool {
|
||||
return c.cfg.Enabled && c.cfg.APIKey != "" && c.cfg.BaseURL != ""
|
||||
}
|
||||
|
||||
func (c *clientImpl) Config() Config {
|
||||
return c.cfg
|
||||
}
|
||||
|
||||
func (c *clientImpl) ModeratePost(ctx context.Context, title, content string, images []string) (bool, string, error) {
|
||||
if !c.IsEnabled() {
|
||||
return true, "", nil
|
||||
}
|
||||
return c.moderateContentInBatches(ctx, fmt.Sprintf("帖子标题:%s\n帖子内容:%s", title, content), images)
|
||||
}
|
||||
|
||||
func (c *clientImpl) ModerateComment(ctx context.Context, content string, images []string) (bool, string, error) {
|
||||
if !c.IsEnabled() {
|
||||
return true, "", nil
|
||||
}
|
||||
return c.moderateContentInBatches(ctx, fmt.Sprintf("评论内容:%s", content), images)
|
||||
}
|
||||
|
||||
func (c *clientImpl) moderateContentInBatches(ctx context.Context, contentPrompt string, images []string) (bool, string, error) {
|
||||
cleanImages := normalizeImageURLs(images)
|
||||
optimizedImages := c.optimizeImagesForModeration(ctx, cleanImages)
|
||||
maxImagesPerRequest := c.maxImagesPerModerationRequest()
|
||||
totalBatches := 1
|
||||
if len(optimizedImages) > 0 {
|
||||
totalBatches = (len(optimizedImages) + maxImagesPerRequest - 1) / maxImagesPerRequest
|
||||
}
|
||||
|
||||
// 图片超过单批上限时分批审核,任一批次拒绝即整体拒绝
|
||||
for i := 0; i < totalBatches; i++ {
|
||||
start := i * maxImagesPerRequest
|
||||
end := start + maxImagesPerRequest
|
||||
if end > len(optimizedImages) {
|
||||
end = len(optimizedImages)
|
||||
}
|
||||
|
||||
batchImages := []string{}
|
||||
if len(optimizedImages) > 0 {
|
||||
batchImages = optimizedImages[start:end]
|
||||
}
|
||||
|
||||
approved, reason, err := c.moderateSingleBatch(ctx, contentPrompt, batchImages, i+1, totalBatches)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if !approved {
|
||||
if strings.TrimSpace(reason) != "" && totalBatches > 1 {
|
||||
reason = fmt.Sprintf("第%d/%d批图片未通过:%s", i+1, totalBatches, reason)
|
||||
}
|
||||
return false, reason, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func (c *clientImpl) moderateSingleBatch(
|
||||
ctx context.Context,
|
||||
contentPrompt string,
|
||||
images []string,
|
||||
batchNo, totalBatches int,
|
||||
) (bool, string, error) {
|
||||
userPrompt := fmt.Sprintf(
|
||||
"%s\n图片批次:%d/%d(本次仅提供当前批次图片)",
|
||||
contentPrompt,
|
||||
batchNo,
|
||||
totalBatches,
|
||||
)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxModerationResultRetries; attempt++ {
|
||||
replyText, err := c.chatCompletion(ctx, c.cfg.ModerationModel, moderationSystemPrompt, userPrompt, images, 0.1, 220)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else {
|
||||
parsed := struct {
|
||||
Approved bool `json:"approved"`
|
||||
Reason string `json:"reason"`
|
||||
}{}
|
||||
if err := json.Unmarshal([]byte(extractJSONObject(replyText)), &parsed); err != nil {
|
||||
lastErr = fmt.Errorf("failed to parse moderation result: %w", err)
|
||||
} else {
|
||||
return parsed.Approved, parsed.Reason, nil
|
||||
}
|
||||
}
|
||||
|
||||
if attempt == maxModerationResultRetries-1 {
|
||||
break
|
||||
}
|
||||
if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
|
||||
return false, "", sleepErr
|
||||
}
|
||||
}
|
||||
|
||||
return false, "", fmt.Errorf(
|
||||
"moderation batch %d/%d failed after %d attempts: %w",
|
||||
batchNo,
|
||||
totalBatches,
|
||||
maxModerationResultRetries,
|
||||
lastErr,
|
||||
)
|
||||
}
|
||||
|
||||
type chatCompletionsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type contentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL *imageURLPart `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type imageURLPart struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type chatCompletionsResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func (c *clientImpl) chatCompletion(
|
||||
ctx context.Context,
|
||||
model string,
|
||||
systemPrompt string,
|
||||
userPrompt string,
|
||||
images []string,
|
||||
temperature float64,
|
||||
maxTokens int,
|
||||
) (string, error) {
|
||||
if model == "" {
|
||||
return "", fmt.Errorf("model is empty")
|
||||
}
|
||||
|
||||
cleanImages := normalizeImageURLs(images)
|
||||
|
||||
userParts := []contentPart{
|
||||
{Type: "text", Text: userPrompt},
|
||||
}
|
||||
for _, image := range cleanImages {
|
||||
userParts = append(userParts, contentPart{
|
||||
Type: "image_url",
|
||||
ImageURL: &imageURLPart{URL: image},
|
||||
})
|
||||
}
|
||||
|
||||
reqBody := chatCompletionsRequest{
|
||||
Model: model,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userParts},
|
||||
},
|
||||
Temperature: temperature,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(c.cfg.BaseURL, "/")
|
||||
endpoint := baseURL + "/v1/chat/completions"
|
||||
if strings.HasSuffix(baseURL, "/v1") {
|
||||
endpoint = baseURL + "/chat/completions"
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxChatCompletionRetries; attempt++ {
|
||||
body, statusCode, err := c.doChatCompletionRequest(ctx, endpoint, data)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else if statusCode >= 400 {
|
||||
lastErr = fmt.Errorf("openai error status=%d body=%s", statusCode, string(body))
|
||||
if !isRetryableStatusCode(statusCode) {
|
||||
return "", lastErr
|
||||
}
|
||||
} else {
|
||||
var parsed chatCompletionsResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return "", fmt.Errorf("empty response choices")
|
||||
}
|
||||
return parsed.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
if attempt == maxChatCompletionRetries-1 {
|
||||
break
|
||||
}
|
||||
if sleepErr := sleepWithBackoff(ctx, attempt); sleepErr != nil {
|
||||
return "", sleepErr
|
||||
}
|
||||
}
|
||||
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
func (c *clientImpl) doChatCompletionRequest(ctx context.Context, endpoint string, data []byte) ([]byte, int, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("request openai: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
return body, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func isRetryableStatusCode(statusCode int) bool {
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
return statusCode >= 500 && statusCode <= 599
|
||||
}
|
||||
|
||||
func sleepWithBackoff(ctx context.Context, attempt int) error {
|
||||
delay := initialRetryBackoff * time.Duration(1<<attempt)
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("request cancelled: %w", ctx.Err())
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeImageURLs(images []string) []string {
|
||||
clean := make([]string, 0, len(images))
|
||||
for _, image := range images {
|
||||
trimmed := strings.TrimSpace(image)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
clean = append(clean, trimmed)
|
||||
}
|
||||
return clean
|
||||
}
|
||||
|
||||
func extractJSONObject(raw string) string {
|
||||
text := strings.TrimSpace(raw)
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start >= 0 && end > start {
|
||||
return text[start : end+1]
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func (c *clientImpl) maxImagesPerModerationRequest() int {
|
||||
// 审核固定单图请求,降低单次payload体积,减少超时风险。
|
||||
if c.cfg.ModerationMaxImagesPerRequest <= 0 {
|
||||
return defaultMaxImagesPerModerationRequest
|
||||
}
|
||||
if c.cfg.ModerationMaxImagesPerRequest > 1 {
|
||||
return 1
|
||||
}
|
||||
return c.cfg.ModerationMaxImagesPerRequest
|
||||
}
|
||||
|
||||
func (c *clientImpl) optimizeImagesForModeration(ctx context.Context, images []string) []string {
|
||||
if len(images) == 0 {
|
||||
return images
|
||||
}
|
||||
|
||||
optimized := make([]string, 0, len(images))
|
||||
for _, imageURL := range images {
|
||||
optimized = append(optimized, c.tryCompressImageForModeration(ctx, imageURL))
|
||||
}
|
||||
return optimized
|
||||
}
|
||||
|
||||
func (c *clientImpl) tryCompressImageForModeration(ctx context.Context, imageURL string) string {
|
||||
parsed, err := url.Parse(imageURL)
|
||||
if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil)
|
||||
if err != nil {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return imageURL
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return imageURL
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(resp.Header.Get("Content-Type")), "image/") {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
originBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxDownloadImageBytes))
|
||||
if err != nil || len(originBytes) == 0 {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
srcImg, _, err := image.Decode(bytes.NewReader(originBytes))
|
||||
if err != nil {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
dstImg := resizeIfNeeded(srcImg, maxModerationImageSide)
|
||||
var buf bytes.Buffer
|
||||
if err := jpeg.Encode(&buf, dstImg, &jpeg.Options{Quality: compressedJPEGQuality}); err != nil {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
compressed := buf.Bytes()
|
||||
if len(compressed) == 0 || len(compressed) > maxCompressedPayloadBytes {
|
||||
return imageURL
|
||||
}
|
||||
// 压缩效果不明显时直接使用原图URL,避免增大请求体。
|
||||
if len(compressed) >= int(float64(len(originBytes))*0.95) {
|
||||
return imageURL
|
||||
}
|
||||
|
||||
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(compressed)
|
||||
}
|
||||
|
||||
func resizeIfNeeded(src image.Image, maxSide int) image.Image {
|
||||
bounds := src.Bounds()
|
||||
w := bounds.Dx()
|
||||
h := bounds.Dy()
|
||||
if w <= 0 || h <= 0 || max(w, h) <= maxSide {
|
||||
return src
|
||||
}
|
||||
|
||||
newW, newH := w, h
|
||||
if w >= h {
|
||||
newW = maxSide
|
||||
newH = int(float64(h) * (float64(maxSide) / float64(w)))
|
||||
} else {
|
||||
newH = maxSide
|
||||
newW = int(float64(w) * (float64(maxSide) / float64(h)))
|
||||
}
|
||||
if newW < 1 {
|
||||
newW = 1
|
||||
}
|
||||
if newH < 1 {
|
||||
newH = 1
|
||||
}
|
||||
|
||||
dst := image.NewRGBA(image.Rect(0, 0, newW, newH))
|
||||
xdraw.CatmullRom.Scale(dst, dst.Bounds(), src, bounds, xdraw.Over, nil)
|
||||
return dst
|
||||
}
|
||||
27
internal/pkg/openai/config.go
Normal file
27
internal/pkg/openai/config.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package openai
|
||||
|
||||
import "carrot_bbs/internal/config"
|
||||
|
||||
// Config OpenAI 兼容接口配置
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
BaseURL string
|
||||
APIKey string
|
||||
ModerationModel string
|
||||
ModerationMaxImagesPerRequest int
|
||||
RequestTimeoutSeconds int
|
||||
StrictModeration bool
|
||||
}
|
||||
|
||||
// ConfigFromAppConfig 从应用配置转换
|
||||
func ConfigFromAppConfig(cfg *config.OpenAIConfig) Config {
|
||||
return Config{
|
||||
Enabled: cfg.Enabled,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKey: cfg.APIKey,
|
||||
ModerationModel: cfg.ModerationModel,
|
||||
ModerationMaxImagesPerRequest: cfg.ModerationMaxImagesPerRequest,
|
||||
RequestTimeoutSeconds: cfg.RequestTimeout,
|
||||
StrictModeration: cfg.StrictModeration,
|
||||
}
|
||||
}
|
||||
119
internal/pkg/redis/redis.go
Normal file
119
internal/pkg/redis/redis.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"carrot_bbs/internal/config"
|
||||
)
|
||||
|
||||
// Client Redis客户端
|
||||
type Client struct {
|
||||
rdb *redis.Client
|
||||
isMiniRedis bool
|
||||
mr *miniredis.Miniredis
|
||||
}
|
||||
|
||||
// New 创建Redis客户端
|
||||
func New(cfg *config.RedisConfig) (*Client, error) {
|
||||
switch cfg.Type {
|
||||
case "miniredis":
|
||||
// 启动内嵌Redis模拟
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start miniredis: %w", err)
|
||||
}
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
Password: "",
|
||||
DB: 0,
|
||||
})
|
||||
return &Client{
|
||||
rdb: rdb,
|
||||
isMiniRedis: true,
|
||||
mr: mr,
|
||||
}, nil
|
||||
case "redis":
|
||||
// 使用真实Redis
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.Addr(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
PoolSize: cfg.PoolSize,
|
||||
})
|
||||
ctx := context.Background()
|
||||
if err := rdb.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||||
}
|
||||
return &Client{rdb: rdb, isMiniRedis: false}, nil
|
||||
default:
|
||||
// 默认使用miniredis
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start miniredis: %w", err)
|
||||
}
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
return &Client{
|
||||
rdb: rdb,
|
||||
isMiniRedis: true,
|
||||
mr: mr,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取值
|
||||
func (c *Client) Get(ctx context.Context, key string) (string, error) {
|
||||
return c.rdb.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
// Set 设置值
|
||||
func (c *Client) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
|
||||
return c.rdb.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
||||
// Del 删除键
|
||||
func (c *Client) Del(ctx context.Context, keys ...string) error {
|
||||
return c.rdb.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (c *Client) Exists(ctx context.Context, keys ...string) (int64, error) {
|
||||
return c.rdb.Exists(ctx, keys...).Result()
|
||||
}
|
||||
|
||||
// Incr 递增
|
||||
func (c *Client) Incr(ctx context.Context, key string) (int64, error) {
|
||||
return c.rdb.Incr(ctx, key).Result()
|
||||
}
|
||||
|
||||
// Expire 设置过期时间
|
||||
func (c *Client) Expire(ctx context.Context, key string, expiration time.Duration) (bool, error) {
|
||||
return c.rdb.Expire(ctx, key, expiration).Result()
|
||||
}
|
||||
|
||||
// GetClient 获取原生客户端
|
||||
func (c *Client) GetClient() *redis.Client {
|
||||
return c.rdb
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *Client) Close() error {
|
||||
if err := c.rdb.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.mr != nil {
|
||||
c.mr.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsMiniRedis 返回是否是miniredis
|
||||
func (c *Client) IsMiniRedis() bool {
|
||||
return c.isMiniRedis
|
||||
}
|
||||
117
internal/pkg/response/response.go
Normal file
117
internal/pkg/response/response.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 统一响应结构
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseSnakeCase 统一响应结构(snake_case)
|
||||
type ResponseSnakeCase struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Success 成功响应
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// SuccessWithMessage 成功响应带消息
|
||||
func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: 0,
|
||||
Message: message,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Error 错误响应
|
||||
func Error(c *gin.Context, code int, message string) {
|
||||
c.JSON(http.StatusBadRequest, Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorWithStatusCode 带状态码的错误响应
|
||||
func ErrorWithStatusCode(c *gin.Context, statusCode int, code int, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
|
||||
// BadRequest 参数错误
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
ErrorWithStatusCode(c, http.StatusBadRequest, 400, message)
|
||||
}
|
||||
|
||||
// Unauthorized 未授权
|
||||
func Unauthorized(c *gin.Context, message string) {
|
||||
if message == "" {
|
||||
message = "unauthorized"
|
||||
}
|
||||
ErrorWithStatusCode(c, http.StatusUnauthorized, 401, message)
|
||||
}
|
||||
|
||||
// Forbidden 禁止访问
|
||||
func Forbidden(c *gin.Context, message string) {
|
||||
if message == "" {
|
||||
message = "forbidden"
|
||||
}
|
||||
ErrorWithStatusCode(c, http.StatusForbidden, 403, message)
|
||||
}
|
||||
|
||||
// NotFound 资源不存在
|
||||
func NotFound(c *gin.Context, message string) {
|
||||
if message == "" {
|
||||
message = "resource not found"
|
||||
}
|
||||
ErrorWithStatusCode(c, http.StatusNotFound, 404, message)
|
||||
}
|
||||
|
||||
// InternalServerError 服务器内部错误
|
||||
func InternalServerError(c *gin.Context, message string) {
|
||||
if message == "" {
|
||||
message = "internal server error"
|
||||
}
|
||||
ErrorWithStatusCode(c, http.StatusInternalServerError, 500, message)
|
||||
}
|
||||
|
||||
// PaginatedResponse 分页响应
|
||||
type PaginatedResponse struct {
|
||||
List interface{} `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// Paginated 分页成功响应
|
||||
func Paginated(c *gin.Context, list interface{}, total int64, page, pageSize int) {
|
||||
totalPages := int(total) / pageSize
|
||||
if int(total)%pageSize > 0 {
|
||||
totalPages++
|
||||
}
|
||||
|
||||
Success(c, PaginatedResponse{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
})
|
||||
}
|
||||
119
internal/pkg/s3/s3.go
Normal file
119
internal/pkg/s3/s3.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
|
||||
"carrot_bbs/internal/config"
|
||||
)
|
||||
|
||||
// Client S3客户端
|
||||
type Client struct {
|
||||
client *minio.Client
|
||||
bucket string
|
||||
domain string
|
||||
}
|
||||
|
||||
// New 创建S3客户端
|
||||
func New(cfg *config.S3Config) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
client, err := minio.New(cfg.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
|
||||
Secure: cfg.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create S3 client: %w", err)
|
||||
}
|
||||
|
||||
// 检查bucket是否存在
|
||||
exists, err := client.BucketExists(ctx, cfg.Bucket)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check bucket: %w", err)
|
||||
}
|
||||
|
||||
if !exists {
|
||||
if err := client.MakeBucket(ctx, cfg.Bucket, minio.MakeBucketOptions{
|
||||
Region: cfg.Region,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("failed to create bucket: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有配置domain,则使用默认的endpoint
|
||||
domain := cfg.Domain
|
||||
if domain == "" {
|
||||
domain = cfg.Endpoint
|
||||
}
|
||||
|
||||
return &Client{
|
||||
client: client,
|
||||
bucket: cfg.Bucket,
|
||||
domain: domain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Upload 上传文件
|
||||
func (c *Client) Upload(ctx context.Context, objectName string, filePath string, contentType string) (string, error) {
|
||||
_, err := c.client.FPutObject(ctx, c.bucket, objectName, filePath, minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload file: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s", c.bucket, objectName), nil
|
||||
}
|
||||
|
||||
// UploadData 上传数据
|
||||
func (c *Client) UploadData(ctx context.Context, objectName string, data []byte, contentType string) (string, error) {
|
||||
_, err := c.client.PutObject(ctx, c.bucket, objectName, bytes.NewReader(data), int64(len(data)), minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to upload data: %w", err)
|
||||
}
|
||||
|
||||
// 返回完整URL,包含bucket名称
|
||||
scheme := "https"
|
||||
if c.domain == c.bucket || c.domain == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil
|
||||
}
|
||||
|
||||
// GetURL 获取文件URL - 使用自定义域名
|
||||
func (c *Client) GetURL(ctx context.Context, objectName string) (string, error) {
|
||||
// 使用自定义域名构建URL,包含bucket名称
|
||||
scheme := "https"
|
||||
if c.domain == c.bucket || c.domain == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/%s/%s", scheme, c.domain, c.bucket, objectName), nil
|
||||
}
|
||||
|
||||
// GetPresignedURL 获取预签名URL(用于私有桶)
|
||||
func (c *Client) GetPresignedURL(ctx context.Context, objectName string) (string, error) {
|
||||
url, err := c.client.PresignedGetObject(ctx, c.bucket, objectName, time.Hour*24, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get presigned URL: %w", err)
|
||||
}
|
||||
|
||||
return url.String(), nil
|
||||
}
|
||||
|
||||
// Delete 删除文件
|
||||
func (c *Client) Delete(ctx context.Context, objectName string) error {
|
||||
return c.client.RemoveObject(ctx, c.bucket, objectName, minio.RemoveObjectOptions{})
|
||||
}
|
||||
|
||||
// GetClient 获取原生客户端
|
||||
func (c *Client) GetClient() *minio.Client {
|
||||
return c.client
|
||||
}
|
||||
52
internal/pkg/utils/avatar.go
Normal file
52
internal/pkg/utils/avatar.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// AvatarServiceBaseURL 默认头像服务基础URL (使用 UI Avatars API)
|
||||
const AvatarServiceBaseURL = "https://ui-avatars.com/api"
|
||||
|
||||
// DefaultAvatarSize 默认头像尺寸
|
||||
const DefaultAvatarSize = 100
|
||||
|
||||
// AvatarInfo 头像信息
|
||||
type AvatarInfo struct {
|
||||
Username string
|
||||
Nickname string
|
||||
Avatar string
|
||||
}
|
||||
|
||||
// GetAvatarOrDefault 获取头像URL,如果为空则返回在线头像生成服务的URL
|
||||
// 优先使用已有的头像,否则使用昵称或用户名生成默认头像
|
||||
func GetAvatarOrDefault(username, nickname, avatar string) string {
|
||||
if avatar != "" {
|
||||
return avatar
|
||||
}
|
||||
// 使用用户名生成默认头像URL(优先使用昵称)
|
||||
displayName := nickname
|
||||
if displayName == "" {
|
||||
displayName = username
|
||||
}
|
||||
return GenerateDefaultAvatarURL(displayName)
|
||||
}
|
||||
|
||||
// GetAvatarOrDefaultFromInfo 从 AvatarInfo 获取头像URL
|
||||
func GetAvatarOrDefaultFromInfo(info AvatarInfo) string {
|
||||
return GetAvatarOrDefault(info.Username, info.Nickname, info.Avatar)
|
||||
}
|
||||
|
||||
// GenerateDefaultAvatarURL 生成默认头像URL
|
||||
// 使用 UI Avatars API 生成基于用户名首字母的头像
|
||||
func GenerateDefaultAvatarURL(name string) string {
|
||||
if name == "" {
|
||||
name = "?"
|
||||
}
|
||||
// 使用 UI Avatars API 生成头像
|
||||
params := url.Values{}
|
||||
params.Set("name", url.QueryEscape(name))
|
||||
params.Set("size", "100")
|
||||
params.Set("background", "0D8ABC") // 默认蓝色背景
|
||||
params.Set("color", "ffffff") // 白色文字
|
||||
return AvatarServiceBaseURL + "?" + params.Encode()
|
||||
}
|
||||
17
internal/pkg/utils/hash.go
Normal file
17
internal/pkg/utils/hash.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// HashPassword 密码哈希
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
// CheckPasswordHash 验证密码
|
||||
func CheckPasswordHash(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
261
internal/pkg/utils/snowflake.go
Normal file
261
internal/pkg/utils/snowflake.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 雪花算法常量定义
|
||||
const (
|
||||
// 64位ID结构:1位符号位 + 41位时间戳 + 10位机器ID + 12位序列号
|
||||
|
||||
// 机器ID占用的位数
|
||||
nodeIDBits uint64 = 10
|
||||
// 序列号占用的位数
|
||||
sequenceBits uint64 = 12
|
||||
|
||||
// 机器ID的最大值 (0-1023)
|
||||
maxNodeID int64 = -1 ^ (-1 << nodeIDBits)
|
||||
// 序列号的最大值 (0-4095)
|
||||
maxSequence int64 = -1 ^ (-1 << sequenceBits)
|
||||
|
||||
// 机器ID左移位数
|
||||
nodeIDShift uint64 = sequenceBits
|
||||
// 时间戳左移位数
|
||||
timestampShift uint64 = sequenceBits + nodeIDBits
|
||||
|
||||
// 自定义纪元时间:2024-01-01 00:00:00 UTC
|
||||
// 使用自定义纪元可以延长ID有效期约24年(从2024年开始)
|
||||
customEpoch int64 = 1704067200000 // 2024-01-01 00:00:00 UTC 的毫秒时间戳
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
// ErrInvalidNodeID 机器ID无效
|
||||
ErrInvalidNodeID = errors.New("node ID must be between 0 and 1023")
|
||||
// ErrClockBackwards 时钟回拨
|
||||
ErrClockBackwards = errors.New("clock moved backwards, refusing to generate ID")
|
||||
)
|
||||
|
||||
// IDInfo 解析后的ID信息
|
||||
type IDInfo struct {
|
||||
Timestamp int64 // 生成ID时的时间戳(毫秒)
|
||||
NodeID int64 // 机器ID
|
||||
Sequence int64 // 序列号
|
||||
}
|
||||
|
||||
// Snowflake 雪花算法ID生成器
|
||||
type Snowflake struct {
|
||||
mu sync.Mutex // 互斥锁,保证线程安全
|
||||
nodeID int64 // 机器ID (0-1023)
|
||||
sequence int64 // 当前序列号 (0-4095)
|
||||
lastTimestamp int64 // 上次生成ID的时间戳
|
||||
}
|
||||
|
||||
// 全局雪花算法实例
|
||||
var (
|
||||
globalSnowflake *Snowflake
|
||||
globalSnowflakeOnce sync.Once
|
||||
globalSnowflakeErr error
|
||||
)
|
||||
|
||||
// InitSnowflake 初始化全局雪花算法实例
|
||||
// nodeID: 机器ID,范围0-1023,可以通过环境变量 NODE_ID 配置
|
||||
func InitSnowflake(nodeID int64) error {
|
||||
globalSnowflake, globalSnowflakeErr = NewSnowflake(nodeID)
|
||||
return globalSnowflakeErr
|
||||
}
|
||||
|
||||
// GetSnowflake 获取全局雪花算法实例
|
||||
// 如果未初始化,会自动使用默认配置初始化
|
||||
func GetSnowflake() *Snowflake {
|
||||
globalSnowflakeOnce.Do(func() {
|
||||
if globalSnowflake == nil {
|
||||
globalSnowflake, globalSnowflakeErr = NewSnowflake(-1)
|
||||
}
|
||||
})
|
||||
return globalSnowflake
|
||||
}
|
||||
|
||||
// NewSnowflake 创建雪花算法ID生成器实例
|
||||
// nodeID: 机器ID,范围0-1023,可以通过环境变量 NODE_ID 配置
|
||||
// 如果nodeID为-1,则尝试从环境变量 NODE_ID 读取
|
||||
func NewSnowflake(nodeID int64) (*Snowflake, error) {
|
||||
// 如果传入-1,尝试从环境变量读取
|
||||
if nodeID == -1 {
|
||||
nodeIDStr := os.Getenv("NODE_ID")
|
||||
if nodeIDStr != "" {
|
||||
// 解析环境变量
|
||||
parsedID, err := parseInt(nodeIDStr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidNodeID
|
||||
}
|
||||
nodeID = parsedID
|
||||
} else {
|
||||
// 默认使用0
|
||||
nodeID = 0
|
||||
}
|
||||
}
|
||||
|
||||
// 验证机器ID范围
|
||||
if nodeID < 0 || nodeID > maxNodeID {
|
||||
return nil, ErrInvalidNodeID
|
||||
}
|
||||
|
||||
return &Snowflake{
|
||||
nodeID: nodeID,
|
||||
sequence: 0,
|
||||
lastTimestamp: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseInt 辅助函数:解析整数
|
||||
func parseInt(s string) (int64, error) {
|
||||
var result int64
|
||||
var negative bool
|
||||
|
||||
if len(s) == 0 {
|
||||
return 0, errors.New("empty string")
|
||||
}
|
||||
|
||||
i := 0
|
||||
if s[0] == '-' {
|
||||
negative = true
|
||||
i = 1
|
||||
}
|
||||
|
||||
for ; i < len(s); i++ {
|
||||
if s[i] < '0' || s[i] > '9' {
|
||||
return 0, errors.New("invalid character")
|
||||
}
|
||||
result = result*10 + int64(s[i]-'0')
|
||||
}
|
||||
|
||||
if negative {
|
||||
result = -result
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateID 生成唯一的雪花算法ID
|
||||
// 返回值:生成的ID,以及可能的错误(如时钟回拨)
|
||||
// 线程安全:使用互斥锁保证并发安全
|
||||
func (s *Snowflake) GenerateID() (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 获取当前时间戳(毫秒)
|
||||
now := currentTimestamp()
|
||||
|
||||
// 处理时钟回拨
|
||||
if now < s.lastTimestamp {
|
||||
return 0, ErrClockBackwards
|
||||
}
|
||||
|
||||
// 同一毫秒内
|
||||
if now == s.lastTimestamp {
|
||||
// 序列号递增
|
||||
s.sequence = (s.sequence + 1) & maxSequence
|
||||
// 序列号溢出,等待下一毫秒
|
||||
if s.sequence == 0 {
|
||||
now = s.waitNextMillis(now)
|
||||
}
|
||||
} else {
|
||||
// 不同毫秒,序列号重置为0
|
||||
s.sequence = 0
|
||||
}
|
||||
|
||||
// 更新上次生成时间
|
||||
s.lastTimestamp = now
|
||||
|
||||
// 组装ID
|
||||
// ID结构:时间戳部分 | 机器ID部分 | 序列号部分
|
||||
id := ((now - customEpoch) << timestampShift) |
|
||||
(s.nodeID << nodeIDShift) |
|
||||
s.sequence
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// waitNextMillis 等待到下一毫秒
|
||||
// 参数:当前时间戳
|
||||
// 返回值:下一毫秒的时间戳
|
||||
func (s *Snowflake) waitNextMillis(timestamp int64) int64 {
|
||||
now := currentTimestamp()
|
||||
for now <= timestamp {
|
||||
now = currentTimestamp()
|
||||
}
|
||||
return now
|
||||
}
|
||||
|
||||
// ParseID 解析雪花算法ID,提取其中的信息
|
||||
// id: 要解析的雪花算法ID
|
||||
// 返回值:包含时间戳、机器ID、序列号的结构体
|
||||
func ParseID(id int64) *IDInfo {
|
||||
// 提取序列号(低12位)
|
||||
sequence := id & maxSequence
|
||||
|
||||
// 提取机器ID(中间10位)
|
||||
nodeID := (id >> nodeIDShift) & maxNodeID
|
||||
|
||||
// 提取时间戳(高41位)
|
||||
timestamp := (id >> timestampShift) + customEpoch
|
||||
|
||||
return &IDInfo{
|
||||
Timestamp: timestamp,
|
||||
NodeID: nodeID,
|
||||
Sequence: sequence,
|
||||
}
|
||||
}
|
||||
|
||||
// currentTimestamp 获取当前时间戳(毫秒)
|
||||
func currentTimestamp() int64 {
|
||||
return time.Now().UnixNano() / 1e6
|
||||
}
|
||||
|
||||
// GetNodeID 获取当前机器ID
|
||||
func (s *Snowflake) GetNodeID() int64 {
|
||||
return s.nodeID
|
||||
}
|
||||
|
||||
// GetCustomEpoch 获取自定义纪元时间
|
||||
func GetCustomEpoch() int64 {
|
||||
return customEpoch
|
||||
}
|
||||
|
||||
// IDToTime 将雪花算法ID转换为生成时间
|
||||
// 这是一个便捷方法,等价于 ParseID(id).Timestamp
|
||||
func IDToTime(id int64) time.Time {
|
||||
info := ParseID(id)
|
||||
return time.Unix(0, info.Timestamp*1e6) // 毫秒转纳秒
|
||||
}
|
||||
|
||||
// ValidateID 验证ID是否为有效的雪花算法ID
|
||||
// 检查时间戳是否在合理范围内
|
||||
func ValidateID(id int64) bool {
|
||||
if id <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
info := ParseID(id)
|
||||
|
||||
// 检查时间戳是否在合理范围内
|
||||
// 不能早于纪元时间,不能晚于当前时间太多(允许1分钟的时钟偏差)
|
||||
now := currentTimestamp()
|
||||
if info.Timestamp < customEpoch || info.Timestamp > now+60000 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查机器ID和序列号是否在有效范围内
|
||||
if info.NodeID < 0 || info.NodeID > maxNodeID {
|
||||
return false
|
||||
}
|
||||
if info.Sequence < 0 || info.Sequence > maxSequence {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
46
internal/pkg/utils/validator.go
Normal file
46
internal/pkg/utils/validator.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidateEmail 验证邮箱
|
||||
func ValidateEmail(email string) bool {
|
||||
pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
|
||||
matched, _ := regexp.MatchString(pattern, email)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ValidateUsername 验证用户名
|
||||
func ValidateUsername(username string) bool {
|
||||
if len(username) < 3 || len(username) > 30 {
|
||||
return false
|
||||
}
|
||||
pattern := `^[a-zA-Z0-9_]+$`
|
||||
matched, _ := regexp.MatchString(pattern, username)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ValidatePassword 验证密码强度
|
||||
func ValidatePassword(password string) bool {
|
||||
if len(password) < 6 || len(password) > 50 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidatePhone 验证手机号
|
||||
func ValidatePhone(phone string) bool {
|
||||
pattern := `^1[3-9]\d{9}$`
|
||||
matched, _ := regexp.MatchString(pattern, phone)
|
||||
return matched
|
||||
}
|
||||
|
||||
// SanitizeHTML 清理HTML
|
||||
func SanitizeHTML(input string) string {
|
||||
// 简单实现,实际使用建议用专门库
|
||||
input = strings.ReplaceAll(input, "<", "<")
|
||||
input = strings.ReplaceAll(input, ">", ">")
|
||||
return input
|
||||
}
|
||||
440
internal/pkg/websocket/websocket.go
Normal file
440
internal/pkg/websocket/websocket.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// WebSocket消息类型常量
|
||||
const (
|
||||
MessageTypePing = "ping"
|
||||
MessageTypePong = "pong"
|
||||
MessageTypeMessage = "message"
|
||||
MessageTypeTyping = "typing"
|
||||
MessageTypeRead = "read"
|
||||
MessageTypeAck = "ack"
|
||||
MessageTypeError = "error"
|
||||
MessageTypeRecall = "recall" // 撤回消息
|
||||
MessageTypeSystem = "system" // 系统消息
|
||||
MessageTypeNotification = "notification" // 通知消息
|
||||
MessageTypeAnnouncement = "announcement" // 公告消息
|
||||
|
||||
// 群组相关消息类型
|
||||
MessageTypeGroupMessage = "group_message" // 群消息
|
||||
MessageTypeGroupTyping = "group_typing" // 群输入状态
|
||||
MessageTypeGroupNotice = "group_notice" // 群组通知(成员变动等)
|
||||
MessageTypeGroupMention = "group_mention" // @提及通知
|
||||
MessageTypeGroupRead = "group_read" // 群消息已读
|
||||
MessageTypeGroupRecall = "group_recall" // 群消息撤回
|
||||
|
||||
// Meta事件详细类型
|
||||
MetaDetailTypeHeartbeat = "heartbeat"
|
||||
MetaDetailTypeTyping = "typing"
|
||||
MetaDetailTypeAck = "ack" // 消息发送确认
|
||||
MetaDetailTypeRead = "read" // 已读回执
|
||||
)
|
||||
|
||||
// WSMessage WebSocket消息结构
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"`
|
||||
Data interface{} `json:"data"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// ChatMessage 聊天消息结构
|
||||
type ChatMessage struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
SenderID string `json:"sender_id"`
|
||||
Seq int64 `json:"seq"`
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
// SystemMessage 系统消息结构
|
||||
type SystemMessage struct {
|
||||
ID string `json:"id"` // 消息ID
|
||||
Type string `json:"type"` // 消息子类型(如:account_banned, post_approved等)
|
||||
Title string `json:"title"` // 消息标题
|
||||
Content string `json:"content"` // 消息内容
|
||||
Data map[string]interface{} `json:"data"` // 额外数据
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// NotificationMessage 通知消息结构
|
||||
type NotificationMessage struct {
|
||||
ID string `json:"id"` // 通知ID
|
||||
Type string `json:"type"` // 通知类型(like, comment, follow, mention等)
|
||||
Title string `json:"title"` // 通知标题
|
||||
Content string `json:"content"` // 通知内容
|
||||
TriggerUser *NotificationUser `json:"trigger_user"` // 触发用户
|
||||
ResourceType string `json:"resource_type"` // 资源类型(post, comment等)
|
||||
ResourceID string `json:"resource_id"` // 资源ID
|
||||
Extra map[string]interface{} `json:"extra"` // 额外数据
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// NotificationUser 通知中的用户信息
|
||||
type NotificationUser struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Avatar string `json:"avatar"`
|
||||
}
|
||||
|
||||
// AnnouncementMessage 公告消息结构
|
||||
type AnnouncementMessage struct {
|
||||
ID string `json:"id"` // 公告ID
|
||||
Title string `json:"title"` // 公告标题
|
||||
Content string `json:"content"` // 公告内容
|
||||
Priority int `json:"priority"` // 优先级(1-10)
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// GroupMessage 群消息结构
|
||||
type GroupMessage struct {
|
||||
ID string `json:"id"` // 消息ID
|
||||
ConversationID string `json:"conversation_id"` // 会话ID(群聊会话)
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
SenderID string `json:"sender_id"` // 发送者ID
|
||||
Seq int64 `json:"seq"` // 消息序号
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
ReplyToID *string `json:"reply_to_id,omitempty"` // 回复的消息ID
|
||||
MentionUsers []uint64 `json:"mention_users,omitempty"` // @的用户ID列表
|
||||
MentionAll bool `json:"mention_all"` // 是否@所有人
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// GroupTypingMessage 群输入状态消息
|
||||
type GroupTypingMessage struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
UserID string `json:"user_id"` // 用户ID
|
||||
Username string `json:"username"` // 用户名
|
||||
IsTyping bool `json:"is_typing"` // 是否正在输入
|
||||
}
|
||||
|
||||
// GroupNoticeMessage 群组通知消息
|
||||
type GroupNoticeMessage struct {
|
||||
NoticeType string `json:"notice_type"` // 通知类型:member_join, member_leave, member_removed, role_changed, muted, unmuted, group_dissolved
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
Data interface{} `json:"data"` // 通知数据
|
||||
Timestamp int64 `json:"timestamp"` // 时间戳
|
||||
MessageID string `json:"message_id,omitempty"` // 消息ID(如果通知保存为消息)
|
||||
Seq int64 `json:"seq,omitempty"` // 消息序号(如果通知保存为消息)
|
||||
}
|
||||
|
||||
// GroupNoticeData 通知数据结构
|
||||
type GroupNoticeData struct {
|
||||
// 成员变动
|
||||
UserID string `json:"user_id,omitempty"` // 变动的用户ID
|
||||
Username string `json:"username,omitempty"` // 用户名
|
||||
OperatorID string `json:"operator_id,omitempty"` // 操作者ID
|
||||
OpName string `json:"op_name,omitempty"` // 操作者名称
|
||||
NewRole string `json:"new_role,omitempty"` // 新角色
|
||||
OldRole string `json:"old_role,omitempty"` // 旧角色
|
||||
MemberCount int `json:"member_count,omitempty"` // 当前成员数
|
||||
|
||||
// 群设置变更
|
||||
MuteAll bool `json:"mute_all,omitempty"` // 全员禁言状态
|
||||
}
|
||||
|
||||
// GroupMentionMessage @提及通知消息
|
||||
type GroupMentionMessage struct {
|
||||
GroupID string `json:"group_id"` // 群组ID
|
||||
MessageID string `json:"message_id"` // 消息ID
|
||||
FromUserID string `json:"from_user_id"` // 发送者ID
|
||||
FromName string `json:"from_name"` // 发送者名称
|
||||
Content string `json:"content"` // 消息内容预览
|
||||
MentionAll bool `json:"mention_all"` // 是否@所有人
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// AckMessage 消息发送确认结构
|
||||
type AckMessage struct {
|
||||
ConversationID string `json:"conversation_id"` // 会话ID
|
||||
GroupID string `json:"group_id,omitempty"` // 群组ID(群聊时)
|
||||
ID string `json:"id"` // 消息ID
|
||||
SenderID string `json:"sender_id"` // 发送者ID
|
||||
Seq int64 `json:"seq"` // 消息序号
|
||||
Segments model.MessageSegments `json:"segments"` // 消息链(结构体数组)
|
||||
CreatedAt int64 `json:"created_at"` // 创建时间戳
|
||||
}
|
||||
|
||||
// Client WebSocket客户端
|
||||
type Client struct {
|
||||
ID string
|
||||
UserID string
|
||||
Conn *websocket.Conn
|
||||
Send chan []byte
|
||||
Manager *WebSocketManager
|
||||
IsClosed bool
|
||||
Mu sync.Mutex
|
||||
closeOnce sync.Once // 确保 Send channel 只关闭一次
|
||||
}
|
||||
|
||||
// WebSocketManager WebSocket连接管理器
|
||||
type WebSocketManager struct {
|
||||
clients map[string]*Client // userID -> Client
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
broadcast chan *BroadcastMessage
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// BroadcastMessage 广播消息
|
||||
type BroadcastMessage struct {
|
||||
Message *WSMessage
|
||||
ExcludeUser string // 排除的用户ID,为空表示不排除
|
||||
TargetUser string // 目标用户ID,为空表示广播给所有用户
|
||||
}
|
||||
|
||||
// NewWebSocketManager 创建WebSocket管理器
|
||||
func NewWebSocketManager() *WebSocketManager {
|
||||
return &WebSocketManager{
|
||||
clients: make(map[string]*Client),
|
||||
register: make(chan *Client, 100),
|
||||
unregister: make(chan *Client, 100),
|
||||
broadcast: make(chan *BroadcastMessage, 100),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动管理器
|
||||
func (m *WebSocketManager) Start() {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case client := <-m.register:
|
||||
m.mutex.Lock()
|
||||
m.clients[client.UserID] = client
|
||||
m.mutex.Unlock()
|
||||
log.Printf("WebSocket client connected: userID=%s, 当前在线用户数=%d", client.UserID, len(m.clients))
|
||||
|
||||
case client := <-m.unregister:
|
||||
m.mutex.Lock()
|
||||
if _, ok := m.clients[client.UserID]; ok {
|
||||
delete(m.clients, client.UserID)
|
||||
// 使用 closeOnce 确保 channel 只关闭一次,避免 panic
|
||||
client.closeOnce.Do(func() {
|
||||
close(client.Send)
|
||||
})
|
||||
log.Printf("WebSocket client disconnected: userID=%s", client.UserID)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
case broadcast := <-m.broadcast:
|
||||
m.sendMessage(broadcast)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Register 注册客户端
|
||||
func (m *WebSocketManager) Register(client *Client) {
|
||||
m.register <- client
|
||||
}
|
||||
|
||||
// Unregister 注销客户端
|
||||
func (m *WebSocketManager) Unregister(client *Client) {
|
||||
m.unregister <- client
|
||||
}
|
||||
|
||||
// Broadcast 广播消息给所有用户
|
||||
func (m *WebSocketManager) Broadcast(msg *WSMessage) {
|
||||
m.broadcast <- &BroadcastMessage{
|
||||
Message: msg,
|
||||
TargetUser: "",
|
||||
}
|
||||
}
|
||||
|
||||
// SendToUser 发送消息给指定用户
|
||||
func (m *WebSocketManager) SendToUser(userID string, msg *WSMessage) {
|
||||
m.broadcast <- &BroadcastMessage{
|
||||
Message: msg,
|
||||
TargetUser: userID,
|
||||
}
|
||||
}
|
||||
|
||||
// SendToUsers 发送消息给指定用户列表
|
||||
func (m *WebSocketManager) SendToUsers(userIDs []string, msg *WSMessage) {
|
||||
for _, userID := range userIDs {
|
||||
m.SendToUser(userID, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClient 获取客户端
|
||||
func (m *WebSocketManager) GetClient(userID string) (*Client, bool) {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
client, ok := m.clients[userID]
|
||||
return client, ok
|
||||
}
|
||||
|
||||
// GetAllClients 获取所有客户端
|
||||
func (m *WebSocketManager) GetAllClients() map[string]*Client {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.clients
|
||||
}
|
||||
|
||||
// GetClientCount 获取在线用户数量
|
||||
func (m *WebSocketManager) GetClientCount() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return len(m.clients)
|
||||
}
|
||||
|
||||
// IsUserOnline 检查用户是否在线
|
||||
func (m *WebSocketManager) IsUserOnline(userID string) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
_, ok := m.clients[userID]
|
||||
log.Printf("[DEBUG IsUserOnline] 检查用户 %s, 结果=%v, 当前在线用户=%v", userID, ok, m.clients)
|
||||
return ok
|
||||
}
|
||||
|
||||
// sendMessage 发送消息
|
||||
func (m *WebSocketManager) sendMessage(broadcast *BroadcastMessage) {
|
||||
msgBytes, err := json.Marshal(broadcast.Message)
|
||||
if err != nil {
|
||||
log.Printf("Failed to marshal message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG WebSocket] sendMessage: 目标用户=%s, 当前在线用户数=%d, 消息类型=%s",
|
||||
broadcast.TargetUser, len(m.clients), broadcast.Message.Type)
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
for userID, client := range m.clients {
|
||||
// 如果指定了目标用户,只发送给目标用户
|
||||
if broadcast.TargetUser != "" && userID != broadcast.TargetUser {
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果指定了排除用户,跳过
|
||||
if broadcast.ExcludeUser != "" && userID == broadcast.ExcludeUser {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case client.Send <- msgBytes:
|
||||
log.Printf("[DEBUG WebSocket] 成功发送消息到用户 %s, 消息类型=%s", userID, broadcast.Message.Type)
|
||||
default:
|
||||
log.Printf("Failed to send message to user %s: channel full", userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendPing 发送心跳
|
||||
func (c *Client) SendPing() error {
|
||||
c.Mu.Lock()
|
||||
defer c.Mu.Unlock()
|
||||
if c.IsClosed {
|
||||
return nil
|
||||
}
|
||||
msg := WSMessage{
|
||||
Type: MessageTypePing,
|
||||
Data: nil,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
msgBytes, _ := json.Marshal(msg)
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
||||
}
|
||||
|
||||
// SendPong 发送Pong响应
|
||||
func (c *Client) SendPong() error {
|
||||
c.Mu.Lock()
|
||||
defer c.Mu.Unlock()
|
||||
if c.IsClosed {
|
||||
return nil
|
||||
}
|
||||
msg := WSMessage{
|
||||
Type: MessageTypePong,
|
||||
Data: nil,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
msgBytes, _ := json.Marshal(msg)
|
||||
return c.Conn.WriteMessage(websocket.TextMessage, msgBytes)
|
||||
}
|
||||
|
||||
// WritePump 写入泵,将消息从Manager发送到客户端
|
||||
func (c *Client) WritePump() {
|
||||
defer func() {
|
||||
c.Conn.Close()
|
||||
c.Mu.Lock()
|
||||
c.IsClosed = true
|
||||
c.Mu.Unlock()
|
||||
}()
|
||||
|
||||
for {
|
||||
message, ok := <-c.Send
|
||||
if !ok {
|
||||
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
c.Mu.Lock()
|
||||
if c.IsClosed {
|
||||
c.Mu.Unlock()
|
||||
return
|
||||
}
|
||||
err := c.Conn.WriteMessage(websocket.TextMessage, message)
|
||||
c.Mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Write error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPump 读取泵,从客户端读取消息
|
||||
func (c *Client) ReadPump(handler func(msg *WSMessage)) {
|
||||
defer func() {
|
||||
c.Manager.Unregister(c)
|
||||
c.Conn.Close()
|
||||
c.Mu.Lock()
|
||||
c.IsClosed = true
|
||||
c.Mu.Unlock()
|
||||
}()
|
||||
|
||||
c.Conn.SetReadLimit(512 * 1024) // 512KB
|
||||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
c.Conn.SetPongHandler(func(string) error {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, message, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("WebSocket error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var wsMsg WSMessage
|
||||
if err := json.Unmarshal(message, &wsMsg); err != nil {
|
||||
log.Printf("Failed to unmarshal message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
handler(&wsMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWSMessage 创建WebSocket消息
|
||||
func CreateWSMessage(msgType string, data interface{}) *WSMessage {
|
||||
return &WSMessage{
|
||||
Type: msgType,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user