🐛 fix: stream mode delay issue (#53)

This commit is contained in:
Buer 2024-01-25 11:56:31 +08:00 committed by GitHub
parent 705804e6dd
commit d7193b8e46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 291 additions and 262 deletions

View File

@ -127,11 +127,16 @@ func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response,
return nil, HandleErrorResp(resp, requester.ErrorHandler) return nil, HandleErrorResp(resp, requester.ErrorHandler)
} }
return &streamReader[T]{ stream := &streamReader[T]{
reader: bufio.NewReader(resp.Body), reader: bufio.NewReader(resp.Body),
response: resp, response: resp,
handlerPrefix: handlerPrefix, handlerPrefix: handlerPrefix,
}, nil
DataChan: make(chan T),
ErrChan: make(chan error),
}
return stream, nil
} }
// 设置请求体 // 设置请求体

View File

@ -3,16 +3,12 @@ package requester
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"io"
"net/http" "net/http"
) )
// 流处理函数,判断依据如下: var StreamClosed = []byte("stream_closed")
// 1.如果有错误信息,则直接返回错误信息
// 2.如果isFinished=true则返回io.EOF并且如果response不为空还将返回response type HandlerPrefix[T streamable] func(rawLine *[]byte, dataChan chan T, errChan chan error)
// 3.如果rawLine=nil 或者 response长度为0则直接跳过
// 4.如果以上条件都不满足则返回response
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
type streamable interface { type streamable interface {
// types.ChatCompletionStreamResponse | types.CompletionResponse // types.ChatCompletionStreamResponse | types.CompletionResponse
@ -20,57 +16,48 @@ type streamable interface {
} }
type StreamReaderInterface[T streamable] interface { type StreamReaderInterface[T streamable] interface {
Recv() (*[]T, error) Recv() (<-chan T, <-chan error)
Close() Close()
} }
type streamReader[T streamable] struct { type streamReader[T streamable] struct {
isFinished bool
reader *bufio.Reader reader *bufio.Reader
response *http.Response response *http.Response
handlerPrefix HandlerPrefix[T] handlerPrefix HandlerPrefix[T]
DataChan chan T
ErrChan chan error
} }
func (stream *streamReader[T]) Recv() (response *[]T, err error) { func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
if stream.isFinished { go stream.processLines()
err = io.EOF
return return stream.DataChan, stream.ErrChan
}
response, err = stream.processLines()
return
} }
//nolint:gocognit //nolint:gocognit
func (stream *streamReader[T]) processLines() (*[]T, error) { func (stream *streamReader[T]) processLines() {
for { for {
rawLine, readErr := stream.reader.ReadBytes('\n') rawLine, readErr := stream.reader.ReadBytes('\n')
if readErr != nil { if readErr != nil {
return nil, readErr stream.ErrChan <- readErr
return
} }
noSpaceLine := bytes.TrimSpace(rawLine) noSpaceLine := bytes.TrimSpace(rawLine)
if len(noSpaceLine) == 0 {
var response []T
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if noSpaceLine == nil || len(response) == 0 {
continue continue
} }
return &response, nil stream.handlerPrefix(&noSpaceLine, stream.DataChan, stream.ErrChan)
if noSpaceLine == nil {
continue
}
if bytes.Equal(noSpaceLine, StreamClosed) {
return
}
} }
} }

View File

