351 lines
7.6 KiB
Go
351 lines
7.6 KiB
Go
package service
|
||
|
||
import (
|
||
"net"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
// TestYggdrasilService_Constants 测试Yggdrasil服务常量
|
||
func TestYggdrasilService_Constants(t *testing.T) {
|
||
if SessionKeyPrefix != "Join_" {
|
||
t.Errorf("SessionKeyPrefix = %s, want 'Join_'", SessionKeyPrefix)
|
||
}
|
||
|
||
if SessionTTL != 15*time.Minute {
|
||
t.Errorf("SessionTTL = %v, want 15 minutes", SessionTTL)
|
||
}
|
||
}
|
||
|
||
// TestSessionData_Structure 测试SessionData结构
|
||
func TestSessionData_Structure(t *testing.T) {
|
||
data := SessionData{
|
||
AccessToken: "test-token",
|
||
UserName: "TestUser",
|
||
SelectedProfile: "test-profile-uuid",
|
||
IP: "127.0.0.1",
|
||
}
|
||
|
||
if data.AccessToken == "" {
|
||
t.Error("AccessToken should not be empty")
|
||
}
|
||
|
||
if data.UserName == "" {
|
||
t.Error("UserName should not be empty")
|
||
}
|
||
|
||
if data.SelectedProfile == "" {
|
||
t.Error("SelectedProfile should not be empty")
|
||
}
|
||
}
|
||
|
||
// TestJoinServer_InputValidation checks the required-parameter rule for
// JoinServer: serverId, accessToken and selectedProfile must all be
// non-empty, otherwise the call is expected to fail.
//
// NOTE(review): this test evaluates the rule locally instead of calling
// JoinServer, so errContains only documents the expected error text and is
// not asserted here. Assert it once the test drives the real service.
func TestJoinServer_InputValidation(t *testing.T) {
	tests := []struct {
		name            string
		serverID        string
		accessToken     string
		selectedProfile string
		wantErr         bool
		errContains     string
	}{
		{
			name:            "所有参数有效",
			serverID:        "test-server-123",
			accessToken:     "test-token",
			selectedProfile: "test-profile",
			wantErr:         false,
		},
		{
			name:            "serverId为空",
			serverID:        "",
			accessToken:     "test-token",
			selectedProfile: "test-profile",
			wantErr:         true,
			errContains:     "参数不能为空",
		},
		{
			name:            "accessToken为空",
			serverID:        "test-server",
			accessToken:     "",
			selectedProfile: "test-profile",
			wantErr:         true,
			errContains:     "参数不能为空",
		},
		{
			name:            "selectedProfile为空",
			serverID:        "test-server",
			accessToken:     "test-token",
			selectedProfile: "",
			wantErr:         true,
			errContains:     "参数不能为空",
		},
		{
			// Mirrors the "both empty" case of the HasJoinedServer test.
			name:            "所有参数为空",
			serverID:        "",
			accessToken:     "",
			selectedProfile: "",
			wantErr:         true,
			errContains:     "参数不能为空",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			hasError := tt.serverID == "" || tt.accessToken == "" || tt.selectedProfile == ""
			if hasError != tt.wantErr {
				t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
			}
		})
	}
}
|
||
|
||
// TestJoinServer_ServerIDValidation checks the serverId format rule: the ID
// may be at most 100 bytes long and must not contain any of the characters
// < > " ' &.
func TestJoinServer_ServerIDValidation(t *testing.T) {
	tests := []struct {
		name      string
		serverID  string
		wantValid bool
	}{
		{
			name:      "有效的serverId",
			serverID:  "test-server-123",
			wantValid: true,
		},
		{
			// Boundary: exactly at the limit must still be accepted.
			name:      "serverId长度恰好为100",
			serverID:  strings.Repeat("a", 100),
			wantValid: true,
		},
		{
			name:      "serverId过长",
			serverID:  strings.Repeat("a", 101),
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符<",
			serverID:  "test<server",
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符>",
			serverID:  "test>server",
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符\"",
			serverID:  "test\"server",
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符'",
			serverID:  "test'server",
			wantValid: false,
		},
		{
			name:      "serverId包含危险字符&",
			serverID:  "test&server",
			wantValid: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// len counts bytes, not runes, so multi-byte characters reach the
			// 100 limit sooner than their visible length suggests.
			isValid := len(tt.serverID) <= 100 && !strings.ContainsAny(tt.serverID, `<>"'&`)
			if isValid != tt.wantValid {
				t.Errorf("ServerID validation failed: got %v, want %v", isValid, tt.wantValid)
			}
		})
	}
}
|
||
|
||
// TestJoinServer_IPValidation checks the optional client-IP rule: an empty IP
// is accepted (the field is optional), and any non-empty value must parse as
// an IPv4 or IPv6 address via net.ParseIP.
func TestJoinServer_IPValidation(t *testing.T) {
	tests := []struct {
		name      string
		ip        string
		wantValid bool
	}{
		{
			name:      "有效的IPv4地址",
			ip:        "127.0.0.1",
			wantValid: true,
		},
		{
			name:      "有效的IPv6地址",
			ip:        "::1",
			wantValid: true,
		},
		{
			name:      "IPv4映射的IPv6地址",
			ip:        "::ffff:127.0.0.1",
			wantValid: true,
		},
		{
			name:      "无效的IP地址",
			ip:        "invalid-ip",
			wantValid: false,
		},
		{
			// An octet above 255 must be rejected.
			name:      "IPv4段越界",
			ip:        "256.0.0.1",
			wantValid: false,
		},
		{
			// A host:port string is not a bare IP and must be rejected.
			name:      "带端口的地址",
			ip:        "127.0.0.1:25565",
			wantValid: false,
		},
		{
			name:      "空IP地址(可选)",
			ip:        "",
			wantValid: true, // an empty IP is allowed
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Empty means "not provided" and is accepted; otherwise it must parse.
			isValid := tt.ip == "" || net.ParseIP(tt.ip) != nil
			if isValid != tt.wantValid {
				t.Errorf("IP validation failed: got %v, want %v (ip=%s)", isValid, tt.wantValid, tt.ip)
			}
		})
	}
}
|
||
|
||
// TestHasJoinedServer_InputValidation checks that HasJoinedServer's input
// rule rejects a request whenever serverId or username (or both) is empty.
func TestHasJoinedServer_InputValidation(t *testing.T) {
	type testCase struct {
		name     string
		serverID string
		username string
		wantErr  bool
	}

	cases := []testCase{
		{"所有参数有效", "test-server", "TestUser", false},
		{"serverId为空", "", "TestUser", true},
		{"username为空", "test-server", "", true},
		{"两者都为空", "", "", true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			complete := tc.serverID != "" && tc.username != ""
			if got := !complete; got != tc.wantErr {
				t.Errorf("Input validation failed: got %v, want %v", got, tc.wantErr)
			}
		})
	}
}
|
||
|
||
// TestHasJoinedServer_UsernameMatching checks that the username stored in the
// session is compared to the requested username exactly, so a difference in
// case counts as a mismatch.
func TestHasJoinedServer_UsernameMatching(t *testing.T) {
	for _, tc := range []struct {
		name, sessionUser, requestUser string
		wantMatch                      bool
	}{
		{"用户名匹配", "TestUser", "TestUser", true},
		{"用户名不匹配", "TestUser", "OtherUser", false},
		{"大小写敏感", "TestUser", "testuser", false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			if got := tc.sessionUser == tc.requestUser; got != tc.wantMatch {
				t.Errorf("Username matching failed: got %v, want %v", got, tc.wantMatch)
			}
		})
	}
}
|
||
|
||
// TestHasJoinedServer_IPMatching checks the optional IP binding rule:
//   - when the request carries an IP, it must equal the IP stored in the
//     session;
//   - when the request IP is empty, the comparison is skipped and the request
//     is treated as matching.
//
// Each case also states via shouldCheck whether the comparison should apply,
// and that expectation is asserted.
func TestHasJoinedServer_IPMatching(t *testing.T) {
	tests := []struct {
		name        string
		sessionIP   string
		requestIP   string
		wantMatch   bool
		shouldCheck bool
	}{
		{
			name:        "IP匹配",
			sessionIP:   "127.0.0.1",
			requestIP:   "127.0.0.1",
			wantMatch:   true,
			shouldCheck: true,
		},
		{
			name:        "IP不匹配",
			sessionIP:   "127.0.0.1",
			requestIP:   "192.168.1.1",
			wantMatch:   false,
			shouldCheck: true,
		},
		{
			name:        "请求IP为空时不检查",
			sessionIP:   "127.0.0.1",
			requestIP:   "",
			wantMatch:   true,
			shouldCheck: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Only a non-empty request IP triggers the comparison.
			check := tt.requestIP != ""
			if check != tt.shouldCheck {
				t.Errorf("IP check applied = %v, want %v", check, tt.shouldCheck)
			}

			matches := !check || tt.sessionIP == tt.requestIP
			if matches != tt.wantMatch {
				t.Errorf("IP matching failed: got %v, want %v", matches, tt.wantMatch)
			}
		})
	}
}
|
||
|
||
// TestJoinServer_SessionKey 测试会话键生成
|
||
func TestJoinServer_SessionKey(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
serverId string
|
||
expected string
|
||
}{
|
||
{
|
||
name: "生成正确的会话键",
|
||
serverId: "test-server-123",
|
||
expected: "Join_test-server-123",
|
||
},
|
||
{
|
||
name: "空serverId",
|
||
serverId: "",
|
||
expected: "Join_",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
sessionKey := SessionKeyPrefix + tt.serverId
|
||
if sessionKey != tt.expected {
|
||
t.Errorf("Session key = %s, want %s", sessionKey, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|