| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- /**
- * @author chengliang
- * @date 2026/1/27 11:18
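- * @brief IP white-list middleware for gin: regex, shorthand (192.168.1.*, 192.168.1.1-100) and IPv4 CIDR patterns.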
- *
- **/
- package whiteip
- import (
- "fmt"
- //"github.com/dlclark/regexp2" // 支持更复杂的正则表达式
- "github.com/gin-gonic/gin"
- "net"
- "regexp"
- "strings"
- )
// TWhiteIpMgr holds the configuration of a regex-based IP white list; its
// Middleware method returns the gin handler that enforces it. Entries are
// converted by normalizePattern (plain regex, shorthand such as
// "192.168.1.*" / "192.168.1.1-100", or IPv4 CIDR) and compiled once.
//
// NOTE(review): UpdatePattern, AddPattern, Enable and Disable write these
// fields while handlers returned by Middleware read them, with no
// synchronisation — changing the configuration while serving traffic is a
// data race. Confirm callers only reconfigure before the server starts.
type TWhiteIpMgr struct {
	// ipPatterns is the list of precompiled white-list regular expressions.
	ipPatterns []*regexp.Regexp
	// allowInternal lets addresses for which isInternalIP is true skip the
	// pattern check.
	allowInternal bool
	// enabled turns the check on; when false every request is passed through.
	// NewIPWhiteListMiddleware sets it to true.
	enabled bool
}
- // 创建白名单中间件
- func NewIPWhiteListMiddleware(patterns []string, allowInternal bool) *TWhiteIpMgr {
- m := &TWhiteIpMgr{
- allowInternal: allowInternal,
- enabled: true,
- }
- // 编译正则表达式
- for _, pattern := range patterns {
- // 支持简写语法
- compiledPattern := normalizePattern(pattern)
- re, err := regexp.Compile(compiledPattern)
- if err != nil {
- // 如果编译失败,记录错误但仍继续
- fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err)
- continue
- }
- m.ipPatterns = append(m.ipPatterns, re)
- }
- return m
- }
- func (this *TWhiteIpMgr) UpdatePattern(patterns []string, allowInternal bool, enabled bool) {
- this.enabled = enabled
- this.allowInternal = allowInternal
- this.ipPatterns = make([]*regexp.Regexp, len(patterns))
- for _, pattern := range patterns {
- compiledPattern := normalizePattern(pattern)
- re, err := regexp.Compile(compiledPattern)
- if err != nil {
- // 如果编译失败,记录错误但仍继续
- fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err)
- continue
- }
- this.ipPatterns = append(this.ipPatterns, re)
- }
- }
- // 标准化IP模式
- func normalizePattern(pattern string) string {
- // 支持CIDR表示法转换成正则表达式
- if strings.Contains(pattern, "/") {
- return cidrToRegex(pattern)
- }
- // 支持IP段简写,如: 192.168.1.* 或 192.168.1.1-100
- pattern = strings.ReplaceAll(pattern, "*", `\d+`)
- // 支持IP范围,如: 192.168.1.1-100
- if strings.Contains(pattern, "-") {
- parts := strings.Split(pattern, ".")
- for i, part := range parts {
- if strings.Contains(part, "-") {
- rangeParts := strings.Split(part, "-")
- if len(rangeParts) == 2 {
- parts[i] = fmt.Sprintf(`(%s)`, buildRangeRegex(rangeParts[0], rangeParts[1]))
- }
- }
- }
- return "^" + strings.Join(parts, `\.`) + "$"
- }
- // 如果是普通IP,确保完全匹配
- if !strings.HasPrefix(pattern, "^") {
- pattern = "^" + pattern
- }
- if !strings.HasSuffix(pattern, "$") {
- pattern = pattern + "$"
- }
- return pattern
- }
- // cidrToRegex 将CIDR表示法转换为正则表达式
- func cidrToRegex(cidr string) string {
- ip, ipNet, err := net.ParseCIDR(cidr)
- if err != nil {
- return cidr // 如果解析失败,返回原字符串
- }
- // 获取网络掩码
- mask := ipNet.Mask
- ones, bits := mask.Size()
- if bits != 32 {
- return cidr // 仅支持IPv4
- }
- // 将IP转换为整数
- ipInt := ipToInt(ip.To4())
- // 计算网络地址和广播地址
- network := ipInt & (^uint32(0) << uint32(bits-ones))
- broadcast := network | (^uint32(0) >> uint32(ones))
- // 生成正则表达式
- return fmt.Sprintf(`^%s$`, ipRangeToRegex(intToIP(network), intToIP(broadcast)))
- }
// ipToInt packs an IPv4 address into a uint32, with the first octet in the
// most significant byte. A 16-byte slice (the IPv4-in-IPv6 form package net
// uses for IPv4) is read from its last four bytes; any other slice is read
// from its first four bytes. Slices shorter than four bytes panic.
func ipToInt(ip net.IP) uint32 {
	octets := ip
	if len(octets) == 16 {
		octets = octets[12:]
	}
	var n uint32
	for i := 0; i < 4; i++ {
		n = n<<8 | uint32(octets[i])
	}
	return n
}
// intToIP is the inverse of ipToInt: it unpacks n (first octet in the most
// significant byte) into a net.IP built by net.IPv4.
func intToIP(n uint32) net.IP {
	var b [4]byte
	for i := range b {
		b[i] = byte(n >> (24 - 8*uint(i)))
	}
	return net.IPv4(b[0], b[1], b[2], b[3])
}
- // ipRangeToRegex IP范围转正则表达式
- func ipRangeToRegex(start, end net.IP) string {
- startParts := strings.Split(start.String(), ".")
- endParts := strings.Split(end.String(), ".")
- var regexParts []string
- for i := 0; i < 4; i++ {
- startNum := parseInt(startParts[i])
- endNum := parseInt(endParts[i])
- if startNum == endNum {
- regexParts = append(regexParts, fmt.Sprintf("%d", startNum))
- } else {
- regexParts = append(regexParts, fmt.Sprintf("(%s)", buildRangeRegex(startParts[i], endParts[i])))
- }
- }
- return strings.Join(regexParts, `\.`)
- }
// buildRangeRegex returns a regular-expression alternation, without anchors or
// surrounding parentheses, that matches exactly the decimal numbers in the
// inclusive range [start, end]. Numbers are written without leading zeros, as
// in a dotted IPv4 address. For example, buildRangeRegex("0", "127") matches
// "0".."127" but not "128", "200" or "07".
//
// The bounds are parsed with parseInt:
//   - If they are equal, start is returned unchanged.
//   - If start > end, they are swapped rather than yielding an empty
//     expression.
//   - Bounds are clamped to the IPv4 octet domain [0, 255]. A range with no
//     value in that domain yields an expression that matches nothing.
//
// The range is covered greedily by the largest power-of-ten block aligned at
// the current lower bound. For example, [100, 199] becomes `1\d\d` and
// [250, 255] becomes the literals 250..255, which keeps even 0-255 short.
func buildRangeRegex(start, end string) string {
	lo, hi := parseInt(start), parseInt(end)
	if lo == hi {
		return start
	}
	if lo > hi {
		lo, hi = hi, lo
	}
	// Only 0..255 occurs in an IPv4 octet; clamping also keeps the block
	// arithmetic below well clear of integer overflow.
	if lo < 0 {
		lo = 0
	}
	if hi > 255 {
		hi = 255
	}
	if lo > hi {
		// Empty character class: never matches.
		return `[^\x00-\x{10FFFF}]`
	}

	var alts []string
	for lo <= hi {
		// Grow [lo, lo+size-1] while the next power of ten is aligned at lo and
		// still fits. A block starting at 0 is capped at [0, 9] because `\d\d`
		// there would accept leading-zero forms such as "07".
		size, width := 1, 0
		for size*10-1 <= hi-lo && lo%(size*10) == 0 && (lo != 0 || size < 10) {
			size *= 10
			width++
		}
		prefix := fmt.Sprint(lo / size)
		if lo == 0 && width > 0 {
			prefix = "" // [0, 9] is simply `\d`
		}
		alts = append(alts, prefix+strings.Repeat(`\d`, width))
		lo += size
	}
	return strings.Join(alts, "|")
}

// parseInt reads a decimal integer from s with fmt.Sscanf("%d"). The scan
// error is deliberately ignored: input Sscanf cannot parse yields 0.
func parseInt(s string) int {
	var result int
	_, _ = fmt.Sscanf(s, "%d", &result)
	return result
}
// isInternalIP reports whether ipStr parses as a loopback or private address.
// For IPv4, including IPv4-mapped IPv6 forms, that is 127.0.0.0/8, 10.0.0.0/8,
// 172.16.0.0/12 and 192.168.0.0/16. For IPv6 it is ::1 and fc00::/7.
// Strings that are not IP addresses are never internal.
func isInternalIP(ipStr string) bool {
	ip := net.ParseIP(ipStr)
	return ip != nil && (ip.IsLoopback() || ip.IsPrivate())
}
- // GetClientIP 获取客户端真实IP
- func GetClientIP(c *gin.Context) string {
- // 从代理头获取
- if ip := c.GetHeader("X-Forwarded-For"); ip != "" {
- ips := strings.Split(ip, ",")
- if len(ips) > 0 {
- return strings.TrimSpace(ips[0])
- }
- }
- if ip := c.GetHeader("X-Real-IP"); ip != "" {
- return ip
- }
- // 直接获取 RemoteAddr
- remoteAddr := c.Request.RemoteAddr
- if ip, _, err := net.SplitHostPort(remoteAddr); err == nil {
- return ip
- }
- return remoteAddr
- }
- // Middleware 返回Gin中间件函数
- func (m *TWhiteIpMgr) Middleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- if !m.enabled {
- c.Next()
- return
- }
- clientIP := GetClientIP(c)
- // 检查内网IP
- if m.allowInternal && isInternalIP(clientIP) {
- c.Next()
- return
- }
- // 检查白名单
- if m.isIPAllowed(clientIP) {
- c.Next()
- return
- }
- // 拒绝访问
- c.JSON(403, gin.H{
- "code": 403,
- "message": fmt.Sprintf("IP %s 不在白名单中", clientIP),
- "data": nil,
- })
- c.Abort()
- }
- }
- // isIPAllowed 检查IP是否被允许
- func (m *TWhiteIpMgr) isIPAllowed(ip string) bool {
- // 如果没有设置任何模式,拒绝所有(除非是内网)
- if len(m.ipPatterns) == 0 {
- return false
- }
- for _, pattern := range m.ipPatterns {
- if pattern.MatchString(ip) {
- return true
- }
- }
- return false
- }
// Enable turns the white-list check on by setting the flag that handlers
// returned by Middleware read on every request.
//
// NOTE(review): the flag is written without synchronisation while handlers may
// read it concurrently — confirm callers only toggle it before serving.
func (m *TWhiteIpMgr) Enable() {
	m.enabled = true
}
// Disable turns the white-list check off. Handlers returned by Middleware then
// pass every request through without inspecting the client IP.
//
// NOTE(review): the flag is written without synchronisation while handlers may
// read it concurrently — confirm callers only toggle it before serving.
func (m *TWhiteIpMgr) Disable() {
	m.enabled = false
}
- // AddPattern 动态添加IP模式
- func (m *TWhiteIpMgr) AddPattern(pattern string) error {
- compiledPattern := normalizePattern(pattern)
- re, err := regexp.Compile(compiledPattern)
- if err != nil {
- return err
- }
- m.ipPatterns = append(m.ipPatterns, re)
- return nil
- }
|