feat: support non-stream mode for xunfei (#498)

* feat:xunfei suport none stream

* fix:join content ignore seq

---------

Co-authored-by: igophper <admin@jialilgu.cn>
This commit is contained in:
igophper 2023-09-17 18:16:12 +08:00 committed by GitHub
parent 12ef9679a7
commit 24df3e5f62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 98 additions and 83 deletions

View File

@ -541,24 +541,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return nil return nil
} }
case APITypeXunfei: case APITypeXunfei:
if isStream { auth := c.Request.Header.Get("Authorization")
auth := c.Request.Header.Get("Authorization") auth = strings.TrimPrefix(auth, "Bearer ")
auth = strings.TrimPrefix(auth, "Bearer ") splits := strings.Split(auth, "|")
splits := strings.Split(auth, "|") if len(splits) != 3 {
if len(splits) != 3 { return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
} }
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary: case APITypeAIProxyLibrary:
if isStream { if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp) err, usage := aiProxyLibraryStreamHandler(c, resp)

View File

@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
Role: "assistant", Role: "assistant",
Content: response.Payload.Choices.Text[0].Content, Content: response.Payload.Choices.Text[0].Content,
}, },
FinishReason: stopFinishReason,
} }
fullTextResponse := OpenAITextResponse{ fullTextResponse := OpenAITextResponse{
Object: "chat.completion", Object: "chat.completion",
@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
} }
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
setEventStreamHeaders(c)
var usage Usage var usage Usage
query := c.Request.URL.Query() c.Stream(func(w io.Writer) bool {
apiVersion := query.Get("api-version") select {
if apiVersion == "" { case xunfeiResponse := <-dataChan:
apiVersion = c.GetString("api_version") usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
}
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
} }
if apiVersion == "" { var usage Usage
apiVersion = "v1.1" var content string
common.SysLog("api_version not found, use default: " + apiVersion) var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
case xunfeiResponse = <-dataChan:
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
case stop = <-stopChan:
}
} }
domain := "general"
if apiVersion == "v2.1" { xunfeiResponse.Payload.Choices.Text[0].Content = content
domain = "generalv2"
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{ d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
} }
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 { if err != nil || resp.StatusCode != 101 {
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil return nil, nil, err
} }
data := requestOpenAI2Xunfei(textRequest, appId, domain) data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data) err = conn.WriteJSON(data)
if err != nil { if err != nil {
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil return nil, nil, err
} }
dataChan := make(chan XunfeiChatResponse) dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool) stopChan := make(chan bool)
go func() { go func() {
@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
} }
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { return dataChan, stopChan, nil
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
} }
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
var xunfeiResponse XunfeiChatResponse query := c.Request.URL.Query()
responseBody, err := io.ReadAll(resp.Body) apiVersion := query.Get("api-version")
if err != nil { if apiVersion == "" {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil apiVersion = c.GetString("api_version")
} }
err = resp.Body.Close() if apiVersion == "" {
if err != nil { apiVersion = "v1.1"
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil common.SysLog("api_version not found, use default: " + apiVersion)
} }
err = json.Unmarshal(responseBody, &xunfeiResponse) domain := "general"
if err != nil { if apiVersion == "v2.1" {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil domain = "generalv2"
} }
if xunfeiResponse.Header.Code != 0 { authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return &OpenAIErrorWithStatusCode{ return domain, authUrl
OpenAIError: OpenAIError{
Message: xunfeiResponse.Header.Message,
Type: "xunfei_error",
Param: "",
Code: xunfeiResponse.Header.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
} }