From eb260652b26a36f3f5e747253a7dee7fb3a1a4ad Mon Sep 17 00:00:00 2001
From: MartialBE
Date: Sun, 26 May 2024 19:58:15 +0800
Subject: [PATCH] ✨ feat: Support stream_options
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 providers/baichuan/chat.go     | 14 ++++++++++++++
 providers/groq/chat.go         | 13 +++++++++++++
 providers/openai/base.go       | 13 ++++++++++---
 providers/openai/chat.go       | 33 +++++++++++++++++++++++++--------
 providers/openai/completion.go | 24 ++++++++++++++++++++++++
 relay/chat.go                  | 34 +++++++++++++++++++++++++++++++++-
 relay/common.go                | 14 ++++++++++++--
 relay/completions.go           | 34 +++++++++++++++++++++++++++++++++-
 types/chat.go                  |  3 +++
 types/common.go                |  4 ++++
 types/completion.go            | 33 +++++++++++++++++----------------
 11 files changed, 188 insertions(+), 31 deletions(-)

diff --git a/providers/baichuan/chat.go b/providers/baichuan/chat.go
index 4bc19361..36ed4f33 100644
--- a/providers/baichuan/chat.go
+++ b/providers/baichuan/chat.go
@@ -40,12 +40,26 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
 }
 
 func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports usage in streaming responses, override the option:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Strip the option so an unsupported upstream does not reject the request
+		request.StreamOptions = nil
+	}
+
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
 		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the caller's original configuration
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
diff --git a/providers/groq/chat.go b/providers/groq/chat.go
index 26eb008a..d41ccb30 100644
--- a/providers/groq/chat.go
+++ b/providers/groq/chat.go
@@ -40,6 +40,16 @@ func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest
 }
 
 func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports usage in streaming responses, override the option:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Strip the option so an unsupported upstream does not reject the request
+		request.StreamOptions = nil
+	}
 	p.getChatRequestBody(request)
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
@@ -47,6 +57,9 @@ func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionR
 	}
 	defer req.Body.Close()
 
+	// Restore the caller's original configuration
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
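Aside for reviewers: the save/override/restore sequence above is repeated verbatim in each provider's stream constructor (Baichuan and Groq here, OpenAI below). A minimal sketch of the same pattern factored into a shared helper; the name withStreamUsage and its package are hypothetical and not part of this patch:

    package provider // hypothetical location; uses the types added in this patch

    import "one-api/types"

    // withStreamUsage saves the caller's stream_options, forces include_usage
    // on upstreams that support it (or strips the option on those that would
    // reject it), runs do(), then restores the original value.
    func withStreamUsage(request *types.ChatCompletionRequest, supported bool, do func() error) error {
        saved := request.StreamOptions
        if supported {
            request.StreamOptions = &types.StreamOptions{IncludeUsage: true}
        } else {
            request.StreamOptions = nil
        }
        defer func() { request.StreamOptions = saved }()
        return do()
    }
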
diff --git a/providers/openai/base.go b/providers/openai/base.go
index aff6d962..34d5744d 100644
--- a/providers/openai/base.go
+++ b/providers/openai/base.go
@@ -17,8 +17,9 @@ type OpenAIProviderFactory struct{}
 
 type OpenAIProvider struct {
 	base.BaseProvider
-	IsAzure       bool
-	BalanceAction bool
+	IsAzure              bool
+	BalanceAction        bool
+	SupportStreamOptions bool
 }
 
 // Create OpenAIProvider
@@ -33,7 +34,7 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter
 func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
 	config := getOpenAIConfig(baseURL)
 
-	return &OpenAIProvider{
+	OpenAIProvider := &OpenAIProvider{
 		BaseProvider: base.BaseProvider{
 			Config:  config,
 			Channel: channel,
@@ -42,6 +43,12 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide
 		IsAzure:       false,
 		BalanceAction: true,
 	}
+
+	if channel.Type == common.ChannelTypeOpenAI {
+		OpenAIProvider.SupportStreamOptions = true
+	}
+
+	return OpenAIProvider
 }
 
 func getOpenAIConfig(baseURL string) base.ProviderConfig {
diff --git a/providers/openai/chat.go b/providers/openai/chat.go
index 574a5f1f..e57b7d04 100644
--- a/providers/openai/chat.go
+++ b/providers/openai/chat.go
@@ -8,7 +8,6 @@ import (
 	"one-api/common/requester"
 	"one-api/types"
 	"strings"
-	"time"
 )
 
 type OpenAIStreamHandler struct {
@@ -58,12 +57,25 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
 }
 
 func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports usage in streaming responses, override the option:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Strip the option so an unsupported upstream does not reject the request
+		request.StreamOptions = nil
+	}
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
 		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the caller's original configuration
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
@@ -110,18 +122,23 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s
 	}
 
 	if len(openaiResponse.Choices) == 0 {
+		if openaiResponse.Usage != nil {
+			*h.Usage = *openaiResponse.Usage
+		}
 		*rawLine = nil
 		return
 	}
 
 	dataChan <- string(*rawLine)
 
-	if h.isAzure {
-		// Block for 20ms
-		time.Sleep(20 * time.Millisecond)
+	if len(openaiResponse.Choices) > 0 && openaiResponse.Choices[0].Usage != nil {
+		*h.Usage = *openaiResponse.Choices[0].Usage
+	} else {
+		if h.Usage.TotalTokens == 0 {
+			h.Usage.TotalTokens = h.Usage.PromptTokens
+		}
+		countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
+		h.Usage.CompletionTokens += countTokenText
+		h.Usage.TotalTokens += countTokenText
 	}
-
-	countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
-	h.Usage.CompletionTokens += countTokenText
-	h.Usage.TotalTokens += countTokenText
 }
diff --git a/providers/openai/completion.go b/providers/openai/completion.go
index 7d1e7e74..e52120e7 100644
--- a/providers/openai/completion.go
+++ b/providers/openai/completion.go
@@ -40,12 +40,25 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
 }
 
 func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports usage in streaming responses, override the option:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Strip the option so an unsupported upstream does not reject the request
+		request.StreamOptions = nil
+	}
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
 	if errWithCode != nil {
 		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the caller's original configuration
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
@@ -90,8 +103,19 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan
 		return
 	}
 
+	if len(openaiResponse.Choices) == 0 {
+		if openaiResponse.Usage != nil {
+			*h.Usage = *openaiResponse.Usage
+		}
+		*rawLine = nil
+		return
+	}
+
 	dataChan <- string(*rawLine)
 
+	if h.Usage.TotalTokens == 0 {
+		h.Usage.TotalTokens = h.Usage.PromptTokens
+	}
 	countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
 	h.Usage.CompletionTokens += countTokenText
 	h.Usage.TotalTokens += countTokenText
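Context for the new empty-choices branches in both stream handlers: with include_usage set, OpenAI terminates the stream with one extra chunk whose choices array is empty and whose usage field carries the aggregate token counts. A self-contained sketch of parsing such a terminal chunk; the mirror types and field values here are illustrative, not the project's own:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Minimal mirrors of the fields the handler inspects.
    type usage struct {
        PromptTokens     int `json:"prompt_tokens"`
        CompletionTokens int `json:"completion_tokens"`
        TotalTokens      int `json:"total_tokens"`
    }

    type chunk struct {
        Choices []json.RawMessage `json:"choices"`
        Usage   *usage            `json:"usage"`
    }

    func main() {
        // The last data chunk before [DONE] has empty choices and real usage.
        raw := `{"id":"chatcmpl-abc","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`

        var c chunk
        if err := json.Unmarshal([]byte(raw), &c); err != nil {
            panic(err)
        }
        if len(c.Choices) == 0 && c.Usage != nil {
            fmt.Printf("usage: %+v\n", *c.Usage) // what HandlerChatStream copies into h.Usage
        }
    }
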
diff --git a/relay/chat.go b/relay/chat.go
index 3e3de6b4..f8e1eb1b 100644
--- a/relay/chat.go
+++ b/relay/chat.go
@@ -1,7 +1,9 @@
 package relay
 
 import (
+	"encoding/json"
 	"errors"
+	"fmt"
 	"math"
 	"net/http"
 	"one-api/common"
@@ -36,6 +38,10 @@ func (r *relayChat) setRequest() error {
 		r.c.Set("skip_only_chat", true)
 	}
 
+	if !r.chatRequest.Stream && r.chatRequest.StreamOptions != nil {
+		return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+	}
+
 	r.originalModel = r.chatRequest.Model
 
 	return nil
@@ -66,7 +72,11 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 			return
 		}
 
-		err = responseStreamClient(r.c, response, r.cache)
+		doneStr := func() string {
+			return r.getUsageResponse()
+		}
+
+		err = responseStreamClient(r.c, response, r.cache, doneStr)
 	} else {
 		var response *types.ChatCompletionResponse
 		response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
@@ -86,3 +96,25 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 
 	return
 }
+
+func (r *relayChat) getUsageResponse() string {
+	if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage {
+		usageResponse := types.ChatCompletionStreamResponse{
+			ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   r.chatRequest.Model,
+			Choices: []types.ChatCompletionStreamChoice{},
+			Usage:   r.provider.GetUsage(),
+		}
+
+		responseBody, err := json.Marshal(usageResponse)
+		if err != nil {
+			return ""
+		}
+
+		return string(responseBody)
+	}
+
+	return ""
+}
diff --git a/relay/common.go b/relay/common.go
index a230e098..cebfcac5 100644
--- a/relay/common.go
+++ b/relay/common.go
@@ -140,7 +140,9 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
 	return nil
 }
 
-func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
+type StreamEndHandler func() string
+
+func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
 	requester.SetEventStreamHeaders(c)
 	dataChan, errChan := stream.Recv()
 
@@ -160,6 +162,14 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
 			cache.NoCache()
 		}
 
+		if errWithOP == nil && endHandler != nil {
+			streamData := endHandler()
+			if streamData != "" {
+				fmt.Fprint(w, "data: "+streamData+"\n\n")
+				cache.SetResponse(streamData)
+			}
+		}
+
 		streamData := "data: [DONE]\n"
 		fmt.Fprint(w, streamData)
 		cache.SetResponse(streamData)
@@ -167,7 +177,7 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
 		}
 	})
 
-	return errWithOP
+	return nil
 }
 
 func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
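What the endHandler change means on the wire: when the stream shuts down cleanly and the client asked for usage, one extra data frame is written before the [DONE] sentinel. A runnable sketch of the frame order; the JSON payload is abbreviated here, the real one is built by the getUsageResponse helpers in relay/chat.go and relay/completions.go:

    package main

    import "fmt"

    func main() {
        // Frames written by responseStreamClient at end of stream.
        usageChunk := `{"object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`
        fmt.Print("data: " + usageChunk + "\n\n") // extra frame from endHandler, only when errWithOP == nil
        fmt.Print("data: [DONE]\n")               // sentinel, always written
    }
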
"encoding/json" "errors" + "fmt" "math" "net/http" "one-api/common" @@ -32,6 +34,10 @@ func (r *relayCompletions) setRequest() error { return errors.New("max_tokens is invalid") } + if !r.request.Stream && r.request.StreamOptions != nil { + return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.") + } + r.originalModel = r.request.Model return nil @@ -62,7 +68,11 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo return } - err = responseStreamClient(r.c, response, r.cache) + doneStr := func() string { + return r.getUsageResponse() + } + + err = responseStreamClient(r.c, response, r.cache, doneStr) } else { var response *types.CompletionResponse response, err = provider.CreateCompletion(&r.request) @@ -79,3 +89,25 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo return } + +func (r *relayCompletions) getUsageResponse() string { + if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage { + usageResponse := types.CompletionResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: r.request.Model, + Choices: []types.CompletionChoice{}, + Usage: r.provider.GetUsage(), + } + + responseBody, err := json.Marshal(usageResponse) + if err != nil { + return "" + } + + return string(responseBody) + } + + return "" +} diff --git a/types/chat.go b/types/chat.go index 68cdca49..1bb6822d 100644 --- a/types/chat.go +++ b/types/chat.go @@ -165,6 +165,7 @@ type ChatCompletionRequest struct { TopP float64 `json:"top_p,omitempty"` N int `json:"n,omitempty"` Stream bool `json:"stream,omitempty"` + StreamOptions *StreamOptions `json:"stream_options,omitempty"` Stop []string `json:"stop,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` @@ -356,6 +357,7 @@ type ChatCompletionStreamChoice struct { Delta ChatCompletionStreamChoiceDelta `json:"delta"` FinishReason any `json:"finish_reason"` ContentFilterResults any `json:"content_filter_results,omitempty"` + Usage *Usage `json:"usage,omitempty"` } func (c *ChatCompletionStreamChoice) CheckChoice(request *ChatCompletionRequest) { @@ -372,4 +374,5 @@ type ChatCompletionStreamResponse struct { Model string `json:"model"` Choices []ChatCompletionStreamChoice `json:"choices"` PromptAnnotations any `json:"prompt_annotations,omitempty"` + Usage *Usage `json:"usage,omitempty"` } diff --git a/types/common.go b/types/common.go index 9bc02617..2be83e82 100644 --- a/types/common.go +++ b/types/common.go @@ -34,3 +34,7 @@ type OpenAIErrorWithStatusCode struct { type OpenAIErrorResponse struct { Error OpenAIError `json:"error,omitempty"` } + +type StreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} diff --git a/types/completion.go b/types/completion.go index 698bd15d..b513ac47 100644 --- a/types/completion.go +++ b/types/completion.go @@ -1,22 +1,23 @@ package types type CompletionRequest struct { - Model string `json:"model" binding:"required"` - Prompt any `json:"prompt" binding:"required"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` 
diff --git a/types/completion.go b/types/completion.go
index 698bd15d..b513ac47 100644
--- a/types/completion.go
+++ b/types/completion.go
@@ -1,22 +1,23 @@
 package types
 
 type CompletionRequest struct {
-	Model            string   `json:"model" binding:"required"`
-	Prompt           any      `json:"prompt" binding:"required"`
-	Suffix           string   `json:"suffix,omitempty"`
-	MaxTokens        int      `json:"max_tokens,omitempty"`
-	Temperature      float32  `json:"temperature,omitempty"`
-	TopP             float32  `json:"top_p,omitempty"`
-	N                int      `json:"n,omitempty"`
-	Stream           bool     `json:"stream,omitempty"`
-	LogProbs         int      `json:"logprobs,omitempty"`
-	Echo             bool     `json:"echo,omitempty"`
-	Stop             []string `json:"stop,omitempty"`
-	PresencePenalty  float32  `json:"presence_penalty,omitempty"`
-	FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
-	BestOf           int      `json:"best_of,omitempty"`
-	LogitBias        any      `json:"logit_bias,omitempty"`
-	User             string   `json:"user,omitempty"`
+	Model            string         `json:"model" binding:"required"`
+	Prompt           any            `json:"prompt" binding:"required"`
+	Suffix           string         `json:"suffix,omitempty"`
+	MaxTokens        int            `json:"max_tokens,omitempty"`
+	Temperature      float32        `json:"temperature,omitempty"`
+	TopP             float32        `json:"top_p,omitempty"`
+	N                int            `json:"n,omitempty"`
+	Stream           bool           `json:"stream,omitempty"`
+	StreamOptions    *StreamOptions `json:"stream_options,omitempty"`
+	LogProbs         int            `json:"logprobs,omitempty"`
+	Echo             bool           `json:"echo,omitempty"`
+	Stop             []string       `json:"stop,omitempty"`
+	PresencePenalty  float32        `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float32        `json:"frequency_penalty,omitempty"`
+	BestOf           int            `json:"best_of,omitempty"`
+	LogitBias        any            `json:"logit_bias,omitempty"`
+	User             string         `json:"user,omitempty"`
 }
 
 type CompletionChoice struct {