fix: channel testing issue
This commit is contained in:
parent
6215d2e71c
commit
2756554f7c
@ -1,17 +1,21 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) error {
|
func testChannel(channel *model.Channel, request ChatRequest) error {
|
||||||
@ -50,15 +54,69 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
var response TextResponse
|
var response TextResponse
|
||||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
if err != nil {
|
var streamResponseText string
|
||||||
return err
|
|
||||||
}
|
if isStream {
|
||||||
if response.Usage.CompletionTokens == 0 {
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if i := strings.Index(string(data), "\n\n"); i >= 0 {
|
||||||
|
return i + 2, data[0:i], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 6 { // must be something wrong!
|
||||||
|
common.SysError("invalid stream response: " + data)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[6:]
|
||||||
|
if !strings.HasPrefix(data, "[DONE]") {
|
||||||
|
var streamResponse ChatCompletionsStreamResponse
|
||||||
|
err = json.Unmarshal([]byte(data), &streamResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, choice := range streamResponse.Choices {
|
||||||
|
streamResponseText += choice.Delta.Content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if streamResponseText == "" {
|
||||||
|
return errors.New("empty stream response")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// channel.BaseURL starts with https://api.openai.com
|
||||||
|
if response.Usage.CompletionTokens == 0 && strings.HasPrefix(channel.BaseURL, "https://api.openai.com") {
|
||||||
|
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user