feat: Support Coze

This commit is contained in:
Martial BE 2024-04-20 12:07:37 +08:00
parent 7511d614cf
commit 4fc987f4a0
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
7 changed files with 356 additions and 0 deletions

View File

@ -175,6 +175,7 @@ const (
ChannelTypeCloudflareAI = 35
ChannelTypeCohere = 36
ChannelTypeStabilityAI = 37
ChannelTypeCoze = 38
)
var ChannelBaseURLs = []string{
@ -216,6 +217,7 @@ var ChannelBaseURLs = []string{
"", //35
"https://api.cohere.ai/v1", //36
"https://api.stability.ai/v2beta", //37
"https://api.coze.com/open_api", //38
}
const (

83
providers/coze/base.go Normal file
View File

@ -0,0 +1,83 @@
package coze
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
)
type CozeProviderFactory struct{}
// 创建 CozeProvider
func (f CozeProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &CozeProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
},
}
}
type CozeProvider struct {
base.BaseProvider
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.coze.com/open_api",
ChatCompletions: "/v2/chat",
}
}
// 请求错误处理
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
CozeError := &CozeStatus{}
err := json.NewDecoder(resp.Body).Decode(CozeError)
if err != nil {
return nil
}
return errorHandle(CozeError)
}
// 错误处理
func errorHandle(CozeError *CozeStatus) *types.OpenAIError {
if CozeError.Code == 0 {
return nil
}
return &types.OpenAIError{
Message: CozeError.Msg,
Type: "Coze error",
Code: CozeError.Code,
}
}
// 获取请求头
func (p *CozeProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
return headers
}
func (p *CozeProvider) GetFullRequestURL(requestURL string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf("%s%s", baseURL, requestURL)
}
func convertRole(role string) string {
switch role {
case types.ChatMessageRoleSystem, types.ChatMessageRoleAssistant:
return types.ChatMessageRoleAssistant
default:
return types.ChatMessageRoleUser
}
}

199
providers/coze/chat.go Normal file
View File

@ -0,0 +1,199 @@
package coze
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
type CozeStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *CozeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
chatResponse := &CozeResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, chatResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
return p.convertToChatOpenai(chatResponse, request)
}
func (p *CozeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// 发送请求
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &CozeStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *CozeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(url)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_cloudflare_ai_config", http.StatusInternalServerError)
}
// 获取请求头
headers := p.GetRequestHeaders()
chatRequest := p.convertFromChatOpenai(request)
// 创建请求
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(chatRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func (p *CozeProvider) convertToChatOpenai(response *CozeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
err := errorHandle(&response.CozeStatus)
if err != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *err,
StatusCode: http.StatusBadRequest,
}
return
}
openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: request.Model,
Choices: []types.ChatCompletionChoice{{
Index: 0,
Message: types.ChatCompletionMessage{
Role: types.ChatMessageRoleAssistant,
Content: response.String(),
},
FinishReason: types.FinishReasonStop,
}},
}
p.Usage.CompletionTokens = 0
p.Usage.PromptTokens = 1
p.Usage.TotalTokens = 1
openaiResponse.Usage = p.Usage
return
}
func (p *CozeProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *CozeRequest {
model := strings.TrimPrefix(request.Model, "coze-")
chatRequest := &CozeRequest{
Stream: request.Stream,
BotID: model,
User: "OneAPI",
}
msgLen := len(request.Messages) - 1
for index, message := range request.Messages {
if index == msgLen {
chatRequest.Query = message.StringContent()
} else {
chatRequest.ChatHistory = append(chatRequest.ChatHistory, CozeMessage{
Role: convertRole(message.Role),
Content: message.StringContent(),
ContentType: "text",
})
}
}
return chatRequest
}
// 转换为OpenAI聊天流式请求体
func (h *CozeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data:") {
*rawLine = nil
return
}
*rawLine = (*rawLine)[5:]
chatResponse := &CozeStreamResponse{}
err := json.Unmarshal(*rawLine, chatResponse)
if err != nil {
errChan <- common.ErrorToOpenAIError(err)
return
}
if chatResponse.Event == "done" {
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
if chatResponse.Event != "message" || chatResponse.Message.Type != "answer" {
*rawLine = nil
return
}
h.convertToOpenaiStream(chatResponse, dataChan)
}
func (h *CozeStreamHandler) convertToOpenaiStream(chatResponse *CozeStreamResponse, dataChan chan string) {
streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: h.Request.Model,
}
choice := types.ChatCompletionStreamChoice{
Index: 0,
Delta: types.ChatCompletionStreamChoiceDelta{
Role: types.ChatMessageRoleAssistant,
Content: "",
},
}
if chatResponse.IsFinish {
choice.FinishReason = types.FinishReasonStop
} else {
choice.Delta.Content = chatResponse.Message.Content
h.Usage.TotalTokens = 1
h.Usage.PromptTokens = 1
}
streamResponse.Choices = []types.ChatCompletionStreamChoice{choice}
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
}

