fix: channel testing issue

This commit is contained in:
ckt1031 2023-07-08 14:24:51 +08:00
parent 6215d2e71c
commit 2756554f7c

View File

@ -1,17 +1,21 @@
package controller package controller
import ( import (
"bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "io/ioutil"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, request ChatRequest) error { func testChannel(channel *model.Channel, request ChatRequest) error {
@ -50,15 +54,69 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close()
var response TextResponse var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response) isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
var streamResponseText string
if isStream {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n\n"); i >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // must be something wrong!
common.SysError("invalid stream response: " + data)
continue
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
var streamResponse ChatCompletionsStreamResponse
err = json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return err
}
for _, choice := range streamResponse.Choices {
streamResponseText += choice.Delta.Content
}
}
}
if streamResponseText == "" {
return errors.New("empty stream response")
}
} else {
body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return err return err
} }
if response.Usage.CompletionTokens == 0 { err = json.Unmarshal(body, &response)
if err != nil {
return err
}
// channel.BaseURL starts with https://api.openai.com
if response.Usage.CompletionTokens == 0 && strings.HasPrefix(channel.BaseURL, "https://api.openai.com") {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
} }
}
defer resp.Body.Close()
return nil return nil
} }