feat: baidu support functions

This commit is contained in:
Martial BE 2024-01-03 16:25:59 +08:00 committed by Buer
parent 2810a96fd9
commit 475dba1233
2 changed files with 78 additions and 23 deletions

View File

@ -28,9 +28,27 @@ func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (Op
Index: 0, Index: 0,
Message: types.ChatCompletionMessage{ Message: types.ChatCompletionMessage{
Role: "assistant", Role: "assistant",
Content: baiduResponse.Result, // Content: baiduResponse.Result,
}, },
FinishReason: "stop", FinishReason: base.StopFinishReason,
}
if baiduResponse.FunctionCall != nil {
if baiduResponse.FunctionCate == "tool" {
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: baiduResponse.Id,
Type: "function",
Function: *baiduResponse.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
} else {
choice.Message.FunctionCall = baiduResponse.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
}
} else {
choice.Message.Content = baiduResponse.Result
} }
OpenAIResponse = types.ChatCompletionResponse{ OpenAIResponse = types.ChatCompletionResponse{
@ -63,10 +81,24 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
}) })
} }
} }
return &BaiduChatRequest{
baiduChatRequest := &BaiduChatRequest{
Messages: messages, Messages: messages,
Temperature: request.Temperature,
Stream: request.Stream, Stream: request.Stream,
} }
if request.Tools != nil {
functions := make([]*types.ChatCompletionFunction, 0, len(request.Tools))
for _, tool := range request.Tools {
functions = append(functions, &tool.Function)
}
baiduChatRequest.Functions = functions
} else if request.Functions != nil {
baiduChatRequest.Functions = request.Functions
}
return baiduChatRequest
} }
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
@ -88,7 +120,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
} }
if request.Stream { if request.Stream {
usage, errWithCode = p.sendStreamRequest(req, request.Model) usage, errWithCode = p.sendStreamRequest(req, request.Model, request.GetFunctionCate())
if errWithCode != nil { if errWithCode != nil {
return return
} }
@ -96,6 +128,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
} else { } else {
baiduChatRequest := &BaiduChatResponse{ baiduChatRequest := &BaiduChatResponse{
Model: request.Model, Model: request.Model,
FunctionCate: request.GetFunctionCate(),
} }
errWithCode = p.SendRequest(req, baiduChatRequest, false) errWithCode = p.SendRequest(req, baiduChatRequest, false)
if errWithCode != nil { if errWithCode != nil {
@ -110,10 +143,27 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse { func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
var choice types.ChatCompletionStreamChoice var choice types.ChatCompletionStreamChoice
if baiduResponse.FunctionCall != nil {
if baiduResponse.FunctionCate == "tool" {
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: baiduResponse.Id,
Type: "function",
Function: *baiduResponse.FunctionCall,
},
}
choice.FinishReason = &base.StopFinishReasonToolFunction
} else {
choice.Delta.FunctionCall = baiduResponse.FunctionCall
choice.FinishReason = &base.StopFinishReasonCallFunction
}
} else {
choice.Delta.Content = baiduResponse.Result choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd { if baiduResponse.IsEnd {
choice.FinishReason = &base.StopFinishReason choice.FinishReason = &base.StopFinishReason
} }
}
response := types.ChatCompletionStreamResponse{ response := types.ChatCompletionStreamResponse{
ID: baiduResponse.Id, ID: baiduResponse.Id,
@ -125,7 +175,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
return &response return &response
} }
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
defer req.Body.Close() defer req.Body.Close()
usage = &types.Usage{} usage = &types.Usage{}
@ -174,6 +224,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usag
select { select {
case data := <-dataChan: case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse var baiduResponse BaiduChatStreamResponse
baiduResponse.FunctionCate = functionCate
err := json.Unmarshal([]byte(data), &baiduResponse) err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())

View File

@ -20,6 +20,8 @@ type BaiduMessage struct {
type BaiduChatRequest struct { type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"` Messages []BaiduMessage `json:"messages"`
Functions []*types.ChatCompletionFunction `json:"functions,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"` UserId string `json:"user_id,omitempty"`
} }
@ -33,6 +35,8 @@ type BaiduChatResponse struct {
NeedClearHistory bool `json:"need_clear_history"` NeedClearHistory bool `json:"need_clear_history"`
Usage *types.Usage `json:"usage"` Usage *types.Usage `json:"usage"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
FunctionCall *types.ChatCompletionToolCallsFunction `json:"function_call,omitempty"`
FunctionCate string `json:"function_cate,omitempty"`
BaiduError BaiduError
} }