🎨 添加工厂方法

This commit is contained in:
MartialBE 2023-12-02 18:14:48 +08:00
parent 5e08cc8719
commit da87fca2a2
14 changed files with 125 additions and 84 deletions

View File

@ -1,18 +1,20 @@
package aiproxy package aiproxy
import ( import (
"one-api/providers/base"
"one-api/providers/openai" "one-api/providers/openai"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type AIProxyProvider struct { type AIProxyProviderFactory struct{}
*openai.OpenAIProvider
}
// 创建 CreateAIProxyProvider func (f AIProxyProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func CreateAIProxyProvider(c *gin.Context) *AIProxyProvider {
return &AIProxyProvider{ return &AIProxyProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"), OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.aiproxy.io"),
} }
} }
type AIProxyProvider struct {
*openai.OpenAIProvider
}

View File

@ -8,13 +8,12 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type AliProvider struct { // 定义供应商工厂
base.BaseProvider type AliProviderFactory struct{}
}
// 创建 AliProvider
// https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation // https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
// 创建 AliAIProvider func (f AliProviderFactory) Create(c *gin.Context) base.ProviderInterface {
func CreateAliAIProvider(c *gin.Context) *AliProvider {
return &AliProvider{ return &AliProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
BaseURL: "https://dashscope.aliyuncs.com", BaseURL: "https://dashscope.aliyuncs.com",
@ -25,6 +24,10 @@ func CreateAliAIProvider(c *gin.Context) *AliProvider {
} }
} }
type AliProvider struct {
base.BaseProvider
}
// 获取请求头 // 获取请求头
func (p *AliProvider) GetRequestHeaders() (headers map[string]string) { func (p *AliProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string) headers = make(map[string]string)

View File

@ -1,18 +1,21 @@
package api2d package api2d
import ( import (
"one-api/providers/base"
"one-api/providers/openai" "one-api/providers/openai"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type Api2dProvider struct { type Api2dProviderFactory struct{}
*openai.OpenAIProvider
}
// 创建 Api2dProvider // 创建 Api2dProvider
func CreateApi2dProvider(c *gin.Context) *Api2dProvider { func (f Api2dProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &Api2dProvider{ return &Api2dProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"), OpenAIProvider: openai.CreateOpenAIProvider(c, "https://oa.api2d.net"),
} }
} }
type Api2dProvider struct {
*openai.OpenAIProvider
}

View File

@ -7,12 +7,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type AzureProvider struct { type AzureProviderFactory struct{}
openai.OpenAIProvider
}
// 创建 OpenAIProvider // 创建 AzureProvider
func CreateAzureProvider(c *gin.Context) *AzureProvider { func (f AzureProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &AzureProvider{ return &AzureProvider{
OpenAIProvider: openai.OpenAIProvider{ OpenAIProvider: openai.OpenAIProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
@ -32,3 +30,7 @@ func CreateAzureProvider(c *gin.Context) *AzureProvider {
}, },
} }
} }
type AzureProvider struct {
openai.OpenAIProvider
}

View File

@ -13,13 +13,12 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var baiduTokenStore sync.Map // 定义供应商工厂
type BaiduProviderFactory struct{}
type BaiduProvider struct { // 创建 BaiduProvider
base.BaseProvider
}
func CreateBaiduProvider(c *gin.Context) *BaiduProvider { func (f BaiduProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &BaiduProvider{ return &BaiduProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
BaseURL: "https://aip.baidubce.com", BaseURL: "https://aip.baidubce.com",
@ -30,6 +29,12 @@ func CreateBaiduProvider(c *gin.Context) *BaiduProvider {
} }
} }
var baiduTokenStore sync.Map
type BaiduProvider struct {
base.BaseProvider
}
// 获取完整请求 URL // 获取完整请求 URL
func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string { func (p *BaiduProvider) GetFullRequestURL(requestURL string, modelName string) string {
var modelNameMap = map[string]string{ var modelNameMap = map[string]string{

View File

@ -6,11 +6,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type ClaudeProvider struct { type ClaudeProviderFactory struct{}
base.BaseProvider
}
func CreateClaudeProvider(c *gin.Context) *ClaudeProvider { // 创建 ClaudeProvider
func (f ClaudeProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &ClaudeProvider{ return &ClaudeProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
BaseURL: "https://api.anthropic.com", BaseURL: "https://api.anthropic.com",
@ -20,6 +19,10 @@ func CreateClaudeProvider(c *gin.Context) *ClaudeProvider {
} }
} }
type ClaudeProvider struct {
base.BaseProvider
}
// 获取请求头 // 获取请求头
func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) { func (p *ClaudeProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string) headers = make(map[string]string)

View File

@ -1,18 +1,21 @@
package closeai package closeai
import ( import (
"one-api/providers/base"
"one-api/providers/openai" "one-api/providers/openai"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type CloseaiProxyProvider struct { type CloseaiProviderFactory struct{}
*openai.OpenAIProvider
}
// 创建 CloseaiProxyProvider // 创建 CloseaiProvider
func CreateCloseaiProxyProvider(c *gin.Context) *CloseaiProxyProvider { func (f CloseaiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &CloseaiProxyProvider{ return &CloseaiProxyProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"), OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.closeai-proxy.xyz"),
} }
} }
type CloseaiProxyProvider struct {
*openai.OpenAIProvider
}

View File

@ -16,6 +16,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type OpenAIProviderFactory struct{}
// 创建 OpenAIProvider
func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return CreateOpenAIProvider(c, "")
}
type OpenAIProvider struct { type OpenAIProvider struct {
base.BaseProvider base.BaseProvider
IsAzure bool IsAzure bool

View File

@ -1,18 +1,21 @@
package openaisb package openaisb
import ( import (
"one-api/providers/base"
"one-api/providers/openai" "one-api/providers/openai"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type OpenaiSBProvider struct { type OpenaiSBProviderFactory struct{}
*openai.OpenAIProvider
}
// 创建 OpenaiSBProvider // 创建 OpenaiSBProvider
func CreateOpenaiSBProvider(c *gin.Context) *OpenaiSBProvider { func (f OpenaiSBProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &OpenaiSBProvider{ return &OpenaiSBProvider{
OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"), OpenAIProvider: openai.CreateOpenAIProvider(c, "https://api.openai-sb.com"),
} }
} }
type OpenaiSBProvider struct {
*openai.OpenAIProvider
}

View File

@ -8,12 +8,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type PalmProvider struct { type PalmProviderFactory struct{}
base.BaseProvider
}
// 创建 PalmProvider // 创建 PalmProvider
func CreatePalmProvider(c *gin.Context) *PalmProvider { func (f PalmProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &PalmProvider{ return &PalmProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
BaseURL: "https://generativelanguage.googleapis.com", BaseURL: "https://generativelanguage.googleapis.com",
@ -23,6 +21,10 @@ func CreatePalmProvider(c *gin.Context) *PalmProvider {
} }
} }
type PalmProvider struct {
base.BaseProvider
}
// 获取请求头 // 获取请求头
func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) { func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string) headers = make(map[string]string)

View File

@ -20,35 +20,36 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// 定义供应商工厂接口
type ProviderFactory interface {
Create(c *gin.Context) base.ProviderInterface
}
// 创建全局的供应商工厂映射
var providerFactories = make(map[int]ProviderFactory)
// 在程序启动时,添加所有的供应商工厂
func init() {
providerFactories[common.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{}
providerFactories[common.ChannelTypeAzure] = azure.AzureProviderFactory{}
providerFactories[common.ChannelTypeAli] = ali.AliProviderFactory{}
providerFactories[common.ChannelTypeTencent] = tencent.TencentProviderFactory{}
providerFactories[common.ChannelTypeBaidu] = baidu.BaiduProviderFactory{}
providerFactories[common.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{}
providerFactories[common.ChannelTypePaLM] = palm.PalmProviderFactory{}
providerFactories[common.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{}
providerFactories[common.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{}
providerFactories[common.ChannelTypeAIProxy] = aiproxy.AIProxyProviderFactory{}
providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{}
providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{}
providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{}
}
// 获取供应商
func GetProvider(channelType int, c *gin.Context) base.ProviderInterface { func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
switch channelType { factory, ok := providerFactories[channelType]
case common.ChannelTypeOpenAI: if !ok {
return openai.CreateOpenAIProvider(c, "") // 处理未找到的供应商工厂
case common.ChannelTypeAzure:
return azure.CreateAzureProvider(c)
case common.ChannelTypeAli:
return ali.CreateAliAIProvider(c)
case common.ChannelTypeTencent:
return tencent.CreateTencentProvider(c)
case common.ChannelTypeBaidu:
return baidu.CreateBaiduProvider(c)
case common.ChannelTypeAnthropic:
return claude.CreateClaudeProvider(c)
case common.ChannelTypePaLM:
return palm.CreatePalmProvider(c)
case common.ChannelTypeZhipu:
return zhipu.CreateZhipuProvider(c)
case common.ChannelTypeXunfei:
return xunfei.CreateXunfeiProvider(c)
case common.ChannelTypeAIProxy:
return aiproxy.CreateAIProxyProvider(c)
case common.ChannelTypeAPI2D:
return api2d.CreateApi2dProvider(c)
case common.ChannelTypeCloseAI:
return closeai.CreateCloseaiProxyProvider(c)
case common.ChannelTypeOpenAISB:
return openaisb.CreateOpenaiSBProvider(c)
default:
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
@ -59,4 +60,5 @@ func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
return nil return nil
} }
return factory.Create(c)
} }

View File

@ -14,12 +14,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
type TencentProvider struct { type TencentProviderFactory struct{}
base.BaseProvider
}
// 创建 TencentProvider // 创建 TencentProvider
func CreateTencentProvider(c *gin.Context) *TencentProvider { func (f TencentProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &TencentProvider{ return &TencentProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
BaseURL: "https://hunyuan.cloud.tencent.com", BaseURL: "https://hunyuan.cloud.tencent.com",
@ -29,6 +27,10 @@ func CreateTencentProvider(c *gin.Context) *TencentProvider {
} }
} }
type TencentProvider struct {
base.BaseProvider
}
// 获取请求头 // 获取请求头
func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) { func (p *TencentProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string) headers = make(map[string]string)

View File

@ -14,15 +14,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// https://www.xfyun.cn/doc/spark/Web.html type XunfeiProviderFactory struct{}
type XunfeiProvider struct {
base.BaseProvider
domain string
apiId string
}
// 创建 XunfeiProvider // 创建 XunfeiProvider
func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider { func (f XunfeiProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &XunfeiProvider{ return &XunfeiProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
BaseURL: "wss://spark-api.xf-yun.com", BaseURL: "wss://spark-api.xf-yun.com",
@ -32,6 +27,13 @@ func CreateXunfeiProvider(c *gin.Context) *XunfeiProvider {
} }
} }
// https://www.xfyun.cn/doc/spark/Web.html
type XunfeiProvider struct {
base.BaseProvider
domain string
apiId string
}
// 获取请求头 // 获取请求头
func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) { func (p *XunfeiProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string) headers = make(map[string]string)

View File

@ -15,12 +15,10 @@ import (
var zhipuTokens sync.Map var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600 var expSeconds int64 = 24 * 3600
type ZhipuProvider struct { type ZhipuProviderFactory struct{}
base.BaseProvider
}
// 创建 ZhipuProvider // 创建 ZhipuProvider
func CreateZhipuProvider(c *gin.Context) *ZhipuProvider { func (f ZhipuProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return &ZhipuProvider{ return &ZhipuProvider{
BaseProvider: base.BaseProvider{ BaseProvider: base.BaseProvider{
BaseURL: "https://open.bigmodel.cn", BaseURL: "https://open.bigmodel.cn",
@ -30,6 +28,10 @@ func CreateZhipuProvider(c *gin.Context) *ZhipuProvider {
} }
} }
type ZhipuProvider struct {
base.BaseProvider
}
// 获取请求头 // 获取请求头
func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) { func (p *ZhipuProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string) headers = make(map[string]string)