refactor: move functions to render package

This commit is contained in:
JustSong 2024-06-30 18:26:47 +08:00
parent 1923ded809
commit 616759933a
15 changed files with 70 additions and 51 deletions

25
common/render/render.go Normal file
View File

@ -0,0 +1,25 @@
package render
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
)
func StringData(c *gin.Context, str string) {
c.Render(-1, common.CustomEvent{Data: "data: " + str})
}
func ObjectData(c *gin.Context, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
StringData(c, string(jsonData))
return nil
}
func Done(c *gin.Context) {
StringData(c, "[DONE]")
}

View File

@ -1,11 +1,7 @@
package common package common
import ( import (
"encoding/json"
"fmt" "fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
) )
@ -16,18 +12,3 @@ func LogQuota(quota int64) string {
return fmt.Sprintf("%d 点额度", quota) return fmt.Sprintf("%d 点额度", quota)
} }
} }
func RenderStringData(c *gin.Context, data string) {
data = strings.TrimPrefix(data, "data: ")
c.Render(-1, CustomEvent{Data: "data: " + strings.TrimSuffix(data, "\r")})
c.Writer.Flush()
}
func RenderData(c *gin.Context, response interface{}) error {
jsonResponse, err := json.Marshal(response)
if err != nil {
return fmt.Errorf("error marshalling stream response: %w", err)
}
RenderStringData(c, string(jsonResponse))
return nil
}

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
@ -124,7 +125,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
documents = AIProxyLibraryResponse.Documents documents = AIProxyLibraryResponse.Documents
} }
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse) response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -135,11 +136,11 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
response := documentsAIProxyLibrary(documents) response := documentsAIProxyLibrary(documents)
err := common.RenderData(c, response) err := render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {

View File

@ -3,6 +3,7 @@ package ali
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -207,7 +208,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
if response == nil { if response == nil {
continue continue
} }
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -217,7 +218,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -206,7 +207,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
response.Id = id response.Id = id
response.Model = modelName response.Model = modelName
response.Created = createdTime response.Created = createdTime
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -216,7 +217,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -161,7 +162,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
} }
response := streamResponseBaidu2OpenAI(&baiduResponse) response := streamResponseBaidu2OpenAI(&baiduResponse)
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -171,7 +172,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -3,6 +3,7 @@ package cloudflare
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -92,7 +93,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN
response.Id = id response.Id = id
response.Model = responseModel response.Model = responseModel
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -102,7 +103,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelN
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -163,7 +164,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
response.Model = c.GetString("original_model") response.Model = c.GetString("original_model")
response.Created = createdTime response.Created = createdTime
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -173,7 +174,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -141,7 +142,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
response.Model = modelName response.Model = modelName
response.Created = createdTime response.Created = createdTime
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -151,7 +152,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -302,7 +303,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
responseText += response.Choices[0].Delta.StringContent() responseText += response.Choices[0].Delta.StringContent()
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -312,7 +313,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -134,7 +135,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
response := streamResponseOllama2OpenAI(&ollamaResponse) response := streamResponseOllama2OpenAI(&ollamaResponse)
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -144,7 +145,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -39,7 +40,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
continue continue
} }
if strings.HasPrefix(data[dataPrefixLength:], done) { if strings.HasPrefix(data[dataPrefixLength:], done) {
common.RenderStringData(c, data) render.StringData(c, data)
continue continue
} }
switch relayMode { switch relayMode {
@ -48,14 +49,14 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) logger.SysError("error unmarshalling stream response: " + err.Error())
common.RenderStringData(c, data) // if error happened, pass the data to client render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error continue // just ignore the error
} }
if len(streamResponse.Choices) == 0 { if len(streamResponse.Choices) == 0 {
// but for empty choice, we should not pass it to client, this is for azure // but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice continue // just ignore empty choice
} }
common.RenderStringData(c, data) render.StringData(c, data)
for _, choice := range streamResponse.Choices { for _, choice := range streamResponse.Choices {
responseText += conv.AsString(choice.Delta.Content) responseText += conv.AsString(choice.Delta.Content)
} }
@ -63,7 +64,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
usage = streamResponse.Usage usage = streamResponse.Usage
} }
case relaymode.Completions: case relaymode.Completions:
common.RenderStringData(c, data) render.StringData(c, data)
var streamResponse CompletionsStreamResponse var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
if err != nil { if err != nil {
@ -80,7 +81,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -3,6 +3,7 @@ package palm
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
@ -116,12 +117,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), "" return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
} }
err = common.RenderData(c, string(jsonResponse)) err = render.ObjectData(c, string(jsonResponse))
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
return nil, responseText return nil, responseText
} }

View File

@ -8,6 +8,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
@ -111,7 +112,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
responseText += conv.AsString(response.Choices[0].Delta.Content) responseText += conv.AsString(response.Choices[0].Delta.Content)
} }
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError(err.Error()) logger.SysError(err.Error())
} }
@ -121,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {

View File

@ -3,6 +3,7 @@ package zhipu
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -172,7 +173,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
dataSegment += "\n" dataSegment += "\n"
} }
response := streamResponseZhipu2OpenAI(dataSegment) response := streamResponseZhipu2OpenAI(dataSegment)
err := common.RenderData(c, response) err := render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())
} }
@ -185,7 +186,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
continue continue
} }
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
err = common.RenderData(c, response) err = render.ObjectData(c, response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) logger.SysError("error marshalling stream response: " + err.Error())
} }
@ -198,7 +199,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
logger.SysError("error reading stream: " + err.Error()) logger.SysError("error reading stream: " + err.Error())
} }
common.RenderStringData(c, "[DONE]") render.Done(c)
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {