From 348adc2b025b3ef72eaf324c62d021622383e655 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 5 Apr 2024 17:25:28 +0800 Subject: [PATCH] feat: able to set multiple subnets --- common/network/ip.go | 31 ++++++++++++++++++++++-- common/network/ip_test.go | 4 +-- controller/token.go | 2 +- middleware/auth.go | 2 +- web/default/src/pages/Token/EditToken.js | 2 +- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/common/network/ip.go b/common/network/ip.go index eca79499..0fbe5e6f 100644 --- a/common/network/ip.go +++ b/common/network/ip.go @@ -5,9 +5,18 @@ import ( "fmt" "github.com/songquanpeng/one-api/common/logger" "net" + "strings" ) -func IsValidSubnet(subnet string) error { +func splitSubnets(subnets string) []string { + res := strings.Split(subnets, ",") + for i := 0; i < len(res); i++ { + res[i] = strings.TrimSpace(res[i]) + } + return res +} + +func isValidSubnet(subnet string) error { _, _, err := net.ParseCIDR(subnet) if err != nil { return fmt.Errorf("failed to parse subnet: %w", err) @@ -15,7 +24,7 @@ func IsValidSubnet(subnet string) error { return nil } -func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool { +func isIpInSubnet(ctx context.Context, ip string, subnet string) bool { _, ipNet, err := net.ParseCIDR(subnet) if err != nil { logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) @@ -23,3 +32,21 @@ func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool { } return ipNet.Contains(net.ParseIP(ip)) } + +func IsValidSubnets(subnets string) error { + for _, subnet := range splitSubnets(subnets) { + if err := isValidSubnet(subnet); err != nil { + return err + } + } + return nil +} + +func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { + for _, subnet := range splitSubnets(subnets) { + if isIpInSubnet(ctx, ip, subnet) { + return true + } + } + return false +} diff --git a/common/network/ip_test.go b/common/network/ip_test.go index 72ccf0dd..6c593458 100644 --- a/common/network/ip_test.go +++ b/common/network/ip_test.go @@ -13,7 +13,7 @@ func TestIsIpInSubnet(t *testing.T) { ip2 := "125.216.250.89" subnet := "192.168.0.0/24" Convey("TestIsIpInSubnet", t, func() { - So(IsIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) - So(IsIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) + So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) + So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) }) } diff --git a/controller/token.go b/controller/token.go index 068cbab7..7d20371c 100644 --- a/controller/token.go +++ b/controller/token.go @@ -111,7 +111,7 @@ func validateToken(c *gin.Context, token model.Token) error { return fmt.Errorf("令牌名称过长") } if token.Subnet != nil && *token.Subnet != "" { - err := network.IsValidSubnet(*token.Subnet) + err := network.IsValidSubnets(*token.Subnet) if err != nil { return fmt.Errorf("无效的网段:%s", err.Error()) } diff --git a/middleware/auth.go b/middleware/auth.go index 1630c565..223cef3d 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -102,7 +102,7 @@ func TokenAuth() func(c *gin.Context) { return } if token.Subnet != nil && *token.Subnet != "" { - if !network.IsIpInSubnet(ctx, c.ClientIP(), *token.Subnet) { + if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) return } diff --git a/web/default/src/pages/Token/EditToken.js b/web/default/src/pages/Token/EditToken.js index 077f6254..684b7e92 100644 --- a/web/default/src/pages/Token/EditToken.js +++ b/web/default/src/pages/Token/EditToken.js @@ -158,7 +158,7 @@ const EditToken = () => {