From e628b643cd11feb5c428a8c28ee47812e045e644 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sat, 22 Jul 2023 17:36:40 +0800 Subject: [PATCH] refactor: refactor claude related code --- controller/relay-claude.go | 117 +++++++++++++++++++++++++++++++++++++ controller/relay-text.go | 105 ++------------------------------- 2 files changed, 123 insertions(+), 99 deletions(-) diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 99f472e4..22f41cef 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -1,7 +1,12 @@ package controller import ( + "bufio" + "encoding/json" "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" "one-api/common" "strings" ) @@ -102,3 +107,115 @@ func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { } return &fullTextResponse } + +func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { + return i + 4, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if !strings.HasPrefix(data, "event: completion") { + continue + } + data = strings.TrimPrefix(data, "event: completion\r\ndata: ") + dataChan <- data + } + stopChan <- true + }() + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var claudeResponse ClaudeResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + responseText += claudeResponse.Completion + response := streamResponseClaude2OpenAI(&claudeResponse) + response.Id = responseId + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var claudeResponse ClaudeResponse + err = json.Unmarshal(responseBody, &claudeResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if claudeResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: claudeResponse.Error.Message, + Type: claudeResponse.Error.Type, + Param: "", + Code: claudeResponse.Error.Type, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseClaude2OpenAI(&claudeResponse) + completionTokens := countTokenText(claudeResponse.Completion, model) + usage := Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index a3d0c801..a3556efc 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -373,111 +373,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeClaude: if isStream { - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { - return i + 4, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string) - stopChan := make(chan bool) - go func() { - for scanner.Scan() { - data := scanner.Text() - if !strings.HasPrefix(data, "event: completion") { - continue - } - data = strings.TrimPrefix(data, "event: completion\r\ndata: ") - dataChan <- data - } - stopChan <- true - }() - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var claudeResponse ClaudeResponse - err = json.Unmarshal([]byte(data), &claudeResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } - streamResponseText += claudeResponse.Completion - response := streamResponseClaude2OpenAI(&claudeResponse) - response.Id = responseId - response.Created = createdTime - jsonStr, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - err = resp.Body.Close() + err, responseText := claudeStreamHandler(c, resp) if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + return err } + streamResponseText = responseText return nil } else { - responseBody, err := io.ReadAll(resp.Body) + err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return err } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - var claudeResponse ClaudeResponse - err = json.Unmarshal(responseBody, &claudeResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - if claudeResponse.Error.Type != "" { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: claudeResponse.Error.Message, - Type: claudeResponse.Error.Type, - Param: "", - Code: claudeResponse.Error.Type, - }, - StatusCode: resp.StatusCode, - } - } - fullTextResponse := responseClaude2OpenAI(&claudeResponse) - completionTokens := countTokenText(claudeResponse.Completion, textRequest.Model) - fullTextResponse.Usage = Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, - } - textResponse.Usage = fullTextResponse.Usage - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) + textResponse.Usage = *usage return nil } default: