feat: support non-stream mode for xunfei (#498)
* feat:xunfei suport none stream * fix:join content ignore seq --------- Co-authored-by: igophper <admin@jialilgu.cn>
This commit is contained in:
parent
12ef9679a7
commit
24df3e5f62
@ -541,24 +541,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case APITypeXunfei:
|
case APITypeXunfei:
|
||||||
if isStream {
|
auth := c.Request.Header.Get("Authorization")
|
||||||
auth := c.Request.Header.Get("Authorization")
|
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
splits := strings.Split(auth, "|")
|
||||||
splits := strings.Split(auth, "|")
|
if len(splits) != 3 {
|
||||||
if len(splits) != 3 {
|
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
|
||||||
}
|
}
|
||||||
|
var err *OpenAIErrorWithStatusCode
|
||||||
|
var usage *Usage
|
||||||
|
if isStream {
|
||||||
|
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||||
|
} else {
|
||||||
|
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
case APITypeAIProxyLibrary:
|
case APITypeAIProxyLibrary:
|
||||||
if isStream {
|
if isStream {
|
||||||
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
||||||
|
@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
|||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: response.Payload.Choices.Text[0].Content,
|
Content: response.Payload.Choices.Text[0].Content,
|
||||||
},
|
},
|
||||||
|
FinishReason: stopFinishReason,
|
||||||
}
|
}
|
||||||
fullTextResponse := OpenAITextResponse{
|
fullTextResponse := OpenAITextResponse{
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||||
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
setEventStreamHeaders(c)
|
||||||
var usage Usage
|
var usage Usage
|
||||||
query := c.Request.URL.Query()
|
c.Stream(func(w io.Writer) bool {
|
||||||
apiVersion := query.Get("api-version")
|
select {
|
||||||
if apiVersion == "" {
|
case xunfeiResponse := <-dataChan:
|
||||||
apiVersion = c.GetString("api_version")
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||||
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||||
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||||
|
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||||
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
if apiVersion == "" {
|
var usage Usage
|
||||||
apiVersion = "v1.1"
|
var content string
|
||||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
var xunfeiResponse XunfeiChatResponse
|
||||||
|
stop := false
|
||||||
|
for !stop {
|
||||||
|
select {
|
||||||
|
case xunfeiResponse = <-dataChan:
|
||||||
|
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||||
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||||
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||||
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||||
|
case stop = <-stopChan:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
domain := "general"
|
|
||||||
if apiVersion == "v2.1" {
|
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||||
domain = "generalv2"
|
|
||||||
|
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
||||||
d := websocket.Dialer{
|
d := websocket.Dialer{
|
||||||
HandshakeTimeout: 5 * time.Second,
|
HandshakeTimeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
conn, resp, err := d.Dial(authUrl, nil)
|
||||||
if err != nil || resp.StatusCode != 101 {
|
if err != nil || resp.StatusCode != 101 {
|
||||||
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||||
err = conn.WriteJSON(data)
|
err = conn.WriteJSON(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataChan := make(chan XunfeiChatResponse)
|
dataChan := make(chan XunfeiChatResponse)
|
||||||
stopChan := make(chan bool)
|
stopChan := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
return dataChan, stopChan, nil
|
||||||
select {
|
|
||||||
case xunfeiResponse := <-dataChan:
|
|
||||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
||||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
||||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
||||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return nil, &usage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
||||||
var xunfeiResponse XunfeiChatResponse
|
query := c.Request.URL.Query()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
apiVersion := query.Get("api-version")
|
||||||
if err != nil {
|
if apiVersion == "" {
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
apiVersion = c.GetString("api_version")
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
if apiVersion == "" {
|
||||||
if err != nil {
|
apiVersion = "v1.1"
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(responseBody, &xunfeiResponse)
|
domain := "general"
|
||||||
if err != nil {
|
if apiVersion == "v2.1" {
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
domain = "generalv2"
|
||||||
}
|
}
|
||||||
if xunfeiResponse.Header.Code != 0 {
|
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||||
return &OpenAIErrorWithStatusCode{
|
return domain, authUrl
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: xunfeiResponse.Header.Message,
|
|
||||||
Type: "xunfei_error",
|
|
||||||
Param: "",
|
|
||||||
Code: xunfeiResponse.Header.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user