52
providers/coze/type.go Normal file
View File

@ -0,0 +1,52 @@
package coze
import "one-api/types"
type CozeStatus struct {
Code int `json:"code"`
Msg string `json:"msg"`
}
type CozeRequest struct {
BotID string `json:"bot_id"`
Query string `json:"query"`
Stream bool `json:"stream"`
User string `json:"user"`
ConversationID string `json:"conversation_id"`
ChatHistory []CozeMessage `json:"chat_history"`
}
type CozeMessage struct {
Role string `json:"role"`
Type string `json:"type,omitempty"`
Content string `json:"content"`
ContentType string `json:"content_type"`
}
type CozeResponse struct {
CozeStatus
ConversationID string `json:"conversation_id"`
Messages []CozeMessage `json:"messages"`
}
func (cr *CozeResponse) String() string {
message := ""
for _, msg := range cr.Messages {
if msg.Type == "answer" && msg.Role == types.ChatMessageRoleAssistant {
message = msg.Content
break
}
}
return message
}
type CozeStreamResponse struct {
Event string `json:"event"`
ErrorInformation string `json:"error_information,omitempty"`
Message CozeMessage `json:"message,omitempty"`
IsFinish bool `json:"is_finish,omitempty"`
Index int `json:"index,omitempty"`
ConversationID string `json:"conversation_id,omitempty"`
}

View File

@ -13,6 +13,7 @@ import (
"one-api/providers/claude"
"one-api/providers/cloudflareAI"
"one-api/providers/cohere"
"one-api/providers/coze"
"one-api/providers/deepseek"
"one-api/providers/gemini"
"one-api/providers/groq"
@ -60,6 +61,7 @@ func init() {
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
providerFactories[common.ChannelTypeCoze] = coze.CozeProviderFactory{}
}

View File

@ -160,6 +160,13 @@ export const CHANNEL_OPTIONS = {
color: 'default',
url: ''
},
38: {
key: 38,
text: 'Coze',
value: 38,
color: 'primary',
url: ''
},
24: {
key: 24,
text: 'Azure Speech',

View File

@ -300,6 +300,17 @@ const typeConfig = {
test_model: ''
},
modelGroup: 'Stability AI'
},
38: {
input: {
models: ['coze-*']
},
prompt: {
models: '模型名称为coze-{bot_id},你也可以直接使用 coze-* 通配符来匹配所有coze开头的模型',
model_mapping:
'模型名称映射, 你可以取一个容易记忆的名字来代替coze-{bot_id},例如:{"coze-translate": "coze-xxxxx"},注意如果使用了模型映射那么上面的模型名称必须使用映射前的名称上述例子中你应该在模型中填入coze-translate(如果已经使用了coze-*,可以忽略)。'
},
modelGroup: 'Coze'
}
};