🐛 fix: stream mode delay issue (#53)
This commit is contained in:
parent
705804e6dd
commit
d7193b8e46
@ -127,11 +127,16 @@ func RequestStream[T streamable](requester *HTTPRequester, resp *http.Response,
|
|||||||
return nil, HandleErrorResp(resp, requester.ErrorHandler)
|
return nil, HandleErrorResp(resp, requester.ErrorHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &streamReader[T]{
|
stream := &streamReader[T]{
|
||||||
reader: bufio.NewReader(resp.Body),
|
reader: bufio.NewReader(resp.Body),
|
||||||
response: resp,
|
response: resp,
|
||||||
handlerPrefix: handlerPrefix,
|
handlerPrefix: handlerPrefix,
|
||||||
}, nil
|
|
||||||
|
DataChan: make(chan T),
|
||||||
|
ErrChan: make(chan error),
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置请求体
|
// 设置请求体
|
||||||
|
@ -3,16 +3,12 @@ package requester
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 流处理函数,判断依据如下:
|
var StreamClosed = []byte("stream_closed")
|
||||||
// 1.如果有错误信息,则直接返回错误信息
|
|
||||||
// 2.如果isFinished=true,则返回io.EOF,并且如果response不为空,还将返回response
|
type HandlerPrefix[T streamable] func(rawLine *[]byte, dataChan chan T, errChan chan error)
|
||||||
// 3.如果rawLine=nil 或者 response长度为0,则直接跳过
|
|
||||||
// 4.如果以上条件都不满足,则返回response
|
|
||||||
type HandlerPrefix[T streamable] func(rawLine *[]byte, isFinished *bool, response *[]T) error
|
|
||||||
|
|
||||||
type streamable interface {
|
type streamable interface {
|
||||||
// types.ChatCompletionStreamResponse | types.CompletionResponse
|
// types.ChatCompletionStreamResponse | types.CompletionResponse
|
||||||
@ -20,57 +16,48 @@ type streamable interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type StreamReaderInterface[T streamable] interface {
|
type StreamReaderInterface[T streamable] interface {
|
||||||
Recv() (*[]T, error)
|
Recv() (<-chan T, <-chan error)
|
||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamReader[T streamable] struct {
|
type streamReader[T streamable] struct {
|
||||||
isFinished bool
|
|
||||||
|
|
||||||
reader *bufio.Reader
|
reader *bufio.Reader
|
||||||
response *http.Response
|
response *http.Response
|
||||||
|
|
||||||
handlerPrefix HandlerPrefix[T]
|
handlerPrefix HandlerPrefix[T]
|
||||||
|
|
||||||
|
DataChan chan T
|
||||||
|
ErrChan chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *streamReader[T]) Recv() (response *[]T, err error) {
|
func (stream *streamReader[T]) Recv() (<-chan T, <-chan error) {
|
||||||
if stream.isFinished {
|
go stream.processLines()
|
||||||
err = io.EOF
|
|
||||||
return
|
return stream.DataChan, stream.ErrChan
|
||||||
}
|
|
||||||
response, err = stream.processLines()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocognit
|
//nolint:gocognit
|
||||||
func (stream *streamReader[T]) processLines() (*[]T, error) {
|
func (stream *streamReader[T]) processLines() {
|
||||||
for {
|
for {
|
||||||
rawLine, readErr := stream.reader.ReadBytes('\n')
|
rawLine, readErr := stream.reader.ReadBytes('\n')
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
return nil, readErr
|
stream.ErrChan <- readErr
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
noSpaceLine := bytes.TrimSpace(rawLine)
|
noSpaceLine := bytes.TrimSpace(rawLine)
|
||||||
|
if len(noSpaceLine) == 0 {
|
||||||
var response []T
|
|
||||||
err := stream.handlerPrefix(&noSpaceLine, &stream.isFinished, &response)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if stream.isFinished {
|
|
||||||
if len(response) > 0 {
|
|
||||||
return &response, io.EOF
|
|
||||||
}
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
if noSpaceLine == nil || len(response) == 0 {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return &response, nil
|
stream.handlerPrefix(&noSpaceLine, stream.DataChan, stream.ErrChan)
|
||||||
|
|
||||||
|
if noSpaceLine == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(noSpaceLine, StreamClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,55 +1,41 @@
|
|||||||
package requester
|
package requester
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"bytes"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wsReader[T streamable] struct {
|
type wsReader[T streamable] struct {
|
||||||
isFinished bool
|
|
||||||
|
|
||||||
reader *websocket.Conn
|
reader *websocket.Conn
|
||||||
handlerPrefix HandlerPrefix[T]
|
handlerPrefix HandlerPrefix[T]
|
||||||
|
|
||||||
|
DataChan chan T
|
||||||
|
ErrChan chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *wsReader[T]) Recv() (response *[]T, err error) {
|
func (stream *wsReader[T]) Recv() (<-chan T, <-chan error) {
|
||||||
if stream.isFinished {
|
go stream.processLines()
|
||||||
err = io.EOF
|
return stream.DataChan, stream.ErrChan
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, err = stream.processLines()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (stream *wsReader[T]) processLines() (*[]T, error) {
|
func (stream *wsReader[T]) processLines() {
|
||||||
for {
|
for {
|
||||||
_, msg, err := stream.reader.ReadMessage()
|
_, msg, err := stream.reader.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
stream.ErrChan <- err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var response []T
|
stream.handlerPrefix(&msg, stream.DataChan, stream.ErrChan)
|
||||||
err = stream.handlerPrefix(&msg, &stream.isFinished, &response)
|
|
||||||
|
|
||||||
if err != nil {
|
if msg == nil {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if stream.isFinished {
|
|
||||||
if len(response) > 0 {
|
|
||||||
return &response, io.EOF
|
|
||||||
}
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg == nil || len(response) == 0 {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
return &response, nil
|
if bytes.Equal(msg, StreamClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,10 +38,15 @@ func SendWSJsonRequest[T streamable](conn *websocket.Conn, data any, handlerPref
|
|||||||
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
|
return nil, common.ErrorWrapper(err, "ws_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &wsReader[T]{
|
stream := &wsReader[T]{
|
||||||
reader: conn,
|
reader: conn,
|
||||||
handlerPrefix: handlerPrefix,
|
handlerPrefix: handlerPrefix,
|
||||||
}, nil
|
|
||||||
|
DataChan: make(chan T),
|
||||||
|
ErrChan: make(chan error),
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置请求头
|
// 设置请求头
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@ -53,13 +52,13 @@ func RelayChat(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if chatRequest.Stream {
|
if chatRequest.Stream {
|
||||||
var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse]
|
var response requester.StreamReaderInterface[string]
|
||||||
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
|
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
errorHelper(c, errWithCode)
|
errorHelper(c, errWithCode)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response)
|
errWithCode = responseStreamClient(c, response)
|
||||||
} else {
|
} else {
|
||||||
var response *types.ChatCompletionResponse
|
var response *types.ChatCompletionResponse
|
||||||
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
|
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
|
||||||
@ -70,8 +69,6 @@ func RelayChat(c *gin.Context) {
|
|||||||
errWithCode = responseJsonClient(c, response)
|
errWithCode = responseJsonClient(c, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(usage)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
// 如果报错,则退还配额
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
quotaInfo.undo(c, errWithCode)
|
quotaInfo.undo(c, errWithCode)
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/requester"
|
||||||
providersBase "one-api/providers/base"
|
providersBase "one-api/providers/base"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
|
|
||||||
@ -51,14 +52,16 @@ func RelayCompletions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if completionRequest.Stream {
|
if completionRequest.Stream {
|
||||||
response, errWithCode := completionProvider.CreateCompletionStream(&completionRequest)
|
var response requester.StreamReaderInterface[string]
|
||||||
|
response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
errorHelper(c, errWithCode)
|
errorHelper(c, errWithCode)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errWithCode = responseStreamClient[types.CompletionResponse](c, response)
|
errWithCode = responseStreamClient(c, response)
|
||||||
} else {
|
} else {
|
||||||
response, errWithCode := completionProvider.CreateCompletion(&completionRequest)
|
var response *types.CompletionResponse
|
||||||
|
response, errWithCode = completionProvider.CreateCompletion(&completionRequest)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
errorHelper(c, errWithCode)
|
errorHelper(c, errWithCode)
|
||||||
return
|
return
|
||||||
|
@ -154,34 +154,25 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode {
|
func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string]) *types.OpenAIErrorWithStatusCode {
|
||||||
requester.SetEventStreamHeaders(c)
|
requester.SetEventStreamHeaders(c)
|
||||||
|
dataChan, errChan := stream.Recv()
|
||||||
|
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
for {
|
select {
|
||||||
response, err := stream.Recv()
|
case data := <-dataChan:
|
||||||
if errors.Is(err, io.EOF) {
|
fmt.Fprintln(w, "data: "+data+"\n")
|
||||||
if response != nil && len(*response) > 0 {
|
return true
|
||||||
for _, streamResponse := range *response {
|
case err := <-errChan:
|
||||||
responseBody, _ := json.Marshal(streamResponse)
|
if !errors.Is(err, io.EOF) {
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
|
fmt.Fprintln(w, "data: "+err.Error()+"\n")
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
fmt.Fprintln(w, "data: [DONE]")
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()})
|
return false
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, streamResponse := range *response {
|
|
||||||
responseBody, _ := json.Marshal(streamResponse)
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ func (p *AliProvider) CreateChatCompletion(request *types.ChatCompletionRequest)
|
|||||||
return p.convertToChatOpenai(aliResponse, request)
|
return p.convertToChatOpenai(aliResponse, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.getAliChatRequest(request)
|
req, errWithCode := p.getAliChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -52,7 +52,7 @@ func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRe
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -162,11 +162,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *AliChatRequest
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *aliStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data:") {
|
if !strings.HasPrefix(string(*rawLine), "data:") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
// 去除前缀
|
||||||
@ -175,19 +175,21 @@ func (h *aliStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, resp
|
|||||||
var aliResponse AliChatResponse
|
var aliResponse AliChatResponse
|
||||||
err := json.Unmarshal(*rawLine, &aliResponse)
|
err := json.Unmarshal(*rawLine, &aliResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
error := errorHandle(&aliResponse.AliError)
|
error := errorHandle(&aliResponse.AliError)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
errChan <- error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(&aliResponse, response)
|
h.convertToOpenaiStream(&aliResponse, dataChan, errChan)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, dataChan chan string, errChan chan error) {
|
||||||
content := aliResponse.Output.Choices[0].Message.StringContent()
|
content := aliResponse.Output.Choices[0].Message.StringContent()
|
||||||
|
|
||||||
var choice types.ChatCompletionStreamChoice
|
var choice types.ChatCompletionStreamChoice
|
||||||
@ -222,7 +224,6 @@ func (h *aliStreamHandler) convertToOpenaiStream(aliResponse *AliChatResponse, r
|
|||||||
h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
h.Usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
*response = append(*response, streamResponse)
|
responseBody, _ := json.Marshal(streamResponse)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
|
|||||||
return &response.ChatCompletionResponse, nil
|
return &response.ChatCompletionResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -57,7 +57,7 @@ func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatComplet
|
|||||||
ModelName: request.Model,
|
ModelName: request.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取聊天请求体
|
// 获取聊天请求体
|
||||||
|
@ -2,6 +2,7 @@ package baidu
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
@ -31,7 +32,7 @@ func (p *BaiduProvider) CreateChatCompletion(request *types.ChatCompletionReques
|
|||||||
return p.convertToChatOpenai(baiduResponse, request)
|
return p.convertToChatOpenai(baiduResponse, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.getBaiduChatRequest(request)
|
req, errWithCode := p.getBaiduChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -49,7 +50,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -178,11 +179,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatReque
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
// 去除前缀
|
||||||
@ -191,18 +192,26 @@ func (h *baiduStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, re
|
|||||||
var baiduResponse BaiduChatStreamResponse
|
var baiduResponse BaiduChatStreamResponse
|
||||||
err := json.Unmarshal(*rawLine, &baiduResponse)
|
err := json.Unmarshal(*rawLine, &baiduResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
error := errorHandle(&baiduResponse.BaiduError)
|
||||||
|
if error != nil {
|
||||||
|
errChan <- error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.convertToOpenaiStream(&baiduResponse, dataChan, errChan)
|
||||||
|
|
||||||
if baiduResponse.IsEnd {
|
if baiduResponse.IsEnd {
|
||||||
*isFinished = true
|
errChan <- io.EOF
|
||||||
|
*rawLine = requester.StreamClosed
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(&baiduResponse, response)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStreamResponse, dataChan chan string, errChan chan error) {
|
||||||
choice := types.ChatCompletionStreamChoice{
|
choice := types.ChatCompletionStreamChoice{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Delta: types.ChatCompletionStreamChoiceDelta{
|
Delta: types.ChatCompletionStreamChoiceDelta{
|
||||||
@ -240,19 +249,19 @@ func (h *baiduStreamHandler) convertToOpenaiStream(baiduResponse *BaiduChatStrea
|
|||||||
|
|
||||||
if baiduResponse.FunctionCall == nil {
|
if baiduResponse.FunctionCall == nil {
|
||||||
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
|
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
*response = append(*response, chatCompletion)
|
responseBody, _ := json.Marshal(chatCompletion)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
} else {
|
} else {
|
||||||
choices := choice.ConvertOpenaiStream()
|
choices := choice.ConvertOpenaiStream()
|
||||||
for _, choice := range choices {
|
for _, choice := range choices {
|
||||||
chatCompletionCopy := chatCompletion
|
chatCompletionCopy := chatCompletion
|
||||||
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
*response = append(*response, chatCompletionCopy)
|
responseBody, _ := json.Marshal(chatCompletionCopy)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
h.Usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||||
h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
h.Usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||||
h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
|
h.Usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -41,14 +41,14 @@ type ProviderInterface interface {
|
|||||||
type CompletionInterface interface {
|
type CompletionInterface interface {
|
||||||
ProviderInterface
|
ProviderInterface
|
||||||
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
|
CreateCompletion(request *types.CompletionRequest) (*types.CompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||||
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[types.CompletionResponse], *types.OpenAIErrorWithStatusCode)
|
CreateCompletionStream(request *types.CompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 聊天接口
|
// 聊天接口
|
||||||
type ChatInterface interface {
|
type ChatInterface interface {
|
||||||
ProviderInterface
|
ProviderInterface
|
||||||
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
|
CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode)
|
||||||
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode)
|
CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 嵌入接口
|
// 嵌入接口
|
||||||
|
@ -3,6 +3,7 @@ package claude
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
@ -32,7 +33,7 @@ func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
|||||||
return p.convertToChatOpenai(claudeResponse, request)
|
return p.convertToChatOpenai(claudeResponse, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.getChatRequest(request)
|
req, errWithCode := p.getChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -50,7 +51,7 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -149,11 +150,11 @@ func (p *ClaudeProvider) convertToChatOpenai(response *ClaudeResponse, request *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) {
|
if !strings.HasPrefix(string(*rawLine), `data: {"type": "completion"`) {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
// 去除前缀
|
||||||
@ -162,17 +163,26 @@ func (h *claudeStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r
|
|||||||
var claudeResponse *ClaudeResponse
|
var claudeResponse *ClaudeResponse
|
||||||
err := json.Unmarshal(*rawLine, claudeResponse)
|
err := json.Unmarshal(*rawLine, claudeResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
error := errorHandle(&claudeResponse.ClaudeResponseError)
|
||||||
|
if error != nil {
|
||||||
|
errChan <- error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if claudeResponse.StopReason == "stop_sequence" {
|
if claudeResponse.StopReason == "stop_sequence" {
|
||||||
*isFinished = true
|
errChan <- io.EOF
|
||||||
|
*rawLine = requester.StreamClosed
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(claudeResponse, response)
|
h.convertToOpenaiStream(claudeResponse, dataChan, errChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeResponse, dataChan chan string, errChan chan error) {
|
||||||
var choice types.ChatCompletionStreamChoice
|
var choice types.ChatCompletionStreamChoice
|
||||||
choice.Delta.Content = claudeResponse.Completion
|
choice.Delta.Content = claudeResponse.Completion
|
||||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||||
@ -185,9 +195,8 @@ func (h *claudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeRespon
|
|||||||
Choices: []types.ChatCompletionStreamChoice{choice},
|
Choices: []types.ChatCompletionStreamChoice{choice},
|
||||||
}
|
}
|
||||||
|
|
||||||
*response = append(*response, chatCompletion)
|
responseBody, _ := json.Marshal(chatCompletion)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
|
|
||||||
h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
|
h.Usage.PromptTokens += common.CountTokenText(claudeResponse.Completion, h.Request.Model)
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -37,7 +37,7 @@ func (p *GeminiProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
|||||||
return p.convertToChatOpenai(geminiChatResponse, request)
|
return p.convertToChatOpenai(geminiChatResponse, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.getChatRequest(request)
|
req, errWithCode := p.getChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -55,7 +55,7 @@ func (p *GeminiProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *GeminiProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -228,11 +228,11 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
// 去除前缀
|
||||||
@ -241,19 +241,21 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, r
|
|||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err := json.Unmarshal(*rawLine, &geminiResponse)
|
err := json.Unmarshal(*rawLine, &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
error := errorHandle(&geminiResponse.GeminiErrorResponse)
|
error := errorHandle(&geminiResponse.GeminiErrorResponse)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
errChan <- error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(&geminiResponse, response)
|
h.convertToOpenaiStream(&geminiResponse, dataChan, errChan)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) {
|
||||||
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
|
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
|
||||||
|
|
||||||
for i, candidate := range geminiResponse.Candidates {
|
for i, candidate := range geminiResponse.Candidates {
|
||||||
@ -275,10 +277,9 @@ func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatRe
|
|||||||
Choices: choices,
|
Choices: choices,
|
||||||
}
|
}
|
||||||
|
|
||||||
*response = append(*response, streamResponse)
|
responseBody, _ := json.Marshal(streamResponse)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
|
|
||||||
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
|
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
|
||||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ func (p *MiniMaxProvider) CreateChatCompletion(request *types.ChatCompletionRequ
|
|||||||
return p.convertToChatOpenai(response, request)
|
return p.convertToChatOpenai(response, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.getChatRequest(request)
|
req, errWithCode := p.getChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -50,7 +50,7 @@ func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompleti
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -191,11 +191,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *MiniMaxChatReq
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
|
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
*rawLine = (*rawLine)[6:]
|
*rawLine = (*rawLine)[6:]
|
||||||
@ -203,25 +203,27 @@ func (h *minimaxStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool,
|
|||||||
miniResponse := &MiniMaxChatResponse{}
|
miniResponse := &MiniMaxChatResponse{}
|
||||||
err := json.Unmarshal(*rawLine, miniResponse)
|
err := json.Unmarshal(*rawLine, miniResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
error := errorHandle(&miniResponse.BaseResp)
|
error := errorHandle(&miniResponse.BaseResp)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
errChan <- error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
choice := miniResponse.Choices[0]
|
choice := miniResponse.Choices[0]
|
||||||
|
|
||||||
if choice.Messages[0].FunctionCall != nil && choice.FinishReason == "" {
|
if choice.Messages[0].FunctionCall != nil && choice.FinishReason == "" {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(miniResponse, response)
|
h.convertToOpenaiStream(miniResponse, dataChan, errChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatResponse, dataChan chan string, errChan chan error) {
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: miniResponse.RequestID,
|
ID: miniResponse.RequestID,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
@ -235,8 +237,8 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
|
|||||||
if miniChoice.Messages[0].FunctionCall == nil && miniChoice.FinishReason != "" {
|
if miniChoice.Messages[0].FunctionCall == nil && miniChoice.FinishReason != "" {
|
||||||
streamResponse.ID = miniResponse.ID
|
streamResponse.ID = miniResponse.ID
|
||||||
openaiChoice.FinishReason = convertFinishReason(miniChoice.FinishReason)
|
openaiChoice.FinishReason = convertFinishReason(miniChoice.FinishReason)
|
||||||
h.appendResponse(&streamResponse, &openaiChoice, response)
|
dataChan <- h.getResponseString(&streamResponse, &openaiChoice)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiChoice.Delta = types.ChatCompletionStreamChoiceDelta{
|
openaiChoice.Delta = types.ChatCompletionStreamChoiceDelta{
|
||||||
@ -248,19 +250,17 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
|
|||||||
convertChoices := openaiChoice.ConvertOpenaiStream()
|
convertChoices := openaiChoice.ConvertOpenaiStream()
|
||||||
for _, convertChoice := range convertChoices {
|
for _, convertChoice := range convertChoices {
|
||||||
chatCompletionCopy := streamResponse
|
chatCompletionCopy := streamResponse
|
||||||
h.appendResponse(&chatCompletionCopy, &convertChoice, response)
|
dataChan <- h.getResponseString(&chatCompletionCopy, &convertChoice)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
openaiChoice.Delta.Content = miniChoice.Messages[0].Text
|
openaiChoice.Delta.Content = miniChoice.Messages[0].Text
|
||||||
h.appendResponse(&streamResponse, &openaiChoice, response)
|
dataChan <- h.getResponseString(&streamResponse, &openaiChoice)
|
||||||
}
|
}
|
||||||
|
|
||||||
if miniResponse.Usage != nil {
|
if miniResponse.Usage != nil {
|
||||||
h.handleUsage(miniResponse)
|
h.handleUsage(miniResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *types.ChatCompletionStreamChoice) {
|
func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *types.ChatCompletionStreamChoice) {
|
||||||
@ -274,9 +274,10 @@ func (h *minimaxStreamHandler) handleFunctionCall(choice *Choice, openaiChoice *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *minimaxStreamHandler) appendResponse(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice, response *[]types.ChatCompletionStreamResponse) {
|
func (h *minimaxStreamHandler) getResponseString(streamResponse *types.ChatCompletionStreamResponse, openaiChoice *types.ChatCompletionStreamChoice) string {
|
||||||
streamResponse.Choices = []types.ChatCompletionStreamChoice{*openaiChoice}
|
streamResponse.Choices = []types.ChatCompletionStreamChoice{*openaiChoice}
|
||||||
*response = append(*response, *streamResponse)
|
responseBody, _ := json.Marshal(streamResponse)
|
||||||
|
return string(responseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *minimaxStreamHandler) handleUsage(miniResponse *MiniMaxChatResponse) {
|
func (h *minimaxStreamHandler) handleUsage(miniResponse *MiniMaxChatResponse) {
|
||||||
|
@ -2,16 +2,19 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIStreamHandler struct {
|
type OpenAIStreamHandler struct {
|
||||||
Usage *types.Usage
|
Usage *types.Usage
|
||||||
ModelName string
|
ModelName string
|
||||||
|
isAzure bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -43,7 +46,7 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
|||||||
return &response.ChatCompletionResponse, nil
|
return &response.ChatCompletionResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -59,16 +62,17 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
|||||||
chatHandler := OpenAIStreamHandler{
|
chatHandler := OpenAIStreamHandler{
|
||||||
Usage: p.Usage,
|
Usage: p.Usage,
|
||||||
ModelName: request.Model,
|
ModelName: request.Model,
|
||||||
|
isAzure: p.IsAzure,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
// 去除前缀
|
||||||
@ -76,26 +80,32 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *boo
|
|||||||
|
|
||||||
// 如果等于 DONE 则结束
|
// 如果等于 DONE 则结束
|
||||||
if string(*rawLine) == "[DONE]" {
|
if string(*rawLine) == "[DONE]" {
|
||||||
*isFinished = true
|
errChan <- io.EOF
|
||||||
return nil
|
*rawLine = requester.StreamClosed
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var openaiResponse OpenAIProviderChatStreamResponse
|
var openaiResponse OpenAIProviderChatStreamResponse
|
||||||
err := json.Unmarshal(*rawLine, &openaiResponse)
|
err := json.Unmarshal(*rawLine, &openaiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
errChan <- error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dataChan <- string(*rawLine)
|
||||||
|
|
||||||
|
if h.isAzure {
|
||||||
|
// 阻塞 20ms
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
||||||
h.Usage.CompletionTokens += countTokenText
|
h.Usage.CompletionTokens += countTokenText
|
||||||
h.Usage.TotalTokens += countTokenText
|
h.Usage.TotalTokens += countTokenText
|
||||||
|
|
||||||
*response = append(*response, openaiResponse.ChatCompletionStreamResponse)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
@ -38,7 +39,7 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
|
|||||||
return &response.CompletionResponse, nil
|
return &response.CompletionResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -56,14 +57,14 @@ func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest
|
|||||||
ModelName: request.Model,
|
ModelName: request.Model,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerCompletionStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error {
|
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
// 去除前缀
|
||||||
@ -71,26 +72,27 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinishe
|
|||||||
|
|
||||||
// 如果等于 DONE 则结束
|
// 如果等于 DONE 则结束
|
||||||
if string(*rawLine) == "[DONE]" {
|
if string(*rawLine) == "[DONE]" {
|
||||||
*isFinished = true
|
errChan <- io.EOF
|
||||||
return nil
|
*rawLine = requester.StreamClosed
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var openaiResponse OpenAIProviderCompletionResponse
|
var openaiResponse OpenAIProviderCompletionResponse
|
||||||
err := json.Unmarshal(*rawLine, &openaiResponse)
|
err := json.Unmarshal(*rawLine, &openaiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
errChan <- error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dataChan <- string(*rawLine)
|
||||||
|
|
||||||
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
||||||
h.Usage.CompletionTokens += countTokenText
|
h.Usage.CompletionTokens += countTokenText
|
||||||
h.Usage.TotalTokens += countTokenText
|
h.Usage.TotalTokens += countTokenText
|
||||||
|
|
||||||
*response = append(*response, openaiResponse.CompletionResponse)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ func (p *PalmProvider) CreateChatCompletion(request *types.ChatCompletionRequest
|
|||||||
return p.convertToChatOpenai(palmResponse, request)
|
return p.convertToChatOpenai(palmResponse, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.getChatRequest(request)
|
req, errWithCode := p.getChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -50,7 +50,7 @@ func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionR
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -142,11 +142,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *PaLMChatReques
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *palmStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
// 去除前缀
|
||||||
@ -155,19 +155,21 @@ func (h *palmStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, res
|
|||||||
var palmChatResponse PaLMChatResponse
|
var palmChatResponse PaLMChatResponse
|
||||||
err := json.Unmarshal(*rawLine, &palmChatResponse)
|
err := json.Unmarshal(*rawLine, &palmChatResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
error := errorHandle(&palmChatResponse.PaLMErrorResponse)
|
error := errorHandle(&palmChatResponse.PaLMErrorResponse)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
errChan <- error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(&palmChatResponse, response)
|
h.convertToOpenaiStream(&palmChatResponse, dataChan, errChan)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResponse, dataChan chan string, errChan chan error) {
|
||||||
var choice types.ChatCompletionStreamChoice
|
var choice types.ChatCompletionStreamChoice
|
||||||
if len(palmChatResponse.Candidates) > 0 {
|
if len(palmChatResponse.Candidates) > 0 {
|
||||||
choice.Delta.Content = palmChatResponse.Candidates[0].Content
|
choice.Delta.Content = palmChatResponse.Candidates[0].Content
|
||||||
@ -182,10 +184,9 @@ func (h *palmStreamHandler) convertToOpenaiStream(palmChatResponse *PaLMChatResp
|
|||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
}
|
}
|
||||||
|
|
||||||
*response = append(*response, streamResponse)
|
responseBody, _ := json.Marshal(streamResponse)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
|
|
||||||
h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model)
|
h.Usage.CompletionTokens += common.CountTokenText(palmChatResponse.Candidates[0].Content, h.Request.Model)
|
||||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -32,7 +32,7 @@ func (p *TencentProvider) CreateChatCompletion(request *types.ChatCompletionRequ
|
|||||||
return p.convertToChatOpenai(tencentChatResponse, request)
|
return p.convertToChatOpenai(tencentChatResponse, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.getChatRequest(request)
|
req, errWithCode := p.getChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -50,7 +50,7 @@ func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompleti
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -157,11 +157,11 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *TencentChatReq
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为data:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data:") {
|
if !strings.HasPrefix(string(*rawLine), "data:") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 去除前缀
|
// 去除前缀
|
||||||
@ -170,19 +170,21 @@ func (h *tencentStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool,
|
|||||||
var tencentChatResponse TencentChatResponse
|
var tencentChatResponse TencentChatResponse
|
||||||
err := json.Unmarshal(*rawLine, &tencentChatResponse)
|
err := json.Unmarshal(*rawLine, &tencentChatResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
error := errorHandle(&tencentChatResponse.TencentResponseError)
|
error := errorHandle(&tencentChatResponse.TencentResponseError)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
errChan <- error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(&tencentChatResponse, response)
|
h.convertToOpenaiStream(&tencentChatResponse, dataChan, errChan)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *TencentChatResponse, dataChan chan string, errChan chan error) {
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
@ -197,10 +199,9 @@ func (h *tencentStreamHandler) convertToOpenaiStream(tencentChatResponse *Tencen
|
|||||||
streamResponse.Choices = append(streamResponse.Choices, choice)
|
streamResponse.Choices = append(streamResponse.Choices, choice)
|
||||||
}
|
}
|
||||||
|
|
||||||
*response = append(*response, streamResponse)
|
responseBody, _ := json.Marshal(streamResponse)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
|
|
||||||
h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model)
|
h.Usage.CompletionTokens += common.CountTokenText(tencentChatResponse.Choices[0].Delta.Content, h.Request.Model)
|
||||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,7 @@ func (p *XunfeiProvider) CreateChatCompletion(request *types.ChatCompletionReque
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
wsConn, errWithCode := p.getChatRequest(request)
|
wsConn, errWithCode := p.getChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -53,7 +53,7 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.SendWSJsonRequest[types.ChatCompletionStreamResponse](wsConn, xunfeiRequest, chatHandler.handlerStream)
|
return requester.SendWSJsonRequest[string](wsConn, xunfeiRequest, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
|
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -123,23 +123,26 @@ func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequ
|
|||||||
func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterface[XunfeiChatResponse]) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
|
||||||
var content string
|
var content string
|
||||||
var xunfeiResponse XunfeiChatResponse
|
var xunfeiResponse XunfeiChatResponse
|
||||||
|
dataChan, errChan := stream.Recv()
|
||||||
|
|
||||||
for {
|
stop := false
|
||||||
response, err := stream.Recv()
|
for !stop {
|
||||||
|
select {
|
||||||
|
case response := <-dataChan:
|
||||||
|
if len(response.Payload.Choices.Text) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
xunfeiResponse = response
|
||||||
|
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||||
|
case err := <-errChan:
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
|
return nil, common.ErrorWrapper(err, "xunfei_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errors.Is(err, io.EOF) && response == nil {
|
if errors.Is(err, io.EOF) {
|
||||||
break
|
stop = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if len((*response)[0].Payload.Choices.Text) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
xunfeiResponse = (*response)[0]
|
|
||||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||||
@ -193,7 +196,7 @@ func (h *xunfeiHandler) convertToChatOpenai(stream requester.StreamReaderInterfa
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
|
func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiChatResponse, error) {
|
||||||
// 如果rawLine 前缀不为data:,则直接返回
|
// 如果rawLine 前缀不为{,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "{") {
|
if !strings.HasPrefix(string(*rawLine), "{") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -221,34 +224,47 @@ func (h *xunfeiHandler) handlerData(rawLine *[]byte, isFinished *bool) (*XunfeiC
|
|||||||
return &xunfeiChatResponse, nil
|
return &xunfeiChatResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, isFinished *bool, response *[]XunfeiChatResponse) error {
|
func (h *xunfeiHandler) handlerNotStream(rawLine *[]byte, dataChan chan XunfeiChatResponse, errChan chan error) {
|
||||||
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
|
isFinished := false
|
||||||
|
xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
errChan <- err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if *rawLine == nil {
|
if *rawLine == nil {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
*response = append(*response, *xunfeiChatResponse)
|
dataChan <- *xunfeiChatResponse
|
||||||
return nil
|
|
||||||
|
if isFinished {
|
||||||
|
errChan <- io.EOF
|
||||||
|
*rawLine = requester.StreamClosed
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *xunfeiHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
xunfeiChatResponse, err := h.handlerData(rawLine, isFinished)
|
isFinished := false
|
||||||
|
xunfeiChatResponse, err := h.handlerData(rawLine, &isFinished)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
errChan <- err
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if *rawLine == nil {
|
if *rawLine == nil {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(xunfeiChatResponse, response)
|
h.convertToOpenaiStream(xunfeiChatResponse, dataChan, errChan)
|
||||||
|
|
||||||
|
if isFinished {
|
||||||
|
errChan <- io.EOF
|
||||||
|
*rawLine = requester.StreamClosed
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResponse, dataChan chan string, errChan chan error) {
|
||||||
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
|
if len(xunfeiChatResponse.Payload.Choices.Text) == 0 {
|
||||||
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
|
xunfeiChatResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{{}}
|
||||||
}
|
}
|
||||||
@ -293,15 +309,15 @@ func (h *xunfeiHandler) convertToOpenaiStream(xunfeiChatResponse *XunfeiChatResp
|
|||||||
|
|
||||||
if xunfeiText.FunctionCall == nil {
|
if xunfeiText.FunctionCall == nil {
|
||||||
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
|
chatCompletion.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
*response = append(*response, chatCompletion)
|
responseBody, _ := json.Marshal(chatCompletion)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
} else {
|
} else {
|
||||||
choices := choice.ConvertOpenaiStream()
|
choices := choice.ConvertOpenaiStream()
|
||||||
for _, choice := range choices {
|
for _, choice := range choices {
|
||||||
chatCompletionCopy := chatCompletion
|
chatCompletionCopy := chatCompletion
|
||||||
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
*response = append(*response, chatCompletionCopy)
|
responseBody, _ := json.Marshal(chatCompletionCopy)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package zhipu
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
@ -31,7 +32,7 @@ func (p *ZhipuProvider) CreateChatCompletion(request *types.ChatCompletionReques
|
|||||||
return p.convertToChatOpenai(zhipuChatResponse, request)
|
return p.convertToChatOpenai(zhipuChatResponse, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
|
||||||
req, errWithCode := p.getChatRequest(request)
|
req, errWithCode := p.getChatRequest(request)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
return nil, errWithCode
|
return nil, errWithCode
|
||||||
@ -49,7 +50,7 @@ func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletion
|
|||||||
Request: request,
|
Request: request,
|
||||||
}
|
}
|
||||||
|
|
||||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.handlerStream)
|
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||||
@ -140,35 +141,38 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为OpenAI聊天流式请求体
|
// 转换为OpenAI聊天流式请求体
|
||||||
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *zhipuStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
|
||||||
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
|
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
|
||||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||||
*rawLine = nil
|
*rawLine = nil
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
*rawLine = (*rawLine)[6:]
|
*rawLine = (*rawLine)[6:]
|
||||||
|
|
||||||
if strings.HasPrefix(string(*rawLine), "[DONE]") {
|
if strings.HasPrefix(string(*rawLine), "[DONE]") {
|
||||||
*isFinished = true
|
errChan <- io.EOF
|
||||||
return nil
|
*rawLine = requester.StreamClosed
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
zhipuResponse := &ZhipuStreamResponse{}
|
zhipuResponse := &ZhipuStreamResponse{}
|
||||||
err := json.Unmarshal(*rawLine, zhipuResponse)
|
err := json.Unmarshal(*rawLine, zhipuResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return common.ErrorToOpenAIError(err)
|
errChan <- common.ErrorToOpenAIError(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
error := errorHandle(&zhipuResponse.Error)
|
error := errorHandle(&zhipuResponse.Error)
|
||||||
if error != nil {
|
if error != nil {
|
||||||
return error
|
errChan <- error
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.convertToOpenaiStream(zhipuResponse, response)
|
h.convertToOpenaiStream(zhipuResponse, dataChan, errChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, response *[]types.ChatCompletionStreamResponse) error {
|
func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamResponse, dataChan chan string, errChan chan error) {
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: zhipuResponse.ID,
|
ID: zhipuResponse.ID,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
@ -183,16 +187,16 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes
|
|||||||
for _, choice := range choices {
|
for _, choice := range choices {
|
||||||
chatCompletionCopy := streamResponse
|
chatCompletionCopy := streamResponse
|
||||||
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
*response = append(*response, chatCompletionCopy)
|
responseBody, _ := json.Marshal(chatCompletionCopy)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
streamResponse.Choices = []types.ChatCompletionStreamChoice{choice}
|
streamResponse.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
*response = append(*response, streamResponse)
|
responseBody, _ := json.Marshal(streamResponse)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
if zhipuResponse.Usage != nil {
|
if zhipuResponse.Usage != nil {
|
||||||
*h.Usage = *zhipuResponse.Usage
|
*h.Usage = *zhipuResponse.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user