fix: chatgptweb issue

This commit is contained in:
ckt1031 2023-07-16 15:35:32 +08:00
parent 203471d7a9
commit 481c4ebf49
2 changed files with 92 additions and 74 deletions

View File

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
@ -18,6 +17,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func formatFloat(input float64) float64 {
if input == float64(int64(input)) {
return input
}
return float64(int64(input*10)) / 10
}
func testChannel(channel *model.Channel, request ChatRequest) error { func testChannel(channel *model.Channel, request ChatRequest) error {
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
@ -65,11 +71,16 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
} }
// Construct json data without adding escape character // Construct json data without adding escape character
map1 := map[string]string{ map1 := make(map[string]interface{})
"prompt": prompt,
"systemMessage": systemMessage.Content, map1["prompt"] = prompt
"temperature": strconv.FormatFloat(request.Temperature, 'f', 2, 64), map1["systemMessage"] = systemMessage.Content
"top_p": strconv.FormatFloat(request.TopP, 'f', 2, 64),
if request.Temperature != 0 {
map1["temperature"] = formatFloat(request.Temperature)
}
if request.TopP != 0 {
map1["top_p"] = formatFloat(request.TopP)
} }
// Convert map to json string // Convert map to json string
@ -122,8 +133,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
var done = false var done = false
var streamResponseText = "" var streamResponseText = ""
scanner := bufio.NewScanner(resp.Body)
if channel.Type != common.ChannelTypeChatGPTWeb { if channel.Type != common.ChannelTypeChatGPTWeb {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
@ -139,11 +151,14 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
return 0, nil, nil return 0, nil, nil
}) })
for scanner.Scan() { }
data := scanner.Text()
if len(data) < 6 { // must be something wrong! for scanner.Scan() {
continue data := scanner.Text()
} if len(data) < 6 { // must be something wrong!
continue
}
if channel.Type != common.ChannelTypeChatGPTWeb {
// If data has event: event content inside, remove it, it can be prefix or inside the data // If data has event: event content inside, remove it, it can be prefix or inside the data
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
@ -185,28 +200,26 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
done = true done = true
break break
} }
}
} else if channel.Type == common.ChannelTypeChatGPTWeb {
scanner := bufio.NewScanner(resp.Body)
go func() { } else if channel.Type == common.ChannelTypeChatGPTWeb {
for scanner.Scan() { var chatResponse ChatGptWebChatResponse
var chatResponse ChatGptWebChatResponse err = json.Unmarshal([]byte(data), &chatResponse)
err = json.Unmarshal(scanner.Bytes(), &chatResponse) if err != nil {
// Print the body in string
buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
return err
}
if err != nil { // if response role is assistant and contains delta, append the content to streamResponseText
log.Println("error unmarshal chat response: " + err.Error()) if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
continue for _, choice := range chatResponse.Detail.Choices {
} streamResponseText += choice.Delta.Content
// if response role is assistant and contains delta, append the content to streamResponseText
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
for _, choice := range chatResponse.Detail.Choices {
streamResponseText += choice.Delta.Content
}
} }
} }
}()
}
} }
defer resp.Body.Close() defer resp.Body.Close()

View File

@ -11,7 +11,6 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -223,11 +222,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
// Construct json data without adding escape character // Construct json data without adding escape character
map1 := map[string]string{ map1 := make(map[string]interface{})
"prompt": prompt,
"systemMessage": systemMessage.Content, map1["prompt"] = prompt
"temperature": strconv.FormatFloat(reqBody.Temperature, 'f', 2, 64), map1["systemMessage"] = systemMessage.Content
"top_p": strconv.FormatFloat(reqBody.TopP, 'f', 2, 64),
if reqBody.Temperature != 0 {
map1["temperature"] = formatFloat(reqBody.Temperature)
}
if reqBody.TopP != 0 {
map1["top_p"] = formatFloat(reqBody.TopP)
} }
// Convert map to json string // Convert map to json string
@ -348,16 +352,42 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
dataChan := make(chan string) dataChan := make(chan string)
stopChan := make(chan bool) stopChan := make(chan bool)
if channelType == common.ChannelTypeChatGPTWeb { scanner := bufio.NewScanner(resp.Body)
scanner := bufio.NewScanner(resp.Body)
go func() {
for scanner.Scan() {
var chatResponse ChatGptWebChatResponse
err = json.Unmarshal(scanner.Bytes(), &chatResponse)
if channelType != common.ChannelTypeChatGPTWeb {
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
}
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // must be something wrong!
continue
}
if channelType == common.ChannelTypeChatGPTWeb {
var chatResponse ChatGptWebChatResponse
err = json.Unmarshal([]byte(data), &chatResponse)
if err != nil { if err != nil {
log.Println("error unmarshal chat response: " + err.Error()) // Print the body in string
continue buf := new(bytes.Buffer)
buf.ReadFrom(resp.Body)
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
return
} }
// if response role is assistant and contains delta, append the content to streamResponseText // if response role is assistant and contains delta, append the content to streamResponseText
@ -387,33 +417,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
dataChan <- "data: " + string(jsonData) dataChan <- "data: " + string(jsonData)
} }
} }
} } else {
stopChan <- true
}()
} else {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // must be something wrong!
// common.SysError("invalid stream response: " + data)
continue
}
// If data has event: event content inside, remove it, it can be prefix or inside the data // If data has event: event content inside, remove it, it can be prefix or inside the data
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") { if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
// Remove event: event in the front or back // Remove event: event in the front or back
@ -463,10 +467,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
} }
} }
stopChan <- true }
}() stopChan <- true
} }()
c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Cache-Control", "no-cache")