ai-gateway/providers/providers.go
2023-12-26 16:40:50 +08:00

79 lines
2.7 KiB
Go

package providers
import (
"one-api/common"
"one-api/model"
"one-api/providers/aigc2d"
"one-api/providers/aiproxy"
"one-api/providers/ali"
"one-api/providers/api2d"
"one-api/providers/api2gpt"
"one-api/providers/azure"
azurespeech "one-api/providers/azureSpeech"
"one-api/providers/baidu"
"one-api/providers/base"
"one-api/providers/claude"
"one-api/providers/closeai"
"one-api/providers/gemini"
"one-api/providers/openai"
"one-api/providers/openaisb"
"one-api/providers/palm"
"one-api/providers/tencent"
"one-api/providers/xunfei"
"one-api/providers/zhipu"
"github.com/gin-gonic/gin"
)
// ProviderFactory is the interface implemented by every provider factory.
// Implementations are registered in providerFactories, keyed by channel type.
type ProviderFactory interface {
// Create builds a provider bound to the given gin request context.
Create(c *gin.Context) base.ProviderInterface
}
// providerFactories is the package-level registry mapping a channel type to
// the factory that creates its provider. It is filled in by init and looked
// up by GetProvider.
var providerFactories = make(map[int]ProviderFactory)
// 在程序启动时,添加所有的供应商工厂
func init() {
providerFactories[common.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{}
providerFactories[common.ChannelTypeAzure] = azure.AzureProviderFactory{}
providerFactories[common.ChannelTypeAli] = ali.AliProviderFactory{}
providerFactories[common.ChannelTypeTencent] = tencent.TencentProviderFactory{}
providerFactories[common.ChannelTypeBaidu] = baidu.BaiduProviderFactory{}
providerFactories[common.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{}
providerFactories[common.ChannelTypePaLM] = palm.PalmProviderFactory{}
providerFactories[common.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{}
providerFactories[common.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{}
providerFactories[common.ChannelTypeAIProxy] = aiproxy.AIProxyProviderFactory{}
providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{}
providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{}
providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{}
providerFactories[common.ChannelTypeAIGC2D] = aigc2d.Aigc2dProviderFactory{}
providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{}
providerFactories[common.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{}
providerFactories[common.ChannelTypeGemini] = gemini.GeminiProviderFactory{}
}
// GetProvider returns the provider that serves the given channel, bound to
// the request context c.
//
// If a factory is registered for channel.Type, it creates the provider.
// Otherwise the channel is served by an OpenAI-style provider pointed at a
// base URL: the channel's own base URL takes precedence, falling back to the
// default in common.ChannelBaseURLs. If neither yields a non-empty URL,
// GetProvider returns nil.
//
// The returned provider has had SetChannel(channel) called on it.
func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface {
	var provider base.ProviderInterface

	if factory, ok := providerFactories[channel.Type]; ok {
		provider = factory.Create(c)
	} else {
		// No registered factory: factory is a nil interface here, so it must
		// not be invoked. Fall back to an OpenAI-style provider instead.
		baseURL := channel.GetBaseURL()
		if baseURL == "" {
			baseURL = common.ChannelBaseURLs[channel.Type]
		}
		if baseURL == "" {
			return nil
		}
		provider = openai.CreateOpenAIProvider(c, baseURL)
	}

	provider.SetChannel(channel)
	return provider
}