From e4500bf8bfee72fb6d115bed9e19c04b63b67a77 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Fri, 14 Jul 2023 22:41:22 +0800 Subject: [PATCH] featL add token-side model selection --- controller/relay-text.go | 21 ++++++++++++ controller/token.go | 4 ++- i18n/en.json | 3 +- model/token.go | 6 ++-- web/src/pages/Token/EditToken.js | 56 +++++++++++++++++++++++++++++++- 5 files changed, 85 insertions(+), 5 deletions(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index cf20fc72..36b18a1b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -69,6 +69,27 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { isModelMapped = true } } + + // Get token info + tokenInfo, err := model.GetTokenById(tokenId) + + if err != nil { + return errorWrapper(err, "get_token_info_failed", http.StatusInternalServerError) + } + + hasModelAvailable := func() bool { + for _, token := range strings.Split(tokenInfo.Models, ",") { + if token == textRequest.Model { + return true + } + } + return false + }() + + if !hasModelAvailable { + return errorWrapper(errors.New("model not available for use"), "model_not_available_for_use", http.StatusBadRequest) + } + baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if c.GetString("base_url") != "" { diff --git a/controller/token.go b/controller/token.go index 5341ea3a..c71fc288 100644 --- a/controller/token.go +++ b/controller/token.go @@ -1,11 +1,12 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" + + "github.com/gin-gonic/gin" ) func GetAllTokens(c *gin.Context) { @@ -203,6 +204,7 @@ func UpdateToken(c *gin.Context) { cleanToken.ExpiredTime = token.ExpiredTime cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.Models = token.Models } err = cleanToken.Update() if err != nil { diff --git a/i18n/en.json b/i18n/en.json index d7c7af4e..c5edc88b 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -523,5 +523,6 @@ "该 Discord 账户已被绑定": "The Discord account has been bound", "管理员未开启通过 Discord 登录以及注册": "The administrator has not enabled login and registration via Discord", "无法启用 Discord OAuth,请先填入 Discord Client ID 以及 Discord Client Secret!": "Unable to enable Discord OAuth, please fill in the Discord Client ID and Discord Client Secret first!", - "兑换失败,": "Redemption failed, " + "兑换失败,": "Redemption failed, ", + "请选择此密钥支持的模型": "Please select the models supported by this key" } diff --git a/model/token.go b/model/token.go index 7cd226c6..e6ccc537 100644 --- a/model/token.go +++ b/model/token.go @@ -3,8 +3,9 @@ package model import ( "errors" "fmt" - "gorm.io/gorm" "one-api/common" + + "gorm.io/gorm" ) type Token struct { @@ -19,6 +20,7 @@ type Token struct { RemainQuota int `json:"remain_quota" gorm:"default:0"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota + Models string `json:"models"` } func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { @@ -99,7 +101,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error return err } diff --git a/web/src/pages/Token/EditToken.js b/web/src/pages/Token/EditToken.js index a4b6044f..ee8b6eea 100644 --- a/web/src/pages/Token/EditToken.js +++ b/web/src/pages/Token/EditToken.js @@ -13,8 +13,12 @@ const EditToken = () => { name: '', remain_quota: isEdit ? 0 : 500000, expired_time: -1, - unlimited_quota: false + unlimited_quota: false, + models: [], }; + const [modelOptions, setModelOptions] = useState([]); + const [basicModels, setBasicModels] = useState([]); + const [fullModels, setFullModels] = useState([]); const [inputs, setInputs] = useState(originInputs); const { name, remain_quota, expired_time, unlimited_quota } = inputs; @@ -41,6 +45,21 @@ const EditToken = () => { setInputs({ ...inputs, unlimited_quota: !unlimited_quota }); }; + const fetchModels = async () => { + try { + let res = await API.get(`/api/channel/models`); + setModelOptions(res.data.data.map((model) => ({ + key: model.id, + text: model.id, + value: model.id + }))); + setFullModels(res.data.data.map((model) => model.id)); + setBasicModels(res.data.data.filter((model) => !model.id.startsWith('gpt-4')).map((model) => model.id)); + } catch (error) { + showError(error.message); + } + }; + const loadToken = async () => { let res = await API.get(`/api/token/${tokenId}`); const { success, message, data } = res.data; @@ -72,6 +91,11 @@ const EditToken = () => { } localInputs.expired_time = Math.ceil(time / 1000); } + if (inputs.models.length === 0) { + showError('请至少选择一个模型!'); + return; + } + localInputs.models = localInputs.models.join(','); let res; if (isEdit) { res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) }); @@ -91,6 +115,10 @@ const EditToken = () => { } }; + useEffect(() => { + fetchModels().then(); + }, []); + return ( <> @@ -151,6 +179,32 @@ const EditToken = () => { + + + +
+ + + +