From 475dba1233a90501bd71944db0c6256bfd5f109f Mon Sep 17 00:00:00 2001 From: Martial BE Date: Wed, 3 Jan 2024 16:25:59 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20baidu=20support=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/baidu/chat.go | 75 ++++++++++++++++++++++++++++++++++------- providers/baidu/type.go | 26 ++++++++------ 2 files changed, 78 insertions(+), 23 deletions(-) diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index e675adc2..f3ded14c 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -27,10 +27,28 @@ func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (Op choice := types.ChatCompletionChoice{ Index: 0, Message: types.ChatCompletionMessage{ - Role: "assistant", - Content: baiduResponse.Result, + Role: "assistant", + // Content: baiduResponse.Result, }, - FinishReason: "stop", + FinishReason: base.StopFinishReason, + } + + if baiduResponse.FunctionCall != nil { + if baiduResponse.FunctionCate == "tool" { + choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{ + { + Id: baiduResponse.Id, + Type: "function", + Function: *baiduResponse.FunctionCall, + }, + } + choice.FinishReason = &base.StopFinishReasonToolFunction + } else { + choice.Message.FunctionCall = baiduResponse.FunctionCall + choice.FinishReason = &base.StopFinishReasonCallFunction + } + } else { + choice.Message.Content = baiduResponse.Result } OpenAIResponse = types.ChatCompletionResponse{ @@ -63,10 +81,24 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest) }) } } - return &BaiduChatRequest{ - Messages: messages, - Stream: request.Stream, + + baiduChatRequest := &BaiduChatRequest{ + Messages: messages, + Temperature: request.Temperature, + Stream: request.Stream, } + + if request.Tools != nil { + functions := make([]*types.ChatCompletionFunction, 0, len(request.Tools)) + for _, tool := range request.Tools { + functions = append(functions, &tool.Function) + } + baiduChatRequest.Functions = functions + } else if request.Functions != nil { + baiduChatRequest.Functions = request.Functions + } + + return baiduChatRequest } func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { @@ -88,14 +120,15 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel } if request.Stream { - usage, errWithCode = p.sendStreamRequest(req, request.Model) + usage, errWithCode = p.sendStreamRequest(req, request.Model, request.GetFunctionCate()) if errWithCode != nil { return } } else { baiduChatRequest := &BaiduChatResponse{ - Model: request.Model, + Model: request.Model, + FunctionCate: request.GetFunctionCate(), } errWithCode = p.SendRequest(req, baiduChatRequest, false) if errWithCode != nil { @@ -110,9 +143,26 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse { var choice types.ChatCompletionStreamChoice - choice.Delta.Content = baiduResponse.Result - if baiduResponse.IsEnd { - choice.FinishReason = &base.StopFinishReason + + if baiduResponse.FunctionCall != nil { + if baiduResponse.FunctionCate == "tool" { + choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{ + { + Id: baiduResponse.Id, + Type: "function", + Function: *baiduResponse.FunctionCall, + }, + } + choice.FinishReason = &base.StopFinishReasonToolFunction + } else { + choice.Delta.FunctionCall = baiduResponse.FunctionCall + choice.FinishReason = &base.StopFinishReasonCallFunction + } + } else { + choice.Delta.Content = baiduResponse.Result + if baiduResponse.IsEnd { + choice.FinishReason = &base.StopFinishReason + } } response := types.ChatCompletionStreamResponse{ @@ -125,7 +175,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea return &response } -func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { +func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { defer req.Body.Close() usage = &types.Usage{} @@ -174,6 +224,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usag select { case data := <-dataChan: var baiduResponse BaiduChatStreamResponse + baiduResponse.FunctionCate = functionCate err := json.Unmarshal([]byte(data), &baiduResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) diff --git a/providers/baidu/type.go b/providers/baidu/type.go index b2e0f1e8..e222fe7a 100644 --- a/providers/baidu/type.go +++ b/providers/baidu/type.go @@ -19,20 +19,24 @@ type BaiduMessage struct { } type BaiduChatRequest struct { - Messages []BaiduMessage `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` + Messages []BaiduMessage `json:"messages"` + Functions []*types.ChatCompletionFunction `json:"functions,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Stream bool `json:"stream"` + UserId string `json:"user_id,omitempty"` } type BaiduChatResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Result string `json:"result"` - IsTruncated bool `json:"is_truncated"` - NeedClearHistory bool `json:"need_clear_history"` - Usage *types.Usage `json:"usage"` - Model string `json:"model,omitempty"` + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage *types.Usage `json:"usage"` + Model string `json:"model,omitempty"` + FunctionCall *types.ChatCompletionToolCallsFunction `json:"function_call,omitempty"` + FunctionCate string `json:"function_cate,omitempty"` BaiduError }