diff --git a/common/constants.go b/common/constants.go index a9265b86..210c10f1 100644 --- a/common/constants.go +++ b/common/constants.go @@ -189,6 +189,7 @@ const ( ChannelTypeTencent = 23 ChannelTypeAzureSpeech = 24 ChannelTypeGemini = 25 + ChannelTypeBaichuan = 26 ) var ChannelBaseURLs = []string{ @@ -218,6 +219,7 @@ var ChannelBaseURLs = []string{ "https://hunyuan.cloud.tencent.com", //23 "", //24 "", //25 + "https://api.baichuan-ai.com", //26 } const ( diff --git a/common/model-ratio.go b/common/model-ratio.go index d6bd1d3a..3857f3d9 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -101,6 +101,10 @@ var ModelRatio = map[string]float64{ "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 + "Baichuan2-Turbo": 0.5715, // ¥0.008 / 1k tokens + "Baichuan2-Turbo-192k": 1.143, // ¥0.016 / 1k tokens + "Baichuan2-53B": 1.4286, // ¥0.02 / 1k tokens + "Baichuan-Text-Embedding": 0.0357, // ¥0.0005 / 1k tokens } func ModelRatio2JSONString() string { diff --git a/common/token.go b/common/token.go index 887989ad..88aad380 100644 --- a/common/token.go +++ b/common/token.go @@ -190,13 +190,13 @@ func countImageTokens(url string, detail string) (_ int, err error) { func CountTokenInput(input any, model string) int { switch v := input.(type) { case string: - return CountTokenInput(v, model) + return CountTokenText(v, model) case []string: text := "" for _, s := range v { text += s } - return CountTokenInput(text, model) + return CountTokenText(text, model) } return 0 } diff --git a/providers/baichuan/base.go b/providers/baichuan/base.go new file mode 100644 index 00000000..a0e5fb8a --- /dev/null +++ b/providers/baichuan/base.go @@ -0,0 +1,30 @@ +package baichuan + +import ( + "one-api/providers/base" + "one-api/providers/openai" + + "github.com/gin-gonic/gin" +) + +// 定义供应商工厂 +type BaichuanProviderFactory struct{} + +// 创建 BaichuanProvider +// https://platform.baichuan-ai.com/docs/api +func (f BaichuanProviderFactory) Create(c *gin.Context) base.ProviderInterface { + return &BaichuanProvider{ + OpenAIProvider: openai.OpenAIProvider{ + BaseProvider: base.BaseProvider{ + BaseURL: "https://api.baichuan-ai.com", + ChatCompletions: "/v1/chat/completions", + Embeddings: "/v1/embeddings", + Context: c, + }, + }, + } +} + +type BaichuanProvider struct { + openai.OpenAIProvider +} diff --git a/providers/baichuan/chat.go b/providers/baichuan/chat.go new file mode 100644 index 00000000..d439c4d9 --- /dev/null +++ b/providers/baichuan/chat.go @@ -0,0 +1,100 @@ +package baichuan + +import ( + "net/http" + "one-api/common" + "one-api/providers/openai" + "one-api/types" + "strings" +) + +func (baichuanResponse *BaichuanChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { + if baichuanResponse.Error.Message != "" { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: baichuanResponse.Error, + StatusCode: resp.StatusCode, + } + + return + } + + OpenAIResponse = types.ChatCompletionResponse{ + ID: baichuanResponse.ID, + Object: baichuanResponse.Object, + Created: baichuanResponse.Created, + Model: baichuanResponse.Model, + Choices: baichuanResponse.Choices, + Usage: baichuanResponse.Usage, + } + + return +} + +// 获取聊天请求体 +func (p *BaichuanProvider) getChatRequestBody(request *types.ChatCompletionRequest) *BaichuanChatRequest { + messages := make([]BaichuanMessage, 0, len(request.Messages)) + for i := 0; i < len(request.Messages); i++ { + message := request.Messages[i] + if message.Role == "system" || message.Role == "assistant" { + message.Role = "assistant" + } else { + message.Role = "user" + } + messages = append(messages, BaichuanMessage{ + Content: message.StringContent(), + Role: strings.ToLower(message.Role), + }) + } + + return &BaichuanChatRequest{ + Model: request.Model, + Messages: messages, + Stream: request.Stream, + Temperature: request.Temperature, + TopP: request.TopP, + TopK: request.N, + } +} + +// 聊天 +func (p *BaichuanProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + + requestBody := p.getChatRequestBody(request) + + fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model) + headers := p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + } + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + if request.Stream { + openAIProviderChatStreamResponse := &openai.OpenAIProviderChatStreamResponse{} + var textResponse string + errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse) + if errWithCode != nil { + return + } + + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: common.CountTokenText(textResponse, request.Model), + TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model), + } + + } else { + baichuanResponse := &BaichuanChatResponse{} + errWithCode = p.SendRequest(req, baichuanResponse, false) + if errWithCode != nil { + return + } + + usage = baichuanResponse.Usage + } + return +} diff --git a/providers/baichuan/type.go b/providers/baichuan/type.go new file mode 100644 index 00000000..eadccf86 --- /dev/null +++ b/providers/baichuan/type.go @@ -0,0 +1,36 @@ +package baichuan + +import "one-api/providers/openai" + +type BaichuanMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type BaichuanKnowledgeBase struct { + Ids []string `json:"id"` +} + +type BaichuanChatRequest struct { + Model string `json:"model"` + Messages []BaichuanMessage `json:"messages"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + WithSearchEnhance bool `json:"with_search_enhance,omitempty"` + KnowledgeBase BaichuanKnowledgeBase `json:"knowledge_base,omitempty"` +} + +type BaichuanKnowledgeBaseResponse struct { + Cites []struct { + Title string `json:"title"` + Content string `json:"content"` + FileId string `json:"file_id"` + } `json:"cites"` +} + +type BaichuanChatResponse struct { + openai.OpenAIProviderChatResponse + KnowledgeBase BaichuanKnowledgeBaseResponse `json:"knowledge_base,omitempty"` +} diff --git a/providers/openai/base.go b/providers/openai/base.go index 99db8079..1b621e8c 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -108,7 +108,7 @@ func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (reques } // 发送流式请求 -func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) { +func (p *OpenAIProvider) SendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) { defer req.Body.Close() client := common.GetHttpClient(p.Channel.Proxy) diff --git a/providers/openai/chat.go b/providers/openai/chat.go index 01ab8723..142ed58f 100644 --- a/providers/openai/chat.go +++ b/providers/openai/chat.go @@ -46,7 +46,7 @@ func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isMode if request.Stream { openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{} var textResponse string - errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderChatStreamResponse) + errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse) if errWithCode != nil { return } diff --git a/providers/openai/completion.go b/providers/openai/completion.go index 5560abef..7064db0c 100644 --- a/providers/openai/completion.go +++ b/providers/openai/completion.go @@ -47,7 +47,7 @@ func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isMode if request.Stream { // TODO var textResponse string - errWithCode, textResponse = p.sendStreamRequest(req, openAIProviderCompletionResponse) + errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderCompletionResponse) if errWithCode != nil { return } diff --git a/providers/providers.go b/providers/providers.go index 52d9662e..01043ce3 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -10,6 +10,7 @@ import ( "one-api/providers/api2gpt" "one-api/providers/azure" azurespeech "one-api/providers/azureSpeech" + "one-api/providers/baichuan" "one-api/providers/baidu" "one-api/providers/base" "one-api/providers/claude" @@ -52,6 +53,7 @@ func init() { providerFactories[common.ChannelTypeAPI2GPT] = api2gpt.Api2gptProviderFactory{} providerFactories[common.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{} providerFactories[common.ChannelTypeGemini] = gemini.GeminiProviderFactory{} + providerFactories[common.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{} } diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index c91c5de8..ddd51d01 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -65,6 +65,12 @@ export const CHANNEL_OPTIONS = { value: 23, color: 'default' }, + 26: { + key: 26, + text: '百川', + value: 26, + color: 'orange' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 5a6f8a9e..bc770b15 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -130,6 +130,11 @@ const typeConfig = { prompt: { other: '请输入版本号,例如:v1' } + }, + 26: { + input: { + models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan2-53B', 'Baichuan-Text-Embedding'] + } } };