@ -1,55 +1,41 @@
package requester package requester
import ( import (
"io" "bytes"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
type wsReader[T streamable] struct { type wsReader[T streamable] struct {
isFinished bool
reader *websocket.Conn reader *websocket.Conn
handlerPrefix HandlerPrefix[T] handlerPrefix HandlerPrefix[T]
DataChan chan T
ErrChan chan error
} }
func (stream *wsReader[T]) Recv() (response *[]T, err error) { func (stream *wsReader[T]) Recv() (<-chan T, <-chan error) {
if stream.isFinished { go stream.processLines()
err = io.EOF return stream.DataChan, stream.ErrChan
return
}
response, err = stream.processLines()
return
} }
func (stream *wsReader[T]) processLines() (*[]T, error) { func (stream *wsReader[T]) processLines() {
for { for {
_, msg, err := stream.reader.ReadMessage() _, msg, err := stream.reader.ReadMessage()
if err != nil { if err != nil {
return nil, err stream.ErrChan <- err
return
} }
var response []T stream.handlerPrefix(&msg, stream.DataChan, stream.ErrChan)
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
if err != nil { if msg == nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if msg == nil || len(response) == 0 {
continue continue
} }
return &response, nil if bytes.Equal(msg, StreamClosed) {
return
}
} }
} }

View File

@ -38,10 +38,15 @@ func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPref
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError) return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
} }
return &wsReader[T]{ stream := &wsReader[T]{
reader: conn, reader: conn,
handlerPrefix: handlerPrefix, handlerPrefix: handlerPrefix,
}, nil
DataChan: make(chan T),
ErrChan: make(chan error),
}
return stream, nil
} }
// 设置请求头 // 设置请求头

View File

@ -1,7 +1,6 @@
package controller package controller
import ( import (
"fmt"
"math" "math"
"net/http" "net/http"
"one-api/common" "one-api/common"
@ -53,13 +52,13 @@ func RelayChat(c *gin.Context) {
} }
if chatRequest.Stream { if chatRequest.Stream {
var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse] var response requester.StreamReaderInterface[string]
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest) response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
if errWithCode != nil { if errWithCode != nil {
errorHelper(c, errWithCode) errorHelper(c, errWithCode)
return return
} }
errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response) errWithCode = responseStreamClient(c, response)
} else { } else {
var response *types.ChatCompletionResponse var response *types.ChatCompletionResponse
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest) response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
@ -70,8 +69,6 @@ func RelayChat(c *gin.Context) {
errWithCode = responseJsonClient(c, response) errWithCode = responseJsonClient(c, response)
} }
fmt.Println(usage)
// 如果报错,则退还配额 // 如果报错,则退还配额
if errWithCode != nil { if errWithCode != nil {
quotaInfo.undo(c, errWithCode) quotaInfo.undo(c, errWithCode)

View File

@ -4,6 +4,7 @@ import (
"math" "math"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester"
providersBase "one-api/providers/base" providersBase "one-api/providers/base"
"one-api/types" "one-api/types"
@ -51,14 +52,16 @@ func RelayCompletions(c *gin.Context) {
} }
if completionRequest.Stream { if completionRequest.Stream {
response, errWithCode := completionProvider.CreateCompletionStream(&completionRequest) var response requester.StreamReaderInterface[string]
response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest)
if errWithCode != nil { if errWithCode != nil {
errorHelper(c, errWithCode) errorHelper(c, errWithCode)
return return
} }
errWithCode = responseStreamClient[types.CompletionResponse](c, response) errWithCode = responseStreamClient(c, response)
} else { } else {
response, errWithCode := completionProvider.CreateCompletion(&completionRequest) var response *types.CompletionResponse
response, errWithCode = completionProvider.CreateCompletion(&completionRequest)
if errWithCode != nil { if errWithCode != nil {
errorHelper(c, errWithCode) errorHelper(c, errWithCode)
return return

View File

@ -154,34 +154,25 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
return nil return nil
} }
func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode { func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode {
requester.SetEventStreamHeaders(c) requester.SetEventStreamHeaders(c)
dataChan, errChan := stream.Recv()
defer stream.Close() defer stream.Close()
c.Stream(func(w io.Writer) bool {
for { select {
response, err := stream.Recv() case data := <-dataChan:
if errors.Is(err, io.EOF) { fmt.Fprintln(w, "data: "+data+"\n")
if response != nil && len(*response) > 0 { return true
for _, streamResponse := range *response { case err := <-errChan:
responseBody, _ := json.Marshal(streamResponse) if !errors.Is(err, io.EOF) {
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)}) fmt.Fprintln(w, "data: "+err.Error()+"\n")
}
} }
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
break
}
if err != nil { fmt.Fprintln(w, "data: [DONE]")
c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()}) return false
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
break
} }
})
for _, streamResponse := range *response {
responseBody, _ := json.Marshal(streamResponse)
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
}
}
return nil return nil
} }

