From 2756554f7c160e8095a7af827785881749fb39c7 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Sat, 8 Jul 2023 14:24:51 +0800 Subject: [PATCH] fix: channel testing issue --- controller/channel-test.go | 74 +++++++++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 8 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index f2ffad01..4625b576 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -1,17 +1,21 @@ package controller import ( + "bufio" "bytes" "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" + "io/ioutil" "net/http" "one-api/common" "one-api/model" "strconv" + "strings" "sync" "time" + + "github.com/gin-gonic/gin" ) func testChannel(channel *model.Channel, request ChatRequest) error { @@ -50,15 +54,69 @@ func testChannel(channel *model.Channel, request ChatRequest) error { if err != nil { return err } - defer resp.Body.Close() + var response TextResponse - err = json.NewDecoder(resp.Body).Decode(&response) - if err != nil { - return err - } - if response.Usage.CompletionTokens == 0 { - return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) + isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + var streamResponseText string + + if isStream { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + + if i := strings.Index(string(data), "\n\n"); i >= 0 { + return i + 2, data[0:i], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // must be something wrong! + common.SysError("invalid stream response: " + data) + continue + } + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + var streamResponse ChatCompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return err + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Delta.Content + } + } + } + + if streamResponseText == "" { + return errors.New("empty stream response") + } + } else { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + err = json.Unmarshal(body, &response) + if err != nil { + return err + } + + // channel.BaseURL starts with https://api.openai.com + if response.Usage.CompletionTokens == 0 && strings.HasPrefix(channel.BaseURL, "https://api.openai.com") { + return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) + } } + + defer resp.Body.Close() + return nil }