From 481c4ebf49bb23c92317428acfd6b52331d8ebb2 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Sun, 16 Jul 2023 15:35:32 +0800 Subject: [PATCH] fix: chatgptweb issue --- controller/channel-test.go | 73 ++++++++++++++++++------------ controller/relay-text.go | 93 ++++++++++++++++++++------------------ 2 files changed, 92 insertions(+), 74 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index b48110e9..7048c49d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "net/http" "one-api/common" "one-api/model" @@ -18,6 +17,13 @@ import ( "github.com/gin-gonic/gin" ) +func formatFloat(input float64) float64 { + if input == float64(int64(input)) { + return input + } + return float64(int64(input*10)) / 10 +} + func testChannel(channel *model.Channel, request ChatRequest) error { switch channel.Type { case common.ChannelTypeAzure: @@ -65,11 +71,16 @@ func testChannel(channel *model.Channel, request ChatRequest) error { } // Construct json data without adding escape character - map1 := map[string]string{ - "prompt": prompt, - "systemMessage": systemMessage.Content, - "temperature": strconv.FormatFloat(request.Temperature, 'f', 2, 64), - "top_p": strconv.FormatFloat(request.TopP, 'f', 2, 64), + map1 := make(map[string]interface{}) + + map1["prompt"] = prompt + map1["systemMessage"] = systemMessage.Content + + if request.Temperature != 0 { + map1["temperature"] = formatFloat(request.Temperature) + } + if request.TopP != 0 { + map1["top_p"] = formatFloat(request.TopP) } // Convert map to json string @@ -122,8 +133,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error { var done = false var streamResponseText = "" + scanner := bufio.NewScanner(resp.Body) + if channel.Type != common.ChannelTypeChatGPTWeb { - scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil @@ -139,11 +151,14 @@ func testChannel(channel *model.Channel, request ChatRequest) error { return 0, nil, nil }) - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // must be something wrong! - continue - } + } + + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // must be something wrong! + continue + } + if channel.Type != common.ChannelTypeChatGPTWeb { // If data has event: event content inside, remove it, it can be prefix or inside the data if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { @@ -185,28 +200,26 @@ func testChannel(channel *model.Channel, request ChatRequest) error { done = true break } - } - } else if channel.Type == common.ChannelTypeChatGPTWeb { - scanner := bufio.NewScanner(resp.Body) - go func() { - for scanner.Scan() { - var chatResponse ChatGptWebChatResponse - err = json.Unmarshal(scanner.Bytes(), &chatResponse) + } else if channel.Type == common.ChannelTypeChatGPTWeb { + var chatResponse ChatGptWebChatResponse + err = json.Unmarshal([]byte(data), &chatResponse) + if err != nil { + // Print the body in string + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String()) + return err + } - if err != nil { - log.Println("error unmarshal chat response: " + err.Error()) - continue - } - - // if response role is assistant and contains delta, append the content to streamResponseText - if chatResponse.Role == "assistant" && chatResponse.Detail != nil { - for _, choice := range chatResponse.Detail.Choices { - streamResponseText += choice.Delta.Content - } + // if response role is assistant and contains delta, append the content to streamResponseText + if chatResponse.Role == "assistant" && chatResponse.Detail != nil { + for _, choice := range chatResponse.Detail.Choices { + streamResponseText += choice.Delta.Content } } - }() + + } } defer resp.Body.Close() diff --git a/controller/relay-text.go b/controller/relay-text.go index 13e0f793..342e6a8e 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -11,7 +11,6 @@ import ( "net/http" "one-api/common" "one-api/model" - "strconv" "strings" "github.com/gin-gonic/gin" @@ -223,11 +222,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } // Construct json data without adding escape character - map1 := map[string]string{ - "prompt": prompt, - "systemMessage": systemMessage.Content, - "temperature": strconv.FormatFloat(reqBody.Temperature, 'f', 2, 64), - "top_p": strconv.FormatFloat(reqBody.TopP, 'f', 2, 64), + map1 := make(map[string]interface{}) + + map1["prompt"] = prompt + map1["systemMessage"] = systemMessage.Content + + if reqBody.Temperature != 0 { + map1["temperature"] = formatFloat(reqBody.Temperature) + } + if reqBody.TopP != 0 { + map1["top_p"] = formatFloat(reqBody.TopP) } // Convert map to json string @@ -348,16 +352,42 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { dataChan := make(chan string) stopChan := make(chan bool) - if channelType == common.ChannelTypeChatGPTWeb { - scanner := bufio.NewScanner(resp.Body) - go func() { - for scanner.Scan() { - var chatResponse ChatGptWebChatResponse - err = json.Unmarshal(scanner.Bytes(), &chatResponse) + scanner := bufio.NewScanner(resp.Body) + if channelType != common.ChannelTypeChatGPTWeb { + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 2, data[0:i], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + } + + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // must be something wrong! + continue + } + + if channelType == common.ChannelTypeChatGPTWeb { + var chatResponse ChatGptWebChatResponse + err = json.Unmarshal([]byte(data), &chatResponse) if err != nil { - log.Println("error unmarshal chat response: " + err.Error()) - continue + // Print the body in string + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String()) + return } // if response role is assistant and contains delta, append the content to streamResponseText @@ -387,33 +417,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { dataChan <- "data: " + string(jsonData) } } - } - stopChan <- true - }() - } else { - scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - - if atEOF { - return len(data), data, nil - } - - return 0, nil, nil - }) - go func() { - for scanner.Scan() { - data := scanner.Text() - if len(data) < 6 { // must be something wrong! - // common.SysError("invalid stream response: " + data) - continue - } + } else { // If data has event: event content inside, remove it, it can be prefix or inside the data if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { // Remove event: event in the front or back @@ -463,10 +467,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } } + } - stopChan <- true - }() - } + } + stopChan <- true + }() c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache")