🐛 fix: stream mode delay issue (#53)

This commit is contained in:
Buer 2024-01-25 11:56:31 +08:00 committed by GitHub
parent 705804e6dd
commit d7193b8e46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 291 additions and 262 deletions

View File

@ -127,11 +127,16 @@ func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response,
return nil, HandleErrorResp(resp, requester.ErrorHandler)
}
return &streamReader[T]{
stream := &streamReader[T]{
reader: bufio.NewReader(resp.Body),
response: resp,
handlerPrefix: handlerPrefix,
}, nil
DataChan: make(chan T),
ErrChan: make(chan error),
}
return stream, nil
}
// 设置请求体

View File

@ -3,16 +3,12 @@ package requester
import (
"bufio"
"bytes"
"io"
"net/http"
)
// 流处理函数,判断依据如下:
// 1.如果有错误信息,则直接返回错误信息
// 2.如果isFinished=true则返回io.EOF并且如果response不为空还将返回response
// 3.如果rawLine=nil 或者 response长度为0则直接跳过
// 4.如果以上条件都不满足则返回response
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
var StreamClosed = []byte("stream_closed")
type HandlerPrefix[T streamable] func(rawLine *[]byte, dataChan chan T, errChan chan error)
type streamable interface {
// types.ChatCompletionStreamResponse | types.CompletionResponse
@ -20,57 +16,48 @@ type streamable interface {
}
type StreamReaderInterface[T streamable] interface {
Recv() (*[]T, error)
Recv() (<-chan T, <-chan error)
Close()
}
type streamReader[T streamable] struct {
isFinished bool
reader *bufio.Reader
response *http.Response
handlerPrefix HandlerPrefix[T]
DataChan chan T
ErrChan chan error
}
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
if stream.isFinished {
err = io.EOF
return
}
response, err = stream.processLines()
return
func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
go stream.processLines()
return stream.DataChan, stream.ErrChan
}
//nolint:gocognit
func (stream *streamReader[T]) processLines() (*[]T, error) {
func (stream *streamReader[T]) processLines() {
for {
rawLine, readErr := stream.reader.ReadBytes('\n')
if readErr != nil {
return nil, readErr
stream.ErrChan <- readErr
return
}
noSpaceLine := bytes.TrimSpace(rawLine)
var response []T
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if noSpaceLine == nil || len(response) == 0 {
if len(noSpaceLine) == 0 {
continue
}
return &response, nil
stream.handlerPrefix(&noSpaceLine, stream.DataChan, stream.ErrChan)
if noSpaceLine == nil {
continue
}
if bytes.Equal(noSpaceLine, StreamClosed) {
return
}
}
}

View File

@ -1,55 +1,41 @@
package requester
import (
"io"
"bytes"
"github.com/gorilla/websocket"
)
type wsReader[T streamable] struct {
isFinished bool
reader *websocket.Conn
handlerPrefix HandlerPrefix[T]
DataChan chan T
ErrChan chan error
}
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
if stream.isFinished {
err = io.EOF
return
}
response, err = stream.processLines()
return
func (stream *wsReader[T]) Recv() (<-chan T, <-chan error) {
go stream.processLines()
return stream.DataChan, stream.ErrChan
}
func (stream *wsReader[T]) processLines() (*[]T, error) {
func (stream *wsReader[T]) processLines() {
for {
_, msg, err := stream.reader.ReadMessage()
if err != nil {
return nil, err
stream.ErrChan <- err
return
}
var response []T
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
stream.handlerPrefix(&msg, stream.DataChan, stream.ErrChan)
if err != nil {
return nil, err
}
if stream.isFinished {
if len(response) > 0 {
return &response, io.EOF
}
return nil, io.EOF
}
if msg == nil || len(response) == 0 {
if msg == nil {
continue
}
return &response, nil
if bytes.Equal(msg, StreamClosed) {
return
}
}
}

View File

@ -38,10 +38,15 @@ func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPref
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
}
return &wsReader[T]{
stream := &wsReader[T]{
reader: conn,
handlerPrefix: handlerPrefix,
}, nil
DataChan: make(chan T),
ErrChan: make(chan error),
}
return stream, nil
}
// 设置请求头

View File

@ -1,7 +1,6 @@
package controller
import (
"fmt"
"math"
"net/http"
"one-api/common"
@ -53,13 +52,13 @@ func RelayChat(c *gin.Context) {
}
if chatRequest.Stream {
var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse]
var response requester.StreamReaderInterface[string]
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response)
errWithCode = responseStreamClient(c, response)
} else {
var response *types.ChatCompletionResponse
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
@ -70,8 +69,6 @@ func RelayChat(c *gin.Context) {
errWithCode = responseJsonClient(c, response)
}
fmt.Println(usage)
// 如果报错,则退还配额
if errWithCode != nil {
quotaInfo.undo(c, errWithCode)

View File

@ -4,6 +4,7 @@ import (
"math"
"net/http"
"one-api/common"
"one-api/common/requester"
providersBase "one-api/providers/base"
"one-api/types"
@ -51,14 +52,16 @@ func RelayCompletions(c *gin.Context) {
}
if completionRequest.Stream {
response, errWithCode := completionProvider.CreateCompletionStream(&completionRequest)
var response requester.StreamReaderInterface[string]
response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseStreamClient[types.CompletionResponse](c, response)
errWithCode = responseStreamClient(c, response)
} else {
response, errWithCode := completionProvider.CreateCompletion(&completionRequest)
var response *types.CompletionResponse
response, errWithCode = completionProvider.CreateCompletion(&completionRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return

View File

@ -154,34 +154,25 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
return nil
}
func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode {
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode {
requester.SetEventStreamHeaders(c)
dataChan, errChan := stream.Recv()
defer stream.Close()
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
if response != nil && len(*response) > 0 {
for _, streamResponse := range *response {
responseBody, _ := json.Marshal(streamResponse)
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
}
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
fmt.Fprintln(w, "data: "+data+"\n")
return true
case err := <-errChan:
if !errors.Is(err, io.EOF) {
fmt.Fprintln(w, "data: "+err.Error()+"\n")
}
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
break
}
if err != nil {
c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
break
fmt.Fprintln(w, "data: [DONE]")
return false
}
for _, streamResponse := range *response {
responseBody, _ := json.Marshal(streamResponse)
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
}
}
})
return nil
}

View File

@ -34,7 +34,7 @@ func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest)
return p.convertToChatOpenai(aliResponse, request)
}
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getAliChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -52,7 +52,7 @@ func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRe
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -162,11 +162,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest
}
// 转换为OpenAI聊天流式请求体
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data:") {
*rawLine = nil
return nil
return
}
// 去除前缀
@ -175,19 +175,21 @@ func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, resp
var aliResponse AliChatResponse
err := json.Unmarshal(*rawLine, &aliResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&aliResponse.AliError)
if error != nil {
return error
errChan <- error
return
}
return h.convertToOpenaiStream(&aliResponse, response)
h.convertToOpenaiStream(&aliResponse, dataChan, errChan)
}
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, dataChan chan string, errChan chan error) {
content := aliResponse.Output.Choices[0].Message.StringContent()
var choice types.ChatCompletionStreamChoice
@ -222,7 +224,6 @@ func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, r
h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
*response = append(*response, streamResponse)
return nil
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
}