View File

@ -34,7 +34,7 @@ func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest)
return p.convertToChatOpenai(aliResponse, request) return p.convertToChatOpenai(aliResponse, request)
} }
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getAliChatRequest(request) req, errWithCode := p.getAliChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -52,7 +52,7 @@ func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRe
Request: request, Request: request,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
} }
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -162,11 +162,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest
} }
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *aliStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data:") { if !strings.HasPrefix(string(*rawLine), "data:") {
*rawLine = nil *rawLine = nil
return nil return
} }
// 去除前缀 // 去除前缀
@ -175,19 +175,21 @@ func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, resp
var aliResponse AliChatResponse var aliResponse AliChatResponse
err := json.Unmarshal(*rawLine, &aliResponse) err := json.Unmarshal(*rawLine, &aliResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := errorHandle(&aliResponse.AliError) error := errorHandle(&aliResponse.AliError)
if error != nil { if error != nil {
return error errChan <- error
return
} }
return h.convertToOpenaiStream(&aliResponse, response) h.convertToOpenaiStream(&aliResponse, dataChan, errChan)
} }
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, dataChan chan string, errChan chan error) {
content := aliResponse.Output.Choices[0].Message.StringContent() content := aliResponse.Output.Choices[0].Message.StringContent()
var choice types.ChatCompletionStreamChoice var choice types.ChatCompletionStreamChoice
@ -222,7 +224,6 @@ func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, r
h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
} }
*response = append(*response, streamResponse) responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
return nil
} }

View File

@ -39,7 +39,7 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
return &response.ChatCompletionResponse, nil return &response.ChatCompletionResponse, nil
} }
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -57,7 +57,7 @@ func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatComplet
ModelName: request.Model, ModelName: request.Model,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
} }
// 获取聊天请求体 // 获取聊天请求体

View File

@ -2,6 +2,7 @@ package baidu
import ( import (
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
@ -31,7 +32,7 @@ func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionReques
return p.convertToChatOpenai(baiduResponse, request) return p.convertToChatOpenai(baiduResponse, request)
} }
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getBaiduChatRequest(request) req, errWithCode := p.getBaiduChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -49,7 +50,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion
Request: request, Request: request,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
} }
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -178,11 +179,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatReque
} }
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") { if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil *rawLine = nil
return nil return
} }
// 去除前缀 // 去除前缀
@ -191,18 +192,26 @@ func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, re
var baiduResponse BaiduChatStreamResponse var baiduResponse BaiduChatStreamResponse
err := json.Unmarshal(*rawLine, &baiduResponse) err := json.Unmarshal(*rawLine, &baiduResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := errorHandle(&baiduResponse.BaiduError)
if error != nil {
errChan <- error
return
}
h.convertToOpenaiStream(&baiduResponse, dataChan, errChan)
if baiduResponse.IsEnd { if baiduResponse.IsEnd {
*isFinished = true errChan <- io.EOF
*rawLine = requester.StreamClosed
return
} }
return h.convertToOpenaiStream(&baiduResponse, response)
} }
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, dataChan chan string, errChan chan error) {
choice := types.ChatCompletionStreamChoice{ choice := types.ChatCompletionStreamChoice{
Index: 0, Index: 0,
Delta: types.ChatCompletionStreamChoiceDelta{ Delta: types.ChatCompletionStreamChoiceDelta{
@ -240,19 +249,19 @@ func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStrea
if baiduResponse.FunctionCall == nil { if baiduResponse.FunctionCall == nil {
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice} chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletion) responseBody, _ := json.Marshal(chatCompletion)
dataChan <- string(responseBody)
} else { } else {
choices := choice.ConvertOpenaiStream() choices := choice.ConvertOpenaiStream()
for _, choice := range choices { for _, choice := range choices {
chatCompletionCopy := chatCompletion chatCompletionCopy := chatCompletion
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy) responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
} }
} }
h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens
h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens
h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
return nil
} }

