package repository import ( "carrot_bbs/internal/model" "gorm.io/gorm" ) // GroupRepository 群组仓库接口 type GroupRepository interface { // 群组操作 Create(group *model.Group) error GetByID(id string) (*model.Group, error) Update(group *model.Group) error Delete(id string) error GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error) // 群成员操作 AddMember(member *model.GroupMember) error GetMember(groupID string, userID string) (*model.GroupMember, error) GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) UpdateMember(member *model.GroupMember) error RemoveMember(groupID string, userID string) error GetMemberCount(groupID string) (int64, error) IsMember(groupID string, userID string) (bool, error) GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) // 角色相关 GetMemberRole(groupID string, userID string) (string, error) SetMemberRole(groupID string, userID string, role string) error GetAdmins(groupID string) ([]model.GroupMember, error) // 群公告操作 CreateAnnouncement(announcement *model.GroupAnnouncement) error GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) GetAnnouncementByID(id string) (*model.GroupAnnouncement, error) DeleteAnnouncement(id string) error } // groupRepository 群组仓库实现 type groupRepository struct { db *gorm.DB } // NewGroupRepository 创建群组仓库 func NewGroupRepository(db *gorm.DB) GroupRepository { return &groupRepository{db: db} } // Create 创建群组 func (r *groupRepository) Create(group *model.Group) error { return r.db.Create(group).Error } // GetByID 根据ID获取群组 func (r *groupRepository) GetByID(id string) (*model.Group, error) { var group model.Group err := r.db.First(&group, "id = ?", id).Error if err != nil { return nil, err } return &group, nil } // Update 更新群组 func (r *groupRepository) Update(group *model.Group) error { return r.db.Save(group).Error } // Delete 删除群组 func (r *groupRepository) Delete(id string) error { return r.db.Transaction(func(tx *gorm.DB) error { // 删除群成员 if err := tx.Where("group_id = ?", id).Delete(&model.GroupMember{}).Error; err != nil { return err } // 删除群公告 if err := tx.Where("group_id = ?", id).Delete(&model.GroupAnnouncement{}).Error; err != nil { return err } // 删除群组 if err := tx.Delete(&model.Group{}, "id = ?", id).Error; err != nil { return err } return nil }) } // GetByOwnerID 根据群主ID获取群组列表 func (r *groupRepository) GetByOwnerID(ownerID string, page, pageSize int) ([]model.Group, int64, error) { var groups []model.Group var total int64 query := r.db.Model(&model.Group{}).Where("owner_id = ?", ownerID) query.Count(&total) offset := (page - 1) * pageSize err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error return groups, total, err } // AddMember 添加群成员 func (r *groupRepository) AddMember(member *model.GroupMember) error { return r.db.Transaction(func(tx *gorm.DB) error { if err := tx.Create(member).Error; err != nil { return err } // 更新群组成员数量 return tx.Model(&model.Group{}).Where("id = ?", member.GroupID). Update("member_count", gorm.Expr("member_count + ?", 1)).Error }) } // GetMember 获取群成员 func (r *groupRepository) GetMember(groupID string, userID string) (*model.GroupMember, error) { var member model.GroupMember err := r.db.First(&member, "group_id = ? AND user_id = ?", groupID, userID).Error if err != nil { return nil, err } return &member, nil } // GetMembers 获取群成员列表 func (r *groupRepository) GetMembers(groupID string, page, pageSize int) ([]model.GroupMember, int64, error) { var members []model.GroupMember var total int64 query := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID) query.Count(&total) offset := (page - 1) * pageSize err := query.Offset(offset).Limit(pageSize).Order("created_at ASC").Find(&members).Error return members, total, err } // UpdateMember 更新群成员 func (r *groupRepository) UpdateMember(member *model.GroupMember) error { return r.db.Save(member).Error } // RemoveMember 移除群成员 func (r *groupRepository) RemoveMember(groupID string, userID string) error { return r.db.Transaction(func(tx *gorm.DB) error { // 删除成员 if err := tx.Where("group_id = ? AND user_id = ?", groupID, userID).Delete(&model.GroupMember{}).Error; err != nil { return err } // 更新群组成员数量 return tx.Model(&model.Group{}).Where("id = ?", groupID). Update("member_count", gorm.Expr("member_count - ?", 1)).Error }) } // GetMemberCount 获取群成员数量 func (r *groupRepository) GetMemberCount(groupID string) (int64, error) { var count int64 err := r.db.Model(&model.GroupMember{}).Where("group_id = ?", groupID).Count(&count).Error return count, err } // IsMember 检查是否是群成员 func (r *groupRepository) IsMember(groupID string, userID string) (bool, error) { var count int64 err := r.db.Model(&model.GroupMember{}).Where("group_id = ? AND user_id = ?", groupID, userID).Count(&count).Error return count > 0, err } // GetUserGroups 获取用户加入的群组列表 func (r *groupRepository) GetUserGroups(userID string, page, pageSize int) ([]model.Group, int64, error) { var groups []model.Group var total int64 // 通过群成员表查询用户加入的群组 subQuery := r.db.Model(&model.GroupMember{}). Select("group_id"). Where("user_id = ?", userID) query := r.db.Model(&model.Group{}).Where("id IN (?)", subQuery) query.Count(&total) offset := (page - 1) * pageSize err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&groups).Error return groups, total, err } // GetMemberRole 获取成员角色 func (r *groupRepository) GetMemberRole(groupID string, userID string) (string, error) { member, err := r.GetMember(groupID, userID) if err != nil { return "", err } return member.Role, nil } // SetMemberRole 设置成员角色 func (r *groupRepository) SetMemberRole(groupID string, userID string, role string) error { return r.db.Model(&model.GroupMember{}). Where("group_id = ? AND user_id = ?", groupID, userID). Update("role", role).Error } // GetAdmins 获取群管理员列表 func (r *groupRepository) GetAdmins(groupID string) ([]model.GroupMember, error) { var admins []model.GroupMember err := r.db.Where("group_id = ? AND role = ?", groupID, model.GroupRoleAdmin).Find(&admins).Error return admins, err } // CreateAnnouncement 创建群公告 func (r *groupRepository) CreateAnnouncement(announcement *model.GroupAnnouncement) error { return r.db.Create(announcement).Error } // GetAnnouncements 获取群公告列表 func (r *groupRepository) GetAnnouncements(groupID string, page, pageSize int) ([]model.GroupAnnouncement, int64, error) { var announcements []model.GroupAnnouncement var total int64 query := r.db.Model(&model.GroupAnnouncement{}).Where("group_id = ?", groupID) query.Count(&total) offset := (page - 1) * pageSize // 置顶的排在前面,然后按时间倒序 err := query.Offset(offset).Limit(pageSize).Order("is_pinned DESC, created_at DESC").Find(&announcements).Error return announcements, total, err } // GetAnnouncementByID 根据ID获取群公告 func (r *groupRepository) GetAnnouncementByID(id string) (*model.GroupAnnouncement, error) { var announcement model.GroupAnnouncement err := r.db.First(&announcement, "id = ?", id).Error if err != nil { return nil, err } return &announcement, nil } // DeleteAnnouncement 删除群公告 func (r *groupRepository) DeleteAnnouncement(id string) error { return r.db.Delete(&model.GroupAnnouncement{}, "id = ?", id).Error }