diff --git a/common/network/ip.go b/common/network/ip.go index fd0c3f5e..e774351e 100644 --- a/common/network/ip.go +++ b/common/network/ip.go @@ -2,14 +2,20 @@ package network import ( "context" + "fmt" "github.com/songquanpeng/one-api/common/logger" "net" ) +func IsValidSubnet(subnet string) error { + _, _, err := net.ParseCIDR(subnet) + return fmt.Errorf("failed to parse subnet: %w", err) +} + func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool { _, ipNet, err := net.ParseCIDR(subnet) if err != nil { - logger.Errorf(ctx, "failed to parse subnet: %s, subnet: %s", err.Error(), subnet) + logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) return false } return ipNet.Contains(net.ParseIP(ip)) diff --git a/controller/token.go b/controller/token.go index c0b1d24b..068cbab7 100644 --- a/controller/token.go +++ b/controller/token.go @@ -1,10 +1,12 @@ package controller import ( + "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -104,6 +106,19 @@ func GetTokenStatus(c *gin.Context) { }) } +func validateToken(c *gin.Context, token model.Token) error { + if len(token.Name) > 30 { + return fmt.Errorf("令牌名称过长") + } + if token.Subnet != nil && *token.Subnet != "" { + err := network.IsValidSubnet(*token.Subnet) + if err != nil { + return fmt.Errorf("无效的网段:%s", err.Error()) + } + } + return nil +} + func AddToken(c *gin.Context) { token := model.Token{} err := c.ShouldBindJSON(&token) @@ -114,13 +129,15 @@ func AddToken(c *gin.Context) { }) return } - if len(token.Name) > 30 { + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return } + cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, @@ -179,10 +196,11 @@ func UpdateToken(c *gin.Context) { }) return } - if len(token.Name) > 30 { + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return }