View File

@ -41,14 +41,14 @@ type ProviderInterface interface {
type CompletionInterface interface { type CompletionInterface interface {
ProviderInterface ProviderInterface
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode) CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode) CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode)
} }
// 聊天接口 // 聊天接口
type ChatInterface interface { type ChatInterface interface {
ProviderInterface ProviderInterface
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode)
} }
// 嵌入接口 // 嵌入接口

View File

@ -3,6 +3,7 @@ package claude
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
@ -32,7 +33,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque
return p.convertToChatOpenai(claudeResponse, request) return p.convertToChatOpenai(claudeResponse, request)
} }
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request) req, errWithCode := p.getChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -50,7 +51,7 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
Request: request, Request: request,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
} }
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -149,11 +150,11 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
} }
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) { if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) {
*rawLine = nil *rawLine = nil
return nil return
} }
// 去除前缀 // 去除前缀
@ -162,17 +163,26 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r
var claudeResponse *ClaudeResponse var claudeResponse *ClaudeResponse
err := json.Unmarshal(*rawLine, claudeResponse) err := json.Unmarshal(*rawLine, claudeResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&claudeResponse.ClaudeResponseError)
if error != nil {
errChan <- error
return
} }
if claudeResponse.StopReason == "stop_sequence" { if claudeResponse.StopReason == "stop_sequence" {
*isFinished = true errChan <- io.EOF
*rawLine = requester.StreamClosed
return
} }
return h.convertToOpenaiStream(claudeResponse, response) h.convertToOpenaiStream(claudeResponse, dataChan, errChan)
} }
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, dataChan chan string, errChan chan error) {
var choice types.ChatCompletionStreamChoice var choice types.ChatCompletionStreamChoice
choice.Delta.Content = claudeResponse.Completion choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
@ -185,9 +195,8 @@ func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeRespon
Choices: []types.ChatCompletionStreamChoice{choice}, Choices: []types.ChatCompletionStreamChoice{choice},
} }
*response = append(*response, chatCompletion) responseBody, _ := json.Marshal(chatCompletion)
dataChan <- string(responseBody)
h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model) h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
return nil
} }

View File

@ -37,7 +37,7 @@ func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionReque
return p.convertToChatOpenai(geminiChatResponse, request) return p.convertToChatOpenai(geminiChatResponse, request)
} }
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request) req, errWithCode := p.getChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -55,7 +55,7 @@ func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletio
Request: request, Request: request,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
} }
func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -228,11 +228,11 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
} }
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") { if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil *rawLine = nil
return nil return
} }
// 去除前缀 // 去除前缀
@ -241,19 +241,21 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err := json.Unmarshal(*rawLine, &geminiResponse) err := json.Unmarshal(*rawLine, &geminiResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := errorHandle(&geminiResponse.GeminiErrorResponse) error := errorHandle(&geminiResponse.GeminiErrorResponse)
if error != nil { if error != nil {
return error errChan <- error
return
} }
return h.convertToOpenaiStream(&geminiResponse, response) h.convertToOpenaiStream(&geminiResponse, dataChan, errChan)
} }
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) {
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates)) choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
for i, candidate := range geminiResponse.Candidates { for i, candidate := range geminiResponse.Candidates {
@ -275,10 +277,9 @@ func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatRe
Choices: choices, Choices: choices,
} }
*response = append(*response, streamResponse) responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model) h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
} }

View File

