✨ feat: baidu support functions
This commit is contained in:
parent
2810a96fd9
commit
475dba1233
@ -28,9 +28,27 @@ func (baiduResponse *BaiduChatResponse) ResponseHandler(resp *http.Response) (Op
|
|||||||
Index: 0,
|
Index: 0,
|
||||||
Message: types.ChatCompletionMessage{
|
Message: types.ChatCompletionMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: baiduResponse.Result,
|
// Content: baiduResponse.Result,
|
||||||
},
|
},
|
||||||
FinishReason: "stop",
|
FinishReason: base.StopFinishReason,
|
||||||
|
}
|
||||||
|
|
||||||
|
if baiduResponse.FunctionCall != nil {
|
||||||
|
if baiduResponse.FunctionCate == "tool" {
|
||||||
|
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||||
|
{
|
||||||
|
Id: baiduResponse.Id,
|
||||||
|
Type: "function",
|
||||||
|
Function: *baiduResponse.FunctionCall,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
choice.FinishReason = &base.StopFinishReasonToolFunction
|
||||||
|
} else {
|
||||||
|
choice.Message.FunctionCall = baiduResponse.FunctionCall
|
||||||
|
choice.FinishReason = &base.StopFinishReasonCallFunction
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
choice.Message.Content = baiduResponse.Result
|
||||||
}
|
}
|
||||||
|
|
||||||
OpenAIResponse = types.ChatCompletionResponse{
|
OpenAIResponse = types.ChatCompletionResponse{
|
||||||
@ -63,10 +81,24 @@ func (p *BaiduProvider) getChatRequestBody(request *types.ChatCompletionRequest)
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &BaiduChatRequest{
|
|
||||||
|
baiduChatRequest := &BaiduChatRequest{
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
|
Temperature: request.Temperature,
|
||||||
Stream: request.Stream,
|
Stream: request.Stream,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if request.Tools != nil {
|
||||||
|
functions := make([]*types.ChatCompletionFunction, 0, len(request.Tools))
|
||||||
|
for _, tool := range request.Tools {
|
||||||
|
functions = append(functions, &tool.Function)
|
||||||
|
}
|
||||||
|
baiduChatRequest.Functions = functions
|
||||||
|
} else if request.Functions != nil {
|
||||||
|
baiduChatRequest.Functions = request.Functions
|
||||||
|
}
|
||||||
|
|
||||||
|
return baiduChatRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -88,7 +120,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if request.Stream {
|
if request.Stream {
|
||||||
usage, errWithCode = p.sendStreamRequest(req, request.Model)
|
usage, errWithCode = p.sendStreamRequest(req, request.Model, request.GetFunctionCate())
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -96,6 +128,7 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
|
|||||||
} else {
|
} else {
|
||||||
baiduChatRequest := &BaiduChatResponse{
|
baiduChatRequest := &BaiduChatResponse{
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
|
FunctionCate: request.GetFunctionCate(),
|
||||||
}
|
}
|
||||||
errWithCode = p.SendRequest(req, baiduChatRequest, false)
|
errWithCode = p.SendRequest(req, baiduChatRequest, false)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
@ -110,10 +143,27 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
|
|||||||
|
|
||||||
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
|
func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *types.ChatCompletionStreamResponse {
|
||||||
var choice types.ChatCompletionStreamChoice
|
var choice types.ChatCompletionStreamChoice
|
||||||
|
|
||||||
|
if baiduResponse.FunctionCall != nil {
|
||||||
|
if baiduResponse.FunctionCate == "tool" {
|
||||||
|
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||||
|
{
|
||||||
|
Id: baiduResponse.Id,
|
||||||
|
Type: "function",
|
||||||
|
Function: *baiduResponse.FunctionCall,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
choice.FinishReason = &base.StopFinishReasonToolFunction
|
||||||
|
} else {
|
||||||
|
choice.Delta.FunctionCall = baiduResponse.FunctionCall
|
||||||
|
choice.FinishReason = &base.StopFinishReasonCallFunction
|
||||||
|
}
|
||||||
|
} else {
|
||||||
choice.Delta.Content = baiduResponse.Result
|
choice.Delta.Content = baiduResponse.Result
|
||||||
if baiduResponse.IsEnd {
|
if baiduResponse.IsEnd {
|
||||||
choice.FinishReason = &base.StopFinishReason
|
choice.FinishReason = &base.StopFinishReason
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
response := types.ChatCompletionStreamResponse{
|
response := types.ChatCompletionStreamResponse{
|
||||||
ID: baiduResponse.Id,
|
ID: baiduResponse.Id,
|
||||||
@ -125,7 +175,7 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string, functionCate string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
defer req.Body.Close()
|
defer req.Body.Close()
|
||||||
|
|
||||||
usage = &types.Usage{}
|
usage = &types.Usage{}
|
||||||
@ -174,6 +224,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usag
|
|||||||
select {
|
select {
|
||||||
case data := <-dataChan:
|
case data := <-dataChan:
|
||||||
var baiduResponse BaiduChatStreamResponse
|
var baiduResponse BaiduChatStreamResponse
|
||||||
|
baiduResponse.FunctionCate = functionCate
|
||||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
@ -20,6 +20,8 @@ type BaiduMessage struct {
|
|||||||
|
|
||||||
type BaiduChatRequest struct {
|
type BaiduChatRequest struct {
|
||||||
Messages []BaiduMessage `json:"messages"`
|
Messages []BaiduMessage `json:"messages"`
|
||||||
|
Functions []*types.ChatCompletionFunction `json:"functions,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
Stream bool `json:"stream"`
|
Stream bool `json:"stream"`
|
||||||
UserId string `json:"user_id,omitempty"`
|
UserId string `json:"user_id,omitempty"`
|
||||||
}
|
}
|
||||||
@ -33,6 +35,8 @@ type BaiduChatResponse struct {
|
|||||||
NeedClearHistory bool `json:"need_clear_history"`
|
NeedClearHistory bool `json:"need_clear_history"`
|
||||||
Usage *types.Usage `json:"usage"`
|
Usage *types.Usage `json:"usage"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
|
FunctionCall *types.ChatCompletionToolCallsFunction `json:"function_call,omitempty"`
|
||||||
|
FunctionCate string `json:"function_cate,omitempty"`
|
||||||
BaiduError
|
BaiduError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user