// whiteipmgr.go
/**
 * @author chengliang
 * @date 2026/1/27 11:18
 * @brief IP whitelist middleware for gin. Client IPs are matched against
 *        regular expressions, CIDR blocks and IPv4 shorthand patterns
 *        (wildcards such as 192.168.1.* and ranges such as 192.168.1.1-100).
 *
 **/
  7. package whiteip
  8. import (
  9. "fmt"
  10. //"github.com/dlclark/regexp2" // 支持更复杂的正则表达式
  11. "github.com/gin-gonic/gin"
  12. "net"
  13. "regexp"
  14. "strings"
  15. )
  16. // 支持正则的IP白名单中间件
  17. type TWhiteIpMgr struct {
  18. // 预编译的正则表达式列表
  19. ipPatterns []*regexp.Regexp
  20. // 是否允许内网IP(默认允许)
  21. allowInternal bool
  22. // 是否启用IP检查(默认启用)
  23. enabled bool
  24. }
  25. // 创建白名单中间件
  26. func NewIPWhiteListMiddleware(patterns []string, allowInternal bool) *TWhiteIpMgr {
  27. m := &TWhiteIpMgr{
  28. allowInternal: allowInternal,
  29. enabled: true,
  30. }
  31. // 编译正则表达式
  32. for _, pattern := range patterns {
  33. // 支持简写语法
  34. compiledPattern := normalizePattern(pattern)
  35. re, err := regexp.Compile(compiledPattern)
  36. if err != nil {
  37. // 如果编译失败,记录错误但仍继续
  38. fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err)
  39. continue
  40. }
  41. m.ipPatterns = append(m.ipPatterns, re)
  42. }
  43. return m
  44. }
  45. func (this *TWhiteIpMgr) UpdatePattern(patterns []string, allowInternal bool, enabled bool) {
  46. this.enabled = enabled
  47. this.allowInternal = allowInternal
  48. this.ipPatterns = make([]*regexp.Regexp, len(patterns))
  49. for _, pattern := range patterns {
  50. compiledPattern := normalizePattern(pattern)
  51. re, err := regexp.Compile(compiledPattern)
  52. if err != nil {
  53. // 如果编译失败,记录错误但仍继续
  54. fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err)
  55. continue
  56. }
  57. this.ipPatterns = append(this.ipPatterns, re)
  58. }
  59. }
  60. // 标准化IP模式
  61. func normalizePattern(pattern string) string {
  62. // 支持CIDR表示法转换成正则表达式
  63. if strings.Contains(pattern, "/") {
  64. return cidrToRegex(pattern)
  65. }
  66. // 支持IP段简写,如: 192.168.1.* 或 192.168.1.1-100
  67. pattern = strings.ReplaceAll(pattern, "*", `\d+`)
  68. // 支持IP范围,如: 192.168.1.1-100
  69. if strings.Contains(pattern, "-") {
  70. parts := strings.Split(pattern, ".")
  71. for i, part := range parts {
  72. if strings.Contains(part, "-") {
  73. rangeParts := strings.Split(part, "-")
  74. if len(rangeParts) == 2 {
  75. parts[i] = fmt.Sprintf(`(%s)`, buildRangeRegex(rangeParts[0], rangeParts[1]))
  76. }
  77. }
  78. }
  79. return "^" + strings.Join(parts, `\.`) + "$"
  80. }
  81. // 如果是普通IP,确保完全匹配
  82. if !strings.HasPrefix(pattern, "^") {
  83. pattern = "^" + pattern
  84. }
  85. if !strings.HasSuffix(pattern, "$") {
  86. pattern = pattern + "$"
  87. }
  88. return pattern
  89. }
  90. // cidrToRegex 将CIDR表示法转换为正则表达式
  91. func cidrToRegex(cidr string) string {
  92. ip, ipNet, err := net.ParseCIDR(cidr)
  93. if err != nil {
  94. return cidr // 如果解析失败,返回原字符串
  95. }
  96. // 获取网络掩码
  97. mask := ipNet.Mask
  98. ones, bits := mask.Size()
  99. if bits != 32 {
  100. return cidr // 仅支持IPv4
  101. }
  102. // 将IP转换为整数
  103. ipInt := ipToInt(ip.To4())
  104. // 计算网络地址和广播地址
  105. network := ipInt & (^uint32(0) << uint32(bits-ones))
  106. broadcast := network | (^uint32(0) >> uint32(ones))
  107. // 生成正则表达式
  108. return fmt.Sprintf(`^%s$`, ipRangeToRegex(intToIP(network), intToIP(broadcast)))
  109. }
  110. // ipToInt IP转整数
  111. func ipToInt(ip net.IP) uint32 {
  112. if len(ip) == 16 {
  113. return uint32(ip[12])<<24 | uint32(ip[13])<<16 | uint32(ip[14])<<8 | uint32(ip[15])
  114. }
  115. return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
  116. }
  117. // intToIP 整数转IP
  118. func intToIP(n uint32) net.IP {
  119. return net.IPv4(byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
  120. }
  121. // ipRangeToRegex IP范围转正则表达式
  122. func ipRangeToRegex(start, end net.IP) string {
  123. startParts := strings.Split(start.String(), ".")
  124. endParts := strings.Split(end.String(), ".")
  125. var regexParts []string
  126. for i := 0; i < 4; i++ {
  127. startNum := parseInt(startParts[i])
  128. endNum := parseInt(endParts[i])
  129. if startNum == endNum {
  130. regexParts = append(regexParts, fmt.Sprintf("%d", startNum))
  131. } else {
  132. regexParts = append(regexParts, fmt.Sprintf("(%s)", buildRangeRegex(startParts[i], endParts[i])))
  133. }
  134. }
  135. return strings.Join(regexParts, `\.`)
  136. }
  137. // buildRangeRegex 构建数字范围的正则表达式
  138. func buildRangeRegex(start, end string) string {
  139. startNum := parseInt(start)
  140. endNum := parseInt(end)
  141. if startNum == endNum {
  142. return start
  143. }
  144. // 简单处理:如果范围小,直接列举
  145. if endNum-startNum < 10 {
  146. var options []string
  147. for i := startNum; i <= endNum; i++ {
  148. options = append(options, fmt.Sprintf("%d", i))
  149. }
  150. return strings.Join(options, "|")
  151. }
  152. // 复杂范围,使用正则表达式模式
  153. return fmt.Sprintf("%d|%d|[1-9]\\d{0,2}", startNum, endNum) // 简化处理
  154. }
  155. // parseInt 字符串转整数
  156. func parseInt(s string) int {
  157. var result int
  158. fmt.Sscanf(s, "%d", &result)
  159. return result
  160. }
  161. // 检查是否为内网IP
  162. func isInternalIP(ipStr string) bool {
  163. ip := net.ParseIP(ipStr)
  164. if ip == nil {
  165. return false
  166. }
  167. // IPv4 检查
  168. if ip4 := ip.To4(); ip4 != nil {
  169. return ip4[0] == 10 ||
  170. (ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) ||
  171. (ip4[0] == 192 && ip4[1] == 168) ||
  172. ip4[0] == 127
  173. }
  174. // IPv6 检查
  175. return ip.IsLoopback() || ip.IsPrivate()
  176. }
  177. // GetClientIP 获取客户端真实IP
  178. func GetClientIP(c *gin.Context) string {
  179. // 从代理头获取
  180. if ip := c.GetHeader("X-Forwarded-For"); ip != "" {
  181. ips := strings.Split(ip, ",")
  182. if len(ips) > 0 {
  183. return strings.TrimSpace(ips[0])
  184. }
  185. }
  186. if ip := c.GetHeader("X-Real-IP"); ip != "" {
  187. return ip
  188. }
  189. // 直接获取 RemoteAddr
  190. remoteAddr := c.Request.RemoteAddr
  191. if ip, _, err := net.SplitHostPort(remoteAddr); err == nil {
  192. return ip
  193. }
  194. return remoteAddr
  195. }
  196. // Middleware 返回Gin中间件函数
  197. func (m *TWhiteIpMgr) Middleware() gin.HandlerFunc {
  198. return func(c *gin.Context) {
  199. if !m.enabled {
  200. c.Next()
  201. return
  202. }
  203. clientIP := GetClientIP(c)
  204. // 检查内网IP
  205. if m.allowInternal && isInternalIP(clientIP) {
  206. c.Next()
  207. return
  208. }
  209. // 检查白名单
  210. if m.isIPAllowed(clientIP) {
  211. c.Next()
  212. return
  213. }
  214. // 拒绝访问
  215. c.JSON(403, gin.H{
  216. "code": 403,
  217. "message": fmt.Sprintf("IP %s 不在白名单中", clientIP),
  218. "data": nil,
  219. })
  220. c.Abort()
  221. }
  222. }
  223. // isIPAllowed 检查IP是否被允许
  224. func (m *TWhiteIpMgr) isIPAllowed(ip string) bool {
  225. // 如果没有设置任何模式,拒绝所有(除非是内网)
  226. if len(m.ipPatterns) == 0 {
  227. return false
  228. }
  229. for _, pattern := range m.ipPatterns {
  230. if pattern.MatchString(ip) {
  231. return true
  232. }
  233. }
  234. return false
  235. }
  236. // Enable 启用中间件
  237. func (m *TWhiteIpMgr) Enable() {
  238. m.enabled = true
  239. }
  240. // Disable 禁用中间件
  241. func (m *TWhiteIpMgr) Disable() {
  242. m.enabled = false
  243. }
  244. // AddPattern 动态添加IP模式
  245. func (m *TWhiteIpMgr) AddPattern(pattern string) error {
  246. compiledPattern := normalizePattern(pattern)
  247. re, err := regexp.Compile(compiledPattern)
  248. if err != nil {
  249. return err
  250. }
  251. m.ipPatterns = append(m.ipPatterns, re)
  252. return nil
  253. }