diff --git a/common/requester/http_requester.go b/common/requester/http_requester.go index 1b6c84cf..57808ac7 100644 --- a/common/requester/http_requester.go +++ b/common/requester/http_requester.go @@ -127,11 +127,16 @@ func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response, return nil, HandleErrorResp(resp, requester.ErrorHandler) } - return &streamReader[T]{ + stream := &streamReader[T]{ reader: bufio.NewReader(resp.Body), response: resp, handlerPrefix: handlerPrefix, - }, nil + + DataChan: make(chan T), + ErrChan: make(chan error), + } + + return stream, nil } // 设置请求体 diff --git a/common/requester/http_stream_reader.go b/common/requester/http_stream_reader.go index 06e70090..6045a012 100644 --- a/common/requester/http_stream_reader.go +++ b/common/requester/http_stream_reader.go @@ -3,16 +3,12 @@ package requester import ( "bufio" "bytes" - "io" "net/http" ) -// 流处理函数,判断依据如下: -// 1.如果有错误信息,则直接返回错误信息 -// 2.如果isFinished=true,则返回io.EOF,并且如果response不为空,还将返回response -// 3.如果rawLine=nil 或者 response长度为0,则直接跳过 -// 4.如果以上条件都不满足,则返回response -type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error +var StreamClosed = []byte("stream_closed") + +type HandlerPrefix[T streamable] func(rawLine *[]byte, dataChan chan T, errChan chan error) type streamable interface { // types.ChatCompletionStreamResponse | types.CompletionResponse @@ -20,57 +16,48 @@ type streamable interface { } type StreamReaderInterface[T streamable] interface { - Recv() (*[]T, error) + Recv() (<-chan T, <-chan error) Close() } type streamReader[T streamable] struct { - isFinished bool - reader *bufio.Reader response *http.Response handlerPrefix HandlerPrefix[T] + + DataChan chan T + ErrChan chan error } -func (stream *streamReader[T]) Recv() (response *[]T, err error) { - if stream.isFinished { - err = io.EOF - return - } - response, err = stream.processLines() - return +func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) { + go stream.processLines() + + return stream.DataChan, stream.ErrChan } //nolint:gocognit -func (stream *streamReader[T]) processLines() (*[]T, error) { +func (stream *streamReader[T]) processLines() { for { rawLine, readErr := stream.reader.ReadBytes('\n') if readErr != nil { - return nil, readErr + stream.ErrChan <- readErr + return } - noSpaceLine := bytes.TrimSpace(rawLine) - - var response []T - err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response) - - if err != nil { - return nil, err - } - - if stream.isFinished { - if len(response) > 0 { - return &response, io.EOF - } - return nil, io.EOF - } - - if noSpaceLine == nil || len(response) == 0 { + if len(noSpaceLine) == 0 { continue } - return &response, nil + stream.handlerPrefix(&noSpaceLine, stream.DataChan, stream.ErrChan) + + if noSpaceLine == nil { + continue + } + + if bytes.Equal(noSpaceLine, StreamClosed) { + return + } } } diff --git a/common/requester/ws_reader.go b/common/requester/ws_reader.go index 24d91be4..da089065 100644 --- a/common/requester/ws_reader.go +++ b/common/requester/ws_reader.go @@ -1,55 +1,41 @@ package requester import ( - "io" + "bytes" "github.com/gorilla/websocket" ) type wsReader[T streamable] struct { - isFinished bool - reader *websocket.Conn handlerPrefix HandlerPrefix[T] + + DataChan chan T + ErrChan chan error } -func (stream *wsReader[T]) Recv() (response *[]T, err error) { - if stream.isFinished { - err = io.EOF - return - } - - response, err = stream.processLines() - return +func (stream *wsReader[T]) Recv() (<-chan T, <-chan error) { + go stream.processLines() + return stream.DataChan, stream.ErrChan } -func (stream *wsReader[T]) processLines() (*[]T, error) { +func (stream *wsReader[T]) processLines() { for { _, msg, err := stream.reader.ReadMessage() if err != nil { - return nil, err + stream.ErrChan <- err + return } - var response []T - err = stream.handlerPrefix(&msg, &stream.isFinished, &response) + stream.handlerPrefix(&msg, stream.DataChan, stream.ErrChan) - if err != nil { - return nil, err - } - - if stream.isFinished { - if len(response) > 0 { - return &response, io.EOF - } - return nil, io.EOF - } - - if msg == nil || len(response) == 0 { + if msg == nil { continue } - return &response, nil - + if bytes.Equal(msg, StreamClosed) { + return + } } } diff --git a/common/requester/ws_requester.go b/common/requester/ws_requester.go index e16c8121..914e6f1e 100644 --- a/common/requester/ws_requester.go +++ b/common/requester/ws_requester.go @@ -38,10 +38,15 @@ func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPref return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError) } - return &wsReader[T]{ + stream := &wsReader[T]{ reader: conn, handlerPrefix: handlerPrefix, - }, nil + + DataChan: make(chan T), + ErrChan: make(chan error), + } + + return stream, nil } // 设置请求头 diff --git a/controller/relay-chat.go b/controller/relay-chat.go index 9880ac5e..61267311 100644 --- a/controller/relay-chat.go +++ b/controller/relay-chat.go @@ -1,7 +1,6 @@ package controller import ( - "fmt" "math" "net/http" "one-api/common" @@ -53,13 +52,13 @@ func RelayChat(c *gin.Context) { } if chatRequest.Stream { - var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse] + var response requester.StreamReaderInterface[string] response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest) if errWithCode != nil { errorHelper(c, errWithCode) return } - errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response) + errWithCode = responseStreamClient(c, response) } else { var response *types.ChatCompletionResponse response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest) @@ -70,8 +69,6 @@ func RelayChat(c *gin.Context) { errWithCode = responseJsonClient(c, response) } - fmt.Println(usage) - // 如果报错,则退还配额 if errWithCode != nil { quotaInfo.undo(c, errWithCode) diff --git a/controller/relay-completions.go b/controller/relay-completions.go index 4af1e685..0898a016 100644 --- a/controller/relay-completions.go +++ b/controller/relay-completions.go @@ -4,6 +4,7 @@ import ( "math" "net/http" "one-api/common" + "one-api/common/requester" providersBase "one-api/providers/base" "one-api/types" @@ -51,14 +52,16 @@ func RelayCompletions(c *gin.Context) { } if completionRequest.Stream { - response, errWithCode := completionProvider.CreateCompletionStream(&completionRequest) + var response requester.StreamReaderInterface[string] + response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest) if errWithCode != nil { errorHelper(c, errWithCode) return } - errWithCode = responseStreamClient[types.CompletionResponse](c, response) + errWithCode = responseStreamClient(c, response) } else { - response, errWithCode := completionProvider.CreateCompletion(&completionRequest) + var response *types.CompletionResponse + response, errWithCode = completionProvider.CreateCompletion(&completionRequest) if errWithCode != nil { errorHelper(c, errWithCode) return diff --git a/controller/relay-utils.go b/controller/relay-utils.go index ec2e6ede..03d452ad 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -154,34 +154,25 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith return nil } -func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode { +func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode { requester.SetEventStreamHeaders(c) + dataChan, errChan := stream.Recv() + defer stream.Close() - - for { - response, err := stream.Recv() - if errors.Is(err, io.EOF) { - if response != nil && len(*response) > 0 { - for _, streamResponse := range *response { - responseBody, _ := json.Marshal(streamResponse) - c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)}) - } + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + fmt.Fprintln(w, "data: "+data+"\n") + return true + case err := <-errChan: + if !errors.Is(err, io.EOF) { + fmt.Fprintln(w, "data: "+err.Error()+"\n") } - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - break - } - if err != nil { - c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()}) - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - break + fmt.Fprintln(w, "data: [DONE]") + return false } - - for _, streamResponse := range *response { - responseBody, _ := json.Marshal(streamResponse) - c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)}) - } - } + }) return nil } diff --git a/providers/ali/chat.go b/providers/ali/chat.go index b38b3c0f..577dc7ae 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -34,7 +34,7 @@ func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest) return p.convertToChatOpenai(aliResponse, request) } -func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.getAliChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -52,7 +52,7 @@ func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRe Request: request, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -162,11 +162,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest } // 转换为OpenAI聊天流式请求体 -func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *aliStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data:") { *rawLine = nil - return nil + return } // 去除前缀 @@ -175,19 +175,21 @@ func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, resp var aliResponse AliChatResponse err := json.Unmarshal(*rawLine, &aliResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } error := errorHandle(&aliResponse.AliError) if error != nil { - return error + errChan <- error + return } - return h.convertToOpenaiStream(&aliResponse, response) + h.convertToOpenaiStream(&aliResponse, dataChan, errChan) } -func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, dataChan chan string, errChan chan error) { content := aliResponse.Output.Choices[0].Message.StringContent() var choice types.ChatCompletionStreamChoice @@ -222,7 +224,6 @@ func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, r h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens } - *response = append(*response, streamResponse) - - return nil + responseBody, _ := json.Marshal(streamResponse) + dataChan <- string(responseBody) } diff --git a/providers/baichuan/chat.go b/providers/baichuan/chat.go index 39dc3894..7e42367f 100644 --- a/providers/baichuan/chat.go +++ b/providers/baichuan/chat.go @@ -39,7 +39,7 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq return &response.ChatCompletionResponse, nil } -func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode @@ -57,7 +57,7 @@ func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatComplet ModelName: request.Model, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream) } // 获取聊天请求体 diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index 68a60ef8..fd999e45 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -2,6 +2,7 @@ package baidu import ( "encoding/json" + "io" "net/http" "one-api/common" "one-api/common/requester" @@ -31,7 +32,7 @@ func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionReques return p.convertToChatOpenai(baiduResponse, request) } -func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.getBaiduChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -49,7 +50,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion Request: request, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -178,11 +179,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatReque } // 转换为OpenAI聊天流式请求体 -func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data: ") { *rawLine = nil - return nil + return } // 去除前缀 @@ -191,18 +192,26 @@ func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, re var baiduResponse BaiduChatStreamResponse err := json.Unmarshal(*rawLine, &baiduResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } + error := errorHandle(&baiduResponse.BaiduError) + if error != nil { + errChan <- error + return + } + + h.convertToOpenaiStream(&baiduResponse, dataChan, errChan) + if baiduResponse.IsEnd { - *isFinished = true + errChan <- io.EOF + *rawLine = requester.StreamClosed + return } - - return h.convertToOpenaiStream(&baiduResponse, response) - } -func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, dataChan chan string, errChan chan error) { choice := types.ChatCompletionStreamChoice{ Index: 0, Delta: types.ChatCompletionStreamChoiceDelta{ @@ -240,19 +249,19 @@ func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStrea if baiduResponse.FunctionCall == nil { chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice} - *response = append(*response, chatCompletion) + responseBody, _ := json.Marshal(chatCompletion) + dataChan <- string(responseBody) } else { choices := choice.ConvertOpenaiStream() for _, choice := range choices { chatCompletionCopy := chatCompletion chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} - *response = append(*response, chatCompletionCopy) + responseBody, _ := json.Marshal(chatCompletionCopy) + dataChan <- string(responseBody) } } h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens - - return nil } diff --git a/providers/base/interface.go b/providers/base/interface.go index 84889acb..868f5711 100644 --- a/providers/base/interface.go +++ b/providers/base/interface.go @@ -41,14 +41,14 @@ type ProviderInterface interface { type CompletionInterface interface { ProviderInterface CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode) - CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode) + CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) } // 聊天接口 type ChatInterface interface { ProviderInterface CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) - CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) + CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) } // 嵌入接口 diff --git a/providers/claude/chat.go b/providers/claude/chat.go index f7f5d0d2..752ea264 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -3,6 +3,7 @@ package claude import ( "encoding/json" "fmt" + "io" "net/http" "one-api/common" "one-api/common/requester" @@ -32,7 +33,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque return p.convertToChatOpenai(claudeResponse, request) } -func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.getChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -50,7 +51,7 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio Request: request, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -149,11 +150,11 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request * } // 转换为OpenAI聊天流式请求体 -func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) { *rawLine = nil - return nil + return } // 去除前缀 @@ -162,17 +163,26 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r var claudeResponse *ClaudeResponse err := json.Unmarshal(*rawLine, claudeResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return + } + + error := errorHandle(&claudeResponse.ClaudeResponseError) + if error != nil { + errChan <- error + return } if claudeResponse.StopReason == "stop_sequence" { - *isFinished = true + errChan <- io.EOF + *rawLine = requester.StreamClosed + return } - return h.convertToOpenaiStream(claudeResponse, response) + h.convertToOpenaiStream(claudeResponse, dataChan, errChan) } -func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, dataChan chan string, errChan chan error) { var choice types.ChatCompletionStreamChoice choice.Delta.Content = claudeResponse.Completion finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) @@ -185,9 +195,8 @@ func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeRespon Choices: []types.ChatCompletionStreamChoice{choice}, } - *response = append(*response, chatCompletion) + responseBody, _ := json.Marshal(chatCompletion) + dataChan <- string(responseBody) h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model) - - return nil } diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index acaedb97..36389220 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -37,7 +37,7 @@ func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionReque return p.convertToChatOpenai(geminiChatResponse, request) } -func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.getChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -55,7 +55,7 @@ func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletio Request: request, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -228,11 +228,11 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque } // 转换为OpenAI聊天流式请求体 -func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data: ") { *rawLine = nil - return nil + return } // 去除前缀 @@ -241,19 +241,21 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r var geminiResponse GeminiChatResponse err := json.Unmarshal(*rawLine, &geminiResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } error := errorHandle(&geminiResponse.GeminiErrorResponse) if error != nil { - return error + errChan <- error + return } - return h.convertToOpenaiStream(&geminiResponse, response) + h.convertToOpenaiStream(&geminiResponse, dataChan, errChan) } -func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) { choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates)) for i, candidate := range geminiResponse.Candidates { @@ -275,10 +277,9 @@ func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatRe Choices: choices, } - *response = append(*response, streamResponse) + responseBody, _ := json.Marshal(streamResponse) + dataChan <- string(responseBody) h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model) h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens - - return nil } diff --git a/providers/minimax/chat.go b/providers/minimax/chat.go index f2f4be41..b17e40bc 100644 --- a/providers/minimax/chat.go +++ b/providers/minimax/chat.go @@ -32,7 +32,7 @@ func (p *MiniMaxProvider) CreateChatCompletion(request *types.ChatCompletionRequ return p.convertToChatOpenai(response, request) } -func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.getChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -50,7 +50,7 @@ func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompleti Request: request, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -191,11 +191,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *MiniMaxChatReq } // 转换为OpenAI聊天流式请求体 -func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data: 或者 meta:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data: ") { *rawLine = nil - return nil + return } *rawLine = (*rawLine)[6:] @@ -203,25 +203,27 @@ func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, miniResponse := &MiniMaxChatResponse{} err := json.Unmarshal(*rawLine, miniResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } error := errorHandle(&miniResponse.BaseResp) if error != nil { - return error + errChan <- error + return } choice := miniResponse.Choices[0] if choice.Messages[0].FunctionCall != nil && choice.FinishReason == "" { *rawLine = nil - return nil + return } - return h.convertToOpenaiStream(miniResponse, response) + h.convertToOpenaiStream(miniResponse, dataChan, errChan) } -func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, dataChan chan string, errChan chan error) { streamResponse := types.ChatCompletionStreamResponse{ ID: miniResponse.RequestID, Object: "chat.completion.chunk", @@ -235,8 +237,8 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe if miniChoice.Messages[0].FunctionCall == nil && miniChoice.FinishReason != "" { streamResponse.ID = miniResponse.ID openaiChoice.FinishReason = convertFinishReason(miniChoice.FinishReason) - h.appendResponse(&streamResponse, &openaiChoice, response) - return nil + dataChan <- h.getResponseString(&streamResponse, &openaiChoice) + return } openaiChoice.Delta = types.ChatCompletionStreamChoiceDelta{ @@ -248,19 +250,17 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe convertChoices := openaiChoice.ConvertOpenaiStream() for _, convertChoice := range convertChoices { chatCompletionCopy := streamResponse - h.appendResponse(&chatCompletionCopy, &convertChoice, response) + dataChan <- h.getResponseString(&chatCompletionCopy, &convertChoice) } } else { openaiChoice.Delta.Content = miniChoice.Messages[0].Text - h.appendResponse(&streamResponse, &openaiChoice, response) + dataChan <- h.getResponseString(&streamResponse, &openaiChoice) } if miniResponse.Usage != nil { h.handleUsage(miniResponse) } - - return nil } func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *types.ChatCompletionStreamChoice) { @@ -274,9 +274,10 @@ func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice * } } -func (h *minimaxStreamHandler) appendResponse(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice, response *[]types.ChatCompletionStreamResponse) { +func (h *minimaxStreamHandler) getResponseString(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice) string { streamResponse.Choices = []types.ChatCompletionStreamChoice{*openaiChoice} - *response = append(*response, *streamResponse) + responseBody, _ := json.Marshal(streamResponse) + return string(responseBody) } func (h *minimaxStreamHandler) handleUsage(miniResponse *MiniMaxChatResponse) { diff --git a/providers/openai/chat.go b/providers/openai/chat.go index 3186c889..7e892b38 100644 --- a/providers/openai/chat.go +++ b/providers/openai/chat.go @@ -2,16 +2,19 @@ package openai import ( "encoding/json" + "io" "net/http" "one-api/common" "one-api/common/requester" "one-api/types" "strings" + "time" ) type OpenAIStreamHandler struct { Usage *types.Usage ModelName string + isAzure bool } func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { @@ -43,7 +46,7 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque return &response.ChatCompletionResponse, nil } -func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode @@ -59,16 +62,17 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio chatHandler := OpenAIStreamHandler{ Usage: p.Usage, ModelName: request.Model, + isAzure: p.IsAzure, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream) } -func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data: ") { *rawLine = nil - return nil + return } // 去除前缀 @@ -76,26 +80,32 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *boo // 如果等于 DONE 则结束 if string(*rawLine) == "[DONE]" { - *isFinished = true - return nil + errChan <- io.EOF + *rawLine = requester.StreamClosed + return } var openaiResponse OpenAIProviderChatStreamResponse err := json.Unmarshal(*rawLine, &openaiResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } error := ErrorHandle(&openaiResponse.OpenAIErrorResponse) if error != nil { - return error + errChan <- error + return + } + + dataChan <- string(*rawLine) + + if h.isAzure { + // 阻塞 20ms + time.Sleep(20 * time.Millisecond) } countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName) h.Usage.CompletionTokens += countTokenText h.Usage.TotalTokens += countTokenText - - *response = append(*response, openaiResponse.ChatCompletionStreamResponse) - - return nil } diff --git a/providers/openai/completion.go b/providers/openai/completion.go index 81bbd505..2fc04859 100644 --- a/providers/openai/completion.go +++ b/providers/openai/completion.go @@ -2,6 +2,7 @@ package openai import ( "encoding/json" + "io" "net/http" "one-api/common" "one-api/common/requester" @@ -38,7 +39,7 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope return &response.CompletionResponse, nil } -func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) { +func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode @@ -56,14 +57,14 @@ func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest ModelName: request.Model, } - return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerCompletionStream) } -func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error { +func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data: ") { *rawLine = nil - return nil + return } // 去除前缀 @@ -71,26 +72,27 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinishe // 如果等于 DONE 则结束 if string(*rawLine) == "[DONE]" { - *isFinished = true - return nil + errChan <- io.EOF + *rawLine = requester.StreamClosed + return } var openaiResponse OpenAIProviderCompletionResponse err := json.Unmarshal(*rawLine, &openaiResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } error := ErrorHandle(&openaiResponse.OpenAIErrorResponse) if error != nil { - return error + errChan <- error + return } + dataChan <- string(*rawLine) + countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName) h.Usage.CompletionTokens += countTokenText h.Usage.TotalTokens += countTokenText - - *response = append(*response, openaiResponse.CompletionResponse) - - return nil } diff --git a/providers/palm/chat.go b/providers/palm/chat.go index f7fd4fdd..088f2d64 100644 --- a/providers/palm/chat.go +++ b/providers/palm/chat.go @@ -32,7 +32,7 @@ func (p *PalmProvider) CreateChatCompletion(request *types.ChatCompletionRequest return p.convertToChatOpenai(palmResponse, request) } -func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.getChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -50,7 +50,7 @@ func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionR Request: request, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -142,11 +142,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *PaLMChatReques } // 转换为OpenAI聊天流式请求体 -func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *palmStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data: ") { *rawLine = nil - return nil + return } // 去除前缀 @@ -155,19 +155,21 @@ func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, res var palmChatResponse PaLMChatResponse err := json.Unmarshal(*rawLine, &palmChatResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } error := errorHandle(&palmChatResponse.PaLMErrorResponse) if error != nil { - return error + errChan <- error + return } - return h.convertToOpenaiStream(&palmChatResponse, response) + h.convertToOpenaiStream(&palmChatResponse, dataChan, errChan) } -func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, dataChan chan string, errChan chan error) { var choice types.ChatCompletionStreamChoice if len(palmChatResponse.Candidates) > 0 { choice.Delta.Content = palmChatResponse.Candidates[0].Content @@ -182,10 +184,9 @@ func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResp Created: common.GetTimestamp(), } - *response = append(*response, streamResponse) + responseBody, _ := json.Marshal(streamResponse) + dataChan <- string(responseBody) h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model) h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens - - return nil } diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go index 32e2e673..9977d274 100644 --- a/providers/tencent/chat.go +++ b/providers/tencent/chat.go @@ -32,7 +32,7 @@ func (p *TencentProvider) CreateChatCompletion(request *types.ChatCompletionRequ return p.convertToChatOpenai(tencentChatResponse, request) } -func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.getChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -50,7 +50,7 @@ func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompleti Request: request, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -157,11 +157,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatReq } // 转换为OpenAI聊天流式请求体 -func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data:") { *rawLine = nil - return nil + return } // 去除前缀 @@ -170,19 +170,21 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, var tencentChatResponse TencentChatResponse err := json.Unmarshal(*rawLine, &tencentChatResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } error := errorHandle(&tencentChatResponse.TencentResponseError) if error != nil { - return error + errChan <- error + return } - return h.convertToOpenaiStream(&tencentChatResponse, response) + h.convertToOpenaiStream(&tencentChatResponse, dataChan, errChan) } -func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string, errChan chan error) { streamResponse := types.ChatCompletionStreamResponse{ Object: "chat.completion.chunk", Created: common.GetTimestamp(), @@ -197,10 +199,9 @@ func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *Tencen streamResponse.Choices = append(streamResponse.Choices, choice) } - *response = append(*response, streamResponse) + responseBody, _ := json.Marshal(streamResponse) + dataChan <- string(responseBody) h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model) h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens - - return nil } diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go index b7966387..c93920f0 100644 --- a/providers/xunfei/chat.go +++ b/providers/xunfei/chat.go @@ -40,7 +40,7 @@ func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionReque } -func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { wsConn, errWithCode := p.getChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -53,7 +53,7 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio Request: request, } - return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream) + return requester.SendWSJsonRequest[string](wsConn, xunfeiRequest, chatHandler.handlerStream) } func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) { @@ -123,23 +123,26 @@ func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { var content string var xunfeiResponse XunfeiChatResponse + dataChan, errChan := stream.Recv() - for { - response, err := stream.Recv() + stop := false + for !stop { + select { + case response := <-dataChan: + if len(response.Payload.Choices.Text) == 0 { + continue + } + xunfeiResponse = response + content += xunfeiResponse.Payload.Choices.Text[0].Content + case err := <-errChan: + if err != nil && !errors.Is(err, io.EOF) { + return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError) + } - if err != nil && !errors.Is(err, io.EOF) { - return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError) + if errors.Is(err, io.EOF) { + stop = true + } } - - if errors.Is(err, io.EOF) && response == nil { - break - } - - if len((*response)[0].Payload.Choices.Text) == 0 { - continue - } - xunfeiResponse = (*response)[0] - content += xunfeiResponse.Payload.Choices.Text[0].Content } if len(xunfeiResponse.Payload.Choices.Text) == 0 { @@ -193,7 +196,7 @@ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterfa } func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) { - // 如果rawLine 前缀不为data:,则直接返回 + // 如果rawLine 前缀不为{,则直接返回 if !strings.HasPrefix(string(*rawLine), "{") { *rawLine = nil return nil, nil @@ -221,34 +224,47 @@ func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiC return &xunfeiChatResponse, nil } -func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error { - xunfeiChatResponse, err := h.handlerData(rawLine, isFinished) +func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, dataChan chan XunfeiChatResponse, errChan chan error) { + isFinished := false + xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished) if err != nil { - return err + errChan <- err + return } if *rawLine == nil { - return nil + return } - *response = append(*response, *xunfeiChatResponse) - return nil + dataChan <- *xunfeiChatResponse + + if isFinished { + errChan <- io.EOF + *rawLine = requester.StreamClosed + } } -func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { - xunfeiChatResponse, err := h.handlerData(rawLine, isFinished) +func (h *xunfeiHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { + isFinished := false + xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished) if err != nil { - return err + errChan <- err + return } if *rawLine == nil { - return nil + return } - return h.convertToOpenaiStream(xunfeiChatResponse, response) + h.convertToOpenaiStream(xunfeiChatResponse, dataChan, errChan) + + if isFinished { + errChan <- io.EOF + *rawLine = requester.StreamClosed + } } -func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, dataChan chan string, errChan chan error) { if len(xunfeiChatResponse.Payload.Choices.Text) == 0 { xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}} } @@ -293,15 +309,15 @@ func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResp if xunfeiText.FunctionCall == nil { chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice} - *response = append(*response, chatCompletion) + responseBody, _ := json.Marshal(chatCompletion) + dataChan <- string(responseBody) } else { choices := choice.ConvertOpenaiStream() for _, choice := range choices { chatCompletionCopy := chatCompletion chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} - *response = append(*response, chatCompletionCopy) + responseBody, _ := json.Marshal(chatCompletionCopy) + dataChan <- string(responseBody) } } - - return nil } diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index d6431213..2ff589e2 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -2,6 +2,7 @@ package zhipu import ( "encoding/json" + "io" "net/http" "one-api/common" "one-api/common/requester" @@ -31,7 +32,7 @@ func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionReques return p.convertToChatOpenai(zhipuChatResponse, request) } -func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { +func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { req, errWithCode := p.getChatRequest(request) if errWithCode != nil { return nil, errWithCode @@ -49,7 +50,7 @@ func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletion Request: request, } - return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) + return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream) } func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { @@ -140,35 +141,38 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest { } // 转换为OpenAI聊天流式请求体 -func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { +func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) { // 如果rawLine 前缀不为data: 或者 meta:,则直接返回 if !strings.HasPrefix(string(*rawLine), "data: ") { *rawLine = nil - return nil + return } *rawLine = (*rawLine)[6:] if strings.HasPrefix(string(*rawLine), "[DONE]") { - *isFinished = true - return nil + errChan <- io.EOF + *rawLine = requester.StreamClosed + return } zhipuResponse := &ZhipuStreamResponse{} err := json.Unmarshal(*rawLine, zhipuResponse) if err != nil { - return common.ErrorToOpenAIError(err) + errChan <- common.ErrorToOpenAIError(err) + return } error := errorHandle(&zhipuResponse.Error) if error != nil { - return error + errChan <- error + return } - return h.convertToOpenaiStream(zhipuResponse, response) + h.convertToOpenaiStream(zhipuResponse, dataChan, errChan) } -func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, response *[]types.ChatCompletionStreamResponse) error { +func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, dataChan chan string, errChan chan error) { streamResponse := types.ChatCompletionStreamResponse{ ID: zhipuResponse.ID, Object: "chat.completion.chunk", @@ -183,16 +187,16 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes for _, choice := range choices { chatCompletionCopy := streamResponse chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} - *response = append(*response, chatCompletionCopy) + responseBody, _ := json.Marshal(chatCompletionCopy) + dataChan <- string(responseBody) } } else { streamResponse.Choices = []types.ChatCompletionStreamChoice{choice} - *response = append(*response, streamResponse) + responseBody, _ := json.Marshal(streamResponse) + dataChan <- string(responseBody) } if zhipuResponse.Usage != nil { *h.Usage = *zhipuResponse.Usage } - - return nil }