View File

@ -39,7 +39,7 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
return &response.ChatCompletionResponse, nil
}
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
@ -57,7 +57,7 @@ func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatComplet
ModelName: request.Model,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
}
// 获取聊天请求体

View File

@ -2,6 +2,7 @@ package baidu
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
@ -31,7 +32,7 @@ func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionReques
return p.convertToChatOpenai(baiduResponse, request)
}
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getBaiduChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -49,7 +50,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -178,11 +179,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatReque
}
// 转换为OpenAI聊天流式请求体
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
return
}
// 去除前缀
@ -191,18 +192,26 @@ func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, re
var baiduResponse BaiduChatStreamResponse
err := json.Unmarshal(*rawLine, &baiduResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&baiduResponse.BaiduError)
if error != nil {
errChan <- error
return
}
h.convertToOpenaiStream(&baiduResponse, dataChan, errChan)
if baiduResponse.IsEnd {
*isFinished = true
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
return h.convertToOpenaiStream(&baiduResponse, response)
}
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, dataChan chan string, errChan chan error) {
choice := types.ChatCompletionStreamChoice{
Index: 0,
Delta: types.ChatCompletionStreamChoiceDelta{
@ -240,19 +249,19 @@ func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStrea
if baiduResponse.FunctionCall == nil {
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletion)
responseBody, _ := json.Marshal(chatCompletion)
dataChan <- string(responseBody)
} else {
choices := choice.ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := chatCompletion
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy)
responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
}
}
h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens
h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens
h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
return nil
}

View File

@ -41,14 +41,14 @@ type ProviderInterface interface {
type CompletionInterface interface {
ProviderInterface
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode)
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode)
}
// 聊天接口
type ChatInterface interface {
ProviderInterface
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode)
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode)
}
// 嵌入接口

