✨ feat: baidu support functions
This commit is contained in:
parent
2810a96fd9
commit
475dba1233
@ -28,9 +28,27 @@ func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (Op
|
||||
Index: 0,
|
||||
Message: types.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: baiduResponse.Result,
|
||||
// Content: baiduResponse.Result,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
FinishReason: base.StopFinishReason,
|
||||
}
|
||||
|
||||
if baiduResponse.FunctionCall != nil {
|
||||
if baiduResponse.FunctionCate == "tool" {
|
||||
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||
{
|
||||
Id: baiduResponse.Id,
|
||||
Type: "function",
|
||||
Function: *baiduResponse.FunctionCall,
|
||||
},
|
||||
}
|
||||
choice.FinishReason = &base.StopFinishReasonToolFunction
|
||||
} else {
|
||||
choice.Message.FunctionCall = baiduResponse.FunctionCall
|
||||
choice.FinishReason = &base.StopFinishReasonCallFunction
|
||||
}
|
||||
} else {
|
||||
choice.Message.Content = baiduResponse.Result
|
||||
}
|
||||
|
||||
OpenAIResponse = types.ChatCompletionResponse{
|
||||
@ -63,10 +81,24 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
||||
})
|
||||
}
|
||||
}
|
||||
return &BaiduChatRequest{
|
||||
|
||||
baiduChatRequest := &BaiduChatRequest{
|
||||
Messages: messages,
|
||||
Temperature: request.Temperature,
|
||||
Stream: request.Stream,
|
||||
}
|
||||
|
||||
if request.Tools != nil {
|
||||
functions := make([]*types.ChatCompletionFunction, 0, len(request.Tools))
|
||||
for _, tool := range request.Tools {
|
||||
functions = append(functions, &tool.Function)
|
||||
}
|
||||
baiduChatRequest.Functions = functions
|
||||
} else if request.Functions != nil {
|
||||
baiduChatRequest.Functions = request.Functions
|
||||
}
|
||||
|
||||
return baiduChatRequest
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
@ -88,7 +120,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
|
||||
}
|
||||
|
||||
if request.Stream {
|
||||
usage, errWithCode = p.sendStreamRequest(req, request.Model)
|
||||
usage, errWithCode = p.sendStreamRequest(req, request.Model, request.GetFunctionCate())
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
@ -96,6 +128,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
|
||||
} else {
|
||||
baiduChatRequest := &BaiduChatResponse{
|
||||
Model: request.Model,
|
||||
FunctionCate: request.GetFunctionCate(),
|
||||
}
|
||||
errWithCode = p.SendRequest(req, baiduChatRequest, false)
|
||||
if errWithCode != nil {
|
||||
@ -110,10 +143,27 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
|
||||
|
||||
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
|
||||
if baiduResponse.FunctionCall != nil {
|
||||
if baiduResponse.FunctionCate == "tool" {
|
||||
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||
{
|
||||
Id: baiduResponse.Id,
|
||||
Type: "function",
|
||||
Function: *baiduResponse.FunctionCall,
|
||||
},
|
||||
}
|
||||
choice.FinishReason = &base.StopFinishReasonToolFunction
|
||||
} else {
|
||||
choice.Delta.FunctionCall = baiduResponse.FunctionCall
|
||||
choice.FinishReason = &base.StopFinishReasonCallFunction
|
||||
}
|
||||
} else {
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &base.StopFinishReason
|
||||
}
|
||||
}
|
||||
|
||||
response := types.ChatCompletionStreamResponse{
|
||||
ID: baiduResponse.Id,
|
||||
@ -125,7 +175,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
|
||||
return &response
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
defer req.Body.Close()
|
||||
|
||||
usage = &types.Usage{}
|
||||
@ -174,6 +224,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usag
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
baiduResponse.FunctionCate = functionCate
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
|
@ -20,6 +20,8 @@ type BaiduMessage struct {
|
||||
|
||||
type BaiduChatRequest struct {
|
||||
Messages []BaiduMessage `json:"messages"`
|
||||
Functions []*types.ChatCompletionFunction `json:"functions,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
UserId string `json:"user_id,omitempty"`
|
||||
}
|
||||
@ -33,6 +35,8 @@ type BaiduChatResponse struct {
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage *types.Usage `json:"usage"`
|
||||
Model string `json:"model,omitempty"`
|
||||
FunctionCall *types.ChatCompletionToolCallsFunction `json:"function_call,omitempty"`
|
||||
FunctionCate string `json:"function_cate,omitempty"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user