diff --git a/controller/channel-test.go b/controller/channel-test.go index 9cd96e4a..867cacd3 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -1,11 +1,11 @@ package controller import ( + "bufio" "bytes" "encoding/json" "errors" "fmt" - "io" "net/http" "one-api/common" "one-api/model" @@ -58,24 +58,51 @@ func testChannel(channel *model.Channel, request ChatRequest) error { return errors.New("invalid status code: " + strconv.Itoa(resp.StatusCode)) } - var response TextResponse + var streamResponseText string - body, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - err = json.Unmarshal(body, &response) - if err != nil { - return err - } + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } - // channel.BaseURL starts with https://api.openai.com - if response.Usage.CompletionTokens == 0 && strings.HasPrefix(channel.BaseURL, "https://api.openai.com") { - return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) + if i := strings.Index(string(data), "\n\n"); i >= 0 { + return i + 2, data[0:i], nil + } + + if atEOF { + return len(data), data, nil + } + + return 0, nil, nil + }) + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // must be something wrong! + common.SysError("invalid stream response: " + data) + continue + } + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + var streamResponse ChatCompletionsStreamResponse + err = json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return err + } + for _, choice := range streamResponse.Choices { + streamResponseText += choice.Delta.Content + } + } } defer resp.Body.Close() + // Check if streaming is complete and streamResponseText is populated + if streamResponseText == "" { + return errors.New("Streaming not complete") + } + return nil } @@ -83,7 +110,7 @@ func buildTestRequest() *ChatRequest { testRequest := &ChatRequest{ Model: "", // this will be set later MaxTokens: 1, - Stream: false, + Stream: true, } testMessage := Message{ Role: "user",