fix: chatgptweb issue
This commit is contained in:
parent
203471d7a9
commit
481c4ebf49
@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
@ -18,6 +17,13 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func formatFloat(input float64) float64 {
|
||||||
|
if input == float64(int64(input)) {
|
||||||
|
return input
|
||||||
|
}
|
||||||
|
return float64(int64(input*10)) / 10
|
||||||
|
}
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) error {
|
func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
@ -65,11 +71,16 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Construct json data without adding escape character
|
// Construct json data without adding escape character
|
||||||
map1 := map[string]string{
|
map1 := make(map[string]interface{})
|
||||||
"prompt": prompt,
|
|
||||||
"systemMessage": systemMessage.Content,
|
map1["prompt"] = prompt
|
||||||
"temperature": strconv.FormatFloat(request.Temperature, 'f', 2, 64),
|
map1["systemMessage"] = systemMessage.Content
|
||||||
"top_p": strconv.FormatFloat(request.TopP, 'f', 2, 64),
|
|
||||||
|
if request.Temperature != 0 {
|
||||||
|
map1["temperature"] = formatFloat(request.Temperature)
|
||||||
|
}
|
||||||
|
if request.TopP != 0 {
|
||||||
|
map1["top_p"] = formatFloat(request.TopP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert map to json string
|
// Convert map to json string
|
||||||
@ -122,8 +133,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
var done = false
|
var done = false
|
||||||
var streamResponseText = ""
|
var streamResponseText = ""
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
|
||||||
if channel.Type != common.ChannelTypeChatGPTWeb {
|
if channel.Type != common.ChannelTypeChatGPTWeb {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
if atEOF && len(data) == 0 {
|
if atEOF && len(data) == 0 {
|
||||||
return 0, nil, nil
|
return 0, nil, nil
|
||||||
@ -139,11 +151,14 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
|
|
||||||
return 0, nil, nil
|
return 0, nil, nil
|
||||||
})
|
})
|
||||||
for scanner.Scan() {
|
}
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 6 { // must be something wrong!
|
for scanner.Scan() {
|
||||||
continue
|
data := scanner.Text()
|
||||||
}
|
if len(data) < 6 { // must be something wrong!
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if channel.Type != common.ChannelTypeChatGPTWeb {
|
||||||
|
|
||||||
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
||||||
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
||||||
@ -185,28 +200,26 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
done = true
|
done = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else if channel.Type == common.ChannelTypeChatGPTWeb {
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
|
|
||||||
go func() {
|
} else if channel.Type == common.ChannelTypeChatGPTWeb {
|
||||||
for scanner.Scan() {
|
var chatResponse ChatGptWebChatResponse
|
||||||
var chatResponse ChatGptWebChatResponse
|
err = json.Unmarshal([]byte(data), &chatResponse)
|
||||||
err = json.Unmarshal(scanner.Bytes(), &chatResponse)
|
if err != nil {
|
||||||
|
// Print the body in string
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
buf.ReadFrom(resp.Body)
|
||||||
|
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
// if response role is assistant and contains delta, append the content to streamResponseText
|
||||||
log.Println("error unmarshal chat response: " + err.Error())
|
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
|
||||||
continue
|
for _, choice := range chatResponse.Detail.Choices {
|
||||||
}
|
streamResponseText += choice.Delta.Content
|
||||||
|
|
||||||
// if response role is assistant and contains delta, append the content to streamResponseText
|
|
||||||
if chatResponse.Role == "assistant" && chatResponse.Detail != nil {
|
|
||||||
for _, choice := range chatResponse.Detail.Choices {
|
|
||||||
streamResponseText += choice.Delta.Content
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -223,11 +222,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Construct json data without adding escape character
|
// Construct json data without adding escape character
|
||||||
map1 := map[string]string{
|
map1 := make(map[string]interface{})
|
||||||
"prompt": prompt,
|
|
||||||
"systemMessage": systemMessage.Content,
|
map1["prompt"] = prompt
|
||||||
"temperature": strconv.FormatFloat(reqBody.Temperature, 'f', 2, 64),
|
map1["systemMessage"] = systemMessage.Content
|
||||||
"top_p": strconv.FormatFloat(reqBody.TopP, 'f', 2, 64),
|
|
||||||
|
if reqBody.Temperature != 0 {
|
||||||
|
map1["temperature"] = formatFloat(reqBody.Temperature)
|
||||||
|
}
|
||||||
|
if reqBody.TopP != 0 {
|
||||||
|
map1["top_p"] = formatFloat(reqBody.TopP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert map to json string
|
// Convert map to json string
|
||||||
@ -348,16 +352,42 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
dataChan := make(chan string)
|
dataChan := make(chan string)
|
||||||
stopChan := make(chan bool)
|
stopChan := make(chan bool)
|
||||||
|
|
||||||
if channelType == common.ChannelTypeChatGPTWeb {
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
var chatResponse ChatGptWebChatResponse
|
|
||||||
err = json.Unmarshal(scanner.Bytes(), &chatResponse)
|
|
||||||
|
|
||||||
|
if channelType != common.ChannelTypeChatGPTWeb {
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 2, data[0:i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 6 { // must be something wrong!
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if channelType == common.ChannelTypeChatGPTWeb {
|
||||||
|
var chatResponse ChatGptWebChatResponse
|
||||||
|
err = json.Unmarshal([]byte(data), &chatResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("error unmarshal chat response: " + err.Error())
|
// Print the body in string
|
||||||
continue
|
buf := new(bytes.Buffer)
|
||||||
|
buf.ReadFrom(resp.Body)
|
||||||
|
common.SysError("error unmarshalling chat response: " + err.Error() + " " + buf.String())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// if response role is assistant and contains delta, append the content to streamResponseText
|
// if response role is assistant and contains delta, append the content to streamResponseText
|
||||||
@ -387,33 +417,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
dataChan <- "data: " + string(jsonData)
|
dataChan <- "data: " + string(jsonData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
stopChan <- true
|
|
||||||
}()
|
|
||||||
} else {
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
if atEOF && len(data) == 0 {
|
|
||||||
return 0, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
||||||
return i + 1, data[0:i], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if atEOF {
|
|
||||||
return len(data), data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, nil, nil
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
for scanner.Scan() {
|
|
||||||
data := scanner.Text()
|
|
||||||
if len(data) < 6 { // must be something wrong!
|
|
||||||
// common.SysError("invalid stream response: " + data)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
// If data has event: event content inside, remove it, it can be prefix or inside the data
|
||||||
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
|
||||||
// Remove event: event in the front or back
|
// Remove event: event in the front or back
|
||||||
@ -463,10 +467,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
stopChan <- true
|
}
|
||||||
}()
|
stopChan <- true
|
||||||
}
|
}()
|
||||||
|
|
||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||||
|
Loading…
Reference in New Issue
Block a user