🐛 fix: stream mode delay issue (#53)
This commit is contained in:
parent
705804e6dd
commit
d7193b8e46
@ -127,11 +127,16 @@ func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response,
|
||||
return nil, HandleErrorResp(resp, requester.ErrorHandler)
|
||||
}
|
||||
|
||||
return &streamReader[T]{
|
||||
stream := &streamReader[T]{
|
||||
reader: bufio.NewReader(resp.Body),
|
||||
response: resp,
|
||||
handlerPrefix: handlerPrefix,
|
||||
}, nil
|
||||
|
||||
DataChan: make(chan T),
|
||||
ErrChan: make(chan error),
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// 设置请求体
|
||||
|
@ -3,16 +3,12 @@ package requester
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 流处理函数,判断依据如下:
|
||||
// 1.如果有错误信息,则直接返回错误信息
|
||||
// 2.如果isFinished=true,则返回io.EOF,并且如果response不为空,还将返回response
|
||||
// 3.如果rawLine=nil 或者 response长度为0,则直接跳过
|
||||
// 4.如果以上条件都不满足,则返回response
|
||||
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
|
||||
var StreamClosed = []byte("stream_closed")
|
||||
|
||||
type HandlerPrefix[T streamable] func(rawLine *[]byte, dataChan chan T, errChan chan error)
|
||||
|
||||
type streamable interface {
|
||||
// types.ChatCompletionStreamResponse | types.CompletionResponse
|
||||
@ -20,57 +16,48 @@ type streamable interface {
|
||||
}
|
||||
|
||||
type StreamReaderInterface[T streamable] interface {
|
||||
Recv() (*[]T, error)
|
||||
Recv() (<-chan T, <-chan error)
|
||||
Close()
|
||||
}
|
||||
|
||||
type streamReader[T streamable] struct {
|
||||
isFinished bool
|
||||
|
||||
reader *bufio.Reader
|
||||
response *http.Response
|
||||
|
||||
handlerPrefix HandlerPrefix[T]
|
||||
|
||||
DataChan chan T
|
||||
ErrChan chan error
|
||||
}
|
||||
|
||||
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
|
||||
go stream.processLines()
|
||||
|
||||
return stream.DataChan, stream.ErrChan
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (stream *streamReader[T]) processLines() (*[]T, error) {
|
||||
func (stream *streamReader[T]) processLines() {
|
||||
for {
|
||||
rawLine, readErr := stream.reader.ReadBytes('\n')
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
stream.ErrChan <- readErr
|
||||
return
|
||||
}
|
||||
|
||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||
|
||||
var response []T
|
||||
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stream.isFinished {
|
||||
if len(response) > 0 {
|
||||
return &response, io.EOF
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if noSpaceLine == nil || len(response) == 0 {
|
||||
if len(noSpaceLine) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
stream.handlerPrefix(&noSpaceLine, stream.DataChan, stream.ErrChan)
|
||||
|
||||
if noSpaceLine == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(noSpaceLine, StreamClosed) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,55 +1,41 @@
|
||||
package requester
|
||||
|
||||
import (
|
||||
"io"
|
||||
"bytes"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type wsReader[T streamable] struct {
|
||||
isFinished bool
|
||||
|
||||
reader *websocket.Conn
|
||||
handlerPrefix HandlerPrefix[T]
|
||||
|
||||
DataChan chan T
|
||||
ErrChan chan error
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
|
||||
if stream.isFinished {
|
||||
err = io.EOF
|
||||
return
|
||||
}
|
||||
|
||||
response, err = stream.processLines()
|
||||
return
|
||||
func (stream *wsReader[T]) Recv() (<-chan T, <-chan error) {
|
||||
go stream.processLines()
|
||||
return stream.DataChan, stream.ErrChan
|
||||
}
|
||||
|
||||
func (stream *wsReader[T]) processLines() (*[]T, error) {
|
||||
func (stream *wsReader[T]) processLines() {
|
||||
for {
|
||||
_, msg, err := stream.reader.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
stream.ErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
var response []T
|
||||
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
|
||||
stream.handlerPrefix(&msg, stream.DataChan, stream.ErrChan)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stream.isFinished {
|
||||
if len(response) > 0 {
|
||||
return &response, io.EOF
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
if msg == nil || len(response) == 0 {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
|
||||
if bytes.Equal(msg, StreamClosed) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -38,10 +38,15 @@ func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPref
|
||||
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return &wsReader[T]{
|
||||
stream := &wsReader[T]{
|
||||
reader: conn,
|
||||
handlerPrefix: handlerPrefix,
|
||||
}, nil
|
||||
|
||||
DataChan: make(chan T),
|
||||
ErrChan: make(chan error),
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
|
@ -1,7 +1,6 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@ -53,13 +52,13 @@ func RelayChat(c *gin.Context) {
|
||||
}
|
||||
|
||||
if chatRequest.Stream {
|
||||
var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse]
|
||||
var response requester.StreamReaderInterface[string]
|
||||
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response)
|
||||
errWithCode = responseStreamClient(c, response)
|
||||
} else {
|
||||
var response *types.ChatCompletionResponse
|
||||
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
|
||||
@ -70,8 +69,6 @@ func RelayChat(c *gin.Context) {
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
}
|
||||
|
||||
fmt.Println(usage)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
@ -51,14 +52,16 @@ func RelayCompletions(c *gin.Context) {
|
||||
}
|
||||
|
||||
if completionRequest.Stream {
|
||||
response, errWithCode := completionProvider.CreateCompletionStream(&completionRequest)
|
||||
var response requester.StreamReaderInterface[string]
|
||||
response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseStreamClient[types.CompletionResponse](c, response)
|
||||
errWithCode = responseStreamClient(c, response)
|
||||
} else {
|
||||
response, errWithCode := completionProvider.CreateCompletion(&completionRequest)
|
||||
var response *types.CompletionResponse
|
||||
response, errWithCode = completionProvider.CreateCompletion(&completionRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
|
@ -154,34 +154,25 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode {
|
||||
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode {
|
||||
requester.SetEventStreamHeaders(c)
|
||||
dataChan, errChan := stream.Recv()
|
||||
|
||||
defer stream.Close()
|
||||
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
if response != nil && len(*response) > 0 {
|
||||
for _, streamResponse := range *response {
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
|
||||
}
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
fmt.Fprintln(w, "data: "+data+"\n")
|
||||
return true
|
||||
case err := <-errChan:
|
||||
if !errors.Is(err, io.EOF) {
|
||||
fmt.Fprintln(w, "data: "+err.Error()+"\n")
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
break
|
||||
fmt.Fprintln(w, "data: [DONE]")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, streamResponse := range *response {
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest)
|
||||
return p.convertToChatOpenai(aliResponse, request)
|
||||
}
|
||||
|
||||
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getAliChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -52,7 +52,7 @@ func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRe
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -162,11 +162,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data:") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
@ -175,19 +175,21 @@ func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, resp
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal(*rawLine, &aliResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := errorHandle(&aliResponse.AliError)
|
||||
if error != nil {
|
||||
return error
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(&aliResponse, response)
|
||||
h.convertToOpenaiStream(&aliResponse, dataChan, errChan)
|
||||
|
||||
}
|
||||
|
||||
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, dataChan chan string, errChan chan error) {
|
||||
content := aliResponse.Output.Choices[0].Message.StringContent()
|
||||
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
@ -222,7 +224,6 @@ func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, r
|
||||
h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
|
||||
return nil
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
dataChan <- string(responseBody)
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
|
||||
return &response.ChatCompletionResponse, nil
|
||||
}
|
||||
|
||||
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -57,7 +57,7 @@ func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatComplet
|
||||
ModelName: request.Model,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||
}
|
||||
|
||||
// 获取聊天请求体
|
||||
|
@ -2,6 +2,7 @@ package baidu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
@ -31,7 +32,7 @@ func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionReques
|
||||
return p.convertToChatOpenai(baiduResponse, request)
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getBaiduChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -49,7 +50,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -178,11 +179,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatReque
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
@ -191,18 +192,26 @@ func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, re
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := json.Unmarshal(*rawLine, &baiduResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := errorHandle(&baiduResponse.BaiduError)
|
||||
if error != nil {
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
h.convertToOpenaiStream(&baiduResponse, dataChan, errChan)
|
||||
|
||||
if baiduResponse.IsEnd {
|
||||
*isFinished = true
|
||||
errChan <- io.EOF
|
||||
*rawLine = requester.StreamClosed
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(&baiduResponse, response)
|
||||
|
||||
}
|
||||
|
||||
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, dataChan chan string, errChan chan error) {
|
||||
choice := types.ChatCompletionStreamChoice{
|
||||
Index: 0,
|
||||
Delta: types.ChatCompletionStreamChoiceDelta{
|
||||
@ -240,19 +249,19 @@ func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStrea
|
||||
|
||||
if baiduResponse.FunctionCall == nil {
|
||||
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletion)
|
||||
responseBody, _ := json.Marshal(chatCompletion)
|
||||
dataChan <- string(responseBody)
|
||||
} else {
|
||||
choices := choice.ConvertOpenaiStream()
|
||||
for _, choice := range choices {
|
||||
chatCompletionCopy := chatCompletion
|
||||
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletionCopy)
|
||||
responseBody, _ := json.Marshal(chatCompletionCopy)
|
||||
dataChan <- string(responseBody)
|
||||
}
|
||||
}
|
||||
|
||||
h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -41,14 +41,14 @@ type ProviderInterface interface {
|
||||
type CompletionInterface interface {
|
||||
ProviderInterface
|
||||
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode)
|
||||
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 聊天接口
|
||||
type ChatInterface interface {
|
||||
ProviderInterface
|
||||
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode)
|
||||
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 嵌入接口
|
||||
|
@ -3,6 +3,7 @@ package claude
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
@ -32,7 +33,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
||||
return p.convertToChatOpenai(claudeResponse, request)
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -50,7 +51,7 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -149,11 +150,11 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
@ -162,17 +163,26 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r
|
||||
var claudeResponse *ClaudeResponse
|
||||
err := json.Unmarshal(*rawLine, claudeResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := errorHandle(&claudeResponse.ClaudeResponseError)
|
||||
if error != nil {
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
if claudeResponse.StopReason == "stop_sequence" {
|
||||
*isFinished = true
|
||||
errChan <- io.EOF
|
||||
*rawLine = requester.StreamClosed
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(claudeResponse, response)
|
||||
h.convertToOpenaiStream(claudeResponse, dataChan, errChan)
|
||||
}
|
||||
|
||||
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, dataChan chan string, errChan chan error) {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
@ -185,9 +195,8 @@ func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeRespon
|
||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||
}
|
||||
|
||||
*response = append(*response, chatCompletion)
|
||||
responseBody, _ := json.Marshal(chatCompletion)
|
||||
dataChan <- string(responseBody)
|
||||
|
||||
h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
||||
return p.convertToChatOpenai(geminiChatResponse, request)
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -55,7 +55,7 @@ func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -228,11 +228,11 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
@ -241,19 +241,21 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := json.Unmarshal(*rawLine, &geminiResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := errorHandle(&geminiResponse.GeminiErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(&geminiResponse, response)
|
||||
h.convertToOpenaiStream(&geminiResponse, dataChan, errChan)
|
||||
|
||||
}
|
||||
|
||||
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) {
|
||||
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
|
||||
|
||||
for i, candidate := range geminiResponse.Candidates {
|
||||
@ -275,10 +277,9 @@ func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatRe
|
||||
Choices: choices,
|
||||
}
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
dataChan <- string(responseBody)
|
||||
|
||||
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
|
||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func (p *MiniMaxProvider) CreateChatCompletion(request *types.ChatCompletionRequ
|
||||
return p.convertToChatOpenai(response, request)
|
||||
}
|
||||
|
||||
func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -50,7 +50,7 @@ func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompleti
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -191,11 +191,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *MiniMaxChatReq
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
*rawLine = (*rawLine)[6:]
|
||||
@ -203,25 +203,27 @@ func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool,
|
||||
miniResponse := &MiniMaxChatResponse{}
|
||||
err := json.Unmarshal(*rawLine, miniResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := errorHandle(&miniResponse.BaseResp)
|
||||
if error != nil {
|
||||
return error
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
choice := miniResponse.Choices[0]
|
||||
|
||||
if choice.Messages[0].FunctionCall != nil && choice.FinishReason == "" {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(miniResponse, response)
|
||||
h.convertToOpenaiStream(miniResponse, dataChan, errChan)
|
||||
}
|
||||
|
||||
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, dataChan chan string, errChan chan error) {
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
ID: miniResponse.RequestID,
|
||||
Object: "chat.completion.chunk",
|
||||
@ -235,8 +237,8 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
|
||||
if miniChoice.Messages[0].FunctionCall == nil && miniChoice.FinishReason != "" {
|
||||
streamResponse.ID = miniResponse.ID
|
||||
openaiChoice.FinishReason = convertFinishReason(miniChoice.FinishReason)
|
||||
h.appendResponse(&streamResponse, &openaiChoice, response)
|
||||
return nil
|
||||
dataChan <- h.getResponseString(&streamResponse, &openaiChoice)
|
||||
return
|
||||
}
|
||||
|
||||
openaiChoice.Delta = types.ChatCompletionStreamChoiceDelta{
|
||||
@ -248,19 +250,17 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
|
||||
convertChoices := openaiChoice.ConvertOpenaiStream()
|
||||
for _, convertChoice := range convertChoices {
|
||||
chatCompletionCopy := streamResponse
|
||||
h.appendResponse(&chatCompletionCopy, &convertChoice, response)
|
||||
dataChan <- h.getResponseString(&chatCompletionCopy, &convertChoice)
|
||||
}
|
||||
|
||||
} else {
|
||||
openaiChoice.Delta.Content = miniChoice.Messages[0].Text
|
||||
h.appendResponse(&streamResponse, &openaiChoice, response)
|
||||
dataChan <- h.getResponseString(&streamResponse, &openaiChoice)
|
||||
}
|
||||
|
||||
if miniResponse.Usage != nil {
|
||||
h.handleUsage(miniResponse)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *types.ChatCompletionStreamChoice) {
|
||||
@ -274,9 +274,10 @@ func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *
|
||||
}
|
||||
}
|
||||
|
||||
func (h *minimaxStreamHandler) appendResponse(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice, response *[]types.ChatCompletionStreamResponse) {
|
||||
func (h *minimaxStreamHandler) getResponseString(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice) string {
|
||||
streamResponse.Choices = []types.ChatCompletionStreamChoice{*openaiChoice}
|
||||
*response = append(*response, *streamResponse)
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
return string(responseBody)
|
||||
}
|
||||
|
||||
func (h *minimaxStreamHandler) handleUsage(miniResponse *MiniMaxChatResponse) {
|
||||
|
@ -2,16 +2,19 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpenAIStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
ModelName string
|
||||
isAzure bool
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
@ -43,7 +46,7 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
||||
return &response.ChatCompletionResponse, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -59,16 +62,17 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
||||
chatHandler := OpenAIStreamHandler{
|
||||
Usage: p.Usage,
|
||||
ModelName: request.Model,
|
||||
isAzure: p.IsAzure,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||
}
|
||||
|
||||
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
@ -76,26 +80,32 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *boo
|
||||
|
||||
// 如果等于 DONE 则结束
|
||||
if string(*rawLine) == "[DONE]" {
|
||||
*isFinished = true
|
||||
return nil
|
||||
errChan <- io.EOF
|
||||
*rawLine = requester.StreamClosed
|
||||
return
|
||||
}
|
||||
|
||||
var openaiResponse OpenAIProviderChatStreamResponse
|
||||
err := json.Unmarshal(*rawLine, &openaiResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
dataChan <- string(*rawLine)
|
||||
|
||||
if h.isAzure {
|
||||
// 阻塞 20ms
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
||||
h.Usage.CompletionTokens += countTokenText
|
||||
h.Usage.TotalTokens += countTokenText
|
||||
|
||||
*response = append(*response, openaiResponse.ChatCompletionStreamResponse)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
@ -38,7 +39,7 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
|
||||
return &response.CompletionResponse, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -56,14 +57,14 @@ func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest
|
||||
ModelName: request.Model,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerCompletionStream)
|
||||
}
|
||||
|
||||
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error {
|
||||
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
@ -71,26 +72,27 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinishe
|
||||
|
||||
// 如果等于 DONE 则结束
|
||||
if string(*rawLine) == "[DONE]" {
|
||||
*isFinished = true
|
||||
return nil
|
||||
errChan <- io.EOF
|
||||
*rawLine = requester.StreamClosed
|
||||
return
|
||||
}
|
||||
|
||||
var openaiResponse OpenAIProviderCompletionResponse
|
||||
err := json.Unmarshal(*rawLine, &openaiResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
dataChan <- string(*rawLine)
|
||||
|
||||
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
||||
h.Usage.CompletionTokens += countTokenText
|
||||
h.Usage.TotalTokens += countTokenText
|
||||
|
||||
*response = append(*response, openaiResponse.CompletionResponse)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func (p *PalmProvider) CreateChatCompletion(request *types.ChatCompletionRequest
|
||||
return p.convertToChatOpenai(palmResponse, request)
|
||||
}
|
||||
|
||||
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -50,7 +50,7 @@ func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionR
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -142,11 +142,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *PaLMChatReques
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
@ -155,19 +155,21 @@ func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, res
|
||||
var palmChatResponse PaLMChatResponse
|
||||
err := json.Unmarshal(*rawLine, &palmChatResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := errorHandle(&palmChatResponse.PaLMErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(&palmChatResponse, response)
|
||||
h.convertToOpenaiStream(&palmChatResponse, dataChan, errChan)
|
||||
|
||||
}
|
||||
|
||||
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, dataChan chan string, errChan chan error) {
|
||||
var choice types.ChatCompletionStreamChoice
|
||||
if len(palmChatResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmChatResponse.Candidates[0].Content
|
||||
@ -182,10 +184,9 @@ func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResp
|
||||
Created: common.GetTimestamp(),
|
||||
}
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
dataChan <- string(responseBody)
|
||||
|
||||
h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model)
|
||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func (p *TencentProvider) CreateChatCompletion(request *types.ChatCompletionRequ
|
||||
return p.convertToChatOpenai(tencentChatResponse, request)
|
||||
}
|
||||
|
||||
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -50,7 +50,7 @@ func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompleti
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -157,11 +157,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatReq
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data:") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
@ -170,19 +170,21 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool,
|
||||
var tencentChatResponse TencentChatResponse
|
||||
err := json.Unmarshal(*rawLine, &tencentChatResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := errorHandle(&tencentChatResponse.TencentResponseError)
|
||||
if error != nil {
|
||||
return error
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(&tencentChatResponse, response)
|
||||
h.convertToOpenaiStream(&tencentChatResponse, dataChan, errChan)
|
||||
|
||||
}
|
||||
|
||||
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string, errChan chan error) {
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
@ -197,10 +199,9 @@ func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *Tencen
|
||||
streamResponse.Choices = append(streamResponse.Choices, choice)
|
||||
}
|
||||
|
||||
*response = append(*response, streamResponse)
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
dataChan <- string(responseBody)
|
||||
|
||||
h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model)
|
||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
||||
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
wsConn, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -53,7 +53,7 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream)
|
||||
return requester.SendWSJsonRequest[string](wsConn, xunfeiRequest, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -123,23 +123,26 @@ func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequ
|
||||
func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
var content string
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
dataChan, errChan := stream.Recv()
|
||||
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
stop := false
|
||||
for !stop {
|
||||
select {
|
||||
case response := <-dataChan:
|
||||
if len(response.Payload.Choices.Text) == 0 {
|
||||
continue
|
||||
}
|
||||
xunfeiResponse = response
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
case err := <-errChan:
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
|
||||
if errors.Is(err, io.EOF) {
|
||||
stop = true
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) && response == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if len((*response)[0].Payload.Choices.Text) == 0 {
|
||||
continue
|
||||
}
|
||||
xunfeiResponse = (*response)[0]
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
}
|
||||
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
@ -193,7 +196,7 @@ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterfa
|
||||
}
|
||||
|
||||
func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
// 如果rawLine 前缀不为{,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "{") {
|
||||
*rawLine = nil
|
||||
return nil, nil
|
||||
@ -221,34 +224,47 @@ func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiC
|
||||
return &xunfeiChatResponse, nil
|
||||
}
|
||||
|
||||
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error {
|
||||
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
|
||||
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, dataChan chan XunfeiChatResponse, errChan chan error) {
|
||||
isFinished := false
|
||||
xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished)
|
||||
if err != nil {
|
||||
return err
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
if *rawLine == nil {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
*response = append(*response, *xunfeiChatResponse)
|
||||
return nil
|
||||
dataChan <- *xunfeiChatResponse
|
||||
|
||||
if isFinished {
|
||||
errChan <- io.EOF
|
||||
*rawLine = requester.StreamClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
|
||||
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
isFinished := false
|
||||
xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished)
|
||||
if err != nil {
|
||||
return err
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
if *rawLine == nil {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(xunfeiChatResponse, response)
|
||||
h.convertToOpenaiStream(xunfeiChatResponse, dataChan, errChan)
|
||||
|
||||
if isFinished {
|
||||
errChan <- io.EOF
|
||||
*rawLine = requester.StreamClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, dataChan chan string, errChan chan error) {
|
||||
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
|
||||
}
|
||||
@ -293,15 +309,15 @@ func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResp
|
||||
|
||||
if xunfeiText.FunctionCall == nil {
|
||||
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletion)
|
||||
responseBody, _ := json.Marshal(chatCompletion)
|
||||
dataChan <- string(responseBody)
|
||||
} else {
|
||||
choices := choice.ConvertOpenaiStream()
|
||||
for _, choice := range choices {
|
||||
chatCompletionCopy := chatCompletion
|
||||
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletionCopy)
|
||||
responseBody, _ := json.Marshal(chatCompletionCopy)
|
||||
dataChan <- string(responseBody)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package zhipu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
@ -31,7 +32,7 @@ func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionReques
|
||||
return p.convertToChatOpenai(zhipuChatResponse, request)
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getChatRequest(request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
@ -49,7 +50,7 @@ func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletion
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
||||
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||
}
|
||||
|
||||
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
@ -140,35 +141,38 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
|
||||
}
|
||||
|
||||
// 转换为OpenAI聊天流式请求体
|
||||
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
if strings.HasPrefix(string(*rawLine), "[DONE]") {
|
||||
*isFinished = true
|
||||
return nil
|
||||
errChan <- io.EOF
|
||||
*rawLine = requester.StreamClosed
|
||||
return
|
||||
}
|
||||
|
||||
zhipuResponse := &ZhipuStreamResponse{}
|
||||
err := json.Unmarshal(*rawLine, zhipuResponse)
|
||||
if err != nil {
|
||||
return common.ErrorToOpenAIError(err)
|
||||
errChan <- common.ErrorToOpenAIError(err)
|
||||
return
|
||||
}
|
||||
|
||||
error := errorHandle(&zhipuResponse.Error)
|
||||
if error != nil {
|
||||
return error
|
||||
errChan <- error
|
||||
return
|
||||
}
|
||||
|
||||
return h.convertToOpenaiStream(zhipuResponse, response)
|
||||
h.convertToOpenaiStream(zhipuResponse, dataChan, errChan)
|
||||
}
|
||||
|
||||
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
|
||||
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, dataChan chan string, errChan chan error) {
|
||||
streamResponse := types.ChatCompletionStreamResponse{
|
||||
ID: zhipuResponse.ID,
|
||||
Object: "chat.completion.chunk",
|
||||
@ -183,16 +187,16 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes
|
||||
for _, choice := range choices {
|
||||
chatCompletionCopy := streamResponse
|
||||
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, chatCompletionCopy)
|
||||
responseBody, _ := json.Marshal(chatCompletionCopy)
|
||||
dataChan <- string(responseBody)
|
||||
}
|
||||
} else {
|
||||
streamResponse.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||
*response = append(*response, streamResponse)
|
||||
responseBody, _ := json.Marshal(streamResponse)
|
||||
dataChan <- string(responseBody)
|
||||
}
|
||||
|
||||
if zhipuResponse.Usage != nil {
|
||||
*h.Usage = *zhipuResponse.Usage
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user