feat: add subnet validation (#1275)
This commit is contained in:
parent
c49778c254
commit
68605800af
@ -2,14 +2,20 @@ package network
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func IsValidSubnet(subnet string) error {
|
||||||
|
_, _, err := net.ParseCIDR(subnet)
|
||||||
|
return fmt.Errorf("failed to parse subnet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
|
func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
|
||||||
_, ipNet, err := net.ParseCIDR(subnet)
|
_, ipNet, err := net.ParseCIDR(subnet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf(ctx, "failed to parse subnet: %s, subnet: %s", err.Error(), subnet)
|
logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return ipNet.Contains(net.ParseIP(ip))
|
return ipNet.Contains(net.ParseIP(ip))
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
|
"github.com/songquanpeng/one-api/common/network"
|
||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -104,6 +106,19 @@ func GetTokenStatus(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateToken(c *gin.Context, token model.Token) error {
|
||||||
|
if len(token.Name) > 30 {
|
||||||
|
return fmt.Errorf("令牌名称过长")
|
||||||
|
}
|
||||||
|
if token.Subnet != nil && *token.Subnet != "" {
|
||||||
|
err := network.IsValidSubnet(*token.Subnet)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("无效的网段:%s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func AddToken(c *gin.Context) {
|
func AddToken(c *gin.Context) {
|
||||||
token := model.Token{}
|
token := model.Token{}
|
||||||
err := c.ShouldBindJSON(&token)
|
err := c.ShouldBindJSON(&token)
|
||||||
@ -114,13 +129,15 @@ func AddToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(token.Name) > 30 {
|
err = validateToken(c, token)
|
||||||
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "令牌名称过长",
|
"message": fmt.Sprintf("参数错误:%s", err.Error()),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
UserId: c.GetInt("id"),
|
UserId: c.GetInt("id"),
|
||||||
Name: token.Name,
|
Name: token.Name,
|
||||||
@ -179,10 +196,11 @@ func UpdateToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(token.Name) > 30 {
|
err = validateToken(c, token)
|
||||||
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "令牌名称过长",
|
"message": fmt.Sprintf("参数错误:%s", err.Error()),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user