@ -32,7 +32,7 @@ func (p *MiniMaxProvider) CreateChatCompletion(request *types.ChatCompletionRequ
return p.convertToChatOpenai(response, request) return p.convertToChatOpenai(response, request)
} }
func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request) req, errWithCode := p.getChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -50,7 +50,7 @@ func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompleti
Request: request, Request: request,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
} }
func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -191,11 +191,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *MiniMaxChatReq
} }
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回 // 如果rawLine 前缀不为data: 或者 meta:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") { if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil *rawLine = nil
return nil return
} }
*rawLine = (*rawLine)[6:] *rawLine = (*rawLine)[6:]
@ -203,25 +203,27 @@ func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool,
miniResponse := &MiniMaxChatResponse{} miniResponse := &MiniMaxChatResponse{}
err := json.Unmarshal(*rawLine, miniResponse) err := json.Unmarshal(*rawLine, miniResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := errorHandle(&miniResponse.BaseResp) error := errorHandle(&miniResponse.BaseResp)
if error != nil { if error != nil {
return error errChan <- error
return
} }
choice := miniResponse.Choices[0] choice := miniResponse.Choices[0]
if choice.Messages[0].FunctionCall != nil && choice.FinishReason == "" { if choice.Messages[0].FunctionCall != nil && choice.FinishReason == "" {
*rawLine = nil *rawLine = nil
return nil return
} }
return h.convertToOpenaiStream(miniResponse, response) h.convertToOpenaiStream(miniResponse, dataChan, errChan)
} }
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, dataChan chan string, errChan chan error) {
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
ID: miniResponse.RequestID, ID: miniResponse.RequestID,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
@ -235,8 +237,8 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
if miniChoice.Messages[0].FunctionCall == nil && miniChoice.FinishReason != "" { if miniChoice.Messages[0].FunctionCall == nil && miniChoice.FinishReason != "" {
streamResponse.ID = miniResponse.ID streamResponse.ID = miniResponse.ID
openaiChoice.FinishReason = convertFinishReason(miniChoice.FinishReason) openaiChoice.FinishReason = convertFinishReason(miniChoice.FinishReason)
h.appendResponse(&streamResponse, &openaiChoice, response) dataChan <- h.getResponseString(&streamResponse, &openaiChoice)
return nil return
} }
openaiChoice.Delta = types.ChatCompletionStreamChoiceDelta{ openaiChoice.Delta = types.ChatCompletionStreamChoiceDelta{
@ -248,19 +250,17 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
convertChoices := openaiChoice.ConvertOpenaiStream() convertChoices := openaiChoice.ConvertOpenaiStream()
for _, convertChoice := range convertChoices { for _, convertChoice := range convertChoices {
chatCompletionCopy := streamResponse chatCompletionCopy := streamResponse
h.appendResponse(&chatCompletionCopy, &convertChoice, response) dataChan <- h.getResponseString(&chatCompletionCopy, &convertChoice)
} }
} else { } else {
openaiChoice.Delta.Content = miniChoice.Messages[0].Text openaiChoice.Delta.Content = miniChoice.Messages[0].Text
h.appendResponse(&streamResponse, &openaiChoice, response) dataChan <- h.getResponseString(&streamResponse, &openaiChoice)
} }
if miniResponse.Usage != nil { if miniResponse.Usage != nil {
h.handleUsage(miniResponse) h.handleUsage(miniResponse)
} }
return nil
} }
func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *types.ChatCompletionStreamChoice) { func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *types.ChatCompletionStreamChoice) {
@ -274,9 +274,10 @@ func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *
} }
} }
func (h *minimaxStreamHandler) appendResponse(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice, response *[]types.ChatCompletionStreamResponse) { func (h *minimaxStreamHandler) getResponseString(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice) string {
streamResponse.Choices = []types.ChatCompletionStreamChoice{*openaiChoice} streamResponse.Choices = []types.ChatCompletionStreamChoice{*openaiChoice}
*response = append(*response, *streamResponse) responseBody, _ := json.Marshal(streamResponse)
return string(responseBody)
} }
func (h *minimaxStreamHandler) handleUsage(miniResponse *MiniMaxChatResponse) { func (h *minimaxStreamHandler) handleUsage(miniResponse *MiniMaxChatResponse) {

View File

@ -2,16 +2,19 @@ package openai
import ( import (
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
"time"
) )
type OpenAIStreamHandler struct { type OpenAIStreamHandler struct {
Usage *types.Usage Usage *types.Usage
ModelName string ModelName string
isAzure bool
} }
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
@ -43,7 +46,7 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
return &response.ChatCompletionResponse, nil return &response.ChatCompletionResponse, nil
} }
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -59,16 +62,17 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio
chatHandler := OpenAIStreamHandler{ chatHandler := OpenAIStreamHandler{
Usage: p.Usage, Usage: p.Usage,
ModelName: request.Model, ModelName: request.Model,
isAzure: p.IsAzure,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
} }
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") { if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil *rawLine = nil
return nil return
} }
// 去除前缀 // 去除前缀
@ -76,26 +80,32 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *boo
// 如果等于 DONE 则结束 // 如果等于 DONE 则结束
if string(*rawLine) == "[DONE]" { if string(*rawLine) == "[DONE]" {
*isFinished = true errChan <- io.EOF
return nil *rawLine = requester.StreamClosed
return
} }
var openaiResponse OpenAIProviderChatStreamResponse var openaiResponse OpenAIProviderChatStreamResponse
err := json.Unmarshal(*rawLine, &openaiResponse) err := json.Unmarshal(*rawLine, &openaiResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse) error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
if error != nil { if error != nil {
return error errChan <- error
return
}
dataChan <- string(*rawLine)
if h.isAzure {
// 阻塞 20ms
time.Sleep(20 * time.Millisecond)
} }
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName) countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
h.Usage.CompletionTokens += countTokenText h.Usage.CompletionTokens += countTokenText
h.Usage.TotalTokens += countTokenText h.Usage.TotalTokens += countTokenText
*response = append(*response, openaiResponse.ChatCompletionStreamResponse)
return nil
} }

