refactor: use config field to save config

This commit is contained in:
JustSong 2024-02-18 02:22:50 +08:00
parent 1aa374ccfb
commit de9a58ca0b
9 changed files with 43 additions and 15 deletions

View File

@ -92,3 +92,11 @@ var ChannelBaseURLs = []string{
"https://hunyuan.cloud.tencent.com", // 23 "https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24 "https://generativelanguage.googleapis.com", // 24
} }
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

View File

@ -83,17 +83,22 @@ func Distribute() func(c *gin.Context) {
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
c.Set("api_version", channel.Other) c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei: case common.ChannelTypeXunfei:
c.Set("api_version", channel.Other) c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini: case common.ChannelTypeGemini:
c.Set("api_version", channel.Other) c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary: case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other) c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli: case common.ChannelTypeAli:
c.Set("plugin", channel.Other) c.Set(common.ConfigKeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()
for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v)
} }
c.Next() c.Next()
} }

View File

@ -21,7 +21,7 @@ type Channel struct {
TestTime int64 `json:"test_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"` Other string `json:"other"` // DEPRECATED: please save config to field Config
Balance float64 `json:"balance"` // in USD Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"` Models string `json:"models"`
@ -29,6 +29,7 @@ type Channel struct {
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"`
} }
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@ -155,6 +156,18 @@ func (channel *Channel) Delete() error {
return err return err
} }
func (channel *Channel) LoadConfig() (map[string]string, error) {
if channel.Config == "" {
return nil, nil
}
cfg := make(map[string]string)
err := json.Unmarshal([]byte(channel.Config), &cfg)
if err != nil {
return nil, err
}
return cfg, nil
}
func UpdateChannelStatusById(id int, status int) { func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
@ -29,7 +30,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
aiProxyLibraryRequest := ConvertRequest(*request) aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id") aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
return aiProxyLibraryRequest, nil return aiProxyLibraryRequest, nil
} }

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
@ -29,8 +30,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
if meta.IsStream { if meta.IsStream {
req.Header.Set("X-DashScope-SSE", "enable") req.Header.Set("X-DashScope-SSE", "enable")
} }
if c.GetString("plugin") != "" { if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
} }
return nil return nil
} }

View File

@ -246,7 +246,7 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
query := c.Request.URL.Query() query := c.Request.URL.Query()
apiVersion := query.Get("api-version") apiVersion := query.Get("api-version")
if apiVersion == "" { if apiVersion == "" {
apiVersion = c.GetString("api_version") apiVersion = c.GetString(common.ConfigKeyAPIVersion)
} }
if apiVersion == "" { if apiVersion == "" {
apiVersion = "v1.1" apiVersion = "v1.1"

View File

@ -174,7 +174,7 @@ func GetRequestBody(c *gin.Context, textRequest model.GeneralOpenAIRequest, isMo
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
case constant.APITypeAIProxyLibrary: case constant.APITypeAIProxyLibrary:
aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest) aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
aiProxyLibraryRequest.LibraryId = c.GetString("library_id") aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
jsonStr, err := json.Marshal(aiProxyLibraryRequest) jsonStr, err := json.Marshal(aiProxyLibraryRequest)
if err != nil { if err != nil {
return nil, err return nil, err
@ -222,8 +222,8 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
if isStream { if isStream {
req.Header.Set("X-DashScope-SSE", "enable") req.Header.Set("X-DashScope-SSE", "enable")
} }
if c.GetString("plugin") != "" { if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
} }
case constant.APITypeTencent: case constant.APITypeTencent:
req.Header.Set("Authorization", apiKey) req.Header.Set("Authorization", apiKey)

View File

@ -162,7 +162,7 @@ func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query() query := c.Request.URL.Query()
apiVersion := query.Get("api-version") apiVersion := query.Get("api-version")
if apiVersion == "" { if apiVersion == "" {
apiVersion = c.GetString("api_version") apiVersion = c.GetString(common.ConfigKeyAPIVersion)
} }
return apiVersion return apiVersion
} }

View File

@ -39,7 +39,7 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
Group: c.GetString("group"), Group: c.GetString("group"),
ModelMapping: c.GetStringMapString("model_mapping"), ModelMapping: c.GetStringMapString("model_mapping"),
BaseURL: c.GetString("base_url"), BaseURL: c.GetString("base_url"),
APIVersion: c.GetString("api_version"), APIVersion: c.GetString(common.ConfigKeyAPIVersion),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Config: nil, Config: nil,
RequestURLPath: c.Request.URL.String(), RequestURLPath: c.Request.URL.String(),