Update relay-ali.go: 改进 stream 模式,添加联网搜索能力

通义千问支持 stream 的增量输出模式(incremental_output),因此不再需要每次去掉上一次响应的前缀;实测 qwen-max 的联网搜索效果不错,所以添加了联网模式(enable_search)。如果其他模型出现问题,可以改为仅对 qwen-max 开放。
This commit is contained in:
moondie 2023-12-22 04:06:45 +08:00 committed by GitHub
parent b7fcb319da
commit 7dd1cf1e24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,6 +27,9 @@ type AliParameters struct {
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"` Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"` EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
Stream bool `json:"stream,omitempty"`
} }
type AliChatRequest struct { type AliChatRequest struct {
@ -95,12 +98,14 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
Input: AliInput{ Input: AliInput{
Messages: messages, Messages: messages,
}, },
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP, // TopP: request.TopP,
// TopK: 50, // TopK: 50,
// //Seed: 0, // //Seed: 0,
// //EnableSearch: false, EnableSearch: true,
//}, IncrementalOutput=true,
Stream=request.Stream,
},
} }
} }
@ -202,7 +207,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStre
Id: aliResponse.RequestId, Id: aliResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: "ernie-bot", Model: "qwen",
Choices: []ChatCompletionsStreamResponseChoice{choice}, Choices: []ChatCompletionsStreamResponseChoice{choice},
} }
return &response return &response
@ -240,7 +245,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c) setEventStreamHeaders(c)
lastResponseText := "" //lastResponseText := ""
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
@ -256,8 +261,8 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
} }
response := streamResponseAli2OpenAI(&aliResponse) response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text //lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())