351 lines
7.6 KiB
Go
351 lines
7.6 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"net"
|
|||
|
|
"strings"
|
|||
|
|
"testing"
|
|||
|
|
"time"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// TestYggdrasilService_Constants 测试Yggdrasil服务常量
|
|||
|
|
func TestYggdrasilService_Constants(t *testing.T) {
|
|||
|
|
if SessionKeyPrefix != "Join_" {
|
|||
|
|
t.Errorf("SessionKeyPrefix = %s, want 'Join_'", SessionKeyPrefix)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if SessionTTL != 15*time.Minute {
|
|||
|
|
t.Errorf("SessionTTL = %v, want 15 minutes", SessionTTL)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestSessionData_Structure 测试SessionData结构
|
|||
|
|
func TestSessionData_Structure(t *testing.T) {
|
|||
|
|
data := SessionData{
|
|||
|
|
AccessToken: "test-token",
|
|||
|
|
UserName: "TestUser",
|
|||
|
|
SelectedProfile: "test-profile-uuid",
|
|||
|
|
IP: "127.0.0.1",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if data.AccessToken == "" {
|
|||
|
|
t.Error("AccessToken should not be empty")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if data.UserName == "" {
|
|||
|
|
t.Error("UserName should not be empty")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if data.SelectedProfile == "" {
|
|||
|
|
t.Error("SelectedProfile should not be empty")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestJoinServer_InputValidation 测试JoinServer输入验证逻辑
|
|||
|
|
func TestJoinServer_InputValidation(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
serverId string
|
|||
|
|
accessToken string
|
|||
|
|
selectedProfile string
|
|||
|
|
wantErr bool
|
|||
|
|
errContains string
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "所有参数有效",
|
|||
|
|
serverId: "test-server-123",
|
|||
|
|
accessToken: "test-token",
|
|||
|
|
selectedProfile: "test-profile",
|
|||
|
|
wantErr: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "serverId为空",
|
|||
|
|
serverId: "",
|
|||
|
|
accessToken: "test-token",
|
|||
|
|
selectedProfile: "test-profile",
|
|||
|
|
wantErr: true,
|
|||
|
|
errContains: "参数不能为空",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "accessToken为空",
|
|||
|
|
serverId: "test-server",
|
|||
|
|
accessToken: "",
|
|||
|
|
selectedProfile: "test-profile",
|
|||
|
|
wantErr: true,
|
|||
|
|
errContains: "参数不能为空",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "selectedProfile为空",
|
|||
|
|
serverId: "test-server",
|
|||
|
|
accessToken: "test-token",
|
|||
|
|
selectedProfile: "",
|
|||
|
|
wantErr: true,
|
|||
|
|
errContains: "参数不能为空",
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
hasError := tt.serverId == "" || tt.accessToken == "" || tt.selectedProfile == ""
|
|||
|
|
if hasError != tt.wantErr {
|
|||
|
|
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestJoinServer_ServerIDValidation 测试服务器ID格式验证
|
|||
|
|
func TestJoinServer_ServerIDValidation(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
serverId string
|
|||
|
|
wantValid bool
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "有效的serverId",
|
|||
|
|
serverId: "test-server-123",
|
|||
|
|
wantValid: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "serverId过长",
|
|||
|
|
serverId: strings.Repeat("a", 101),
|
|||
|
|
wantValid: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "serverId包含危险字符<",
|
|||
|
|
serverId: "test<server",
|
|||
|
|
wantValid: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "serverId包含危险字符>",
|
|||
|
|
serverId: "test>server",
|
|||
|
|
wantValid: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "serverId包含危险字符\"",
|
|||
|
|
serverId: "test\"server",
|
|||
|
|
wantValid: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "serverId包含危险字符'",
|
|||
|
|
serverId: "test'server",
|
|||
|
|
wantValid: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "serverId包含危险字符&",
|
|||
|
|
serverId: "test&server",
|
|||
|
|
wantValid: false,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
isValid := len(tt.serverId) <= 100 && !strings.ContainsAny(tt.serverId, "<>\"'&")
|
|||
|
|
if isValid != tt.wantValid {
|
|||
|
|
t.Errorf("ServerID validation failed: got %v, want %v", isValid, tt.wantValid)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestJoinServer_IPValidation 测试IP地址验证逻辑
|
|||
|
|
func TestJoinServer_IPValidation(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
ip string
|
|||
|
|
wantValid bool
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "有效的IPv4地址",
|
|||
|
|
ip: "127.0.0.1",
|
|||
|
|
wantValid: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "有效的IPv6地址",
|
|||
|
|
ip: "::1",
|
|||
|
|
wantValid: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "无效的IP地址",
|
|||
|
|
ip: "invalid-ip",
|
|||
|
|
wantValid: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "空IP地址(可选)",
|
|||
|
|
ip: "",
|
|||
|
|
wantValid: true, // 空IP是允许的
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
var isValid bool
|
|||
|
|
if tt.ip == "" {
|
|||
|
|
isValid = true // 空IP是允许的
|
|||
|
|
} else {
|
|||
|
|
isValid = net.ParseIP(tt.ip) != nil
|
|||
|
|
}
|
|||
|
|
if isValid != tt.wantValid {
|
|||
|
|
t.Errorf("IP validation failed: got %v, want %v (ip=%s)", isValid, tt.wantValid, tt.ip)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestHasJoinedServer_InputValidation 测试HasJoinedServer输入验证
|
|||
|
|
func TestHasJoinedServer_InputValidation(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
serverId string
|
|||
|
|
username string
|
|||
|
|
wantErr bool
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "所有参数有效",
|
|||
|
|
serverId: "test-server",
|
|||
|
|
username: "TestUser",
|
|||
|
|
wantErr: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "serverId为空",
|
|||
|
|
serverId: "",
|
|||
|
|
username: "TestUser",
|
|||
|
|
wantErr: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "username为空",
|
|||
|
|
serverId: "test-server",
|
|||
|
|
username: "",
|
|||
|
|
wantErr: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "两者都为空",
|
|||
|
|
serverId: "",
|
|||
|
|
username: "",
|
|||
|
|
wantErr: true,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
hasError := tt.serverId == "" || tt.username == ""
|
|||
|
|
if hasError != tt.wantErr {
|
|||
|
|
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestHasJoinedServer_UsernameMatching 测试用户名匹配逻辑
|
|||
|
|
func TestHasJoinedServer_UsernameMatching(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
sessionUser string
|
|||
|
|
requestUser string
|
|||
|
|
wantMatch bool
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "用户名匹配",
|
|||
|
|
sessionUser: "TestUser",
|
|||
|
|
requestUser: "TestUser",
|
|||
|
|
wantMatch: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "用户名不匹配",
|
|||
|
|
sessionUser: "TestUser",
|
|||
|
|
requestUser: "OtherUser",
|
|||
|
|
wantMatch: false,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "大小写敏感",
|
|||
|
|
sessionUser: "TestUser",
|
|||
|
|
requestUser: "testuser",
|
|||
|
|
wantMatch: false,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
matches := tt.sessionUser == tt.requestUser
|
|||
|
|
if matches != tt.wantMatch {
|
|||
|
|
t.Errorf("Username matching failed: got %v, want %v", matches, tt.wantMatch)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestHasJoinedServer_IPMatching 测试IP地址匹配逻辑
|
|||
|
|
func TestHasJoinedServer_IPMatching(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
sessionIP string
|
|||
|
|
requestIP string
|
|||
|
|
wantMatch bool
|
|||
|
|
shouldCheck bool
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "IP匹配",
|
|||
|
|
sessionIP: "127.0.0.1",
|
|||
|
|
requestIP: "127.0.0.1",
|
|||
|
|
wantMatch: true,
|
|||
|
|
shouldCheck: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "IP不匹配",
|
|||
|
|
sessionIP: "127.0.0.1",
|
|||
|
|
requestIP: "192.168.1.1",
|
|||
|
|
wantMatch: false,
|
|||
|
|
shouldCheck: true,
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "请求IP为空时不检查",
|
|||
|
|
sessionIP: "127.0.0.1",
|
|||
|
|
requestIP: "",
|
|||
|
|
wantMatch: true,
|
|||
|
|
shouldCheck: false,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
var matches bool
|
|||
|
|
if tt.requestIP == "" {
|
|||
|
|
matches = true // 空IP不检查
|
|||
|
|
} else {
|
|||
|
|
matches = tt.sessionIP == tt.requestIP
|
|||
|
|
}
|
|||
|
|
if matches != tt.wantMatch {
|
|||
|
|
t.Errorf("IP matching failed: got %v, want %v", matches, tt.wantMatch)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestJoinServer_SessionKey 测试会话键生成
|
|||
|
|
func TestJoinServer_SessionKey(t *testing.T) {
|
|||
|
|
tests := []struct {
|
|||
|
|
name string
|
|||
|
|
serverId string
|
|||
|
|
expected string
|
|||
|
|
}{
|
|||
|
|
{
|
|||
|
|
name: "生成正确的会话键",
|
|||
|
|
serverId: "test-server-123",
|
|||
|
|
expected: "Join_test-server-123",
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
name: "空serverId",
|
|||
|
|
serverId: "",
|
|||
|
|
expected: "Join_",
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, tt := range tests {
|
|||
|
|
t.Run(tt.name, func(t *testing.T) {
|
|||
|
|
sessionKey := SessionKeyPrefix + tt.serverId
|
|||
|
|
if sessionKey != tt.expected {
|
|||
|
|
t.Errorf("Session key = %s, want %s", sessionKey, tt.expected)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|