From 12a0e7105ec2c5f8c926d48912f4605237f475d0 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 22 Jul 2023 17:48:45 +0800 Subject: [PATCH] refactor: refactor openai related code --- controller/relay-openai.go | 133 +++++++++++++++++++++++++++++++++++++ controller/relay-text.go | 114 ++----------------------------- 2 files changed, 139 insertions(+), 108 deletions(-) create mode 100644 controller/relay-openai.go diff --git a/controller/relay-openai.go b/controller/relay-openai.go new file mode 100644 index 00000000..a3a95af3 --- /dev/null +++ b/controller/relay-openai.go @@ -0,0 +1,133 @@ +package controller + +import ( + "bufio" + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strings" +) + +func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // ignore blank line or wrong format + continue + } + dataChan <- data + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + switch relayMode { + case RelayModeChatCompletions: + var streamResponse ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + responseText += choice.Delta.Content + } + case RelayModeCompletions: + var streamResponse CompletionsStreamResponse + err := json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return + } + for _, choice := range streamResponse.Choices { + responseText += choice.Text + } + } + } + } + stopChan <- true + }() + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + if strings.HasPrefix(data, "data: [DONE]") { + data = data[:12] + } + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + c.Render(-1, common.CustomEvent{Data: data}) + return true + case <-stopChan: + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) { + var textResponse TextResponse + if consumeQuota { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if textResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: textResponse.Error, + StatusCode: resp.StatusCode, + }, nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + } + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the client will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err := io.Copy(c.Writer, resp.Body) + if err != nil { + return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &textResponse.Usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index a3556efc..723ed59e 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -1,7 +1,6 @@ package controller import ( - "bufio" "bytes" "encoding/json" "errors" @@ -256,119 +255,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { switch apiType { case APITypeOpenAI: if isStream { - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // ignore blank line or wrong format - continue - } - dataChan <- data - data = data[6:] - if !strings.HasPrefix(data, "[DONE]") { - switch relayMode { - case RelayModeChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err = json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return - } - for _, choice := range streamResponse.Choices { - streamResponseText += choice.Delta.Content - } - case RelayModeCompletions: - var streamResponse CompletionsStreamResponse - err = json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return - } - for _, choice := range streamResponse.Choices { - streamResponseText += choice.Text - } - } - } - } - stopChan <- true - }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false - } - }) - err = resp.Body.Close() + err, responseText := openaiStreamHandler(c, resp, relayMode) if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return err } + streamResponseText = responseText return nil } else { - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - if textResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: textResponse.Error, - StatusCode: resp.StatusCode, - } - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the client will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) + err, usage := openaiHandler(c, resp, consumeQuota) if err != nil { - return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return err } + textResponse.Usage = *usage return nil } case APITypeClaude: