diff --git a/controller/relay-text.go b/controller/relay-text.go index 6d481983..4481c652 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -541,24 +541,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return nil } case APITypeXunfei: - if isStream { - auth := c.Request.Header.Get("Authorization") - auth = strings.TrimPrefix(auth, "Bearer ") - splits := strings.Split(auth, "|") - if len(splits) != 3 { - return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - } - err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) + auth := c.Request.Header.Get("Authorization") + auth = strings.TrimPrefix(auth, "Bearer ") + splits := strings.Split(auth, "|") + if len(splits) != 3 { + return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) } + var err *OpenAIErrorWithStatusCode + var usage *Usage + if isStream { + err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) + } else { + err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) + } + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil case APITypeAIProxyLibrary: if isStream { err, usage := aiProxyLibraryStreamHandler(c, resp) diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 3b6fe5a0..ff6bf065 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { Role: "assistant", Content: response.Payload.Choices.Text[0].Content, }, + FinishReason: stopFinishReason, } fullTextResponse := OpenAITextResponse{ Object: "chat.completion", @@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { } func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { + domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + } + setEventStreamHeaders(c) var usage Usage - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") + c.Stream(func(w io.Writer) bool { + select { + case xunfeiResponse := <-dataChan: + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := streamResponseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + return nil, &usage +} + +func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { + domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - if apiVersion == "" { - apiVersion = "v1.1" - common.SysLog("api_version not found, use default: " + apiVersion) + var usage Usage + var content string + var xunfeiResponse XunfeiChatResponse + stop := false + for !stop { + select { + case xunfeiResponse = <-dataChan: + content += xunfeiResponse.Payload.Choices.Text[0].Content + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + case stop = <-stopChan: + } } - domain := "general" - if apiVersion == "v2.1" { - domain = "generalv2" + + xunfeiResponse.Payload.Choices.Text[0].Content = content + + response := responseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } - hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) + c.Writer.Header().Set("Content-Type", "application/json") + _, _ = c.Writer.Write(jsonResponse) + return nil, &usage +} + +func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) + conn, resp, err := d.Dial(authUrl, nil) if err != nil || resp.StatusCode != 101 { - return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil + return nil, nil, err } data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { - return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil + return nil, nil, err } + dataChan := make(chan XunfeiChatResponse) stopChan := make(chan bool) go func() { @@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId } stopChan <- true }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case xunfeiResponse := <-dataChan: - usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens - usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens - usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens - response := streamResponseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - return nil, &usage + + return dataChan, stopChan, nil } -func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var xunfeiResponse XunfeiChatResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil +func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + if apiVersion == "" { + apiVersion = "v1.1" + common.SysLog("api_version not found, use default: " + apiVersion) } - err = json.Unmarshal(responseBody, &xunfeiResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + domain := "general" + if apiVersion == "v2.1" { + domain = "generalv2" } - if xunfeiResponse.Header.Code != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: xunfeiResponse.Header.Message, - Type: "xunfei_error", - Param: "", - Code: xunfeiResponse.Header.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage + authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + return domain, authUrl }