From eeb867da10f3c656033031ee468d9238ae18081f Mon Sep 17 00:00:00 2001 From: Martial BE Date: Tue, 26 Dec 2023 16:40:50 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Change=20the=20method=20of=20get?= =?UTF-8?q?ting=20channel=20parameters?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-billing.go | 5 +---- controller/channel-test.go | 6 ++---- controller/channel.go | 10 ++-------- controller/github.go | 6 +++--- controller/group.go | 2 +- controller/log.go | 10 ++-------- controller/misc.go | 7 ------- controller/relay-chat.go | 2 +- controller/relay-completions.go | 2 +- controller/relay-embeddings.go | 2 +- controller/relay-image-edits.go | 2 +- controller/relay-image-generations.go | 2 +- controller/relay-image-variationsy.go | 2 +- controller/relay-moderations.go | 2 +- controller/relay-speech.go | 2 +- controller/relay-transcriptions.go | 2 +- controller/relay-translations.go | 2 +- controller/relay-utils.go | 26 ++------------------------ controller/token.go | 9 ++------- controller/user.go | 14 -------------- controller/wechat.go | 4 ++-- providers/ali/base.go | 6 +++--- providers/azureSpeech/base.go | 2 +- providers/baidu/base.go | 2 +- providers/base/common.go | 10 ++++++++-- providers/base/interface.go | 1 + providers/claude/base.go | 2 +- providers/gemini/base.go | 6 +++--- providers/openai/base.go | 6 +++--- providers/palm/base.go | 2 +- providers/providers.go | 23 ++++++++++++++--------- providers/tencent/base.go | 2 +- providers/xunfei/base.go | 4 ++-- providers/zhipu/base.go | 2 +- 34 files changed, 67 insertions(+), 120 deletions(-) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index da80a26d..ed4447c0 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -55,10 +55,9 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { c, _ := gin.CreateTestContext(w) c.Request = req - setChannelToContext(c, channel) req.Header.Set("Content-Type", "application/json") - provider := providers.GetProvider(channel.Type, c) + provider := providers.GetProvider(channel, c) if provider == nil { return 0, errors.New("provider not found") } @@ -102,7 +101,6 @@ func UpdateChannelBalance(c *gin.Context) { "message": "", "balance": balance, }) - return } func updateAllChannelsBalance() error { @@ -146,7 +144,6 @@ func UpdateAllChannelsBalance(c *gin.Context) { "success": true, "message": "", }) - return } func AutomaticallyUpdateChannels(frequency int) { diff --git a/controller/channel-test.go b/controller/channel-test.go index e8024778..15faba24 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -29,7 +29,6 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e c, _ := gin.CreateTestContext(w) c.Request = req - setChannelToContext(c, channel) // 创建映射 channelTypeToModel := map[int]string{ common.ChannelTypePaLM: "PaLM-2", @@ -50,7 +49,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e } request.Model = model - provider := providers.GetProvider(channel.Type, c) + provider := providers.GetProvider(channel, c) if provider == nil { return errors.New("channel not implemented"), nil } @@ -74,7 +73,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e } if Usage.CompletionTokens == 0 { - return errors.New(fmt.Sprintf("channel %s, message 补全 tokens 非预期返回 0", channel.Name)), nil + return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil } return nil, nil @@ -132,7 +131,6 @@ func TestChannel(c *gin.Context) { "message": "", "time": consumedTime, }) - return } var testAllChannelsLock sync.Mutex diff --git a/controller/channel.go b/controller/channel.go index 904abc23..7b2a45f7 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1,12 +1,13 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "strings" + + "github.com/gin-gonic/gin" ) func GetAllChannels(c *gin.Context) { @@ -27,7 +28,6 @@ func GetAllChannels(c *gin.Context) { "message": "", "data": channels, }) - return } func SearchChannels(c *gin.Context) { @@ -45,7 +45,6 @@ func SearchChannels(c *gin.Context) { "message": "", "data": channels, }) - return } func GetChannel(c *gin.Context) { @@ -70,7 +69,6 @@ func GetChannel(c *gin.Context) { "message": "", "data": channel, }) - return } func AddChannel(c *gin.Context) { @@ -106,7 +104,6 @@ func AddChannel(c *gin.Context) { "success": true, "message": "", }) - return } func DeleteChannel(c *gin.Context) { @@ -124,7 +121,6 @@ func DeleteChannel(c *gin.Context) { "success": true, "message": "", }) - return } func DeleteDisabledChannel(c *gin.Context) { @@ -141,7 +137,6 @@ func DeleteDisabledChannel(c *gin.Context) { "message": "", "data": rows, }) - return } func UpdateChannel(c *gin.Context) { @@ -167,5 +162,4 @@ func UpdateChannel(c *gin.Context) { "message": "", "data": channel, }) - return } diff --git a/controller/github.go b/controller/github.go index ee995379..00ec3a88 100644 --- a/controller/github.go +++ b/controller/github.go @@ -5,13 +5,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-contrib/sessions" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "time" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" ) type GitHubOAuthResponse struct { @@ -211,7 +212,6 @@ func GitHubBind(c *gin.Context) { "success": true, "message": "bind", }) - return } func GenerateOAuthCode(c *gin.Context) { diff --git a/controller/group.go b/controller/group.go index 109e2bce..76eb5df8 100644 --- a/controller/group.go +++ b/controller/group.go @@ -9,7 +9,7 @@ import ( func GetGroups(c *gin.Context) { groupNames := make([]string, 0) - for groupName, _ := range common.GroupRatio { + for groupName := range common.GroupRatio { groupNames = append(groupNames, groupName) } c.JSON(http.StatusOK, gin.H{ diff --git a/controller/log.go b/controller/log.go index b65867fe..6327a6a8 100644 --- a/controller/log.go +++ b/controller/log.go @@ -1,11 +1,12 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" + + "github.com/gin-gonic/gin" ) func GetAllLogs(c *gin.Context) { @@ -33,7 +34,6 @@ func GetAllLogs(c *gin.Context) { "message": "", "data": logs, }) - return } func GetUserLogs(c *gin.Context) { @@ -60,7 +60,6 @@ func GetUserLogs(c *gin.Context) { "message": "", "data": logs, }) - return } func SearchAllLogs(c *gin.Context) { @@ -78,7 +77,6 @@ func SearchAllLogs(c *gin.Context) { "message": "", "data": logs, }) - return } func SearchUserLogs(c *gin.Context) { @@ -97,7 +95,6 @@ func SearchUserLogs(c *gin.Context) { "message": "", "data": logs, }) - return } func GetLogsStat(c *gin.Context) { @@ -118,7 +115,6 @@ func GetLogsStat(c *gin.Context) { //"token": tokenNum, }, }) - return } func GetLogsSelfStat(c *gin.Context) { @@ -139,7 +135,6 @@ func GetLogsSelfStat(c *gin.Context) { //"token": tokenNum, }, }) - return } func DeleteHistoryLogs(c *gin.Context) { @@ -164,5 +159,4 @@ func DeleteHistoryLogs(c *gin.Context) { "message": "", "data": count, }) - return } diff --git a/controller/misc.go b/controller/misc.go index 2bcbb41f..4940bcf7 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -35,7 +35,6 @@ func GetStatus(c *gin.Context) { "display_in_currency": common.DisplayInCurrencyEnabled, }, }) - return } func GetNotice(c *gin.Context) { @@ -46,7 +45,6 @@ func GetNotice(c *gin.Context) { "message": "", "data": common.OptionMap["Notice"], }) - return } func GetAbout(c *gin.Context) { @@ -57,7 +55,6 @@ func GetAbout(c *gin.Context) { "message": "", "data": common.OptionMap["About"], }) - return } func GetHomePageContent(c *gin.Context) { @@ -68,7 +65,6 @@ func GetHomePageContent(c *gin.Context) { "message": "", "data": common.OptionMap["HomePageContent"], }) - return } func SendEmailVerification(c *gin.Context) { @@ -121,7 +117,6 @@ func SendEmailVerification(c *gin.Context) { "success": true, "message": "", }) - return } func SendPasswordResetEmail(c *gin.Context) { @@ -160,7 +155,6 @@ func SendPasswordResetEmail(c *gin.Context) { "success": true, "message": "", }) - return } type PasswordResetRequest struct { @@ -200,5 +194,4 @@ func ResetPassword(c *gin.Context) { "message": "", "data": password, }) - return } diff --git a/controller/relay-chat.go b/controller/relay-chat.go index a1e93e25..8e74c789 100644 --- a/controller/relay-chat.go +++ b/controller/relay-chat.go @@ -43,7 +43,7 @@ func RelayChat(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeChatCompletions) + provider, pass := getProvider(c, channel, common.RelayModeChatCompletions) if pass { return } diff --git a/controller/relay-completions.go b/controller/relay-completions.go index da60a773..1731cb86 100644 --- a/controller/relay-completions.go +++ b/controller/relay-completions.go @@ -43,7 +43,7 @@ func RelayCompletions(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeCompletions) + provider, pass := getProvider(c, channel, common.RelayModeCompletions) if pass { return } diff --git a/controller/relay-embeddings.go b/controller/relay-embeddings.go index 5d3f9aec..df3192bf 100644 --- a/controller/relay-embeddings.go +++ b/controller/relay-embeddings.go @@ -42,7 +42,7 @@ func RelayEmbeddings(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeEmbeddings) + provider, pass := getProvider(c, channel, common.RelayModeEmbeddings) if pass { return } diff --git a/controller/relay-image-edits.go b/controller/relay-image-edits.go index fb7c8850..006c9520 100644 --- a/controller/relay-image-edits.go +++ b/controller/relay-image-edits.go @@ -51,7 +51,7 @@ func RelayImageEdits(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeImagesEdits) + provider, pass := getProvider(c, channel, common.RelayModeImagesEdits) if pass { return } diff --git a/controller/relay-image-generations.go b/controller/relay-image-generations.go index 20092e0e..4c7f30c0 100644 --- a/controller/relay-image-generations.go +++ b/controller/relay-image-generations.go @@ -54,7 +54,7 @@ func RelayImageGenerations(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeImagesGenerations) + provider, pass := getProvider(c, channel, common.RelayModeImagesGenerations) if pass { return } diff --git a/controller/relay-image-variationsy.go b/controller/relay-image-variationsy.go index c128625a..019b431f 100644 --- a/controller/relay-image-variationsy.go +++ b/controller/relay-image-variationsy.go @@ -46,7 +46,7 @@ func RelayImageVariations(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeImagesVariations) + provider, pass := getProvider(c, channel, common.RelayModeImagesVariations) if pass { return } diff --git a/controller/relay-moderations.go b/controller/relay-moderations.go index 2ffda2da..5feccdb4 100644 --- a/controller/relay-moderations.go +++ b/controller/relay-moderations.go @@ -42,7 +42,7 @@ func RelayModerations(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeModerations) + provider, pass := getProvider(c, channel, common.RelayModeModerations) if pass { return } diff --git a/controller/relay-speech.go b/controller/relay-speech.go index 03ac3151..e5ace14c 100644 --- a/controller/relay-speech.go +++ b/controller/relay-speech.go @@ -38,7 +38,7 @@ func RelaySpeech(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeAudioSpeech) + provider, pass := getProvider(c, channel, common.RelayModeAudioSpeech) if pass { return } diff --git a/controller/relay-transcriptions.go b/controller/relay-transcriptions.go index cf0f1831..b08174a1 100644 --- a/controller/relay-transcriptions.go +++ b/controller/relay-transcriptions.go @@ -38,7 +38,7 @@ func RelayTranscriptions(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranscription) + provider, pass := getProvider(c, channel, common.RelayModeAudioTranscription) if pass { return } diff --git a/controller/relay-translations.go b/controller/relay-translations.go index 776a6f7d..fcdada36 100644 --- a/controller/relay-translations.go +++ b/controller/relay-translations.go @@ -38,7 +38,7 @@ func RelayTranslations(c *gin.Context) { } // 获取供应商 - provider, pass := getProvider(c, channel.Type, common.RelayModeAudioTranslation) + provider, pass := getProvider(c, channel, common.RelayModeAudioTranslation) if pass { return } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 3307b0d1..c513883e 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -45,7 +45,6 @@ func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pas return } - setChannelToContext(c, channel) return } @@ -84,8 +83,8 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool return channel, false } -func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase.ProviderInterface, bool) { - provider := providers.GetProvider(channelType, c) +func getProvider(c *gin.Context, channel *model.Channel, relayMode int) (providersBase.ProviderInterface, bool) { + provider := providers.GetProvider(channel, c) if provider == nil { common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found") return nil, true @@ -99,27 +98,6 @@ func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase. return provider, false } -func setChannelToContext(c *gin.Context, channel *model.Channel) { - // c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - c.Set("api_key", channel.Key) - c.Set("base_url", channel.GetBaseURL()) - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAIProxyLibrary: - c.Set("library_id", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) - } - -} - func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool { if !common.AutomaticDisableChannelEnabled { return false diff --git a/controller/token.go b/controller/token.go index 8642122c..a4e3a235 100644 --- a/controller/token.go +++ b/controller/token.go @@ -1,11 +1,12 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" + + "github.com/gin-gonic/gin" ) func GetAllTokens(c *gin.Context) { @@ -27,7 +28,6 @@ func GetAllTokens(c *gin.Context) { "message": "", "data": tokens, }) - return } func SearchTokens(c *gin.Context) { @@ -46,7 +46,6 @@ func SearchTokens(c *gin.Context) { "message": "", "data": tokens, }) - return } func GetToken(c *gin.Context) { @@ -72,7 +71,6 @@ func GetToken(c *gin.Context) { "message": "", "data": token, }) - return } func GetTokenStatus(c *gin.Context) { @@ -138,7 +136,6 @@ func AddToken(c *gin.Context) { "success": true, "message": "", }) - return } func DeleteToken(c *gin.Context) { @@ -156,7 +153,6 @@ func DeleteToken(c *gin.Context) { "success": true, "message": "", }) - return } func UpdateToken(c *gin.Context) { @@ -224,5 +220,4 @@ func UpdateToken(c *gin.Context) { "message": "", "data": cleanToken, }) - return } diff --git a/controller/user.go b/controller/user.go index 3c8ea997..5b5c4624 100644 --- a/controller/user.go +++ b/controller/user.go @@ -174,7 +174,6 @@ func Register(c *gin.Context) { "success": true, "message": "", }) - return } func GetAllUsers(c *gin.Context) { @@ -195,7 +194,6 @@ func GetAllUsers(c *gin.Context) { "message": "", "data": users, }) - return } func SearchUsers(c *gin.Context) { @@ -213,7 +211,6 @@ func SearchUsers(c *gin.Context) { "message": "", "data": users, }) - return } func GetUser(c *gin.Context) { @@ -246,7 +243,6 @@ func GetUser(c *gin.Context) { "message": "", "data": user, }) - return } func GetUserDashboard(c *gin.Context) { @@ -306,7 +302,6 @@ func GenerateAccessToken(c *gin.Context) { "message": "", "data": user.AccessToken, }) - return } func GetAffCode(c *gin.Context) { @@ -334,7 +329,6 @@ func GetAffCode(c *gin.Context) { "message": "", "data": user.AffCode, }) - return } func GetSelf(c *gin.Context) { @@ -352,7 +346,6 @@ func GetSelf(c *gin.Context) { "message": "", "data": user, }) - return } func UpdateUser(c *gin.Context) { @@ -416,7 +409,6 @@ func UpdateUser(c *gin.Context) { "success": true, "message": "", }) - return } func UpdateSelf(c *gin.Context) { @@ -463,7 +455,6 @@ func UpdateSelf(c *gin.Context) { "success": true, "message": "", }) - return } func DeleteUser(c *gin.Context) { @@ -525,7 +516,6 @@ func DeleteSelf(c *gin.Context) { "success": true, "message": "", }) - return } func CreateUser(c *gin.Context) { @@ -574,7 +564,6 @@ func CreateUser(c *gin.Context) { "success": true, "message": "", }) - return } type ManageRequest struct { @@ -691,7 +680,6 @@ func ManageUser(c *gin.Context) { "message": "", "data": clearUser, }) - return } func EmailBind(c *gin.Context) { @@ -733,7 +721,6 @@ func EmailBind(c *gin.Context) { "success": true, "message": "", }) - return } type topUpRequest struct { @@ -764,5 +751,4 @@ func TopUp(c *gin.Context) { "message": "", "data": quota, }) - return } diff --git a/controller/wechat.go b/controller/wechat.go index ff4c9fb6..fbd7d2bd 100644 --- a/controller/wechat.go +++ b/controller/wechat.go @@ -4,12 +4,13 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "time" + + "github.com/gin-gonic/gin" ) type wechatLoginResponse struct { @@ -160,5 +161,4 @@ func WeChatBind(c *gin.Context) { "success": true, "message": "", }) - return } diff --git a/providers/ali/base.go b/providers/ali/base.go index 72ee3317..dec39b99 100644 --- a/providers/ali/base.go +++ b/providers/ali/base.go @@ -32,9 +32,9 @@ type AliProvider struct { func (p *AliProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) p.CommonRequestHeaders(headers) - headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) - if p.Context.GetString("plugin") != "" { - headers["X-DashScope-Plugin"] = p.Context.GetString("plugin") + headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key) + if p.Channel.Other != "" { + headers["X-DashScope-Plugin"] = p.Channel.Other } return headers diff --git a/providers/azureSpeech/base.go b/providers/azureSpeech/base.go index 9d6dbdd6..9f88a69a 100644 --- a/providers/azureSpeech/base.go +++ b/providers/azureSpeech/base.go @@ -27,7 +27,7 @@ type AzureSpeechProvider struct { // 获取请求头 func (p *AzureSpeechProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) - headers["Ocp-Apim-Subscription-Key"] = p.Context.GetString("api_key") + headers["Ocp-Apim-Subscription-Key"] = p.Channel.Key headers["Content-Type"] = "application/ssml+xml" headers["User-Agent"] = "OneAPI" // headers["X-Microsoft-OutputFormat"] = "audio-16khz-128kbitrate-mono-mp3" diff --git a/providers/baidu/base.go b/providers/baidu/base.go index ce2900b7..b365e0fb 100644 --- a/providers/baidu/base.go +++ b/providers/baidu/base.go @@ -63,7 +63,7 @@ func (p *BaiduProvider) GetRequestHeaders() (headers map[string]string) { } func (p *BaiduProvider) getBaiduAccessToken() (string, error) { - apiKey := p.Context.GetString("api_key") + apiKey := p.Channel.Key if val, ok := baiduTokenStore.Load(apiKey); ok { var accessToken BaiduAccessToken if accessToken, ok = val.(BaiduAccessToken); ok { diff --git a/providers/base/common.go b/providers/base/common.go index 30359151..02d63e57 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/model" "one-api/types" "strings" @@ -28,17 +29,22 @@ type BaseProvider struct { ImagesVariations string Proxy string Context *gin.Context + Channel *model.Channel } // 获取基础URL func (p *BaseProvider) GetBaseURL() string { - if p.Context.GetString("base_url") != "" { - return p.Context.GetString("base_url") + if p.Channel.GetBaseURL() != "" { + return p.Channel.GetBaseURL() } return p.BaseURL } +func (p *BaseProvider) SetChannel(channel *model.Channel) { + p.Channel = channel +} + // 获取完整请求URL func (p *BaseProvider) GetFullRequestURL(requestURL string, modelName string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") diff --git a/providers/base/interface.go b/providers/base/interface.go index 5c05b404..584e12b0 100644 --- a/providers/base/interface.go +++ b/providers/base/interface.go @@ -12,6 +12,7 @@ type ProviderInterface interface { GetFullRequestURL(requestURL string, modelName string) string GetRequestHeaders() (headers map[string]string) SupportAPI(relayMode int) bool + SetChannel(channel *model.Channel) } // 完成接口 diff --git a/providers/claude/base.go b/providers/claude/base.go index dc334ac7..59b819eb 100644 --- a/providers/claude/base.go +++ b/providers/claude/base.go @@ -28,7 +28,7 @@ func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) p.CommonRequestHeaders(headers) - headers["x-api-key"] = p.Context.GetString("api_key") + headers["x-api-key"] = p.Channel.Key anthropicVersion := p.Context.Request.Header.Get("anthropic-version") if anthropicVersion == "" { anthropicVersion = "2023-06-01" diff --git a/providers/gemini/base.go b/providers/gemini/base.go index 43fb3d7a..26d71c7d 100644 --- a/providers/gemini/base.go +++ b/providers/gemini/base.go @@ -28,8 +28,8 @@ type GeminiProvider struct { func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string { baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") version := "v1" - if p.Context.GetString("api_version") != "" { - version = p.Context.GetString("api_version") + if p.Channel.Other != "" { + version = p.Channel.Other } return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL) @@ -40,7 +40,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) p.CommonRequestHeaders(headers) - headers["x-goog-api-key"] = p.Context.GetString("api_key") + headers["x-goog-api-key"] = p.Channel.Key return headers } diff --git a/providers/openai/base.go b/providers/openai/base.go index 88e0e800..42edf1b2 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -59,7 +59,7 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") if p.IsAzure { - apiVersion := p.Context.GetString("api_version") + apiVersion := p.Channel.Other if modelName == "dall-e-2" { // 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本 // 已经没有dall-e-2了,所以暂时写死 @@ -85,9 +85,9 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) p.CommonRequestHeaders(headers) if p.IsAzure { - headers["api-key"] = p.Context.GetString("api_key") + headers["api-key"] = p.Channel.Key } else { - headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) + headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key) } return headers diff --git a/providers/palm/base.go b/providers/palm/base.go index 8f89bb72..f500b418 100644 --- a/providers/palm/base.go +++ b/providers/palm/base.go @@ -29,7 +29,7 @@ type PalmProvider struct { func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) p.CommonRequestHeaders(headers) - headers["x-goog-api-key"] = p.Context.GetString("api_key") + headers["x-goog-api-key"] = p.Channel.Key return headers } diff --git a/providers/providers.go b/providers/providers.go index a143d843..52d9662e 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -2,6 +2,7 @@ package providers import ( "one-api/common" + "one-api/model" "one-api/providers/aigc2d" "one-api/providers/aiproxy" "one-api/providers/ali" @@ -55,19 +56,23 @@ func init() { } // 获取供应商 -func GetProvider(channelType int, c *gin.Context) base.ProviderInterface { - factory, ok := providerFactories[channelType] +func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface { + factory, ok := providerFactories[channel.Type] + var provider base.ProviderInterface if !ok { // 处理未找到的供应商工厂 - baseURL := common.ChannelBaseURLs[channelType] - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") + baseURL := common.ChannelBaseURLs[channel.Type] + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() } - if baseURL != "" { - return openai.CreateOpenAIProvider(c, baseURL) + if baseURL == "" { + return nil } - return nil + provider = openai.CreateOpenAIProvider(c, baseURL) } - return factory.Create(c) + provider = factory.Create(c) + provider.SetChannel(channel) + + return provider } diff --git a/providers/tencent/base.go b/providers/tencent/base.go index 17cf6ca1..7c259f9f 100644 --- a/providers/tencent/base.go +++ b/providers/tencent/base.go @@ -52,7 +52,7 @@ func (p *TencentProvider) parseTencentConfig(config string) (appId int64, secret } func (p *TencentProvider) getTencentSign(req TencentChatRequest) string { - apiKey := p.Context.GetString("api_key") + apiKey := p.Channel.Key appId, secretId, secretKey, err := p.parseTencentConfig(apiKey) if err != nil { return "" diff --git a/providers/xunfei/base.go b/providers/xunfei/base.go index 7d85a524..4cc22cb2 100644 --- a/providers/xunfei/base.go +++ b/providers/xunfei/base.go @@ -42,7 +42,7 @@ func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) { // 获取完整请求 URL func (p *XunfeiProvider) GetFullRequestURL(requestURL string, modelName string) string { - splits := strings.Split(p.Context.GetString("api_key"), "|") + splits := strings.Split(p.Channel.Key, "|") if len(splits) != 3 { return "" } @@ -58,7 +58,7 @@ func (p *XunfeiProvider) getXunfeiAuthUrl(apiKey string, apiSecret string) (stri query := p.Context.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion == "" { - apiVersion = p.Context.GetString("api_version") + apiVersion = p.Channel.Key } if apiVersion == "" { apiVersion = "v1.1" diff --git a/providers/zhipu/base.go b/providers/zhipu/base.go index 9342a39e..7d1d08ce 100644 --- a/providers/zhipu/base.go +++ b/providers/zhipu/base.go @@ -49,7 +49,7 @@ func (p *ZhipuProvider) GetFullRequestURL(requestURL string, modelName string) s } func (p *ZhipuProvider) getZhipuToken() string { - apikey := p.Context.GetString("api_key") + apikey := p.Channel.Key data, ok := zhipuTokens.Load(apikey) if ok { tokenData := data.(zhipuTokenData)