View File

@ -2,6 +2,7 @@ package openai
import ( import (
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
@ -38,7 +39,7 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
return &response.CompletionResponse, nil return &response.CompletionResponse, nil
} }
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) { func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -56,14 +57,14 @@ func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest
ModelName: request.Model, ModelName: request.Model,
} }
return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerCompletionStream)
} }
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error { func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") { if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil *rawLine = nil
return nil return
} }
// 去除前缀 // 去除前缀
@ -71,26 +72,27 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinishe
// 如果等于 DONE 则结束 // 如果等于 DONE 则结束
if string(*rawLine) == "[DONE]" { if string(*rawLine) == "[DONE]" {
*isFinished = true errChan <- io.EOF
return nil *rawLine = requester.StreamClosed
return
} }
var openaiResponse OpenAIProviderCompletionResponse var openaiResponse OpenAIProviderCompletionResponse
err := json.Unmarshal(*rawLine, &openaiResponse) err := json.Unmarshal(*rawLine, &openaiResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse) error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
if error != nil { if error != nil {
return error errChan <- error
return
} }
dataChan <- string(*rawLine)
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName) countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
h.Usage.CompletionTokens += countTokenText h.Usage.CompletionTokens += countTokenText
h.Usage.TotalTokens += countTokenText h.Usage.TotalTokens += countTokenText
*response = append(*response, openaiResponse.CompletionResponse)
return nil
} }

View File

@ -32,7 +32,7 @@ func (p *PalmProvider) CreateChatCompletion(request *types.ChatCompletionRequest
return p.convertToChatOpenai(palmResponse, request) return p.convertToChatOpenai(palmResponse, request)
} }
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request) req, errWithCode := p.getChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -50,7 +50,7 @@ func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionR
Request: request, Request: request,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
} }
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -142,11 +142,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *PaLMChatReques
} }
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *palmStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") { if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil *rawLine = nil
return nil return
} }
// 去除前缀 // 去除前缀
@ -155,19 +155,21 @@ func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, res
var palmChatResponse PaLMChatResponse var palmChatResponse PaLMChatResponse
err := json.Unmarshal(*rawLine, &palmChatResponse) err := json.Unmarshal(*rawLine, &palmChatResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := errorHandle(&palmChatResponse.PaLMErrorResponse) error := errorHandle(&palmChatResponse.PaLMErrorResponse)
if error != nil { if error != nil {
return error errChan <- error
return
} }
return h.convertToOpenaiStream(&palmChatResponse, response) h.convertToOpenaiStream(&palmChatResponse, dataChan, errChan)
} }
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, dataChan chan string, errChan chan error) {
var choice types.ChatCompletionStreamChoice var choice types.ChatCompletionStreamChoice
if len(palmChatResponse.Candidates) > 0 { if len(palmChatResponse.Candidates) > 0 {
choice.Delta.Content = palmChatResponse.Candidates[0].Content choice.Delta.Content = palmChatResponse.Candidates[0].Content
@ -182,10 +184,9 @@ func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResp
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
} }
*response = append(*response, streamResponse) responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model) h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
} }

