Files
backend/internal/service/yggdrasil_service_test.go

351 lines
7.6 KiB
Go
Raw Permalink Normal View History

package service
import (
"net"
"strings"
"testing"
"time"
)
// TestYggdrasilService_Constants checks the exported session constants: the
// key prefix used for join sessions and the session time-to-live.
func TestYggdrasilService_Constants(t *testing.T) {
	const (
		wantPrefix = "Join_"
		wantTTL    = 15 * time.Minute
	)

	if prefix := SessionKeyPrefix; prefix != wantPrefix {
		t.Errorf("SessionKeyPrefix = %s, want 'Join_'", prefix)
	}
	if ttl := SessionTTL; ttl != wantTTL {
		t.Errorf("SessionTTL = %v, want 15 minutes", ttl)
	}
}
// TestSessionData_Structure verifies that every SessionData field holds exactly
// the value it was constructed with.
//
// The previous version only checked that three fields were non-empty, so a
// swapped or mis-assigned field would pass, and IP was never checked at all.
func TestSessionData_Structure(t *testing.T) {
	const (
		wantAccessToken     = "test-token"
		wantUserName        = "TestUser"
		wantSelectedProfile = "test-profile-uuid"
		wantIP              = "127.0.0.1"
	)

	data := SessionData{
		AccessToken:     wantAccessToken,
		UserName:        wantUserName,
		SelectedProfile: wantSelectedProfile,
		IP:              wantIP,
	}

	if data.AccessToken != wantAccessToken {
		t.Errorf("AccessToken = %q, want %q", data.AccessToken, wantAccessToken)
	}
	if data.UserName != wantUserName {
		t.Errorf("UserName = %q, want %q", data.UserName, wantUserName)
	}
	if data.SelectedProfile != wantSelectedProfile {
		t.Errorf("SelectedProfile = %q, want %q", data.SelectedProfile, wantSelectedProfile)
	}
	if data.IP != wantIP {
		t.Errorf("IP = %q, want %q", data.IP, wantIP)
	}
}
// TestJoinServer_InputValidation checks the required-parameter rule for joining
// a server: serverID, accessToken and selectedProfile must all be non-empty.
//
// NOTE(review): the predicate below re-implements the rule inline rather than
// calling the service, so it cannot catch a regression in the service itself.
// errContains records the error text the service is presumably expected to
// return; every error case must specify it so the table can be wired to the
// real call without edits.
func TestJoinServer_InputValidation(t *testing.T) {
	tests := []struct {
		name            string
		serverID        string
		accessToken     string
		selectedProfile string
		wantErr         bool
		errContains     string
	}{
		{
			name:            "所有参数有效",
			serverID:        "test-server-123",
			accessToken:     "test-token",
			selectedProfile: "test-profile",
			wantErr:         false,
		},
		{
			name:            "serverId为空",
			serverID:        "",
			accessToken:     "test-token",
			selectedProfile: "test-profile",
			wantErr:         true,
			errContains:     "参数不能为空",
		},
		{
			name:            "accessToken为空",
			serverID:        "test-server",
			accessToken:     "",
			selectedProfile: "test-profile",
			wantErr:         true,
			errContains:     "参数不能为空",
		},
		{
			name:            "selectedProfile为空",
			serverID:        "test-server",
			accessToken:     "test-token",
			selectedProfile: "",
			wantErr:         true,
			errContains:     "参数不能为空",
		},
		{
			// Previously missing: every required parameter absent at once.
			name:        "所有参数为空",
			wantErr:     true,
			errContains: "参数不能为空",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Keep the table self-consistent: an expected failure must say
			// which error message it expects.
			if tt.wantErr && tt.errContains == "" {
				t.Fatalf("test case %q expects an error but sets no errContains", tt.name)
			}
			hasError := tt.serverID == "" || tt.accessToken == "" || tt.selectedProfile == ""
			if hasError != tt.wantErr {
				t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
			}
		})
	}
}
// TestJoinServer_ServerIDValidation checks the server-ID format rule: at most
// maxServerIDLen bytes long and free of the HTML-significant characters
// < > " ' &.
//
// The exact-length case (maxServerIDLen bytes) is included so that an
// off-by-one in the length comparison (< vs <=) is caught.
func TestJoinServer_ServerIDValidation(t *testing.T) {
	const maxServerIDLen = 100

	tests := []struct {
		name      string
		serverID  string
		wantValid bool
	}{
		{
			name:      "有效的serverId",
			serverID:  "test-server-123",
			wantValid: true,
		},
		{
			name:      "serverId恰好100个字符",
			serverID:  strings.Repeat("a", maxServerIDLen),
			wantValid: true,
		},
		{
			name:      "serverId过长",
			serverID:  strings.Repeat("a", maxServerIDLen+1),
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符<",
			serverID:  "test<server",
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符>",
			serverID:  "test>server",
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符\"",
			serverID:  "test\"server",
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符'",
			serverID:  "test'server",
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符&",
			serverID:  "test&server",
			wantValid: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			isValid := len(tt.serverID) <= maxServerIDLen && !strings.ContainsAny(tt.serverID, "<>\"'&")
			if isValid != tt.wantValid {
				t.Errorf("ServerID validation failed: got %v, want %v", isValid, tt.wantValid)
			}
		})
	}
}
// TestJoinServer_IPValidation checks the optional client-IP rule: an empty
// value is accepted; any other value must parse with net.ParseIP.
func TestJoinServer_IPValidation(t *testing.T) {
	cases := []struct {
		name      string
		ip        string
		wantValid bool
	}{
		{"有效的IPv4地址", "127.0.0.1", true},
		{"有效的IPv6地址", "::1", true},
		{"无效的IP地址", "invalid-ip", false},
		{"空IP地址可选", "", true}, // an empty IP is allowed
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Short-circuit: an empty address is accepted without parsing.
			valid := tc.ip == "" || net.ParseIP(tc.ip) != nil
			if valid != tc.wantValid {
				t.Errorf("IP validation failed: got %v, want %v (ip=%s)", valid, tc.wantValid, tc.ip)
			}
		})
	}
}
// TestHasJoinedServer_InputValidation checks the required-input rule for the
// has-joined lookup: an error is expected unless both serverID and username
// are provided.
//
// NOTE(review): the rule is re-implemented inline; the service is not called.
func TestHasJoinedServer_InputValidation(t *testing.T) {
	cases := []struct {
		name     string
		serverID string
		username string
		wantErr  bool
	}{
		{"所有参数有效", "test-server", "TestUser", false},
		{"serverId为空", "", "TestUser", true},
		{"username为空", "test-server", "", true},
		{"两者都为空", "", "", true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			complete := tc.serverID != "" && tc.username != ""
			if gotErr := !complete; gotErr != tc.wantErr {
				t.Errorf("Input validation failed: got %v, want %v", gotErr, tc.wantErr)
			}
		})
	}
}
// TestHasJoinedServer_UsernameMatching checks the username comparison rule:
// the session's username and the requested one must be byte-for-byte equal,
// so a difference in letter case is a mismatch.
func TestHasJoinedServer_UsernameMatching(t *testing.T) {
	type matchCase struct {
		name        string
		sessionUser string
		requestUser string
		wantMatch   bool
	}

	cases := []matchCase{
		{name: "用户名匹配", sessionUser: "TestUser", requestUser: "TestUser", wantMatch: true},
		{name: "用户名不匹配", sessionUser: "TestUser", requestUser: "OtherUser", wantMatch: false},
		{name: "大小写敏感", sessionUser: "TestUser", requestUser: "testuser", wantMatch: false},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := tc.sessionUser == tc.requestUser
			if got != tc.wantMatch {
				t.Errorf("Username matching failed: got %v, want %v", got, tc.wantMatch)
			}
		})
	}
}
// TestHasJoinedServer_IPMatching checks the client-IP comparison rule: when a
// request IP is supplied it must equal the IP stored in the session; when it
// is empty, no comparison is made and the lookup is treated as matching.
//
// shouldCheck was previously populated but never asserted. It now pins
// whether the comparison is performed at all, independently of its outcome.
func TestHasJoinedServer_IPMatching(t *testing.T) {
	tests := []struct {
		name        string
		sessionIP   string
		requestIP   string
		wantMatch   bool
		shouldCheck bool
	}{
		{
			name:        "IP匹配",
			sessionIP:   "127.0.0.1",
			requestIP:   "127.0.0.1",
			wantMatch:   true,
			shouldCheck: true,
		},
		{
			name:        "IP不匹配",
			sessionIP:   "127.0.0.1",
			requestIP:   "192.168.1.1",
			wantMatch:   false,
			shouldCheck: true,
		},
		{
			name:        "请求IP为空时不检查",
			sessionIP:   "127.0.0.1",
			requestIP:   "",
			wantMatch:   true,
			shouldCheck: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// An empty request IP skips the comparison entirely.
			checked := tt.requestIP != ""
			if checked != tt.shouldCheck {
				t.Errorf("IP check performed = %v, want %v", checked, tt.shouldCheck)
			}
			matches := !checked || tt.sessionIP == tt.requestIP
			if matches != tt.wantMatch {
				t.Errorf("IP matching failed: got %v, want %v", matches, tt.wantMatch)
			}
		})
	}
}
// TestJoinServer_SessionKey checks that a join-session key is formed by
// prefixing the server ID with SessionKeyPrefix, including for an empty ID.
func TestJoinServer_SessionKey(t *testing.T) {
	cases := []struct {
		name, serverID, want string
	}{
		{"生成正确的会话键", "test-server-123", "Join_test-server-123"},
		{"空serverId", "", "Join_"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := SessionKeyPrefix + c.serverID; got != c.want {
				t.Errorf("Session key = %s, want %s", got, c.want)
			}
		})
	}
}