From 2810a96fd9299c4c89def4fc6d6f61b1743c1533 Mon Sep 17 00:00:00 2001 From: Martial BE Date: Wed, 3 Jan 2024 15:40:20 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20xunfei=20support=20function?= =?UTF-8?q?s(stream)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/base/common.go | 3 +- providers/xunfei/chat.go | 159 ++++++++++++++++++++++++++------------- providers/xunfei/type.go | 6 -- types/chat.go | 22 ++++-- 4 files changed, 126 insertions(+), 64 deletions(-) diff --git a/providers/base/common.go b/providers/base/common.go index 7f2d1a41..ef80a397 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -14,6 +14,8 @@ import ( ) var StopFinishReason = "stop" +var StopFinishReasonToolFunction = "tool_calls" +var StopFinishReasonCallFunction = "function_call" type BaseProvider struct { BaseURL string @@ -27,7 +29,6 @@ type BaseProvider struct { ImagesGenerations string ImagesEdit string ImagesVariations string - Proxy string Context *gin.Context Channel *model.Channel } diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go index c3b94ec9..8a63f90e 100644 --- a/providers/xunfei/chat.go +++ b/providers/xunfei/chat.go @@ -2,6 +2,7 @@ package xunfei import ( "encoding/json" + "fmt" "io" "net/http" "one-api/common" @@ -15,22 +16,24 @@ import ( func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model) - if request.Stream { - return p.sendStreamRequest(request, authUrl) - } else { - return p.sendRequest(request, authUrl) - } -} - -func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - usage = &types.Usage{} dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl) if err != nil { return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError) } + if request.Stream { + return p.sendStreamRequest(dataChan, stopChan, request.GetFunctionCate()) + } else { + return p.sendRequest(dataChan, stopChan, request.GetFunctionCate()) + } +} + +func (p *XunfeiProvider) sendRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + usage = &types.Usage{} + var content string var xunfeiResponse XunfeiChatResponse + stop := false for !stop { select { @@ -46,17 +49,17 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU } } + if xunfeiResponse.Header.Code != 0 { + return nil, common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError) + } + if len(xunfeiResponse.Payload.Choices.Text) == 0 { - xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ - { - Content: "", - }, - } + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}} } xunfeiResponse.Payload.Choices.Text[0].Content = content - response := p.responseXunfei2OpenAI(&xunfeiResponse) + response := p.responseXunfei2OpenAI(&xunfeiResponse, functionCate) jsonResponse, err := json.Marshal(response) if err != nil { return nil, common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) @@ -66,30 +69,56 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU return usage, nil } -func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { +func (p *XunfeiProvider) sendStreamRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { usage = &types.Usage{} - dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl) - if err != nil { - return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError) + + // 等待第一个dataChan的响应 + xunfeiResponse, ok := <-dataChan + if !ok { + return nil, common.ErrorWrapper(fmt.Errorf("xunfei response channel closed"), "xunfei_response_error", http.StatusInternalServerError) } + if xunfeiResponse.Header.Code != 0 { + errWithCode = common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError) + return nil, errWithCode + } + + // 如果第一个响应没有错误,设置StreamHeaders并开始streaming common.SetEventStreamHeaders(p.Context) p.Context.Stream(func(w io.Writer) bool { - select { - case xunfeiResponse := <-dataChan: - usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens - usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens - usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens - response := p.streamResponseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + // 处理第一个响应 + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) return true - case <-stopChan: - p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false + } + p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + + // 处理后续的响应 + for { + select { + case xunfeiResponse, ok := <-dataChan: + if !ok { + p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + case <-stopChan: + p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } } }) return usage, nil @@ -123,6 +152,9 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque } xunfeiRequest.Payload.Functions = &XunfeiChatPayloadFunctions{} xunfeiRequest.Payload.Functions.Text = functions + } else if request.Functions != nil { + xunfeiRequest.Payload.Functions = &XunfeiChatPayloadFunctions{} + xunfeiRequest.Payload.Functions.Text = request.Functions } xunfeiRequest.Header.AppId = p.apiId @@ -134,13 +166,9 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque return &xunfeiRequest } -func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *types.ChatCompletionResponse { +func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, functionCate string) *types.ChatCompletionResponse { if len(response.Payload.Choices.Text) == 0 { - response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ - { - Content: "", - }, - } + response.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}} } choice := types.ChatCompletionChoice{ @@ -153,13 +181,22 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *ty if xunfeiText.FunctionCall != nil { choice.Message = types.ChatCompletionMessage{ Role: "assistant", - ToolCalls: []*types.ChatCompletionToolCalls{ + } + + if functionCate == "tool" { + choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{ { + Id: response.Header.Sid, Type: "function", Function: *xunfeiText.FunctionCall, }, - }, + } + choice.FinishReason = &base.StopFinishReasonToolFunction + } else { + choice.Message.FunctionCall = xunfeiText.FunctionCall + choice.FinishReason = &base.StopFinishReasonCallFunction } + } else { choice.Message = types.ChatCompletionMessage{ Role: "assistant", @@ -168,7 +205,9 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *ty } fullTextResponse := types.ChatCompletionResponse{ + ID: response.Header.Sid, Object: "chat.completion", + Model: "SparkDesk", Created: common.GetTimestamp(), Choices: []types.ChatCompletionChoice{choice}, Usage: &response.Payload.Usage.Text, @@ -220,20 +259,38 @@ func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequ return dataChan, stopChan, nil } -func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *types.ChatCompletionStreamResponse { +func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse, functionCate string) *types.ChatCompletionStreamResponse { if len(xunfeiResponse.Payload.Choices.Text) == 0 { - xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ - { - Content: "", - }, - } + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}} } var choice types.ChatCompletionStreamChoice - choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content - if xunfeiResponse.Payload.Choices.Status == 2 { - choice.FinishReason = &base.StopFinishReason + xunfeiText := xunfeiResponse.Payload.Choices.Text[0] + + if xunfeiText.FunctionCall != nil { + if functionCate == "tool" { + choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{ + { + Id: xunfeiResponse.Header.Sid, + Index: 0, + Type: "function", + Function: *xunfeiText.FunctionCall, + }, + } + choice.FinishReason = &base.StopFinishReasonToolFunction + } else { + choice.Delta.FunctionCall = xunfeiText.FunctionCall + choice.FinishReason = &base.StopFinishReasonCallFunction + } + + } else { + choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + if xunfeiResponse.Payload.Choices.Status == 2 { + choice.FinishReason = &base.StopFinishReason + } } + response := types.ChatCompletionStreamResponse{ + ID: xunfeiResponse.Header.Sid, Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: "SparkDesk", diff --git a/providers/xunfei/type.go b/providers/xunfei/type.go index b74d064d..7da8daf4 100644 --- a/providers/xunfei/type.go +++ b/providers/xunfei/type.go @@ -62,12 +62,6 @@ type XunfeiChatResponse struct { Text []XunfeiChatResponseTextItem `json:"text"` } `json:"choices"` Usage struct { - //Text struct { - // QuestionTokens string `json:"question_tokens"` - // PromptTokens string `json:"prompt_tokens"` - // CompletionTokens string `json:"completion_tokens"` - // TotalTokens string `json:"total_tokens"` - //} `json:"text"` Text types.Usage `json:"text"` } `json:"usage"` } `json:"payload"` diff --git a/types/chat.go b/types/chat.go index 701e9b86..cf87d787 100644 --- a/types/chat.go +++ b/types/chat.go @@ -6,12 +6,13 @@ const ( ) type ChatCompletionToolCallsFunction struct { - Name string `json:"name"` - Arguments string `json:"arguments"` + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` } type ChatCompletionToolCalls struct { Id string `json:"id"` + Index int `json:"index,omitempty"` Type string `json:"type"` Function ChatCompletionToolCallsFunction `json:"function"` } @@ -129,6 +130,15 @@ type ChatCompletionRequest struct { ToolChoice any `json:"tool_choice,omitempty"` } +func (r ChatCompletionRequest) GetFunctionCate() string { + if r.Tools != nil { + return "tool" + } else if r.Functions != nil { + return "function" + } + return "" +} + type ChatCompletionFunction struct { Name string `json:"name"` Description string `json:"description"` @@ -157,10 +167,10 @@ type ChatCompletionResponse struct { } type ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` - FunctionCall any `json:"function_call,omitempty"` - ToolCalls any `json:"tool_calls,omitempty"` + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + FunctionCall *ChatCompletionToolCallsFunction `json:"function_call,omitempty"` + ToolCalls []*ChatCompletionToolCalls `json:"tool_calls,omitempty"` } type ChatCompletionStreamChoice struct {