View File

@ -3,6 +3,7 @@ package claude
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
@ -32,7 +33,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque
return p.convertToChatOpenai(claudeResponse, request)
}
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -50,7 +51,7 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -149,11 +150,11 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
}
// 转换为OpenAI聊天流式请求体
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) {
*rawLine = nil
return nil
return
}
// 去除前缀
@ -162,17 +163,26 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r
var claudeResponse *ClaudeResponse
err := json.Unmarshal(*rawLine, claudeResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&claudeResponse.ClaudeResponseError)
if error != nil {
errChan <- error
return
}
if claudeResponse.StopReason == "stop_sequence" {
*isFinished = true
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
return h.convertToOpenaiStream(claudeResponse, response)
h.convertToOpenaiStream(claudeResponse, dataChan, errChan)
}
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, dataChan chan string, errChan chan error) {
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
@ -185,9 +195,8 @@ func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeRespon
Choices: []types.ChatCompletionStreamChoice{choice},
}
*response = append(*response, chatCompletion)
responseBody, _ := json.Marshal(chatCompletion)
dataChan <- string(responseBody)
h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
return nil
}

View File

@ -37,7 +37,7 @@ func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionReque
return p.convertToChatOpenai(geminiChatResponse, request)
}
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -55,7 +55,7 @@ func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletio
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -228,11 +228,11 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
}
// 转换为OpenAI聊天流式请求体
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
return
}
// 去除前缀
@ -241,19 +241,21 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r
var geminiResponse GeminiChatResponse
err := json.Unmarshal(*rawLine, &geminiResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&geminiResponse.GeminiErrorResponse)
if error != nil {
return error
errChan <- error
return
}
return h.convertToOpenaiStream(&geminiResponse, response)
h.convertToOpenaiStream(&geminiResponse, dataChan, errChan)
}
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) {
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
for i, candidate := range geminiResponse.Candidates {
@ -275,10 +277,9 @@ func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatRe
Choices: choices,
}
*response = append(*response, streamResponse)
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
}

View File

