feat: baidu support functions

This commit is contained in:
Martial BE 2024-01-03 16:25:59 +08:00 committed by Buer
parent 2810a96fd9
commit 475dba1233
2 changed files with 78 additions and 23 deletions

View File

@ -28,9 +28,27 @@ func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (Op
Index: 0,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: baiduResponse.Result,
// Content: baiduResponse.Result,
},
FinishReason: "stop",
FinishReason: base.StopFinishReason,
}
if baiduResponse.FunctionCall != nil {
if baiduResponse.FunctionCate == "tool" {
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: baiduResponse.Id,
Type: "function",
Function: *baiduResponse.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
} else {
choice.Message.FunctionCall = baiduResponse.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
}
} else {
choice.Message.Content = baiduResponse.Result
}
OpenAIResponse = types.ChatCompletionResponse{
@ -63,10 +81,24 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
})
}
}
return &BaiduChatRequest{
baiduChatRequest := &BaiduChatRequest{
Messages: messages,
Temperature: request.Temperature,
Stream: request.Stream,
}
if request.Tools != nil {
functions := make([]*types.ChatCompletionFunction, 0, len(request.Tools))
for _, tool := range request.Tools {
functions = append(functions, &tool.Function)
}
baiduChatRequest.Functions = functions
} else if request.Functions != nil {
baiduChatRequest.Functions = request.Functions
}
return baiduChatRequest
}
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
@ -88,7 +120,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
}
if request.Stream {
usage, errWithCode = p.sendStreamRequest(req, request.Model)
usage, errWithCode = p.sendStreamRequest(req, request.Model, request.GetFunctionCate())
if errWithCode != nil {
return
}
@ -96,6 +128,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
} else {
baiduChatRequest := &BaiduChatResponse{
Model: request.Model,
FunctionCate: request.GetFunctionCate(),
}
errWithCode = p.SendRequest(req, baiduChatRequest, false)
if errWithCode != nil {
@ -110,10 +143,27 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice
if baiduResponse.FunctionCall != nil {
if baiduResponse.FunctionCate == "tool" {
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: baiduResponse.Id,
Type: "function",
Function: *baiduResponse.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
} else {
choice.Delta.FunctionCall = baiduResponse.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
}
} else {
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &base.StopFinishReason
}
}
response := types.ChatCompletionStreamResponse{
ID: baiduResponse.Id,
@ -125,7 +175,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
return &response
}
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close()
usage = &types.Usage{}
@ -174,6 +224,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usag
select {
case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse
baiduResponse.FunctionCate = functionCate
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())

View File

@ -20,6 +20,8 @@ type BaiduMessage struct {
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Functions []*types.ChatCompletionFunction `json:"functions,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
@ -33,6 +35,8 @@ type BaiduChatResponse struct {
NeedClearHistory bool `json:"need_clear_history"`
Usage *types.Usage `json:"usage"`
Model string `json:"model,omitempty"`
FunctionCall *types.ChatCompletionToolCallsFunction `json:"function_call,omitempty"`
FunctionCate string `json:"function_cate,omitempty"`
BaiduError
}