View File

@ -32,7 +32,7 @@ func (p *TencentProvider) CreateChatCompletion(request *types.ChatCompletionRequ
return p.convertToChatOpenai(tencentChatResponse, request) return p.convertToChatOpenai(tencentChatResponse, request)
} }
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request) req, errWithCode := p.getChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -50,7 +50,7 @@ func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompleti
Request: request, Request: request,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
} }
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -157,11 +157,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatReq
} }
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data:") { if !strings.HasPrefix(string(*rawLine), "data:") {
*rawLine = nil *rawLine = nil
return nil return
} }
// 去除前缀 // 去除前缀
@ -170,19 +170,21 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool,
var tencentChatResponse TencentChatResponse var tencentChatResponse TencentChatResponse
err := json.Unmarshal(*rawLine, &tencentChatResponse) err := json.Unmarshal(*rawLine, &tencentChatResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := errorHandle(&tencentChatResponse.TencentResponseError) error := errorHandle(&tencentChatResponse.TencentResponseError)
if error != nil { if error != nil {
return error errChan <- error
return
} }
return h.convertToOpenaiStream(&tencentChatResponse, response) h.convertToOpenaiStream(&tencentChatResponse, dataChan, errChan)
} }
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string, errChan chan error) {
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
@ -197,10 +199,9 @@ func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *Tencen
streamResponse.Choices = append(streamResponse.Choices, choice) streamResponse.Choices = append(streamResponse.Choices, choice)
} }
*response = append(*response, streamResponse) responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model) h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
} }

View File

@ -40,7 +40,7 @@ func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionReque
} }
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
wsConn, errWithCode := p.getChatRequest(request) wsConn, errWithCode := p.getChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -53,7 +53,7 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio
Request: request, Request: request,
} }
return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream) return requester.SendWSJsonRequest[string](wsConn, xunfeiRequest, chatHandler.handlerStream)
} }
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) { func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
@ -123,23 +123,26 @@ func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequ
func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) { func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
var content string var content string
var xunfeiResponse XunfeiChatResponse var xunfeiResponse XunfeiChatResponse
dataChan, errChan := stream.Recv()
for { stop := false
response, err := stream.Recv() for !stop {
select {
case response := <-dataChan:
if len(response.Payload.Choices.Text) == 0 {
continue
}
xunfeiResponse = response
content += xunfeiResponse.Payload.Choices.Text[0].Content
case err := <-errChan:
if err != nil && !errors.Is(err, io.EOF) {
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
}
if err != nil && !errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError) stop = true
}
} }
if errors.Is(err, io.EOF) && response == nil {
break
}
if len((*response)[0].Payload.Choices.Text) == 0 {
continue
}
xunfeiResponse = (*response)[0]
content += xunfeiResponse.Payload.Choices.Text[0].Content
} }
if len(xunfeiResponse.Payload.Choices.Text) == 0 { if len(xunfeiResponse.Payload.Choices.Text) == 0 {
@ -193,7 +196,7 @@ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterfa
} }
func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) { func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
// 如果rawLine 前缀不为data:,则直接返回 // 如果rawLine 前缀不为{,则直接返回
if !strings.HasPrefix(string(*rawLine), "{") { if !strings.HasPrefix(string(*rawLine), "{") {
*rawLine = nil *rawLine = nil
return nil, nil return nil, nil
@ -221,34 +224,47 @@ func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiC
return &xunfeiChatResponse, nil return &xunfeiChatResponse, nil
} }
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error { func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, dataChan chan XunfeiChatResponse, errChan chan error) {
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished) isFinished := false
xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished)
if err != nil { if err != nil {
return err errChan <- err
return
} }
if *rawLine == nil { if *rawLine == nil {
return nil return
} }
*response = append(*response, *xunfeiChatResponse) dataChan <- *xunfeiChatResponse
return nil
if isFinished {
errChan <- io.EOF
*rawLine = requester.StreamClosed
}
} }
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *xunfeiHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished) isFinished := false
xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished)
if err != nil { if err != nil {
return err errChan <- err
return
} }
if *rawLine == nil { if *rawLine == nil {
return nil return
} }
return h.convertToOpenaiStream(xunfeiChatResponse, response) h.convertToOpenaiStream(xunfeiChatResponse, dataChan, errChan)
if isFinished {
errChan <- io.EOF
*rawLine = requester.StreamClosed
}
} }
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, dataChan chan string, errChan chan error) {
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 { if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}} xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
} }
@ -293,15 +309,15 @@ func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResp
if xunfeiText.FunctionCall == nil { if xunfeiText.FunctionCall == nil {
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice} chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletion) responseBody, _ := json.Marshal(chatCompletion)
dataChan <- string(responseBody)
} else { } else {
choices := choice.ConvertOpenaiStream() choices := choice.ConvertOpenaiStream()
for _, choice := range choices { for _, choice := range choices {
chatCompletionCopy := chatCompletion chatCompletionCopy := chatCompletion
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy) responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
} }
} }
return nil
} }

