|
@@ -0,0 +1,306 @@
|
|
|
|
|
+/**
|
|
|
|
|
+ * @author chengliang
|
|
|
|
|
+ * @date 2026/1/27 11:18
|
|
|
|
|
+ * @brief
|
|
|
|
|
+ *
|
|
|
|
|
+ **/
|
|
|
|
|
+
|
|
|
|
|
+package whiteip
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "fmt"
|
|
|
|
|
+ //"github.com/dlclark/regexp2" // 支持更复杂的正则表达式
|
|
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
|
|
+ "net"
|
|
|
|
|
+ "regexp"
|
|
|
|
|
+ "strings"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+// 支持正则的IP白名单中间件
|
|
|
|
|
+type TWhiteIpMgr struct {
|
|
|
|
|
+ // 预编译的正则表达式列表
|
|
|
|
|
+ ipPatterns []*regexp.Regexp
|
|
|
|
|
+ // 是否允许内网IP(默认允许)
|
|
|
|
|
+ allowInternal bool
|
|
|
|
|
+ // 是否启用IP检查(默认启用)
|
|
|
|
|
+ enabled bool
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// 创建白名单中间件
|
|
|
|
|
+func NewIPWhiteListMiddleware(patterns []string, allowInternal bool) *TWhiteIpMgr {
|
|
|
|
|
+ m := &TWhiteIpMgr{
|
|
|
|
|
+ allowInternal: allowInternal,
|
|
|
|
|
+ enabled: true,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 编译正则表达式
|
|
|
|
|
+ for _, pattern := range patterns {
|
|
|
|
|
+ // 支持简写语法
|
|
|
|
|
+ compiledPattern := normalizePattern(pattern)
|
|
|
|
|
+
|
|
|
|
|
+ re, err := regexp.Compile(compiledPattern)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ // 如果编译失败,记录错误但仍继续
|
|
|
|
|
+ fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ m.ipPatterns = append(m.ipPatterns, re)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return m
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (this *TWhiteIpMgr) UpdatePattern(patterns []string, allowInternal bool, enabled bool) {
|
|
|
|
|
+ this.enabled = enabled
|
|
|
|
|
+ this.allowInternal = allowInternal
|
|
|
|
|
+ this.ipPatterns = make([]*regexp.Regexp, len(patterns))
|
|
|
|
|
+ for _, pattern := range patterns {
|
|
|
|
|
+ compiledPattern := normalizePattern(pattern)
|
|
|
|
|
+
|
|
|
|
|
+ re, err := regexp.Compile(compiledPattern)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ // 如果编译失败,记录错误但仍继续
|
|
|
|
|
+ fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ this.ipPatterns = append(this.ipPatterns, re)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// 标准化IP模式
|
|
|
|
|
+func normalizePattern(pattern string) string {
|
|
|
|
|
+ // 支持CIDR表示法转换成正则表达式
|
|
|
|
|
+ if strings.Contains(pattern, "/") {
|
|
|
|
|
+ return cidrToRegex(pattern)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 支持IP段简写,如: 192.168.1.* 或 192.168.1.1-100
|
|
|
|
|
+ pattern = strings.ReplaceAll(pattern, "*", `\d+`)
|
|
|
|
|
+
|
|
|
|
|
+ // 支持IP范围,如: 192.168.1.1-100
|
|
|
|
|
+ if strings.Contains(pattern, "-") {
|
|
|
|
|
+ parts := strings.Split(pattern, ".")
|
|
|
|
|
+ for i, part := range parts {
|
|
|
|
|
+ if strings.Contains(part, "-") {
|
|
|
|
|
+ rangeParts := strings.Split(part, "-")
|
|
|
|
|
+ if len(rangeParts) == 2 {
|
|
|
|
|
+ parts[i] = fmt.Sprintf(`(%s)`, buildRangeRegex(rangeParts[0], rangeParts[1]))
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return "^" + strings.Join(parts, `\.`) + "$"
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 如果是普通IP,确保完全匹配
|
|
|
|
|
+ if !strings.HasPrefix(pattern, "^") {
|
|
|
|
|
+ pattern = "^" + pattern
|
|
|
|
|
+ }
|
|
|
|
|
+ if !strings.HasSuffix(pattern, "$") {
|
|
|
|
|
+ pattern = pattern + "$"
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return pattern
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// cidrToRegex 将CIDR表示法转换为正则表达式
|
|
|
|
|
+func cidrToRegex(cidr string) string {
|
|
|
|
|
+ ip, ipNet, err := net.ParseCIDR(cidr)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return cidr // 如果解析失败,返回原字符串
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 获取网络掩码
|
|
|
|
|
+ mask := ipNet.Mask
|
|
|
|
|
+ ones, bits := mask.Size()
|
|
|
|
|
+
|
|
|
|
|
+ if bits != 32 {
|
|
|
|
|
+ return cidr // 仅支持IPv4
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 将IP转换为整数
|
|
|
|
|
+ ipInt := ipToInt(ip.To4())
|
|
|
|
|
+
|
|
|
|
|
+ // 计算网络地址和广播地址
|
|
|
|
|
+ network := ipInt & (^uint32(0) << uint32(bits-ones))
|
|
|
|
|
+ broadcast := network | (^uint32(0) >> uint32(ones))
|
|
|
|
|
+
|
|
|
|
|
+ // 生成正则表达式
|
|
|
|
|
+ return fmt.Sprintf(`^%s$`, ipRangeToRegex(intToIP(network), intToIP(broadcast)))
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// ipToInt IP转整数
|
|
|
|
|
+func ipToInt(ip net.IP) uint32 {
|
|
|
|
|
+ if len(ip) == 16 {
|
|
|
|
|
+ return uint32(ip[12])<<24 | uint32(ip[13])<<16 | uint32(ip[14])<<8 | uint32(ip[15])
|
|
|
|
|
+ }
|
|
|
|
|
+ return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// intToIP 整数转IP
|
|
|
|
|
+func intToIP(n uint32) net.IP {
|
|
|
|
|
+ return net.IPv4(byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// ipRangeToRegex IP范围转正则表达式
|
|
|
|
|
+func ipRangeToRegex(start, end net.IP) string {
|
|
|
|
|
+ startParts := strings.Split(start.String(), ".")
|
|
|
|
|
+ endParts := strings.Split(end.String(), ".")
|
|
|
|
|
+
|
|
|
|
|
+ var regexParts []string
|
|
|
|
|
+ for i := 0; i < 4; i++ {
|
|
|
|
|
+ startNum := parseInt(startParts[i])
|
|
|
|
|
+ endNum := parseInt(endParts[i])
|
|
|
|
|
+
|
|
|
|
|
+ if startNum == endNum {
|
|
|
|
|
+ regexParts = append(regexParts, fmt.Sprintf("%d", startNum))
|
|
|
|
|
+ } else {
|
|
|
|
|
+ regexParts = append(regexParts, fmt.Sprintf("(%s)", buildRangeRegex(startParts[i], endParts[i])))
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return strings.Join(regexParts, `\.`)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// buildRangeRegex 构建数字范围的正则表达式
|
|
|
|
|
+func buildRangeRegex(start, end string) string {
|
|
|
|
|
+ startNum := parseInt(start)
|
|
|
|
|
+ endNum := parseInt(end)
|
|
|
|
|
+
|
|
|
|
|
+ if startNum == endNum {
|
|
|
|
|
+ return start
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 简单处理:如果范围小,直接列举
|
|
|
|
|
+ if endNum-startNum < 10 {
|
|
|
|
|
+ var options []string
|
|
|
|
|
+ for i := startNum; i <= endNum; i++ {
|
|
|
|
|
+ options = append(options, fmt.Sprintf("%d", i))
|
|
|
|
|
+ }
|
|
|
|
|
+ return strings.Join(options, "|")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 复杂范围,使用正则表达式模式
|
|
|
|
|
+ return fmt.Sprintf("%d|%d|[1-9]\\d{0,2}", startNum, endNum) // 简化处理
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// parseInt 字符串转整数
|
|
|
|
|
+func parseInt(s string) int {
|
|
|
|
|
+ var result int
|
|
|
|
|
+ fmt.Sscanf(s, "%d", &result)
|
|
|
|
|
+ return result
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// 检查是否为内网IP
|
|
|
|
|
+func isInternalIP(ipStr string) bool {
|
|
|
|
|
+ ip := net.ParseIP(ipStr)
|
|
|
|
|
+ if ip == nil {
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // IPv4 检查
|
|
|
|
|
+ if ip4 := ip.To4(); ip4 != nil {
|
|
|
|
|
+ return ip4[0] == 10 ||
|
|
|
|
|
+ (ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) ||
|
|
|
|
|
+ (ip4[0] == 192 && ip4[1] == 168) ||
|
|
|
|
|
+ ip4[0] == 127
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // IPv6 检查
|
|
|
|
|
+ return ip.IsLoopback() || ip.IsPrivate()
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// GetClientIP 获取客户端真实IP
|
|
|
|
|
+func GetClientIP(c *gin.Context) string {
|
|
|
|
|
+ // 从代理头获取
|
|
|
|
|
+ if ip := c.GetHeader("X-Forwarded-For"); ip != "" {
|
|
|
|
|
+ ips := strings.Split(ip, ",")
|
|
|
|
|
+ if len(ips) > 0 {
|
|
|
|
|
+ return strings.TrimSpace(ips[0])
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if ip := c.GetHeader("X-Real-IP"); ip != "" {
|
|
|
|
|
+ return ip
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 直接获取 RemoteAddr
|
|
|
|
|
+ remoteAddr := c.Request.RemoteAddr
|
|
|
|
|
+ if ip, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
|
|
|
|
+ return ip
|
|
|
|
|
+ }
|
|
|
|
|
+ return remoteAddr
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Middleware 返回Gin中间件函数
|
|
|
|
|
+func (m *TWhiteIpMgr) Middleware() gin.HandlerFunc {
|
|
|
|
|
+ return func(c *gin.Context) {
|
|
|
|
|
+ if !m.enabled {
|
|
|
|
|
+ c.Next()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ clientIP := GetClientIP(c)
|
|
|
|
|
+
|
|
|
|
|
+ // 检查内网IP
|
|
|
|
|
+ if m.allowInternal && isInternalIP(clientIP) {
|
|
|
|
|
+ c.Next()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 检查白名单
|
|
|
|
|
+ if m.isIPAllowed(clientIP) {
|
|
|
|
|
+ c.Next()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 拒绝访问
|
|
|
|
|
+ c.JSON(403, gin.H{
|
|
|
|
|
+ "code": 403,
|
|
|
|
|
+ "message": fmt.Sprintf("IP %s 不在白名单中", clientIP),
|
|
|
|
|
+ "data": nil,
|
|
|
|
|
+ })
|
|
|
|
|
+ c.Abort()
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// isIPAllowed 检查IP是否被允许
|
|
|
|
|
+func (m *TWhiteIpMgr) isIPAllowed(ip string) bool {
|
|
|
|
|
+ // 如果没有设置任何模式,拒绝所有(除非是内网)
|
|
|
|
|
+ if len(m.ipPatterns) == 0 {
|
|
|
|
|
+ return false
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for _, pattern := range m.ipPatterns {
|
|
|
|
|
+ if pattern.MatchString(ip) {
|
|
|
|
|
+ return true
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return false
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Enable 启用中间件
|
|
|
|
|
+func (m *TWhiteIpMgr) Enable() {
|
|
|
|
|
+ m.enabled = true
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Disable 禁用中间件
|
|
|
|
|
+func (m *TWhiteIpMgr) Disable() {
|
|
|
|
|
+ m.enabled = false
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// AddPattern 动态添加IP模式
|
|
|
|
|
+func (m *TWhiteIpMgr) AddPattern(pattern string) error {
|
|
|
|
|
+ compiledPattern := normalizePattern(pattern)
|
|
|
|
|
+
|
|
|
|
|
+ re, err := regexp.Compile(compiledPattern)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ m.ipPatterns = append(m.ipPatterns, re)
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|