@ -32,7 +32,7 @@ func (p *MiniMaxProvider) CreateChatCompletion(request *types.ChatCompletionRequ
return p.convertToChatOpenai(response, request)
}
func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -50,7 +50,7 @@ func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompleti
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -191,11 +191,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *MiniMaxChatReq
}
// 转换为OpenAI聊天流式请求体
func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
return
}
*rawLine = (*rawLine)[6:]
@ -203,25 +203,27 @@ func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool,
miniResponse := &MiniMaxChatResponse{}
err := json.Unmarshal(*rawLine, miniResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&miniResponse.BaseResp)
if error != nil {
return error
errChan <- error
return
}
choice := miniResponse.Choices[0]
if choice.Messages[0].FunctionCall != nil && choice.FinishReason == "" {
*rawLine = nil
return nil
return
}
return h.convertToOpenaiStream(miniResponse, response)
h.convertToOpenaiStream(miniResponse, dataChan, errChan)
}
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, dataChan chan string, errChan chan error) {
streamResponse := types.ChatCompletionStreamResponse{
ID: miniResponse.RequestID,
Object: "chat.completion.chunk",
@ -235,8 +237,8 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
if miniChoice.Messages[0].FunctionCall == nil && miniChoice.FinishReason != "" {
streamResponse.ID = miniResponse.ID
openaiChoice.FinishReason = convertFinishReason(miniChoice.FinishReason)
h.appendResponse(&streamResponse, &openaiChoice, response)
return nil
dataChan <- h.getResponseString(&streamResponse, &openaiChoice)
return
}
openaiChoice.Delta = types.ChatCompletionStreamChoiceDelta{
@ -248,19 +250,17 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
convertChoices := openaiChoice.ConvertOpenaiStream()
for _, convertChoice := range convertChoices {
chatCompletionCopy := streamResponse
h.appendResponse(&chatCompletionCopy, &convertChoice, response)
dataChan <- h.getResponseString(&chatCompletionCopy, &convertChoice)
}
} else {
openaiChoice.Delta.Content = miniChoice.Messages[0].Text
h.appendResponse(&streamResponse, &openaiChoice, response)
dataChan <- h.getResponseString(&streamResponse, &openaiChoice)
}
if miniResponse.Usage != nil {
h.handleUsage(miniResponse)
}
return nil
}
func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *types.ChatCompletionStreamChoice) {
@ -274,9 +274,10 @@ func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *
}
}
func (h *minimaxStreamHandler) appendResponse(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice, response *[]types.ChatCompletionStreamResponse) {
func (h *minimaxStreamHandler) getResponseString(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice) string {
streamResponse.Choices = []types.ChatCompletionStreamChoice{*openaiChoice}
*response = append(*response, *streamResponse)
responseBody, _ := json.Marshal(streamResponse)
return string(responseBody)
}
func (h *minimaxStreamHandler) handleUsage(miniResponse *MiniMaxChatResponse) {

View File

@ -2,16 +2,19 @@ package openai
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
"time"
)
type OpenAIStreamHandler struct {
Usage *types.Usage
ModelName string
isAzure bool
}
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
@ -43,7 +46,7 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
return &response.ChatCompletionResponse, nil
}
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
@ -59,16 +62,17 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio
chatHandler := OpenAIStreamHandler{
Usage: p.Usage,
ModelName: request.Model,
isAzure: p.IsAzure,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
}
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
return
}
// 去除前缀
@ -76,26 +80,32 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *boo
// 如果等于 DONE 则结束
if string(*rawLine) == "[DONE]" {
*isFinished = true
return nil
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
var openaiResponse OpenAIProviderChatStreamResponse
err := json.Unmarshal(*rawLine, &openaiResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
if error != nil {
return error
errChan <- error
return
}
dataChan <- string(*rawLine)
if h.isAzure {
// 阻塞 20ms
time.Sleep(20 * time.Millisecond)
}
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
h.Usage.CompletionTokens += countTokenText
h.Usage.TotalTokens += countTokenText
*response = append(*response, openaiResponse.ChatCompletionStreamResponse)
return nil
}

View File

@ -2,6 +2,7 @@ package openai
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
@ -38,7 +39,7 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
return &response.CompletionResponse, nil
}
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) {
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
@ -56,14 +57,14 @@ func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest
ModelName: request.Model,
}
return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerCompletionStream)
}
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error {
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
return
}
// 去除前缀
@ -71,26 +72,27 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinishe
// 如果等于 DONE 则结束
if string(*rawLine) == "[DONE]" {
*isFinished = true
return nil
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
var openaiResponse OpenAIProviderCompletionResponse
err := json.Unmarshal(*rawLine, &openaiResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
if error != nil {
return error
errChan <- error
return
}
dataChan <- string(*rawLine)
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
h.Usage.CompletionTokens += countTokenText
h.Usage.TotalTokens += countTokenText
*response = append(*response, openaiResponse.CompletionResponse)
return nil
}

View File

@ -32,7 +32,7 @@ func (p *PalmProvider) CreateChatCompletion(request *types.ChatCompletionRequest
return p.convertToChatOpenai(palmResponse, request)
}
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -50,7 +50,7 @@ func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionR
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -142,11 +142,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *PaLMChatReques
}
// 转换为OpenAI聊天流式请求体
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
return
}
// 去除前缀
@ -155,19 +155,21 @@ func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, res
var palmChatResponse PaLMChatResponse
err := json.Unmarshal(*rawLine, &palmChatResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&palmChatResponse.PaLMErrorResponse)
if error != nil {
return error
errChan <- error
return
}
return h.convertToOpenaiStream(&palmChatResponse, response)
h.convertToOpenaiStream(&palmChatResponse, dataChan, errChan)
}
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, dataChan chan string, errChan chan error) {
var choice types.ChatCompletionStreamChoice
if len(palmChatResponse.Candidates) > 0 {
choice.Delta.Content = palmChatResponse.Candidates[0].Content
@ -182,10 +184,9 @@ func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResp
Created: common.GetTimestamp(),
}
*response = append(*response, streamResponse)
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
}

View File

@ -32,7 +32,7 @@ func (p *TencentProvider) CreateChatCompletion(request *types.ChatCompletionRequ
return p.convertToChatOpenai(tencentChatResponse, request)
}
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -50,7 +50,7 @@ func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompleti
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -157,11 +157,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatReq
}
// 转换为OpenAI聊天流式请求体
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data:") {
*rawLine = nil
return nil
return
}
// 去除前缀
@ -170,19 +170,21 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool,
var tencentChatResponse TencentChatResponse
err := json.Unmarshal(*rawLine, &tencentChatResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&tencentChatResponse.TencentResponseError)
if error != nil {
return error
errChan <- error
return
}
return h.convertToOpenaiStream(&tencentChatResponse, response)
h.convertToOpenaiStream(&tencentChatResponse, dataChan, errChan)
}
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string, errChan chan error) {
streamResponse := types.ChatCompletionStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
@ -197,10 +199,9 @@ func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *Tencen
streamResponse.Choices = append(streamResponse.Choices, choice)
}
*response = append(*response, streamResponse)
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
return nil
}

