From ee9e746520e57520fd2b7c7f1bc85d4eaed077fd Mon Sep 17 00:00:00 2001
From: moondie <528893699@qq.com>
Date: Sun, 24 Dec 2023 16:17:21 +0800
Subject: [PATCH] feat: update ali stream implementation & enable internet
 search (#856)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Update relay-ali.go: improve stream mode and add internet search support

Tongyi Qianwen (Qwen) supports incremental output in stream mode, so there is
no need to strip the previous chunk's text as a prefix from every response. In
testing, qwen-max works well with internet search enabled, so that mode has
been added. If other models have problems with it, this can be changed to
enable search for qwen-max only.

* Remove the "stream" parameter

Just noticed that the Ali API does not actually have this parameter; it was
added by mistake in the previous change.

* refactor: only enable search when specified

* fix: remove custom suffix when get model ratio

---------

Co-authored-by: JustSong
---
 common/model-ratio.go                |  3 +++
 controller/relay-ali.go              | 37 +++++++++++++++++-----------
 web/src/pages/Channel/EditChannel.js |  7 ++++++
 3 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/common/model-ratio.go b/common/model-ratio.go
index d1c96d96..d6b51f84 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -115,6 +115,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
 }
 
 func GetModelRatio(name string) float64 {
+	if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
+		name = strings.TrimSuffix(name, "-internet")
+	}
 	ratio, ok := ModelRatio[name]
 	if !ok {
 		SysError("model ratio not found: " + name)
diff --git a/controller/relay-ali.go b/controller/relay-ali.go
index 65626f6a..7968bfb6 100644
--- a/controller/relay-ali.go
+++ b/controller/relay-ali.go
@@ -23,10 +23,11 @@ type AliInput struct {
 }
 
 type AliParameters struct {
-	TopP         float64 `json:"top_p,omitempty"`
-	TopK         int     `json:"top_k,omitempty"`
-	Seed         uint64  `json:"seed,omitempty"`
-	EnableSearch bool    `json:"enable_search,omitempty"`
+	TopP              float64 `json:"top_p,omitempty"`
+	TopK              int     `json:"top_k,omitempty"`
+	Seed              uint64  `json:"seed,omitempty"`
+	EnableSearch      bool    `json:"enable_search,omitempty"`
+	IncrementalOutput bool    `json:"incremental_output,omitempty"`
 }
 
 type AliChatRequest struct {
@@ -81,6 +82,8 @@ type AliChatResponse struct {
 	AliError
 }
 
+const AliEnableSearchModelSuffix = "-internet"
+
 func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 	messages := make([]AliMessage, 0, len(request.Messages))
 	for i := 0; i < len(request.Messages); i++ {
@@ -90,17 +93,21 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 			Role:    strings.ToLower(message.Role),
 		})
 	}
+	enableSearch := false
+	aliModel := request.Model
+	if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
+		enableSearch = true
+		aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
+	}
 	return &AliChatRequest{
-		Model: request.Model,
+		Model: aliModel,
 		Input: AliInput{
 			Messages: messages,
 		},
-		//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
-		//	TopP: request.TopP,
-		//	TopK: 50,
-		//	//Seed: 0,
-		//	//EnableSearch: false,
-		//},
+		Parameters: AliParameters{
+			EnableSearch:      enableSearch,
+			IncrementalOutput: request.Stream,
+		},
 	}
 }
 
@@ -202,7 +209,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStre
 		Id:      aliResponse.RequestId,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
-		Model:   "ernie-bot",
+		Model:   "qwen",
 		Choices: []ChatCompletionsStreamResponseChoice{choice},
 	}
 	return &response
@@ -240,7 +247,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 		stopChan <- true
 	}()
 	setEventStreamHeaders(c)
-	lastResponseText := ""
+	//lastResponseText := ""
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -256,8 +263,8 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 			}
 			response := streamResponseAli2OpenAI(&aliResponse)
-			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
-			lastResponseText = aliResponse.Output.Text
+			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
+			//lastResponseText = aliResponse.Output.Text
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 364da69d..b1c7ae62 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -70,6 +70,13 @@ const EditChannel = () => {
         break;
       case 17:
         localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1'];
+        let withInternetVersion = [];
+        for (let i = 0; i < localModels.length; i++) {
+          if (localModels[i].startsWith('qwen-')) {
+            withInternetVersion.push(localModels[i] + '-internet');
+          }
+        }
+        localModels = [...localModels, ...withInternetVersion];
        break;
      case 16:
        localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
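
For reference, a minimal client-side sketch (not part of the patch) of how one of the new "-internet" models could be requested through the OpenAI-compatible endpoint once this change is deployed; the base URL, port, and token are placeholders, not values taken from the patch:

// client_sketch.go: call a "-internet" Qwen model via one-api's /v1/chat/completions.
package main

import (
	"bytes"
	"io"
	"net/http"
	"os"
)

func main() {
	body := []byte(`{
		"model": "qwen-max-internet",
		"stream": true,
		"messages": [{"role": "user", "content": "What is in the news today?"}]
	}`)
	// Placeholder base URL and port; adjust to your deployment.
	req, err := http.NewRequest("POST", "http://localhost:3000/v1/chat/completions", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer sk-xxx") // placeholder token
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	// The relay strips the "-internet" suffix before forwarding, sets
	// enable_search=true and incremental_output=true in AliParameters, and
	// streams back OpenAI-style chat.completion.chunk events.
	io.Copy(os.Stdout, resp.Body)
}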