This commit is contained in:
Xyfacai 2023-12-02 02:21:39 +00:00 committed by GitHub
commit f4b29b03cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 76 additions and 32 deletions

View File

@ -13,6 +13,7 @@ import (
"strconv"
"strings"
"time"
"unsafe"
)
func OpenBrowser(url string) {
@ -138,11 +139,12 @@ func GetUUID() string {
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
if !strings.Contains(runtime.Version(), "go1.2") { // go1.20之前版本需要全局 seed其他插件无需再 seed
rand.Seed(time.Now().UnixNano()) //nolint: staticcheck
}
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
@ -159,7 +161,6 @@ func GenerateKey() string {
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
@ -207,3 +208,10 @@ func String2Int(str string) int {
}
return num
}
// []byte only read, panic on append
func StringToByteSlice(s string) []byte {
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}

View File

@ -9,10 +9,12 @@ import (
"net/http"
"one-api/common"
"strings"
"sync"
"time"
)
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
@ -26,9 +28,16 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
dataChan := make(chan string, 5)
stopChan := make(chan bool, 2)
defer close(stopChan)
defer close(dataChan)
var wg sync.WaitGroup
go func() {
wg.Add(1)
defer wg.Done()
var streamItems []string
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
@ -40,29 +49,39 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data)
}
}
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case RelayModeChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
var streamResponses []ChatCompletionsStreamResponseSimple
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
return // just ignore the error
}
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
responseTextBuilder.WriteString(choice.Delta.Content)
}
}
case RelayModeCompletions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
var streamResponses []CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
return // just ignore the error
}
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseText += choice.Text
responseTextBuilder.WriteString(choice.Text)
}
}
}
if len(dataChan) > 0 {
// wait data out
time.Sleep(2 * time.Second)
}
stopChan <- true
}()
@ -85,7 +104,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
wg.Wait()
return nil, responseTextBuilder.String()
}
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
@ -110,7 +130,6 @@ func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.

View File

@ -226,6 +226,10 @@ type ChatCompletionsStreamResponse struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type ChatCompletionsStreamResponseSimple struct {
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`

2
go.mod
View File

@ -14,7 +14,7 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/pkoukk/tiktoken-go v0.1.6
golang.org/x/crypto v0.14.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2

4
go.sum
View File

@ -118,8 +118,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=

13
main.go
View File

@ -6,6 +6,7 @@ import (
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/controller"
"one-api/middleware"
@ -13,6 +14,8 @@ import (
"one-api/router"
"os"
"strconv"
_ "net/http/pprof"
)
//go:embed web/build
@ -30,6 +33,16 @@ func main() {
if common.DebugEnabled {
common.SysLog("running in debug mode")
}
if os.Getenv("ENABLE_PPROF") == "true" {
go func() {
err := http.ListenAndServe("0.0.0.0:8005", nil)
if err != nil {
common.FatalLog("pprof enabled failed: " + err.Error())
} else {
common.SysLog("pprof enabled")
}
}()
}
// Initialize SQL Database
err := model.InitDB()
if err != nil {