chengliang 2 месяцев назад
Родитель
Сommit
56817cc22f

+ 20 - 20
app/api/hotupdate/internal.go

@@ -29,11 +29,11 @@ func AddVersion(c *gin.Context) {
 	req := service.TAddVersionReq{}
 	err := c.ShouldBind(&req)
 	if err == nil {
-		err = checkAddVersionHash(req, true)
-		if err != nil {
-			appG.Response(http.StatusOK, e.ERROR_UPLOAD_SIGNURL_FAIL, err.Error())
-			return
-		}
+		//err = checkAddVersionHash(req, true)
+		//if err != nil {
+		//	appG.Response(http.StatusOK, e.ERROR_UPLOAD_SIGNURL_FAIL, err.Error())
+		//	return
+		//}
 		var err = service.GetTHotUpdateVerManager().AddVersion(&req)
 		if err != nil {
 			appG.Response(http.StatusOK, e.NO_RECORD, err.Error())
@@ -52,11 +52,11 @@ func GetMaxVerInfo(c *gin.Context) {
 	req := service.TGetVersionReq{}
 	err := c.ShouldBind(&req)
 	if err == nil {
-		err = checkAddVersionHash(req, false)
-		if err != nil {
-			appG.Response(http.StatusOK, e.ERROR_UPLOAD_SIGNURL_FAIL, err.Error())
-			return
-		}
+		//err = checkAddVersionHash(req, false)
+		//if err != nil {
+		//	appG.Response(http.StatusOK, e.ERROR_UPLOAD_SIGNURL_FAIL, err.Error())
+		//	return
+		//}
 		var rsp, err = service.GetTHotUpdateVerManager().GetMaxVerInfo(req.Proj, req.Os)
 		if err != nil {
 			appG.Response(http.StatusOK, e.NO_RECORD, err.Error())
@@ -75,11 +75,11 @@ func GetVersionList(c *gin.Context) {
 	req := service.TGetVersionListReq{}
 	err := c.ShouldBind(&req)
 	if err == nil {
-		err = checkAddVersionHash(req, false)
-		if err != nil {
-			appG.Response(http.StatusOK, e.ERROR_UPLOAD_SIGNURL_FAIL, err.Error())
-			return
-		}
+		//err = checkAddVersionHash(req, false)
+		//if err != nil {
+		//	appG.Response(http.StatusOK, e.ERROR_UPLOAD_SIGNURL_FAIL, err.Error())
+		//	return
+		//}
 		var rsp, err = service.GetTHotUpdateVerManager().GetVersionList(req.Start, req.Limit)
 		if err != nil {
 			appG.Response(http.StatusOK, e.NO_RECORD, err.Error())
@@ -98,11 +98,11 @@ func ChangeStatus(c *gin.Context) {
 	req := service.TChangeStautsReq{}
 	err := c.ShouldBind(&req)
 	if err == nil {
-		err = checkAddVersionHash(req, false)
-		if err != nil {
-			appG.Response(http.StatusOK, e.ERROR_UPLOAD_SIGNURL_FAIL, err.Error())
-			return
-		}
+		//err = checkAddVersionHash(req, false)
+		//if err != nil {
+		//	appG.Response(http.StatusOK, e.ERROR_UPLOAD_SIGNURL_FAIL, err.Error())
+		//	return
+		//}
 		var rsp, err = service.GetTHotUpdateVerManager().ChangeStatus(req.ID, req.Status)
 		if err != nil {
 			appG.Response(http.StatusOK, e.NO_RECORD, err.Error())

+ 13 - 0
app/service/hotupdate.go

@@ -84,6 +84,15 @@ func (this *THotUpdateVerManager) GetMaxPubVerInfo(proj, os string) (*TGetVersio
 	}
 }
 
+func (this *THotUpdateVerManager) reloadPubVerBy(proj, os string) error {
+	versionRsp, err := this.findDBMaxPubVersion(proj, os)
+	if err != nil {
+		return err
+	}
+	this.updateVersionMap(proj, os, versionRsp)
+	return nil
+}
+
 // 内部API调用
 func (this *THotUpdateVerManager) GetMaxVerInfo(proj, os string) (*TGetVersionRsp, error) {
 	if proj == "" || os == "" {
@@ -123,6 +132,10 @@ func (this *THotUpdateVerManager) ChangeStatus(id string, status int16) (*TGetVe
 		return nil, err
 	}
 	logger.Info("after update mVersion", mVersion)
+	err = this.reloadPubVerBy(mVersion.Proj, mVersion.Os)
+	if err != nil {
+		return nil, fmt.Errorf("reloadPubVerBy error: %v", err)
+	}
 	vRsp := &TGetVersionRsp{}
 	vRsp.FromMVersion(mVersion)
 	return vRsp, nil

+ 2 - 2
app/service/type.go

@@ -43,8 +43,8 @@ type TAddVersionReq struct {
 	PackageUrl        string `form:"packageUrl" binding:"required" json:"packageUrl"`
 	RemoteManifestUrl string `form:"remoteManifestUrl" binding:"required" json:"remoteManifestUrl"`
 	RemoteVersionUrl  string `form:"remoteVersionUrl" binding:"required" json:"remoteVersionUrl"`
-	TimeSec           int64  `form:"timesec" binding:"required" json:"timesec"`
-	Sign              string `form:"sign" binding:"required" json:"sign"`
+	TimeSec           int64  `form:"timesec" json:"timesec"`
+	Sign              string `form:"sign" json:"sign"`
 }
 
 type TGetVersionListReq struct {

+ 306 - 0
middleware/whiteip/whiteipmgr.go

@@ -0,0 +1,306 @@
+/**
+ * @author chengliang
+ * @date 2026/1/27 11:18
+ * @brief
+ *
+ **/
+
+package whiteip
+
+import (
+	"fmt"
+	//"github.com/dlclark/regexp2" // 支持更复杂的正则表达式
+	"github.com/gin-gonic/gin"
+	"net"
+	"regexp"
+	"strings"
+)
+
+// 支持正则的IP白名单中间件
+type TWhiteIpMgr struct {
+	// 预编译的正则表达式列表
+	ipPatterns []*regexp.Regexp
+	// 是否允许内网IP(默认允许)
+	allowInternal bool
+	// 是否启用IP检查(默认启用)
+	enabled bool
+}
+
+// 创建白名单中间件
+func NewIPWhiteListMiddleware(patterns []string, allowInternal bool) *TWhiteIpMgr {
+	m := &TWhiteIpMgr{
+		allowInternal: allowInternal,
+		enabled:       true,
+	}
+
+	// 编译正则表达式
+	for _, pattern := range patterns {
+		// 支持简写语法
+		compiledPattern := normalizePattern(pattern)
+
+		re, err := regexp.Compile(compiledPattern)
+		if err != nil {
+			// 如果编译失败,记录错误但仍继续
+			fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err)
+			continue
+		}
+
+		m.ipPatterns = append(m.ipPatterns, re)
+	}
+
+	return m
+}
+
+func (this *TWhiteIpMgr) UpdatePattern(patterns []string, allowInternal bool, enabled bool) {
+	this.enabled = enabled
+	this.allowInternal = allowInternal
+	this.ipPatterns = make([]*regexp.Regexp, len(patterns))
+	for _, pattern := range patterns {
+		compiledPattern := normalizePattern(pattern)
+
+		re, err := regexp.Compile(compiledPattern)
+		if err != nil {
+			// 如果编译失败,记录错误但仍继续
+			fmt.Printf("warning:IP white list compile failed: %s, error: %v\n", pattern, err)
+			continue
+		}
+
+		this.ipPatterns = append(this.ipPatterns, re)
+	}
+
+}
+
+// 标准化IP模式
+func normalizePattern(pattern string) string {
+	// 支持CIDR表示法转换成正则表达式
+	if strings.Contains(pattern, "/") {
+		return cidrToRegex(pattern)
+	}
+
+	// 支持IP段简写,如: 192.168.1.* 或 192.168.1.1-100
+	pattern = strings.ReplaceAll(pattern, "*", `\d+`)
+
+	// 支持IP范围,如: 192.168.1.1-100
+	if strings.Contains(pattern, "-") {
+		parts := strings.Split(pattern, ".")
+		for i, part := range parts {
+			if strings.Contains(part, "-") {
+				rangeParts := strings.Split(part, "-")
+				if len(rangeParts) == 2 {
+					parts[i] = fmt.Sprintf(`(%s)`, buildRangeRegex(rangeParts[0], rangeParts[1]))
+				}
+			}
+		}
+		return "^" + strings.Join(parts, `\.`) + "$"
+	}
+
+	// 如果是普通IP,确保完全匹配
+	if !strings.HasPrefix(pattern, "^") {
+		pattern = "^" + pattern
+	}
+	if !strings.HasSuffix(pattern, "$") {
+		pattern = pattern + "$"
+	}
+
+	return pattern
+}
+
+// cidrToRegex 将CIDR表示法转换为正则表达式
+func cidrToRegex(cidr string) string {
+	ip, ipNet, err := net.ParseCIDR(cidr)
+	if err != nil {
+		return cidr // 如果解析失败,返回原字符串
+	}
+
+	// 获取网络掩码
+	mask := ipNet.Mask
+	ones, bits := mask.Size()
+
+	if bits != 32 {
+		return cidr // 仅支持IPv4
+	}
+
+	// 将IP转换为整数
+	ipInt := ipToInt(ip.To4())
+
+	// 计算网络地址和广播地址
+	network := ipInt & (^uint32(0) << uint32(bits-ones))
+	broadcast := network | (^uint32(0) >> uint32(ones))
+
+	// 生成正则表达式
+	return fmt.Sprintf(`^%s$`, ipRangeToRegex(intToIP(network), intToIP(broadcast)))
+}
+
+// ipToInt IP转整数
+func ipToInt(ip net.IP) uint32 {
+	if len(ip) == 16 {
+		return uint32(ip[12])<<24 | uint32(ip[13])<<16 | uint32(ip[14])<<8 | uint32(ip[15])
+	}
+	return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
+}
+
+// intToIP 整数转IP
+func intToIP(n uint32) net.IP {
+	return net.IPv4(byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
+}
+
+// ipRangeToRegex IP范围转正则表达式
+func ipRangeToRegex(start, end net.IP) string {
+	startParts := strings.Split(start.String(), ".")
+	endParts := strings.Split(end.String(), ".")
+
+	var regexParts []string
+	for i := 0; i < 4; i++ {
+		startNum := parseInt(startParts[i])
+		endNum := parseInt(endParts[i])
+
+		if startNum == endNum {
+			regexParts = append(regexParts, fmt.Sprintf("%d", startNum))
+		} else {
+			regexParts = append(regexParts, fmt.Sprintf("(%s)", buildRangeRegex(startParts[i], endParts[i])))
+		}
+	}
+
+	return strings.Join(regexParts, `\.`)
+}
+
+// buildRangeRegex 构建数字范围的正则表达式
+func buildRangeRegex(start, end string) string {
+	startNum := parseInt(start)
+	endNum := parseInt(end)
+
+	if startNum == endNum {
+		return start
+	}
+
+	// 简单处理:如果范围小,直接列举
+	if endNum-startNum < 10 {
+		var options []string
+		for i := startNum; i <= endNum; i++ {
+			options = append(options, fmt.Sprintf("%d", i))
+		}
+		return strings.Join(options, "|")
+	}
+
+	// 复杂范围,使用正则表达式模式
+	return fmt.Sprintf("%d|%d|[1-9]\\d{0,2}", startNum, endNum) // 简化处理
+}
+
+// parseInt 字符串转整数
+func parseInt(s string) int {
+	var result int
+	fmt.Sscanf(s, "%d", &result)
+	return result
+}
+
+// 检查是否为内网IP
+func isInternalIP(ipStr string) bool {
+	ip := net.ParseIP(ipStr)
+	if ip == nil {
+		return false
+	}
+
+	// IPv4 检查
+	if ip4 := ip.To4(); ip4 != nil {
+		return ip4[0] == 10 ||
+			(ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) ||
+			(ip4[0] == 192 && ip4[1] == 168) ||
+			ip4[0] == 127
+	}
+
+	// IPv6 检查
+	return ip.IsLoopback() || ip.IsPrivate()
+}
+
+// GetClientIP 获取客户端真实IP
+func GetClientIP(c *gin.Context) string {
+	// 从代理头获取
+	if ip := c.GetHeader("X-Forwarded-For"); ip != "" {
+		ips := strings.Split(ip, ",")
+		if len(ips) > 0 {
+			return strings.TrimSpace(ips[0])
+		}
+	}
+
+	if ip := c.GetHeader("X-Real-IP"); ip != "" {
+		return ip
+	}
+
+	// 直接获取 RemoteAddr
+	remoteAddr := c.Request.RemoteAddr
+	if ip, _, err := net.SplitHostPort(remoteAddr); err == nil {
+		return ip
+	}
+	return remoteAddr
+}
+
+// Middleware 返回Gin中间件函数
+func (m *TWhiteIpMgr) Middleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		if !m.enabled {
+			c.Next()
+			return
+		}
+
+		clientIP := GetClientIP(c)
+
+		// 检查内网IP
+		if m.allowInternal && isInternalIP(clientIP) {
+			c.Next()
+			return
+		}
+
+		// 检查白名单
+		if m.isIPAllowed(clientIP) {
+			c.Next()
+			return
+		}
+
+		// 拒绝访问
+		c.JSON(403, gin.H{
+			"code":    403,
+			"message": fmt.Sprintf("IP %s 不在白名单中", clientIP),
+			"data":    nil,
+		})
+		c.Abort()
+	}
+}
+
+// isIPAllowed 检查IP是否被允许
+func (m *TWhiteIpMgr) isIPAllowed(ip string) bool {
+	// 如果没有设置任何模式,拒绝所有(除非是内网)
+	if len(m.ipPatterns) == 0 {
+		return false
+	}
+
+	for _, pattern := range m.ipPatterns {
+		if pattern.MatchString(ip) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// Enable 启用中间件
+func (m *TWhiteIpMgr) Enable() {
+	m.enabled = true
+}
+
+// Disable 禁用中间件
+func (m *TWhiteIpMgr) Disable() {
+	m.enabled = false
+}
+
+// AddPattern 动态添加IP模式
+func (m *TWhiteIpMgr) AddPattern(pattern string) error {
+	compiledPattern := normalizePattern(pattern)
+
+	re, err := regexp.Compile(compiledPattern)
+	if err != nil {
+		return err
+	}
+
+	m.ipPatterns = append(m.ipPatterns, re)
+	return nil
+}

+ 17 - 0
model/mongo/ipwhite/ipwhite.go

@@ -0,0 +1,17 @@
+/**
+ * @author chengliang
+ * @date 2026/1/27 11:38
+ * @brief
+ *
+ **/
+
+package ipwhite
+
+import "github.com/kamva/mgm/v3"
+
+type TIpWhiteList struct {
+	mgm.DefaultModel `bson:",inline"`
+	Enabled          bool           `json:"enabled" bson:"enabled"`             // 白名单是否开启
+	AllowInternal    bool           `json:"allowInternal" bson:"allowInternal"` // 是否允许内网
+	IpPatternsMap    map[string]int `json:"ipPatternsMap" bson:"ipPatternsMap"` // ip 列表 支持正则
+}