refactor: move functions to render package

This commit is contained in:
JustSong 2024-06-30 18:26:47 +08:00
parent 1923ded809
commit 616759933a
15 changed files with 70 additions and 51 deletions

25
common/render/render.go Normal file
View File

@ -0,0 +1,25 @@
package render
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
)
func StringData(c *gin.Context, str string) {
c.Render(-1, common.CustomEvent{Data: "data: " + str})
}
func ObjectData(c *gin.Context, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
StringData(c, string(jsonData))
return nil
}
func Done(c *gin.Context) {
StringData(c, "[DONE]")
}

View File

@ -1,11 +1,7 @@
package common
import (
"encoding/json"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
)
@ -16,18 +12,3 @@ func LogQuota(quota int64) string {
return fmt.Sprintf("%d 点额度", quota)
}
}
func RenderStringData(c *gin.Context, data string) {
data = strings.TrimPrefix(data, "data: ")
c.Render(-1, CustomEvent{Data: "data: " + strings.TrimSuffix(data, "\r")})
c.Writer.Flush()
}
func RenderData(c *gin.Context, response interface{}) error {
jsonResponse, err := json.Marshal(response)
if err != nil {
return fmt.Errorf("error marshalling stream response: %w", err)
}
RenderStringData(c, string(jsonResponse))
return nil
}

View File

@ -4,6 +4,7 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strconv"
@ -124,7 +125,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -135,11 +136,11 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
response := documentsAIProxyLibrary(documents)
err := common.RenderData(c, response)
err := render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err = resp.Body.Close()
if err != nil {

View File

@ -3,6 +3,7 @@ package ali
import (
"bufio"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -207,7 +208,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
if response == nil {
continue
}
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -217,7 +218,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -206,7 +207,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
response.Id = id
response.Model = modelName
response.Created = createdTime
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -216,7 +217,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -161,7 +162,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -171,7 +172,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -3,6 +3,7 @@ package cloudflare
import (
"bufio"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -92,7 +93,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN
response.Id = id
response.Model = responseModel
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -102,7 +103,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -163,7 +164,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
response.Model = c.GetString("original_model")
response.Created = createdTime
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -173,7 +174,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -141,7 +142,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
response.Model = modelName
response.Created = createdTime
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -151,7 +152,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -302,7 +303,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
responseText += response.Choices[0].Delta.StringContent()
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -312,7 +313,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -134,7 +135,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
response := streamResponseOllama2OpenAI(&ollamaResponse)
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -144,7 +145,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -39,7 +40,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
continue
}
if strings.HasPrefix(data[dataPrefixLength:], done) {
common.RenderStringData(c, data)
render.StringData(c, data)
continue
}
switch relayMode {
@ -48,14 +49,14 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.RenderStringData(c, data) // if error happened, pass the data to client
continue // just ignore the error
render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error
}
if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice
}
common.RenderStringData(c, data)
render.StringData(c, data)
for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content)
}
@ -63,7 +64,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
usage = streamResponse.Usage
}
case relaymode.Completions:
common.RenderStringData(c, data)
render.StringData(c, data)
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil {
@ -80,7 +81,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -3,6 +3,7 @@ package palm
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
@ -116,12 +117,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
}
err = common.RenderData(c, string(jsonResponse))
err = render.ObjectData(c, string(jsonResponse))
if err != nil {
logger.SysError(err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
return nil, responseText
}

View File

@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strconv"
@ -111,7 +112,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
}
@ -121,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {

View File

@ -3,6 +3,7 @@ package zhipu
import (
"bufio"
"encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io"
"net/http"
"strings"
@ -172,7 +173,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
dataSegment += "\n"
}
response := streamResponseZhipu2OpenAI(dataSegment)
err := common.RenderData(c, response)
err := render.ObjectData(c, response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
}
@ -185,7 +186,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
continue
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
err = common.RenderData(c, response)
err = render.ObjectData(c, response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
}
@ -198,7 +199,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error())
}
common.RenderStringData(c, "[DONE]")
render.Done(c)
err := resp.Body.Close()
if err != nil {