From 616759933a760403ceaf36c09fdc11bab1c92d6b Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 30 Jun 2024 18:26:47 +0800 Subject: [PATCH] refactor: move functions to render package --- common/render/render.go | 25 +++++++++++++++++++++++++ common/utils.go | 19 ------------------- relay/adaptor/aiproxy/main.go | 7 ++++--- relay/adaptor/ali/main.go | 5 +++-- relay/adaptor/anthropic/main.go | 5 +++-- relay/adaptor/baidu/main.go | 5 +++-- relay/adaptor/cloudflare/main.go | 5 +++-- relay/adaptor/cohere/main.go | 5 +++-- relay/adaptor/coze/main.go | 5 +++-- relay/adaptor/gemini/main.go | 5 +++-- relay/adaptor/ollama/main.go | 5 +++-- relay/adaptor/openai/main.go | 13 +++++++------ relay/adaptor/palm/palm.go | 5 +++-- relay/adaptor/tencent/main.go | 5 +++-- relay/adaptor/zhipu/main.go | 7 ++++--- 15 files changed, 70 insertions(+), 51 deletions(-) create mode 100644 common/render/render.go diff --git a/common/render/render.go b/common/render/render.go new file mode 100644 index 00000000..0b988793 --- /dev/null +++ b/common/render/render.go @@ -0,0 +1,25 @@ +package render + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" +) + +func StringData(c *gin.Context, str string) { + c.Render(-1, common.CustomEvent{Data: "data: " + str}) +} + +func ObjectData(c *gin.Context, object interface{}) error { + jsonData, err := json.Marshal(object) + if err != nil { + return fmt.Errorf("error marshalling object: %w", err) + } + StringData(c, string(jsonData)) + return nil +} + +func Done(c *gin.Context) { + StringData(c, "[DONE]") +} diff --git a/common/utils.go b/common/utils.go index 730730b7..ecee2c8e 100644 --- a/common/utils.go +++ b/common/utils.go @@ -1,11 +1,7 @@ package common import ( - "encoding/json" "fmt" - "strings" - - "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" ) @@ -16,18 +12,3 @@ func LogQuota(quota int64) string { return fmt.Sprintf("%d 点额度", quota) } } - -func RenderStringData(c *gin.Context, data string) { - data = strings.TrimPrefix(data, "data: ") - c.Render(-1, CustomEvent{Data: "data: " + strings.TrimSuffix(data, "\r")}) - c.Writer.Flush() -} - -func RenderData(c *gin.Context, response interface{}) error { - jsonResponse, err := json.Marshal(response) - if err != nil { - return fmt.Errorf("error marshalling stream response: %w", err) - } - RenderStringData(c, string(jsonResponse)) - return nil -} diff --git a/relay/adaptor/aiproxy/main.go b/relay/adaptor/aiproxy/main.go index f29ce43f..d64b6809 100644 --- a/relay/adaptor/aiproxy/main.go +++ b/relay/adaptor/aiproxy/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strconv" @@ -124,7 +125,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC documents = AIProxyLibraryResponse.Documents } response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -135,11 +136,11 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } response := documentsAIProxyLibrary(documents) - err := common.RenderData(c, response) + err := render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err = resp.Body.Close() if err != nil { diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index 63fa7e37..f9039dbe 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -3,6 +3,7 @@ package ali import ( "bufio" "encoding/json" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -207,7 +208,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC if response == nil { continue } - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -217,7 +218,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index 4f511057..c817a9d1 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -206,7 +207,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC response.Id = id response.Model = modelName response.Created = createdTime - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -216,7 +217,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/baidu/main.go b/relay/adaptor/baidu/main.go index 7a3bd94d..ebe70c32 100644 --- a/relay/adaptor/baidu/main.go +++ b/relay/adaptor/baidu/main.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -161,7 +162,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens } response := streamResponseBaidu2OpenAI(&baiduResponse) - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -171,7 +172,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/cloudflare/main.go b/relay/adaptor/cloudflare/main.go index 617ab27e..c76520a2 100644 --- a/relay/adaptor/cloudflare/main.go +++ b/relay/adaptor/cloudflare/main.go @@ -3,6 +3,7 @@ package cloudflare import ( "bufio" "encoding/json" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -92,7 +93,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN response.Id = id response.Model = responseModel - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -102,7 +103,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/cohere/main.go b/relay/adaptor/cohere/main.go index 661cf8b7..45db437b 100644 --- a/relay/adaptor/cohere/main.go +++ b/relay/adaptor/cohere/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -163,7 +164,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC response.Model = c.GetString("original_model") response.Created = createdTime - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -173,7 +174,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/coze/main.go b/relay/adaptor/coze/main.go index 6993f01f..d0402a76 100644 --- a/relay/adaptor/coze/main.go +++ b/relay/adaptor/coze/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -141,7 +142,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC response.Model = modelName response.Created = createdTime - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -151,7 +152,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 315c01b3..51fd6aa8 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -302,7 +303,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC responseText += response.Choices[0].Delta.StringContent() - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -312,7 +313,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go index b039a2d8..936a7e14 100644 --- a/relay/adaptor/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -134,7 +135,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } response := streamResponseOllama2OpenAI(&ollamaResponse) - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -144,7 +145,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 098b3c06..1d534644 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -39,7 +40,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E continue } if strings.HasPrefix(data[dataPrefixLength:], done) { - common.RenderStringData(c, data) + render.StringData(c, data) continue } switch relayMode { @@ -48,14 +49,14 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) if err != nil { logger.SysError("error unmarshalling stream response: " + err.Error()) - common.RenderStringData(c, data) // if error happened, pass the data to client - continue // just ignore the error + render.StringData(c, data) // if error happened, pass the data to client + continue // just ignore the error } if len(streamResponse.Choices) == 0 { // but for empty choice, we should not pass it to client, this is for azure continue // just ignore empty choice } - common.RenderStringData(c, data) + render.StringData(c, data) for _, choice := range streamResponse.Choices { responseText += conv.AsString(choice.Delta.Content) } @@ -63,7 +64,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E usage = streamResponse.Usage } case relaymode.Completions: - common.RenderStringData(c, data) + render.StringData(c, data) var streamResponse CompletionsStreamResponse err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) if err != nil { @@ -80,7 +81,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/palm/palm.go b/relay/adaptor/palm/palm.go index ed9e7d9e..d31784ec 100644 --- a/relay/adaptor/palm/palm.go +++ b/relay/adaptor/palm/palm.go @@ -3,6 +3,7 @@ package palm import ( "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" @@ -116,12 +117,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), "" } - err = common.RenderData(c, string(jsonResponse)) + err = render.ObjectData(c, string(jsonResponse)) if err != nil { logger.SysError(err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) return nil, responseText } diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go index 535d6bae..365e33ae 100644 --- a/relay/adaptor/tencent/main.go +++ b/relay/adaptor/tencent/main.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strconv" @@ -111,7 +112,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC responseText += conv.AsString(response.Choices[0].Delta.Content) } - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) } @@ -121,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil { diff --git a/relay/adaptor/zhipu/main.go b/relay/adaptor/zhipu/main.go index 10b5434c..ab3a5678 100644 --- a/relay/adaptor/zhipu/main.go +++ b/relay/adaptor/zhipu/main.go @@ -3,6 +3,7 @@ package zhipu import ( "bufio" "encoding/json" + "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" @@ -172,7 +173,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC dataSegment += "\n" } response := streamResponseZhipu2OpenAI(dataSegment) - err := common.RenderData(c, response) + err := render.ObjectData(c, response) if err != nil { logger.SysError("error marshalling stream response: " + err.Error()) } @@ -185,7 +186,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC continue } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) - err = common.RenderData(c, response) + err = render.ObjectData(c, response) if err != nil { logger.SysError("error marshalling stream response: " + err.Error()) } @@ -198,7 +199,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC logger.SysError("error reading stream: " + err.Error()) } - common.RenderStringData(c, "[DONE]") + render.Done(c) err := resp.Body.Close() if err != nil {