ai-gateway/providers/providers.go

74 lines
2.6 KiB
Go
Raw Permalink Normal View History

package providers
import (
"one-api/common"
2023-12-02 11:54:21 +00:00
"one-api/providers/aigc2d"
2023-12-02 09:51:28 +00:00
"one-api/providers/aiproxy"
"one-api/providers/ali"
2023-12-02 09:51:28 +00:00
"one-api/providers/api2d"
2023-12-02 11:54:21 +00:00
"one-api/providers/api2gpt"
"one-api/providers/azure"
2023-12-02 14:13:47 +00:00
azurespeech "one-api/providers/azureSpeech"
"one-api/providers/baidu"
"one-api/providers/base"
"one-api/providers/claude"
2023-12-02 09:51:28 +00:00
"one-api/providers/closeai"
"one-api/providers/gemini"
"one-api/providers/openai"
2023-12-02 09:51:28 +00:00
"one-api/providers/openaisb"
"one-api/providers/palm"
"one-api/providers/tencent"
"one-api/providers/xunfei"
"one-api/providers/zhipu"
"github.com/gin-gonic/gin"
)
2023-12-02 10:14:48 +00:00
// ProviderFactory is the factory interface for providers. Values of this type
// are registered per channel type in providerFactories and invoked by
// GetProvider.
type ProviderFactory interface {
// Create returns a base.ProviderInterface for the gin request context c.
Create(c *gin.Context) base.ProviderInterface
}
// providerFactories is the package-level registry of provider factories,
// keyed by channel type (the common.ChannelType* constants). It is populated
// by init and consulted by GetProvider.
var providerFactories = make(map[int]ProviderFactory)
// 在程序启动时,添加所有的供应商工厂
func init() {
providerFactories[common.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{}
providerFactories[common.ChannelTypeAzure] = azure.AzureProviderFactory{}
providerFactories[common.ChannelTypeAli] = ali.AliProviderFactory{}
providerFactories[common.ChannelTypeTencent] = tencent.TencentProviderFactory{}
providerFactories[common.ChannelTypeBaidu] = baidu.BaiduProviderFactory{}
providerFactories[common.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{}
providerFactories[common.ChannelTypePaLM] = palm.PalmProviderFactory{}
providerFactories[common.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{}
providerFactories[common.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{}
providerFactories[common.ChannelTypeAIProxy] = aiproxy.AIProxyProviderFactory{}
providerFactories[common.ChannelTypeAPI2D] = api2d.Api2dProviderFactory{}
providerFactories[common.ChannelTypeCloseAI] = closeai.CloseaiProviderFactory{}
providerFactories[common.ChannelTypeOpenAISB] = openaisb.OpenaiSBProviderFactory{}
2023-12-02 11:54:21 +00:00
providerFactories[common.ChannelTypeAIGC2D] = aigc2d.Aigc2dProviderFactory{}
providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{}
2023-12-02 14:13:47 +00:00
providerFactories[common.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{}
providerFactories[common.ChannelTypeGemini] = gemini.GeminiProviderFactory{}
2023-12-02 11:54:21 +00:00
2023-12-02 10:14:48 +00:00
}
// 获取供应商
func GetProvider(channelType int, c *gin.Context) base.ProviderInterface {
2023-12-02 10:14:48 +00:00
factory, ok := providerFactories[channelType]
if !ok {
// 处理未找到的供应商工厂
baseURL := common.ChannelBaseURLs[channelType]
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
if baseURL != "" {
return openai.CreateOpenAIProvider(c, baseURL)
}
return nil
}
2023-12-02 10:14:48 +00:00
return factory.Create(c)
}