diff --git a/common/constants.go b/common/constants.go index 325454d4..31036685 100644 --- a/common/constants.go +++ b/common/constants.go @@ -92,3 +92,11 @@ var ChannelBaseURLs = []string{ "https://hunyuan.cloud.tencent.com", // 23 "https://generativelanguage.googleapis.com", // 24 } + +const ( + ConfigKeyPrefix = "cfg_" + + ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version" + ConfigKeyLibraryID = ConfigKeyPrefix + "library_id" + ConfigKeyPlugin = ConfigKeyPrefix + "plugin" +) diff --git a/middleware/distributor.go b/middleware/distributor.go index 0ed250fd..704f6236 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -83,17 +83,22 @@ func Distribute() func(c *gin.Context) { c.Set("model_mapping", channel.GetModelMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) + // this is for backward compatibility switch channel.Type { case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) + c.Set(common.ConfigKeyAPIVersion, channel.Other) case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) + c.Set(common.ConfigKeyAPIVersion, channel.Other) case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) + c.Set(common.ConfigKeyAPIVersion, channel.Other) case common.ChannelTypeAIProxyLibrary: - c.Set("library_id", channel.Other) + c.Set(common.ConfigKeyLibraryID, channel.Other) case common.ChannelTypeAli: - c.Set("plugin", channel.Other) + c.Set(common.ConfigKeyPlugin, channel.Other) + } + cfg, _ := channel.LoadConfig() + for k, v := range cfg { + c.Set(common.ConfigKeyPrefix+k, v) } c.Next() } diff --git a/model/channel.go b/model/channel.go index 0503a620..19af2263 100644 --- a/model/channel.go +++ b/model/channel.go @@ -21,7 +21,7 @@ type Channel struct { TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` - Other string `json:"other"` + Other string `json:"other"` // DEPRECATED: please save config to field Config Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` @@ -29,6 +29,7 @@ type Channel struct { UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` + Config string `json:"config"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -155,6 +156,18 @@ func (channel *Channel) Delete() error { return err } +func (channel *Channel) LoadConfig() (map[string]string, error) { + if channel.Config == "" { + return nil, nil + } + cfg := make(map[string]string) + err := json.Unmarshal([]byte(channel.Config), &cfg) + if err != nil { + return nil, err + } + return cfg, nil +} + func UpdateChannelStatusById(id int, status int) { err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { diff --git a/relay/channel/aiproxy/adaptor.go b/relay/channel/aiproxy/adaptor.go index eab79c30..34fd62f5 100644 --- a/relay/channel/aiproxy/adaptor.go +++ b/relay/channel/aiproxy/adaptor.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" @@ -29,7 +30,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } aiProxyLibraryRequest := ConvertRequest(*request) - aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID) return aiProxyLibraryRequest, nil } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 177aa49e..85986cce 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" @@ -29,8 +30,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut if meta.IsStream { req.Header.Set("X-DashScope-SSE", "enable") } - if c.GetString("plugin") != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + if c.GetString(common.ConfigKeyPlugin) != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin)) } return nil } diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go index 8efade87..4a777075 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/channel/xunfei/main.go @@ -246,7 +246,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, query := c.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion == "" { - apiVersion = c.GetString("api_version") + apiVersion = c.GetString(common.ConfigKeyAPIVersion) } if apiVersion == "" { apiVersion = "v1.1" diff --git a/relay/controller/temp.go b/relay/controller/temp.go index 75aea4ff..ac73fac3 100644 --- a/relay/controller/temp.go +++ b/relay/controller/temp.go @@ -174,7 +174,7 @@ func GetRequestBody(c *gin.Context, textRequest model.GeneralOpenAIRequest, isMo requestBody = bytes.NewBuffer(jsonStr) case constant.APITypeAIProxyLibrary: aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest) - aiProxyLibraryRequest.LibraryId = c.GetString("library_id") + aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID) jsonStr, err := json.Marshal(aiProxyLibraryRequest) if err != nil { return nil, err @@ -222,8 +222,8 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i if isStream { req.Header.Set("X-DashScope-SSE", "enable") } - if c.GetString("plugin") != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + if c.GetString(common.ConfigKeyPlugin) != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin)) } case constant.APITypeTencent: req.Header.Set("Authorization", apiKey) diff --git a/relay/util/common.go b/relay/util/common.go index 21e1dfaf..6d993378 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -162,7 +162,7 @@ func GetAzureAPIVersion(c *gin.Context) string { query := c.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion == "" { - apiVersion = c.GetString("api_version") + apiVersion = c.GetString(common.ConfigKeyAPIVersion) } return apiVersion } diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go index 58afc23f..31b9d2b4 100644 --- a/relay/util/relay_meta.go +++ b/relay/util/relay_meta.go @@ -39,7 +39,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta { Group: c.GetString("group"), ModelMapping: c.GetStringMapString("model_mapping"), BaseURL: c.GetString("base_url"), - APIVersion: c.GetString("api_version"), + APIVersion: c.GetString(common.ConfigKeyAPIVersion), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Config: nil, RequestURLPath: c.Request.URL.String(),