diff --git a/common/utils.go b/common/utils.go index 21bec8f5..b21f5a89 100644 --- a/common/utils.go +++ b/common/utils.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "time" + "unsafe" ) func OpenBrowser(url string) { @@ -138,11 +139,12 @@ func GetUUID() string { const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" func init() { - rand.Seed(time.Now().UnixNano()) + if !strings.Contains(runtime.Version(), "go1.2") { // go1.20之前版本需要全局 seed,其他插件无需再 seed + rand.Seed(time.Now().UnixNano()) //nolint: staticcheck + } } func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) key := make([]byte, 48) for i := 0; i < 16; i++ { key[i] = keyChars[rand.Intn(len(keyChars))] @@ -159,7 +161,6 @@ func GenerateKey() string { } func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) key := make([]byte, length) for i := 0; i < length; i++ { key[i] = keyChars[rand.Intn(len(keyChars))] @@ -207,3 +208,10 @@ func String2Int(str string) int { } return num } + +// []byte only read, panic on append +func StringToByteSlice(s string) []byte { + tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) + tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} + return *(*[]byte)(unsafe.Pointer(&tmp2)) +} diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 37867843..6ae34ae1 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -9,10 +9,12 @@ import ( "net/http" "one-api/common" "strings" + "sync" + "time" ) func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { - responseText := "" + var responseTextBuilder strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -26,9 +28,16 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O } return 0, nil, nil }) - dataChan := make(chan string) - stopChan := make(chan bool) + dataChan := make(chan string, 5) + stopChan := make(chan bool, 2) + defer close(stopChan) + defer close(dataChan) + + var wg sync.WaitGroup go func() { + wg.Add(1) + defer wg.Done() + var streamItems []string for scanner.Scan() { data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format @@ -40,29 +49,39 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O dataChan <- data data = data[6:] if !strings.HasPrefix(data, "[DONE]") { - switch relayMode { - case RelayModeChatCompletions: - var streamResponse ChatCompletionsStreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue // just ignore the error - } - for _, choice := range streamResponse.Choices { - responseText += choice.Delta.Content - } - case RelayModeCompletions: - var streamResponse CompletionsStreamResponse - err := json.Unmarshal([]byte(data), &streamResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - continue - } - for _, choice := range streamResponse.Choices { - responseText += choice.Text - } + streamItems = append(streamItems, data) + } + } + streamResp := "[" + strings.Join(streamItems, ",") + "]" + switch relayMode { + case RelayModeChatCompletions: + var streamResponses []ChatCompletionsStreamResponseSimple + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return // just ignore the error + } + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.Content) } } + case RelayModeCompletions: + var streamResponses []CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return // just ignore the error + } + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } + } + } + if len(dataChan) > 0 { + // wait data out + time.Sleep(2 * time.Second) } stopChan <- true }() @@ -85,7 +104,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } - return nil, responseText + wg.Wait() + return nil, responseTextBuilder.String() } func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { @@ -110,7 +130,6 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model } // Reset response body resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. // So the httpClient will be confused by the response. diff --git a/controller/relay.go b/controller/relay.go index f91ba6da..1534997a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -226,6 +226,10 @@ type ChatCompletionsStreamResponse struct { Choices []ChatCompletionsStreamResponseChoice `json:"choices"` } +type ChatCompletionsStreamResponseSimple struct { + Choices []ChatCompletionsStreamResponseChoice `json:"choices"` +} + type CompletionsStreamResponse struct { Choices []struct { Text string `json:"text"` diff --git a/go.mod b/go.mod index 10b78d68..47a281d8 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 - github.com/pkoukk/tiktoken-go v0.1.5 + github.com/pkoukk/tiktoken-go v0.1.6 golang.org/x/crypto v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 diff --git a/go.sum b/go.sum index 4865bcaa..65761c9a 100644 --- a/go.sum +++ b/go.sum @@ -118,8 +118,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= -github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= +github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= diff --git a/main.go b/main.go index 88938516..b8c0f138 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" + "net/http" "one-api/common" "one-api/controller" "one-api/middleware" @@ -13,6 +14,8 @@ import ( "one-api/router" "os" "strconv" + + _ "net/http/pprof" ) //go:embed web/build @@ -30,6 +33,16 @@ func main() { if common.DebugEnabled { common.SysLog("running in debug mode") } + if os.Getenv("ENABLE_PPROF") == "true" { + go func() { + err := http.ListenAndServe("0.0.0.0:8005", nil) + if err != nil { + common.FatalLog("pprof enabled failed: " + err.Error()) + } else { + common.SysLog("pprof enabled") + } + }() + } // Initialize SQL Database err := model.InitDB() if err != nil {