diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 87037e34..3b6fe5a0 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -75,7 +75,7 @@ type XunfeiChatResponse struct { } `json:"payload"` } -func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest { +func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { messages := make([]XunfeiMessage, 0, len(request.Messages)) for _, message := range request.Messages { if message.Role == "system" { @@ -96,7 +96,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *Xun } xunfeiRequest := XunfeiChatRequest{} xunfeiRequest.Header.AppId = xunfeiAppId - xunfeiRequest.Parameter.Chat.Domain = "general" + xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens @@ -178,15 +178,28 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { var usage Usage + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + if apiVersion == "" { + apiVersion = "v1.1" + common.SysLog("api_version not found, use default: " + apiVersion) + } + domain := "general" + if apiVersion == "v2.1" { + domain = "generalv2" + } + hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - hostUrl := "wss://aichat.xf-yun.com/v1/chat" conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) if err != nil || resp.StatusCode != 101 { return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil } - data := requestOpenAI2Xunfei(textRequest, appId) + data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil diff --git a/i18n/en.json b/i18n/en.json index ae395dae..aed65979 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -521,5 +521,7 @@ "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com", "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?", "按照如下格式输入:": "Enter in the following format:", + "模型版本": "Model version", + "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", "点击查看": "click to view" } diff --git a/middleware/distributor.go b/middleware/distributor.go index 91c00e1a..ebbde535 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -107,7 +107,7 @@ func Distribute() func(c *gin.Context) { c.Set("model_mapping", channel.ModelMapping) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.BaseURL) - if channel.Type == common.ChannelTypeAzure { + if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { c.Set("api_version", channel.Other) } c.Next() diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index b5fb524e..fcbdb980 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -163,6 +163,9 @@ const EditChannel = () => { if (localInputs.type === 3 && localInputs.other === '') { localInputs.other = '2023-06-01-preview'; } + if (localInputs.type === 18 && localInputs.other === '') { + localInputs.other = 'v2.1'; + } if (localInputs.model_mapping === '') { localInputs.model_mapping = '{}'; } @@ -275,6 +278,20 @@ const EditChannel = () => { options={groupOptions} /> + { + inputs.type === 18 && ( + + + + ) + }