From 4fc987f4a00a800fade759729a8548d92df0a4e7 Mon Sep 17 00:00:00 2001 From: Martial BE Date: Sat, 20 Apr 2024 12:07:37 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Support=20Coze?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/constants.go | 2 + providers/coze/base.go | 83 +++++++++++ providers/coze/chat.go | 199 ++++++++++++++++++++++++++ providers/coze/type.go | 52 +++++++ providers/providers.go | 2 + web/src/constants/ChannelConstants.js | 7 + web/src/views/Channel/type/Config.js | 11 ++ 7 files changed, 356 insertions(+) create mode 100644 providers/coze/base.go create mode 100644 providers/coze/chat.go create mode 100644 providers/coze/type.go diff --git a/common/constants.go b/common/constants.go index 82dbb90c..2b9d4c2c 100644 --- a/common/constants.go +++ b/common/constants.go @@ -175,6 +175,7 @@ const ( ChannelTypeCloudflareAI = 35 ChannelTypeCohere = 36 ChannelTypeStabilityAI = 37 + ChannelTypeCoze = 38 ) var ChannelBaseURLs = []string{ @@ -216,6 +217,7 @@ var ChannelBaseURLs = []string{ "", //35 "https://api.cohere.ai/v1", //36 "https://api.stability.ai/v2beta", //37 + "https://api.coze.com/open_api", //38 } const ( diff --git a/providers/coze/base.go b/providers/coze/base.go new file mode 100644 index 00000000..729e1a96 --- /dev/null +++ b/providers/coze/base.go @@ -0,0 +1,83 @@ +package coze + +import ( + "encoding/json" + "fmt" + "net/http" + "one-api/common/requester" + "one-api/model" + "one-api/providers/base" + "one-api/types" + "strings" +) + +type CozeProviderFactory struct{} + +// 创建 CozeProvider +func (f CozeProviderFactory) Create(channel *model.Channel) base.ProviderInterface { + return &CozeProvider{ + BaseProvider: base.BaseProvider{ + Config: getConfig(), + Channel: channel, + Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle), + }, + } +} + +type CozeProvider struct { + base.BaseProvider +} + +func getConfig() base.ProviderConfig { + return base.ProviderConfig{ + BaseURL: "https://api.coze.com/open_api", + ChatCompletions: "/v2/chat", + } +} + +// 请求错误处理 +func requestErrorHandle(resp *http.Response) *types.OpenAIError { + CozeError := &CozeStatus{} + err := json.NewDecoder(resp.Body).Decode(CozeError) + if err != nil { + return nil + } + + return errorHandle(CozeError) +} + +// 错误处理 +func errorHandle(CozeError *CozeStatus) *types.OpenAIError { + if CozeError.Code == 0 { + return nil + } + return &types.OpenAIError{ + Message: CozeError.Msg, + Type: "Coze error", + Code: CozeError.Code, + } +} + +// 获取请求头 +func (p *CozeProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + p.CommonRequestHeaders(headers) + headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key) + + return headers +} + +func (p *CozeProvider) GetFullRequestURL(requestURL string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + + return fmt.Sprintf("%s%s", baseURL, requestURL) +} + +func convertRole(role string) string { + switch role { + case types.ChatMessageRoleSystem, types.ChatMessageRoleAssistant: + return types.ChatMessageRoleAssistant + default: + return types.ChatMessageRoleUser + } +} diff --git a/providers/coze/chat.go b/providers/coze/chat.go new file mode 100644 index 00000000..bc5fd679 --- /dev/null +++ b/providers/coze/chat.go @@ -0,0 +1,199 @@ +package coze + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/common/requester" + "one-api/types" + "strings" +) + +type CozeStreamHandler struct { + Usage *types.Usage + Request *types.ChatCompletionRequest +} + +func (p *CozeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + chatResponse := &CozeResponse{} + // 发送请求 + _, errWithCode = p.Requester.SendRequest(req, chatResponse, false) + if errWithCode != nil { + return nil, errWithCode + } + + return p.convertToChatOpenai(chatResponse, request) +} + +func (p *CozeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { + req, errWithCode := p.getChatRequest(request) + if errWithCode != nil { + return nil, errWithCode + } + defer req.Body.Close() + + // 发送请求 + resp, errWithCode := p.Requester.SendRequestRaw(req) + if errWithCode != nil { + return nil, errWithCode + } + + chatHandler := &CozeStreamHandler{ + Usage: p.Usage, + Request: request, + } + + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) +} + +func (p *CozeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { + url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + if errWithCode != nil { + return nil, errWithCode + } + // 获取请求地址 + fullRequestURL := p.GetFullRequestURL(url) + if fullRequestURL == "" { + return nil, common.ErrorWrapper(nil, "invalid_cloudflare_ai_config", http.StatusInternalServerError) + } + + // 获取请求头 + headers := p.GetRequestHeaders() + chatRequest := p.convertFromChatOpenai(request) + + // 创建请求 + req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(chatRequest), p.Requester.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + return req, nil +} + +func (p *CozeProvider) convertToChatOpenai(response *CozeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { + err := errorHandle(&response.CozeStatus) + if err != nil { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: *err, + StatusCode: http.StatusBadRequest, + } + return + } + + openaiResponse = &types.ChatCompletionResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Model: request.Model, + Choices: []types.ChatCompletionChoice{{ + Index: 0, + Message: types.ChatCompletionMessage{ + Role: types.ChatMessageRoleAssistant, + Content: response.String(), + }, + FinishReason: types.FinishReasonStop, + }}, + } + + p.Usage.CompletionTokens = 0 + p.Usage.PromptTokens = 1 + p.Usage.TotalTokens = 1 + openaiResponse.Usage = p.Usage + + return +} + +func (p *CozeProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *CozeRequest { + model := strings.TrimPrefix(request.Model, "coze-") + chatRequest := &CozeRequest{ + Stream: request.Stream, + BotID: model, + User: "OneAPI", + } + msgLen := len(request.Messages) - 1 + + for index, message := range request.Messages { + if index == msgLen { + chatRequest.Query = message.StringContent() + } else { + chatRequest.ChatHistory = append(chatRequest.ChatHistory, CozeMessage{ + Role: convertRole(message.Role), + Content: message.StringContent(), + ContentType: "text", + }) + } + + } + + return chatRequest +} + +// 转换为OpenAI聊天流式请求体 +func (h *CozeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { + // 如果rawLine 前缀不为data: 或者 meta:,则直接返回 + if !strings.HasPrefix(string(*rawLine), "data:") { + *rawLine = nil + return + } + + *rawLine = (*rawLine)[5:] + + chatResponse := &CozeStreamResponse{} + err := json.Unmarshal(*rawLine, chatResponse) + if err != nil { + errChan <- common.ErrorToOpenAIError(err) + return + } + + if chatResponse.Event == "done" { + errChan <- io.EOF + *rawLine = requester.StreamClosed + return + } + + if chatResponse.Event != "message" || chatResponse.Message.Type != "answer" { + *rawLine = nil + return + } + + h.convertToOpenaiStream(chatResponse, dataChan) +} + +func (h *CozeStreamHandler) convertToOpenaiStream(chatResponse *CozeStreamResponse, dataChan chan string) { + streamResponse := types.ChatCompletionStreamResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: h.Request.Model, + } + + choice := types.ChatCompletionStreamChoice{ + Index: 0, + Delta: types.ChatCompletionStreamChoiceDelta{ + Role: types.ChatMessageRoleAssistant, + Content: "", + }, + } + + if chatResponse.IsFinish { + choice.FinishReason = types.FinishReasonStop + } else { + choice.Delta.Content = chatResponse.Message.Content + + h.Usage.TotalTokens = 1 + h.Usage.PromptTokens = 1 + } + + streamResponse.Choices = []types.ChatCompletionStreamChoice{choice} + responseBody, _ := json.Marshal(streamResponse) + dataChan <- string(responseBody) + +} diff --git a/providers/coze/type.go b/providers/coze/type.go new file mode 100644 index 00000000..a4ad109c --- /dev/null +++ b/providers/coze/type.go @@ -0,0 +1,52 @@ +package coze + +import "one-api/types" + +type CozeStatus struct { + Code int `json:"code"` + Msg string `json:"msg"` +} + +type CozeRequest struct { + BotID string `json:"bot_id"` + Query string `json:"query"` + Stream bool `json:"stream"` + User string `json:"user"` + ConversationID string `json:"conversation_id"` + ChatHistory []CozeMessage `json:"chat_history"` +} + +type CozeMessage struct { + Role string `json:"role"` + Type string `json:"type,omitempty"` + Content string `json:"content"` + ContentType string `json:"content_type"` +} + +type CozeResponse struct { + CozeStatus + ConversationID string `json:"conversation_id"` + Messages []CozeMessage `json:"messages"` +} + +func (cr *CozeResponse) String() string { + message := "" + + for _, msg := range cr.Messages { + if msg.Type == "answer" && msg.Role == types.ChatMessageRoleAssistant { + message = msg.Content + break + } + } + + return message +} + +type CozeStreamResponse struct { + Event string `json:"event"` + ErrorInformation string `json:"error_information,omitempty"` + Message CozeMessage `json:"message,omitempty"` + IsFinish bool `json:"is_finish,omitempty"` + Index int `json:"index,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` +} diff --git a/providers/providers.go b/providers/providers.go index bce39afd..2e1f0f78 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -13,6 +13,7 @@ import ( "one-api/providers/claude" "one-api/providers/cloudflareAI" "one-api/providers/cohere" + "one-api/providers/coze" "one-api/providers/deepseek" "one-api/providers/gemini" "one-api/providers/groq" @@ -60,6 +61,7 @@ func init() { providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{} providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{} providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{} + providerFactories[common.ChannelTypeCoze] = coze.CozeProviderFactory{} } diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index 179ea070..bc78c7e2 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -160,6 +160,13 @@ export const CHANNEL_OPTIONS = { color: 'default', url: '' }, + 38: { + key: 38, + text: 'Coze', + value: 38, + color: 'primary', + url: '' + }, 24: { key: 24, text: 'Azure Speech', diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 5c6416a1..86c2176d 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -300,6 +300,17 @@ const typeConfig = { test_model: '' }, modelGroup: 'Stability AI' + }, + 38: { + input: { + models: ['coze-*'] + }, + prompt: { + models: '模型名称为coze-{bot_id},你也可以直接使用 coze-* 通配符来匹配所有coze开头的模型', + model_mapping: + '模型名称映射, 你可以取一个容易记忆的名字来代替coze-{bot_id},例如:{"coze-translate": "coze-xxxxx"},注意:如果使用了模型映射,那么上面的模型名称必须使用映射前的名称,上述例子中,你应该在模型中填入coze-translate(如果已经使用了coze-*,可以忽略)。' + }, + modelGroup: 'Coze' } };