View File

@ -2,6 +2,7 @@ package zhipu
import ( import (
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/requester" "one-api/common/requester"
@ -31,7 +32,7 @@ func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionReques
return p.convertToChatOpenai(zhipuChatResponse, request) return p.convertToChatOpenai(zhipuChatResponse, request)
} }
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) { func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request) req, errWithCode := p.getChatRequest(request)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -49,7 +50,7 @@ func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletion
Request: request, Request: request,
} }
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream) return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
} }
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -140,35 +141,38 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
} }
// 转换为OpenAI聊天流式请求体 // 转换为OpenAI聊天流式请求体
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error { func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回 // 如果rawLine 前缀不为data: 或者 meta:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") { if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil *rawLine = nil
return nil return
} }
*rawLine = (*rawLine)[6:] *rawLine = (*rawLine)[6:]
if strings.HasPrefix(string(*rawLine), "[DONE]") { if strings.HasPrefix(string(*rawLine), "[DONE]") {
*isFinished = true errChan <- io.EOF
return nil *rawLine = requester.StreamClosed
return
} }
zhipuResponse := &ZhipuStreamResponse{} zhipuResponse := &ZhipuStreamResponse{}
err := json.Unmarshal(*rawLine, zhipuResponse) err := json.Unmarshal(*rawLine, zhipuResponse)
if err != nil { if err != nil {
return common.ErrorToOpenAIError(err) errChan <- common.ErrorToOpenAIError(err)
return
} }
error := errorHandle(&zhipuResponse.Error) error := errorHandle(&zhipuResponse.Error)
if error != nil { if error != nil {
return error errChan <- error
return
} }
return h.convertToOpenaiStream(zhipuResponse, response) h.convertToOpenaiStream(zhipuResponse, dataChan, errChan)
} }
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, response *[]types.ChatCompletionStreamResponse) error { func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, dataChan chan string, errChan chan error) {
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
ID: zhipuResponse.ID, ID: zhipuResponse.ID,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
@ -183,16 +187,16 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes
for _, choice := range choices { for _, choice := range choices {
chatCompletionCopy := streamResponse chatCompletionCopy := streamResponse
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy) responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
} }
} else { } else {
streamResponse.Choices = []types.ChatCompletionStreamChoice{choice} streamResponse.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, streamResponse) responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
} }
if zhipuResponse.Usage != nil { if zhipuResponse.Usage != nil {
*h.Usage = *zhipuResponse.Usage *h.Usage = *zhipuResponse.Usage
} }
return nil
} }