fix: chatgptweb issue
This commit is contained in:
parent
203471d7a9
commit
481c4ebf49
@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@ -18,6 +17,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func formatFloat(input float64) float64 {
|
||||
if input == float64(int64(input)) {
|
||||
return input
|
||||
}
|
||||
return float64(int64(input*10)) / 10
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
@ -65,11 +71,16 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
}
|
||||
|
||||
// Construct json data without adding escape character
|
||||
map1 := map[string]string{
|
||||
"prompt": prompt,
|
||||
"systemMessage": systemMessage.Content,
|
||||
"temperature": strconv.FormatFloat(request.Temperature, 'f', 2, 64),
|
||||
"top_p": strconv.FormatFloat(request.TopP, 'f', 2, 64),
|
||||
map1 := make(map[string]interface{})
|
||||
|
||||
map1["prompt"] = prompt
|
||||
map1["systemMessage"] = systemMessage.Content
|
||||
|
||||
if request.Temperature != 0 {
|
||||
map1["temperature"] = formatFloat(request.Temperature)
|
||||
}
|
||||
if request.TopP != 0 {
|
||||
map1["top_p"] = formatFloat(request.TopP)
|
||||
}
|
||||
|
||||
// Convert map to json string
|
||||
@ -122,8 +133,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
var done = false
|
||||
var streamResponseText = ""
|
||||
|
||||
if channel.Type != common.ChannelTypeChatGPTWeb {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
if channel.Type != common.ChannelTypeChatGPTWeb {
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
@ -139,11 +151,14 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
|
||||
return 0, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // must be something wrong!
|
||||
continue
|
||||
}
|
||||
if channel.Type != common.ChannelTypeChatGPTWeb {
|
||||
|
||||
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
||||
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
||||
@ -185,18 +200,16 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
} else if channel.Type == common.ChannelTypeChatGPTWeb {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
var chatResponse ChatGptWebChatResponse
|
||||
err = json.Unmarshal(scanner.Bytes(), &chatResponse)
|
||||
|
||||
err = json.Unmarshal([]byte(data), &chatResponse)
|
||||
if err != nil {
|
||||
log.Println("error unmarshal chat response: " + err.Error())
|
||||
continue
|
||||
// Print the body in string
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(resp.Body)
|
||||
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
|
||||
return err
|
||||
}
|
||||
|
||||
// if response role is assistant and contains delta, append the content to streamResponseText
|
||||
@ -205,8 +218,8 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||
streamResponseText += choice.Delta.Content
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -223,11 +222,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
|
||||
// Construct json data without adding escape character
|
||||
map1 := map[string]string{
|
||||
"prompt": prompt,
|
||||
"systemMessage": systemMessage.Content,
|
||||
"temperature": strconv.FormatFloat(reqBody.Temperature, 'f', 2, 64),
|
||||
"top_p": strconv.FormatFloat(reqBody.TopP, 'f', 2, 64),
|
||||
map1 := make(map[string]interface{})
|
||||
|
||||
map1["prompt"] = prompt
|
||||
map1["systemMessage"] = systemMessage.Content
|
||||
|
||||
if reqBody.Temperature != 0 {
|
||||
map1["temperature"] = formatFloat(reqBody.Temperature)
|
||||
}
|
||||
if reqBody.TopP != 0 {
|
||||
map1["top_p"] = formatFloat(reqBody.TopP)
|
||||
}
|
||||
|
||||
// Convert map to json string
|
||||
@ -348,18 +352,44 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
|
||||
if channelType == common.ChannelTypeChatGPTWeb {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
|
||||
if channelType != common.ChannelTypeChatGPTWeb {
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 2, data[0:i], nil
|
||||
}
|
||||
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
var chatResponse ChatGptWebChatResponse
|
||||
err = json.Unmarshal(scanner.Bytes(), &chatResponse)
|
||||
|
||||
if err != nil {
|
||||
log.Println("error unmarshal chat response: " + err.Error())
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // must be something wrong!
|
||||
continue
|
||||
}
|
||||
|
||||
if channelType == common.ChannelTypeChatGPTWeb {
|
||||
var chatResponse ChatGptWebChatResponse
|
||||
err = json.Unmarshal([]byte(data), &chatResponse)
|
||||
if err != nil {
|
||||
// Print the body in string
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(resp.Body)
|
||||
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
|
||||
return
|
||||
}
|
||||
|
||||
// if response role is assistant and contains delta, append the content to streamResponseText
|
||||
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
|
||||
for _, choice := range chatResponse.Detail.Choices {
|
||||
@ -387,33 +417,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
dataChan <- "data: " + string(jsonData)
|
||||
}
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
} else {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // must be something wrong!
|
||||
// common.SysError("invalid stream response: " + data)
|
||||
continue
|
||||
}
|
||||
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
||||
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
||||
// Remove event: event in the front or back
|
||||
@ -463,10 +467,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
|
Loading…
Reference in New Issue
Block a user