feat: xunfei support functions (stream)

This commit is contained in:
Martial BE 2024-01-03 15:40:20 +08:00 committed by Buer
parent e052009eba
commit 2810a96fd9
4 changed files with 126 additions and 64 deletions

View File

@ -14,6 +14,8 @@ import (
)
var StopFinishReason = "stop"
var StopFinishReasonToolFunction = "tool_calls"
var StopFinishReasonCallFunction = "function_call"
type BaseProvider struct {
BaseURL string
@ -27,7 +29,6 @@ type BaseProvider struct {
ImagesGenerations string
ImagesEdit string
ImagesVariations string
Proxy string
Context *gin.Context
Channel *model.Channel
}

View File

@ -2,6 +2,7 @@ package xunfei
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
@ -15,22 +16,24 @@ import (
func (p *XunfeiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
authUrl := p.GetFullRequestURL(p.ChatCompletions, request.Model)
if request.Stream {
return p.sendStreamRequest(request, authUrl)
} else {
return p.sendRequest(request, authUrl)
}
}
func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
if err != nil {
return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
}
if request.Stream {
return p.sendStreamRequest(dataChan, stopChan, request.GetFunctionCate())
} else {
return p.sendRequest(dataChan, stopChan, request.GetFunctionCate())
}
}
func (p *XunfeiProvider) sendRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
var content string
var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
@ -46,17 +49,17 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU
}
}
if xunfeiResponse.Header.Code != 0 {
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
xunfeiResponse.Payload.Choices.Text[0].Content = content
response := p.responseXunfei2OpenAI(&xunfeiResponse)
response := p.responseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
return nil, common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
@ -66,30 +69,56 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU
return usage, nil
}
func (p *XunfeiProvider) sendStreamRequest(request *types.ChatCompletionRequest, authUrl string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
func (p *XunfeiProvider) sendStreamRequest(dataChan chan XunfeiChatResponse, stopChan chan bool, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
usage = &types.Usage{}
dataChan, stopChan, err := p.xunfeiMakeRequest(request, authUrl)
if err != nil {
return nil, common.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError)
// 等待第一个dataChan的响应
xunfeiResponse, ok := <-dataChan
if !ok {
return nil, common.ErrorWrapper(fmt.Errorf("xunfei response channel closed"), "xunfei_response_error", http.StatusInternalServerError)
}
if xunfeiResponse.Header.Code != 0 {
errWithCode = common.ErrorWrapper(fmt.Errorf("xunfei response: %s", xunfeiResponse.Header.Message), "xunfei_response_error", http.StatusInternalServerError)
return nil, errWithCode
}
// 如果第一个响应没有错误设置StreamHeaders并开始streaming
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// 处理第一个响应
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
// 处理后续的响应
for {
select {
case xunfeiResponse, ok := <-dataChan:
if !ok {
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := p.streamResponseXunfei2OpenAI(&xunfeiResponse, functionCate)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
case <-stopChan:
p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
}
})
return usage, nil
@ -123,6 +152,9 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
}
xunfeiRequest.Payload.Functions = &XunfeiChatPayloadFunctions{}
xunfeiRequest.Payload.Functions.Text = functions
} else if request.Functions != nil {
xunfeiRequest.Payload.Functions = &XunfeiChatPayloadFunctions{}
xunfeiRequest.Payload.Functions.Text = request.Functions
}
xunfeiRequest.Header.AppId = p.apiId
@ -134,13 +166,9 @@ func (p *XunfeiProvider) requestOpenAI2Xunfei(request *types.ChatCompletionReque
return &xunfeiRequest
}
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *types.ChatCompletionResponse {
func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse, functionCate string) *types.ChatCompletionResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
choice := types.ChatCompletionChoice{
@ -153,13 +181,22 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *ty
if xunfeiText.FunctionCall != nil {
choice.Message = types.ChatCompletionMessage{
Role: "assistant",
ToolCalls: []*types.ChatCompletionToolCalls{
}
if functionCate == "tool" {
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: response.Header.Sid,
Type: "function",
Function: *xunfeiText.FunctionCall,
},
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
} else {
choice.Message.FunctionCall = xunfeiText.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
}
} else {
choice.Message = types.ChatCompletionMessage{
Role: "assistant",
@ -168,7 +205,9 @@ func (p *XunfeiProvider) responseXunfei2OpenAI(response *XunfeiChatResponse) *ty
}
fullTextResponse := types.ChatCompletionResponse{
ID: response.Header.Sid,
Object: "chat.completion",
Model: "SparkDesk",
Created: common.GetTimestamp(),
Choices: []types.ChatCompletionChoice{choice},
Usage: &response.Payload.Usage.Text,
@ -220,20 +259,38 @@ func (p *XunfeiProvider) xunfeiMakeRequest(textRequest *types.ChatCompletionRequ
return dataChan, stopChan, nil
}
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *types.ChatCompletionStreamResponse {
func (p *XunfeiProvider) streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse, functionCate string) *types.ChatCompletionStreamResponse {
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
{
Content: "",
},
}
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &base.StopFinishReason
xunfeiText := xunfeiResponse.Payload.Choices.Text[0]
if xunfeiText.FunctionCall != nil {
if functionCate == "tool" {
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: xunfeiResponse.Header.Sid,
Index: 0,
Type: "function",
Function: *xunfeiText.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
} else {
choice.Delta.FunctionCall = xunfeiText.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
}
} else {
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &base.StopFinishReason
}
}
response := types.ChatCompletionStreamResponse{
ID: xunfeiResponse.Header.Sid,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "SparkDesk",

View File

@ -62,12 +62,6 @@ type XunfeiChatResponse struct {
Text []XunfeiChatResponseTextItem `json:"text"`
} `json:"choices"`
Usage struct {
//Text struct {
// QuestionTokens string `json:"question_tokens"`
// PromptTokens string `json:"prompt_tokens"`
// CompletionTokens string `json:"completion_tokens"`
// TotalTokens string `json:"total_tokens"`
//} `json:"text"`
Text types.Usage `json:"text"`
} `json:"usage"`
} `json:"payload"`

View File

@ -6,12 +6,13 @@ const (
)
type ChatCompletionToolCallsFunction struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
Name string `json:"name,omitempty"`
Arguments string `json:"arguments,omitempty"`
}
// ChatCompletionToolCalls is one entry of an OpenAI-style `tool_calls`
// array attached to an assistant message or stream delta.
type ChatCompletionToolCalls struct {
// Id identifies this tool call (callers echo it back in tool results).
Id string `json:"id"`
// Index is the position of this call within the tool_calls array;
// omitted from JSON when zero.
Index int `json:"index,omitempty"`
// Type of the call; the visible call sites in this file set "function".
Type string `json:"type"`
// Function carries the function name and JSON-encoded arguments.
Function ChatCompletionToolCallsFunction `json:"function"`
}
@ -129,6 +130,15 @@ type ChatCompletionRequest struct {
ToolChoice any `json:"tool_choice,omitempty"`
}
// GetFunctionCate reports which function-calling flavor the request uses:
// "tool" when the tools field is set, "function" when only the legacy
// functions field is set, and "" when neither is present. Tools take
// precedence when both are supplied.
func (r ChatCompletionRequest) GetFunctionCate() string {
	// Prefer the newer tools API over the deprecated functions API;
	// early returns avoid the else-after-return chain.
	if r.Tools != nil {
		return "tool"
	}
	if r.Functions != nil {
		return "function"
	}
	return ""
}
type ChatCompletionFunction struct {
Name string `json:"name"`
Description string `json:"description"`
@ -157,10 +167,10 @@ type ChatCompletionResponse struct {
}
type ChatCompletionStreamChoiceDelta struct {
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
FunctionCall *ChatCompletionToolCallsFunction `json:"function_call,omitempty"`
ToolCalls []*ChatCompletionToolCalls `json:"tool_calls,omitempty"`
}
type ChatCompletionStreamChoice struct {