fix: suport stream header (#805 #861)

closes #805 #861
This commit is contained in:
igophper 2023-12-24 13:52:59 +08:00
parent b7fcb319da
commit 2b73e13800

View File

@ -55,6 +55,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
group := c.GetString("group") group := c.GetString("group")
var textRequest GeneralOpenAIRequest var textRequest GeneralOpenAIRequest
err := common.UnmarshalBodyReusable(c, &textRequest) err := common.UnmarshalBodyReusable(c, &textRequest)
isStream := textRequest.Stream || strings.HasPrefix(c.Request.Header.Get("Accept"), "text/event-stream")
textRequest.Stream = isStream
if err != nil { if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
} }
@ -193,7 +197,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
version = c.GetString("api_version") version = c.GetString("api_version")
} }
action := "generateContent" action := "generateContent"
if textRequest.Stream { if isStream {
action = "streamGenerateContent" action = "streamGenerateContent"
} }
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
@ -202,7 +206,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL += "?key=" + apiKey fullRequestURL += "?key=" + apiKey
case APITypeZhipu: case APITypeZhipu:
method := "invoke" method := "invoke"
if textRequest.Stream { if isStream {
method = "sse-invoke" method = "sse-invoke"
} }
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
@ -355,7 +359,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
var req *http.Request var req *http.Request
var resp *http.Response var resp *http.Response
isStream := textRequest.Stream
if apiType != APITypeXunfei { // cause xunfei use websocket if apiType != APITypeXunfei { // cause xunfei use websocket
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
@ -387,7 +390,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
req.Header.Set("Authorization", token) req.Header.Set("Authorization", token)
case APITypeAli: case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey) req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream { if isStream {
req.Header.Set("X-DashScope-SSE", "enable") req.Header.Set("X-DashScope-SSE", "enable")
} }
if c.GetString("plugin") != "" { if c.GetString("plugin") != "" {
@ -541,7 +544,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return nil return nil
} }
case APITypePaLM: case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream if isStream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp) err, responseText := palmStreamHandler(c, resp)
if err != nil { if err != nil {
return err return err
@ -560,7 +563,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return nil return nil
} }
case APITypeGemini: case APITypeGemini:
if textRequest.Stream { if isStream {
err, responseText := geminiChatStreamHandler(c, resp) err, responseText := geminiChatStreamHandler(c, resp)
if err != nil { if err != nil {
return err return err