refactor: use config field to save config

This commit is contained in:
JustSong 2024-02-18 02:22:50 +08:00
parent 1aa374ccfb
commit de9a58ca0b
9 changed files with 43 additions and 15 deletions

View File

@ -92,3 +92,11 @@ var ChannelBaseURLs = []string{
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
}
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

View File

@ -83,17 +83,22 @@ func Distribute() func(c *gin.Context) {
c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility
switch channel.Type {
case common.ChannelTypeAzure:
c.Set("api_version", channel.Other)
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
c.Set("plugin", channel.Other)
c.Set(common.ConfigKeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()
for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v)
}
c.Next()
}

View File

@ -21,7 +21,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"`
Other string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
@ -29,6 +29,7 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@ -155,6 +156,18 @@ func (channel *Channel) Delete() error {
return err
}
func (channel *Channel) LoadConfig() (map[string]string, error) {
if channel.Config == "" {
return nil, nil
}
cfg := make(map[string]string)
err := json.Unmarshal([]byte(channel.Config), &cfg)
if err != nil {
return nil, err
}
return cfg, nil
}
func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
@ -29,7 +30,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
return aiProxyLibraryRequest, nil
}

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@ -29,8 +30,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
if meta.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
}
return nil
}

View File

@ -246,7 +246,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
apiVersion = c.GetString(common.ConfigKeyAPIVersion)
}
if apiVersion == "" {
apiVersion = "v1.1"

View File

@ -174,7 +174,7 @@ func GetRequestBody(c *gin.Context, textRequest model.GeneralOpenAIRequest, isMo
requestBody = bytes.NewBuffer(jsonStr)
case constant.APITypeAIProxyLibrary:
aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil {
return nil, err
@ -222,8 +222,8 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
if isStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
}
case constant.APITypeTencent:
req.Header.Set("Authorization", apiKey)

View File

@ -162,7 +162,7 @@ func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
apiVersion = c.GetString(common.ConfigKeyAPIVersion)
}
return apiVersion
}

View File

@ -39,7 +39,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
Group: c.GetString("group"),
ModelMapping: c.GetStringMapString("model_mapping"),
BaseURL: c.GetString("base_url"),
APIVersion: c.GetString("api_version"),
APIVersion: c.GetString(common.ConfigKeyAPIVersion),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Config: nil,
RequestURLPath: c.Request.URL.String(),