diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 03d452ad..1d5d019c 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -13,7 +13,6 @@ import ( providersBase "one-api/providers/base" "one-api/types" "reflect" - "strconv" "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" @@ -55,9 +54,9 @@ func GetValidFieldName(err error, obj interface{}) string { } func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) { - channelId, ok := c.Get("channelId") - if ok { - channel, fail = fetchChannelById(c, channelId.(int)) + channelId := c.GetInt("channelId") + if channelId > 0 { + channel, fail = fetchChannelById(c, channelId) if fail { return } @@ -73,13 +72,8 @@ func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fai return } -func fetchChannelById(c *gin.Context, channelId any) (*model.Channel, bool) { - id, err := strconv.Atoi(channelId.(string)) - if err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") - return nil, true - } - channel, err := model.GetChannelById(id, true) +func fetchChannelById(c *gin.Context, channelId int) (*model.Channel, bool) { + channel, err := model.GetChannelById(channelId, true) if err != nil { common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return nil, true diff --git a/middleware/auth.go b/middleware/auth.go index ad7e64b7..e6f48c62 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,12 +1,13 @@ package middleware import ( - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strings" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" ) func authHelper(c *gin.Context, minRole int) { @@ -108,7 +109,12 @@ func TokenAuth() func(c *gin.Context) { c.Set("token_name", token.Name) if len(parts) > 1 { if model.IsAdmin(token.UserId) { - c.Set("channelId", parts[1]) + channelId := common.String2Int(parts[1]) + if channelId == 0 { + abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id") + return + } + c.Set("channelId", channelId) } else { abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return