View File

@ -40,7 +40,7 @@ func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionReque
}
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
wsConn, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -53,7 +53,7 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio
Request: request,
}
return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream)
return requester.SendWSJsonRequest[string](wsConn, xunfeiRequest, chatHandler.handlerStream)
}
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
@ -123,23 +123,26 @@ func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequ
func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
var content string
var xunfeiResponse XunfeiChatResponse
dataChan, errChan := stream.Recv()
for {
response, err := stream.Recv()
stop := false
for !stop {
select {
case response := <-dataChan:
if len(response.Payload.Choices.Text) == 0 {
continue
}
xunfeiResponse = response
content += xunfeiResponse.Payload.Choices.Text[0].Content
case err := <-errChan:
if err != nil && !errors.Is(err, io.EOF) {
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
}
if err != nil && !errors.Is(err, io.EOF) {
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
if errors.Is(err, io.EOF) {
stop = true
}
}
if errors.Is(err, io.EOF) && response == nil {
break
}
if len((*response)[0].Payload.Choices.Text) == 0 {
continue
}
xunfeiResponse = (*response)[0]
content += xunfeiResponse.Payload.Choices.Text[0].Content
}
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
@ -193,7 +196,7 @@ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterfa
}
func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
// 如果rawLine 前缀不为data:,则直接返回
// 如果rawLine 前缀不为{,则直接返回
if !strings.HasPrefix(string(*rawLine), "{") {
*rawLine = nil
return nil, nil
@ -221,34 +224,47 @@ func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiC
return &xunfeiChatResponse, nil
}
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error {
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, dataChan chan XunfeiChatResponse, errChan chan error) {
isFinished := false
xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished)
if err != nil {
return err
errChan <- err
return
}
if *rawLine == nil {
return nil
return
}
*response = append(*response, *xunfeiChatResponse)
return nil
dataChan <- *xunfeiChatResponse
if isFinished {
errChan <- io.EOF
*rawLine = requester.StreamClosed
}
}
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
isFinished := false
xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished)
if err != nil {
return err
errChan <- err
return
}
if *rawLine == nil {
return nil
return
}
return h.convertToOpenaiStream(xunfeiChatResponse, response)
h.convertToOpenaiStream(xunfeiChatResponse, dataChan, errChan)
if isFinished {
errChan <- io.EOF
*rawLine = requester.StreamClosed
}
}
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, dataChan chan string, errChan chan error) {
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
}
@ -293,15 +309,15 @@ func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResp
if xunfeiText.FunctionCall == nil {
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletion)
responseBody, _ := json.Marshal(chatCompletion)
dataChan <- string(responseBody)
} else {
choices := choice.ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := chatCompletion
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy)
responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
}
}
return nil
}

View File

@ -2,6 +2,7 @@ package zhipu
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
@ -31,7 +32,7 @@ func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionReques
return p.convertToChatOpenai(zhipuChatResponse, request)
}
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
@ -49,7 +50,7 @@ func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletion
Request: request,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
@ -140,35 +141,38 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
}
// 转换为OpenAI聊天流式请求体
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
return
}
*rawLine = (*rawLine)[6:]
if strings.HasPrefix(string(*rawLine), "[DONE]") {
*isFinished = true
return nil
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
zhipuResponse := &ZhipuStreamResponse{}
err := json.Unmarshal(*rawLine, zhipuResponse)
if err != nil {
return common.ErrorToOpenAIError(err)
errChan <- common.ErrorToOpenAIError(err)
return
}
error := errorHandle(&zhipuResponse.Error)
if error != nil {
return error
errChan <- error
return
}
return h.convertToOpenaiStream(zhipuResponse, response)
h.convertToOpenaiStream(zhipuResponse, dataChan, errChan)
}
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, dataChan chan string, errChan chan error) {
streamResponse := types.ChatCompletionStreamResponse{
ID: zhipuResponse.ID,
Object: "chat.completion.chunk",
@ -183,16 +187,16 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes
for _, choice := range choices {
chatCompletionCopy := streamResponse
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, chatCompletionCopy)
responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
}
} else {
streamResponse.Choices = []types.ChatCompletionStreamChoice{choice}
*response = append(*response, streamResponse)
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
}
if zhipuResponse.Usage != nil {
*h.Usage = *zhipuResponse.Usage
}
return nil
}