/** * @author chengliang * @date 2026/1/27 11:18 * @brief * **/ package whiteip import ( "fmt" //"github.com/dlclark/regexp2" // 支持更复杂的正则表达式 "github.com/gin-gonic/gin" "net" "regexp" "strings" ) // 支持正则的IP白名单中间件 type TWhiteIpMgr struct { // 预编译的正则表达式列表 ipPatterns []*regexp.Regexp // 是否允许内网IP(默认允许) allowInternal bool // 是否启用IP检查(默认启用) enabled bool } // 创建白名单中间件 func NewIPWhiteListMiddleware(patterns []string, allowInternal bool) *TWhiteIpMgr { m := &TWhiteIpMgr{ allowInternal: allowInternal, enabled: true, } // 编译正则表达式 for _, pattern := range patterns { // 支持简写语法 compiledPattern := normalizePattern(pattern) re, err := regexp.Compile(compiledPattern) if err != nil { // 如果编译失败,记录错误但仍继续 fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err) continue } m.ipPatterns = append(m.ipPatterns, re) } return m } func (this *TWhiteIpMgr) UpdatePattern(patterns []string, allowInternal bool, enabled bool) { this.enabled = enabled this.allowInternal = allowInternal this.ipPatterns = make([]*regexp.Regexp, len(patterns)) for _, pattern := range patterns { compiledPattern := normalizePattern(pattern) re, err := regexp.Compile(compiledPattern) if err != nil { // 如果编译失败,记录错误但仍继续 fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err) continue } this.ipPatterns = append(this.ipPatterns, re) } } // 标准化IP模式 func normalizePattern(pattern string) string { // 支持CIDR表示法转换成正则表达式 if strings.Contains(pattern, "/") { return cidrToRegex(pattern) } // 支持IP段简写,如: 192.168.1.* 或 192.168.1.1-100 pattern = strings.ReplaceAll(pattern, "*", `\d+`) // 支持IP范围,如: 192.168.1.1-100 if strings.Contains(pattern, "-") { parts := strings.Split(pattern, ".") for i, part := range parts { if strings.Contains(part, "-") { rangeParts := strings.Split(part, "-") if len(rangeParts) == 2 { parts[i] = fmt.Sprintf(`(%s)`, buildRangeRegex(rangeParts[0], rangeParts[1])) } } } return "^" + strings.Join(parts, `\.`) + "$" } // 如果是普通IP,确保完全匹配 if !strings.HasPrefix(pattern, "^") { pattern = "^" + pattern } if !strings.HasSuffix(pattern, "$") { pattern = pattern + "$" } return pattern } // cidrToRegex 将CIDR表示法转换为正则表达式 func cidrToRegex(cidr string) string { ip, ipNet, err := net.ParseCIDR(cidr) if err != nil { return cidr // 如果解析失败,返回原字符串 } // 获取网络掩码 mask := ipNet.Mask ones, bits := mask.Size() if bits != 32 { return cidr // 仅支持IPv4 } // 将IP转换为整数 ipInt := ipToInt(ip.To4()) // 计算网络地址和广播地址 network := ipInt & (^uint32(0) << uint32(bits-ones)) broadcast := network | (^uint32(0) >> uint32(ones)) // 生成正则表达式 return fmt.Sprintf(`^%s$`, ipRangeToRegex(intToIP(network), intToIP(broadcast))) } // ipToInt IP转整数 func ipToInt(ip net.IP) uint32 { if len(ip) == 16 { return uint32(ip[12])<<24 | uint32(ip[13])<<16 | uint32(ip[14])<<8 | uint32(ip[15]) } return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) } // intToIP 整数转IP func intToIP(n uint32) net.IP { return net.IPv4(byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) } // ipRangeToRegex IP范围转正则表达式 func ipRangeToRegex(start, end net.IP) string { startParts := strings.Split(start.String(), ".") endParts := strings.Split(end.String(), ".") var regexParts []string for i := 0; i < 4; i++ { startNum := parseInt(startParts[i]) endNum := parseInt(endParts[i]) if startNum == endNum { regexParts = append(regexParts, fmt.Sprintf("%d", startNum)) } else { regexParts = append(regexParts, fmt.Sprintf("(%s)", buildRangeRegex(startParts[i], endParts[i]))) } } return strings.Join(regexParts, `\.`) } // buildRangeRegex 构建数字范围的正则表达式 func buildRangeRegex(start, end string) string { startNum := parseInt(start) endNum := parseInt(end) if startNum == endNum { return start } // 简单处理:如果范围小,直接列举 if endNum-startNum < 10 { var options []string for i := startNum; i <= endNum; i++ { options = append(options, fmt.Sprintf("%d", i)) } return strings.Join(options, "|") } // 复杂范围,使用正则表达式模式 return fmt.Sprintf("%d|%d|[1-9]\\d{0,2}", startNum, endNum) // 简化处理 } // parseInt 字符串转整数 func parseInt(s string) int { var result int fmt.Sscanf(s, "%d", &result) return result } // 检查是否为内网IP func isInternalIP(ipStr string) bool { ip := net.ParseIP(ipStr) if ip == nil { return false } // IPv4 检查 if ip4 := ip.To4(); ip4 != nil { return ip4[0] == 10 || (ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) || (ip4[0] == 192 && ip4[1] == 168) || ip4[0] == 127 } // IPv6 检查 return ip.IsLoopback() || ip.IsPrivate() } // GetClientIP 获取客户端真实IP func GetClientIP(c *gin.Context) string { // 从代理头获取 if ip := c.GetHeader("X-Forwarded-For"); ip != "" { ips := strings.Split(ip, ",") if len(ips) > 0 { return strings.TrimSpace(ips[0]) } } if ip := c.GetHeader("X-Real-IP"); ip != "" { return ip } // 直接获取 RemoteAddr remoteAddr := c.Request.RemoteAddr if ip, _, err := net.SplitHostPort(remoteAddr); err == nil { return ip } return remoteAddr } // Middleware 返回Gin中间件函数 func (m *TWhiteIpMgr) Middleware() gin.HandlerFunc { return func(c *gin.Context) { if !m.enabled { c.Next() return } clientIP := GetClientIP(c) // 检查内网IP if m.allowInternal && isInternalIP(clientIP) { c.Next() return } // 检查白名单 if m.isIPAllowed(clientIP) { c.Next() return } // 拒绝访问 c.JSON(403, gin.H{ "code": 403, "message": fmt.Sprintf("IP %s 不在白名单中", clientIP), "data": nil, }) c.Abort() } } // isIPAllowed 检查IP是否被允许 func (m *TWhiteIpMgr) isIPAllowed(ip string) bool { // 如果没有设置任何模式,拒绝所有(除非是内网) if len(m.ipPatterns) == 0 { return false } for _, pattern := range m.ipPatterns { if pattern.MatchString(ip) { return true } } return false } // Enable 启用中间件 func (m *TWhiteIpMgr) Enable() { m.enabled = true } // Disable 禁用中间件 func (m *TWhiteIpMgr) Disable() { m.enabled = false } // AddPattern 动态添加IP模式 func (m *TWhiteIpMgr) AddPattern(pattern string) error { compiledPattern := normalizePattern(pattern) re, err := regexp.Compile(compiledPattern) if err != nil { return err } m.ipPatterns = append(m.ipPatterns, re) return nil }