From da87fca2a27ab412cbfd4fce5a51c6d5808e2eb9 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Sat, 2 Dec 2023 18:14:48 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20=E6=B7=BB=E5=8A=A0=E5=B7=A5?= =?UTF-8?q?=E5=8E=82=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/aiproxy/base.go | 12 ++++---- providers/ali/base.go | 13 +++++---- providers/api2d/base.go | 11 +++++--- providers/azure/base.go | 12 ++++---- providers/baidu/base.go | 15 ++++++---- providers/claude/base.go | 11 +++++--- providers/closeai/base.go | 13 +++++---- providers/openai/base.go | 7 +++++ providers/openaisb/base.go | 11 +++++--- providers/palm/base.go | 10 ++++--- providers/providers.go | 58 ++++++++++++++++++++------------------ providers/tencent/base.go | 10 ++++--- providers/xunfei/base.go | 16 ++++++----- providers/zhipu/base.go | 10 ++++--- 14 files changed, 125 insertions(+), 84 deletions(-) diff --git a/providers/aiproxy/base.go b/providers/aiproxy/base.go index 9acc859e..8b0d0ff3 100644 --- a/providers/aiproxy/base.go +++ b/providers/aiproxy/base.go @@ -1,18 +1,20 @@ package aiproxy import ( + "one-api/providers/base" "one-api/providers/openai" "github.com/gin-gonic/gin" ) -type AIProxyProvider struct { - *openai.OpenAIProvider -} +type AIProxyProviderFactory struct{} -// 创建 CreateAIProxyProvider -func CreateAIProxyProvider(c *gin.Context) *AIProxyProvider { +func (f AIProxyProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &AIProxyProvider{ OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"), } } + +type AIProxyProvider struct { + *openai.OpenAIProvider +} diff --git a/providers/ali/base.go b/providers/ali/base.go index 249abf05..f49067c0 100644 --- a/providers/ali/base.go +++ b/providers/ali/base.go @@ -8,13 +8,12 @@ import ( "github.com/gin-gonic/gin" ) -type AliProvider struct { - base.BaseProvider -} +// 定义供应商工厂 +type AliProviderFactory struct{} +// 创建 AliProvider // https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation -// 创建 AliAIProvider -func CreateAliAIProvider(c *gin.Context) *AliProvider { +func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &AliProvider{ BaseProvider: base.BaseProvider{ BaseURL: "https://dashscope.aliyuncs.com", @@ -25,6 +24,10 @@ func CreateAliAIProvider(c *gin.Context) *AliProvider { } } +type AliProvider struct { + base.BaseProvider +} + // 获取请求头 func (p *AliProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) diff --git a/providers/api2d/base.go b/providers/api2d/base.go index b81d2371..ca9ab256 100644 --- a/providers/api2d/base.go +++ b/providers/api2d/base.go @@ -1,18 +1,21 @@ package api2d import ( + "one-api/providers/base" "one-api/providers/openai" "github.com/gin-gonic/gin" ) -type Api2dProvider struct { - *openai.OpenAIProvider -} +type Api2dProviderFactory struct{} // 创建 Api2dProvider -func CreateApi2dProvider(c *gin.Context) *Api2dProvider { +func (f Api2dProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &Api2dProvider{ OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"), } } + +type Api2dProvider struct { + *openai.OpenAIProvider +} diff --git a/providers/azure/base.go b/providers/azure/base.go index 1a2f0aaa..6e560e6b 100644 --- a/providers/azure/base.go +++ b/providers/azure/base.go @@ -7,12 +7,10 @@ import ( "github.com/gin-gonic/gin" ) -type AzureProvider struct { - openai.OpenAIProvider -} +type AzureProviderFactory struct{} -// 创建 OpenAIProvider -func CreateAzureProvider(c *gin.Context) *AzureProvider { +// 创建 AzureProvider +func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &AzureProvider{ OpenAIProvider: openai.OpenAIProvider{ BaseProvider: base.BaseProvider{ @@ -32,3 +30,7 @@ func CreateAzureProvider(c *gin.Context) *AzureProvider { }, } } + +type AzureProvider struct { + openai.OpenAIProvider +} diff --git a/providers/baidu/base.go b/providers/baidu/base.go index 1a5005fc..ce2900b7 100644 --- a/providers/baidu/base.go +++ b/providers/baidu/base.go @@ -13,13 +13,12 @@ import ( "github.com/gin-gonic/gin" ) -var baiduTokenStore sync.Map +// 定义供应商工厂 +type BaiduProviderFactory struct{} -type BaiduProvider struct { - base.BaseProvider -} +// 创建 BaiduProvider -func CreateBaiduProvider(c *gin.Context) *BaiduProvider { +func (f BaiduProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &BaiduProvider{ BaseProvider: base.BaseProvider{ BaseURL: "https://aip.baidubce.com", @@ -30,6 +29,12 @@ func CreateBaiduProvider(c *gin.Context) *BaiduProvider { } } +var baiduTokenStore sync.Map + +type BaiduProvider struct { + base.BaseProvider +} + // 获取完整请求 URL func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string { var modelNameMap = map[string]string{ diff --git a/providers/claude/base.go b/providers/claude/base.go index f77a5c2f..dc334ac7 100644 --- a/providers/claude/base.go +++ b/providers/claude/base.go @@ -6,11 +6,10 @@ import ( "github.com/gin-gonic/gin" ) -type ClaudeProvider struct { - base.BaseProvider -} +type ClaudeProviderFactory struct{} -func CreateClaudeProvider(c *gin.Context) *ClaudeProvider { +// 创建 ClaudeProvider +func (f ClaudeProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &ClaudeProvider{ BaseProvider: base.BaseProvider{ BaseURL: "https://api.anthropic.com", @@ -20,6 +19,10 @@ func CreateClaudeProvider(c *gin.Context) *ClaudeProvider { } } +type ClaudeProvider struct { + base.BaseProvider +} + // 获取请求头 func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) diff --git a/providers/closeai/base.go b/providers/closeai/base.go index 0d64d6e7..c5387ec5 100644 --- a/providers/closeai/base.go +++ b/providers/closeai/base.go @@ -1,18 +1,21 @@ package closeai import ( + "one-api/providers/base" "one-api/providers/openai" "github.com/gin-gonic/gin" ) -type CloseaiProxyProvider struct { - *openai.OpenAIProvider -} +type CloseaiProviderFactory struct{} -// 创建 CloseaiProxyProvider -func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider { +// 创建 CloseaiProvider +func (f CloseaiProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &CloseaiProxyProvider{ OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"), } } + +type CloseaiProxyProvider struct { + *openai.OpenAIProvider +} diff --git a/providers/openai/base.go b/providers/openai/base.go index 4f4e1bc5..88e0e800 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -16,6 +16,13 @@ import ( "github.com/gin-gonic/gin" ) +type OpenAIProviderFactory struct{} + +// 创建 OpenAIProvider +func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface { + return CreateOpenAIProvider(c, "") +} + type OpenAIProvider struct { base.BaseProvider IsAzure bool diff --git a/providers/openaisb/base.go b/providers/openaisb/base.go index f5f46dfb..c770d3e4 100644 --- a/providers/openaisb/base.go +++ b/providers/openaisb/base.go @@ -1,18 +1,21 @@ package openaisb import ( + "one-api/providers/base" "one-api/providers/openai" "github.com/gin-gonic/gin" ) -type OpenaiSBProvider struct { - *openai.OpenAIProvider -} +type OpenaiSBProviderFactory struct{} // 创建 OpenaiSBProvider -func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider { +func (f OpenaiSBProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &OpenaiSBProvider{ OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"), } } + +type OpenaiSBProvider struct { + *openai.OpenAIProvider +} diff --git a/providers/palm/base.go b/providers/palm/base.go index ee10709e..aba4fb72 100644 --- a/providers/palm/base.go +++ b/providers/palm/base.go @@ -8,12 +8,10 @@ import ( "github.com/gin-gonic/gin" ) -type PalmProvider struct { - base.BaseProvider -} +type PalmProviderFactory struct{} // 创建 PalmProvider -func CreatePalmProvider(c *gin.Context) *PalmProvider { +func (f PalmProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &PalmProvider{ BaseProvider: base.BaseProvider{ BaseURL: "https://generativelanguage.googleapis.com", @@ -23,6 +21,10 @@ func CreatePalmProvider(c *gin.Context) *PalmProvider { } } +type PalmProvider struct { + base.BaseProvider +} + // 获取请求头 func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) diff --git a/providers/providers.go b/providers/providers.go index cee972b8..74a3c385 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -20,35 +20,36 @@ import ( "github.com/gin-gonic/gin" ) +// 定义供应商工厂接口 +type ProviderFactory interface { + Create(c *gin.Context) base.ProviderInterface +} + +// 创建全局的供应商工厂映射 +var providerFactories = make(map[int]ProviderFactory) + +// 在程序启动时,添加所有的供应商工厂 +func init() { + providerFactories[common.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{} + providerFactories[common.ChannelTypeAzure] = azure.AzureProviderFactory{} + providerFactories[common.ChannelTypeAli] = ali.AliProviderFactory{} + providerFactories[common.ChannelTypeTencent] = tencent.TencentProviderFactory{} + providerFactories[common.ChannelTypeBaidu] = baidu.BaiduProviderFactory{} + providerFactories[common.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{} + providerFactories[common.ChannelTypePaLM] = palm.PalmProviderFactory{} + providerFactories[common.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{} + providerFactories[common.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{} + providerFactories[common.ChannelTypeAIProxy] = aiproxy.AIProxyProviderFactory{} + providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{} + providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{} + providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{} +} + +// 获取供应商 func GetProvider(channelType int, c *gin.Context) base.ProviderInterface { - switch channelType { - case common.ChannelTypeOpenAI: - return openai.CreateOpenAIProvider(c, "") - case common.ChannelTypeAzure: - return azure.CreateAzureProvider(c) - case common.ChannelTypeAli: - return ali.CreateAliAIProvider(c) - case common.ChannelTypeTencent: - return tencent.CreateTencentProvider(c) - case common.ChannelTypeBaidu: - return baidu.CreateBaiduProvider(c) - case common.ChannelTypeAnthropic: - return claude.CreateClaudeProvider(c) - case common.ChannelTypePaLM: - return palm.CreatePalmProvider(c) - case common.ChannelTypeZhipu: - return zhipu.CreateZhipuProvider(c) - case common.ChannelTypeXunfei: - return xunfei.CreateXunfeiProvider(c) - case common.ChannelTypeAIProxy: - return aiproxy.CreateAIProxyProvider(c) - case common.ChannelTypeAPI2D: - return api2d.CreateApi2dProvider(c) - case common.ChannelTypeCloseAI: - return closeai.CreateCloseaiProxyProvider(c) - case common.ChannelTypeOpenAISB: - return openaisb.CreateOpenaiSBProvider(c) - default: + factory, ok := providerFactories[channelType] + if !ok { + // 处理未找到的供应商工厂 baseURL := common.ChannelBaseURLs[channelType] if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") @@ -59,4 +60,5 @@ func GetProvider(channelType int, c *gin.Context) base.ProviderInterface { return nil } + return factory.Create(c) } diff --git a/providers/tencent/base.go b/providers/tencent/base.go index ab7b2b49..17cf6ca1 100644 --- a/providers/tencent/base.go +++ b/providers/tencent/base.go @@ -14,12 +14,10 @@ import ( "github.com/gin-gonic/gin" ) -type TencentProvider struct { - base.BaseProvider -} +type TencentProviderFactory struct{} // 创建 TencentProvider -func CreateTencentProvider(c *gin.Context) *TencentProvider { +func (f TencentProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &TencentProvider{ BaseProvider: base.BaseProvider{ BaseURL: "https://hunyuan.cloud.tencent.com", @@ -29,6 +27,10 @@ func CreateTencentProvider(c *gin.Context) *TencentProvider { } } +type TencentProvider struct { + base.BaseProvider +} + // 获取请求头 func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) diff --git a/providers/xunfei/base.go b/providers/xunfei/base.go index c8b37a94..76d52f5b 100644 --- a/providers/xunfei/base.go +++ b/providers/xunfei/base.go @@ -14,15 +14,10 @@ import ( "github.com/gin-gonic/gin" ) -// https://www.xfyun.cn/doc/spark/Web.html -type XunfeiProvider struct { - base.BaseProvider - domain string - apiId string -} +type XunfeiProviderFactory struct{} // 创建 XunfeiProvider -func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider { +func (f XunfeiProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &XunfeiProvider{ BaseProvider: base.BaseProvider{ BaseURL: "wss://spark-api.xf-yun.com", @@ -32,6 +27,13 @@ func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider { } } +// https://www.xfyun.cn/doc/spark/Web.html +type XunfeiProvider struct { + base.BaseProvider + domain string + apiId string +} + // 获取请求头 func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) diff --git a/providers/zhipu/base.go b/providers/zhipu/base.go index 1b59e1a3..9342a39e 100644 --- a/providers/zhipu/base.go +++ b/providers/zhipu/base.go @@ -15,12 +15,10 @@ import ( var zhipuTokens sync.Map var expSeconds int64 = 24 * 3600 -type ZhipuProvider struct { - base.BaseProvider -} +type ZhipuProviderFactory struct{} // 创建 ZhipuProvider -func CreateZhipuProvider(c *gin.Context) *ZhipuProvider { +func (f ZhipuProviderFactory) Create(c *gin.Context) base.ProviderInterface { return &ZhipuProvider{ BaseProvider: base.BaseProvider{ BaseURL: "https://open.bigmodel.cn", @@ -30,6 +28,10 @@ func CreateZhipuProvider(c *gin.Context) *ZhipuProvider { } } +type ZhipuProvider struct { + base.BaseProvider +} + // 获取请求头 func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string)