feat: Support stream_options

MartialBE 2024-05-26 19:58:15 +08:00
parent fa54ca7b50
commit eb260652b2
No known key found for this signature in database
GPG Key ID: 27C0267EC84B0A5C
11 changed files with 188 additions and 31 deletions

View File

@@ -40,12 +40,26 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
 }
 
 func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the channel supports returning usage in the stream, override the configuration:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Clear it so an unsupported parameter is not forwarded and rejected
+        request.StreamOptions = nil
+    }
+
     req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
     if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
 
+    // Restore the original configuration
+    request.StreamOptions = streamOptions
+
    // Send the request
    resp, errWithCode := p.Requester.SendRequestRaw(req)
    if errWithCode != nil {
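Note: for a channel with SupportStreamOptions set, the body forwarded upstream now carries the stream_options field. A standalone sketch of how the override serializes, using minimal mirrors of the types added in this commit (the model name is illustrative):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Minimal mirrors of the request types in this commit, for illustration only.
    type StreamOptions struct {
        IncludeUsage bool `json:"include_usage,omitempty"`
    }

    type chatRequest struct {
        Model         string         `json:"model"`
        Stream        bool           `json:"stream,omitempty"`
        StreamOptions *StreamOptions `json:"stream_options,omitempty"`
    }

    func main() {
        req := chatRequest{
            Model:         "gpt-3.5-turbo",
            Stream:        true,
            StreamOptions: &StreamOptions{IncludeUsage: true},
        }
        body, _ := json.Marshal(req)
        // Prints: {"model":"gpt-3.5-turbo","stream":true,"stream_options":{"include_usage":true}}
        fmt.Println(string(body))
    }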

View File

@@ -40,6 +40,16 @@ func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest
 }
 
 func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the channel supports returning usage in the stream, override the configuration:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Clear it so an unsupported parameter is not forwarded and rejected
+        request.StreamOptions = nil
+    }
    p.getChatRequestBody(request)
    req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
    if errWithCode != nil {
@@ -47,6 +57,9 @@ func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionR
    }
    defer req.Body.Close()
 
+    // Restore the original configuration
+    request.StreamOptions = streamOptions
+
    // Send the request
    resp, errWithCode := p.Requester.SendRequestRaw(req)
    if errWithCode != nil {
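The save/override/restore sequence above is repeated verbatim in every provider touched by this commit. A generic helper could express it once; a hypothetical sketch (restoreAfter is not part of this commit, Go 1.18+ generics assumed):

    package main

    import "fmt"

    // restoreAfter captures the current value of *p and returns a function that
    // restores it; paired with defer, it gives the save/override/restore pattern
    // that each provider currently writes out by hand.
    func restoreAfter[T any](p *T) func() {
        saved := *p
        return func() { *p = saved }
    }

    func main() {
        opts := "caller value"
        func() {
            defer restoreAfter(&opts)()
            opts = "include_usage override" // temporary value for the upstream call
            fmt.Println("during:", opts)    // during: include_usage override
        }()
        fmt.Println("after:", opts) // after: caller value
    }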

View File

@@ -17,8 +17,9 @@ type OpenAIProviderFactory struct{}
 
 type OpenAIProvider struct {
    base.BaseProvider
    IsAzure              bool
    BalanceAction        bool
+    SupportStreamOptions bool
 }
 
 // Create an OpenAIProvider
@@ -33,7 +34,7 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter
 func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
    config := getOpenAIConfig(baseURL)
 
-    return &OpenAIProvider{
+    OpenAIProvider := &OpenAIProvider{
        BaseProvider: base.BaseProvider{
            Config:  config,
            Channel: channel,
@@ -42,6 +43,12 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide
        IsAzure:       false,
        BalanceAction: true,
    }
+
+    if channel.Type == common.ChannelTypeOpenAI {
+        OpenAIProvider.SupportStreamOptions = true
+    }
+
+    return OpenAIProvider
 }
 
 func getOpenAIConfig(baseURL string) base.ProviderConfig {
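Note: only channels of type common.ChannelTypeOpenAI get SupportStreamOptions = true here; every other channel built through this factory leaves the flag false, so the stream_options field is stripped from the request before it is forwarded upstream.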

View File

@@ -8,7 +8,6 @@ import (
    "one-api/common/requester"
    "one-api/types"
    "strings"
-    "time"
 )
 
 type OpenAIStreamHandler struct {
@@ -58,12 +57,25 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
 }
 
 func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the channel supports returning usage in the stream, override the configuration:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Clear it so an unsupported parameter is not forwarded and rejected
+        request.StreamOptions = nil
+    }
    req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
 
+    // Restore the original configuration
+    request.StreamOptions = streamOptions
+
    // Send the request
    resp, errWithCode := p.Requester.SendRequestRaw(req)
    if errWithCode != nil {
@@ -110,18 +122,23 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s
    }
 
    if len(openaiResponse.Choices) == 0 {
+        if openaiResponse.Usage != nil {
+            *h.Usage = *openaiResponse.Usage
+        }
        *rawLine = nil
        return
    }
 
    dataChan <- string(*rawLine)
 
-    if h.isAzure {
-        // Block for 20 ms
-        time.Sleep(20 * time.Millisecond)
+    if len(openaiResponse.Choices) > 0 && openaiResponse.Choices[0].Usage != nil {
+        *h.Usage = *openaiResponse.Choices[0].Usage
+    } else {
+        if h.Usage.TotalTokens == 0 {
+            h.Usage.TotalTokens = h.Usage.PromptTokens
+        }
+        countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
+        h.Usage.CompletionTokens += countTokenText
+        h.Usage.TotalTokens += countTokenText
    }
-
-    countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
-    h.Usage.CompletionTokens += countTokenText
-    h.Usage.TotalTokens += countTokenText
 }
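With stream_options.include_usage set, OpenAI terminates the stream with one extra chunk whose choices array is empty and whose usage object is populated; that is the len(openaiResponse.Choices) == 0 branch above. A standalone sketch of decoding such a chunk (the struct mirrors only the fields the handler inspects; token counts illustrative):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Minimal mirrors of the stream-chunk fields the handler inspects.
    type usage struct {
        PromptTokens     int `json:"prompt_tokens"`
        CompletionTokens int `json:"completion_tokens"`
        TotalTokens      int `json:"total_tokens"`
    }

    type streamChunk struct {
        Choices []json.RawMessage `json:"choices"`
        Usage   *usage            `json:"usage"`
    }

    func main() {
        // Final frame of a stream requested with {"include_usage": true}.
        line := `{"object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`

        var chunk streamChunk
        if err := json.Unmarshal([]byte(line), &chunk); err != nil {
            panic(err)
        }
        if len(chunk.Choices) == 0 && chunk.Usage != nil {
            fmt.Printf("usage: %+v\n", *chunk.Usage) // usage: {PromptTokens:9 CompletionTokens:12 TotalTokens:21}
        }
    }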

View File

@@ -40,12 +40,25 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
 }
 
 func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
+    streamOptions := request.StreamOptions
+    // If the channel supports returning usage in the stream, override the configuration:
+    if p.SupportStreamOptions {
+        request.StreamOptions = &types.StreamOptions{
+            IncludeUsage: true,
+        }
+    } else {
+        // Clear it so an unsupported parameter is not forwarded and rejected
+        request.StreamOptions = nil
+    }
    req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
    if errWithCode != nil {
        return nil, errWithCode
    }
    defer req.Body.Close()
 
+    // Restore the original configuration
+    request.StreamOptions = streamOptions
+
    // Send the request
    resp, errWithCode := p.Requester.SendRequestRaw(req)
    if errWithCode != nil {
@@ -90,8 +103,19 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan
        return
    }
 
+    if len(openaiResponse.Choices) == 0 {
+        if openaiResponse.Usage != nil {
+            *h.Usage = *openaiResponse.Usage
+        }
+        *rawLine = nil
+        return
+    }
+
    dataChan <- string(*rawLine)
 
+    if h.Usage.TotalTokens == 0 {
+        h.Usage.TotalTokens = h.Usage.PromptTokens
+    }
    countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
    h.Usage.CompletionTokens += countTokenText
    h.Usage.TotalTokens += countTokenText
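When the upstream never reports usage, the handler falls back to local accounting: TotalTokens is seeded from PromptTokens once, then each chunk's text is counted with CountTokenText and added to both CompletionTokens and TotalTokens.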

View File

@@ -1,7 +1,9 @@
 package relay
 
 import (
+    "encoding/json"
    "errors"
+    "fmt"
    "math"
    "net/http"
    "one-api/common"
@@ -36,6 +38,10 @@ func (r *relayChat) setRequest() error {
        r.c.Set("skip_only_chat", true)
    }
 
+    if !r.chatRequest.Stream && r.chatRequest.StreamOptions != nil {
+        return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+    }
+
    r.originalModel = r.chatRequest.Model
 
    return nil
@@ -66,7 +72,11 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
            return
        }
 
-        err = responseStreamClient(r.c, response, r.cache)
+        doneStr := func() string {
+            return r.getUsageResponse()
+        }
+
+        err = responseStreamClient(r.c, response, r.cache, doneStr)
    } else {
        var response *types.ChatCompletionResponse
        response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
@@ -86,3 +96,25 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 
    return
 }
+
+func (r *relayChat) getUsageResponse() string {
+    if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage {
+        usageResponse := types.ChatCompletionStreamResponse{
+            ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+            Object:  "chat.completion.chunk",
+            Created: common.GetTimestamp(),
+            Model:   r.chatRequest.Model,
+            Choices: []types.ChatCompletionStreamChoice{},
+            Usage:   r.provider.GetUsage(),
+        }
+
+        responseBody, err := json.Marshal(usageResponse)
+        if err != nil {
+            return ""
+        }
+
+        return string(responseBody)
+    }
+
+    return ""
+}
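On the wire, a stream requested with include_usage now ends with one extra frame built by getUsageResponse, followed by the usual terminator. Roughly (the id, created timestamp, and token counts are illustrative):

    data: {"id":"chatcmpl-<uuid>","object":"chat.completion.chunk","created":1716720000,"model":"gpt-3.5-turbo","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}

    data: [DONE]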

View File

@@ -140,7 +140,9 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
    return nil
 }
 
-func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
+type StreamEndHandler func() string
+
+func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
    requester.SetEventStreamHeaders(c)
    dataChan, errChan := stream.Recv()
 
@@ -160,6 +162,14 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
                cache.NoCache()
            }
 
+            if errWithOP == nil && endHandler != nil {
+                streamData := endHandler()
+                if streamData != "" {
+                    fmt.Fprint(w, "data: "+streamData+"\n\n")
+                    cache.SetResponse(streamData)
+                }
+            }
+
            streamData := "data: [DONE]\n"
            fmt.Fprint(w, streamData)
            cache.SetResponse(streamData)
@@ -167,7 +177,7 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
        }
    })
 
-    return errWithOP
+    return nil
 }
 
 func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
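Callers with nothing to append pass nil for endHandler, and the handler is skipped when the stream ended with an error, so a failed stream never gets a usage frame. The function also now returns nil unconditionally: by this point the response has already been written to the client as SSE, presumably so the caller does not attempt to send a second error body.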

View File

@@ -1,7 +1,9 @@
 package relay
 
 import (
+    "encoding/json"
    "errors"
+    "fmt"
    "math"
    "net/http"
    "one-api/common"
@@ -32,6 +34,10 @@ func (r *relayCompletions) setRequest() error {
        return errors.New("max_tokens is invalid")
    }
 
+    if !r.request.Stream && r.request.StreamOptions != nil {
+        return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+    }
+
    r.originalModel = r.request.Model
 
    return nil
@@ -62,7 +68,11 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
            return
        }
 
-        err = responseStreamClient(r.c, response, r.cache)
+        doneStr := func() string {
+            return r.getUsageResponse()
+        }
+
+        err = responseStreamClient(r.c, response, r.cache, doneStr)
    } else {
        var response *types.CompletionResponse
        response, err = provider.CreateCompletion(&r.request)
@@ -79,3 +89,25 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
 
    return
 }
+
+func (r *relayCompletions) getUsageResponse() string {
+    if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage {
+        usageResponse := types.CompletionResponse{
+            ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+            Object:  "chat.completion.chunk",
+            Created: common.GetTimestamp(),
+            Model:   r.request.Model,
+            Choices: []types.CompletionChoice{},
+            Usage:   r.provider.GetUsage(),
+        }
+
+        responseBody, err := json.Marshal(usageResponse)
+        if err != nil {
+            return ""
+        }
+
+        return string(responseBody)
+    }
+
+    return ""
+}

View File

@@ -165,6 +165,7 @@ type ChatCompletionRequest struct {
    TopP             float64                       `json:"top_p,omitempty"`
    N                int                           `json:"n,omitempty"`
    Stream           bool                          `json:"stream,omitempty"`
+    StreamOptions    *StreamOptions                `json:"stream_options,omitempty"`
    Stop             []string                      `json:"stop,omitempty"`
    PresencePenalty  float64                       `json:"presence_penalty,omitempty"`
    ResponseFormat   *ChatCompletionResponseFormat `json:"response_format,omitempty"`
@@ -356,6 +357,7 @@ type ChatCompletionStreamChoice struct {
    Delta                ChatCompletionStreamChoiceDelta `json:"delta"`
    FinishReason         any                             `json:"finish_reason"`
    ContentFilterResults any                             `json:"content_filter_results,omitempty"`
+    Usage                *Usage                          `json:"usage,omitempty"`
 }
 
 func (c *ChatCompletionStreamChoice) CheckChoice(request *ChatCompletionRequest) {
@@ -372,4 +374,5 @@ type ChatCompletionStreamResponse struct {
    Model             string                       `json:"model"`
    Choices           []ChatCompletionStreamChoice `json:"choices"`
    PromptAnnotations any                          `json:"prompt_annotations,omitempty"`
+    Usage             *Usage                       `json:"usage,omitempty"`
 }

View File

@@ -34,3 +34,7 @@ type OpenAIErrorWithStatusCode struct {
 type OpenAIErrorResponse struct {
    Error OpenAIError `json:"error,omitempty"`
 }
+
+type StreamOptions struct {
+    IncludeUsage bool `json:"include_usage,omitempty"`
+}

View File

@@ -1,22 +1,23 @@
 package types
 
 type CompletionRequest struct {
    Model            string  `json:"model" binding:"required"`
    Prompt           any     `json:"prompt" binding:"required"`
    Suffix           string  `json:"suffix,omitempty"`
    MaxTokens        int     `json:"max_tokens,omitempty"`
    Temperature      float32 `json:"temperature,omitempty"`
    TopP             float32 `json:"top_p,omitempty"`
    N                int     `json:"n,omitempty"`
    Stream           bool    `json:"stream,omitempty"`
-    LogProbs         int      `json:"logprobs,omitempty"`
-    Echo             bool     `json:"echo,omitempty"`
-    Stop             []string `json:"stop,omitempty"`
-    PresencePenalty  float32  `json:"presence_penalty,omitempty"`
-    FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
-    BestOf           int      `json:"best_of,omitempty"`
-    LogitBias        any      `json:"logit_bias,omitempty"`
-    User             string   `json:"user,omitempty"`
+    StreamOptions    *StreamOptions `json:"stream_options,omitempty"`
+    LogProbs         int            `json:"logprobs,omitempty"`
+    Echo             bool           `json:"echo,omitempty"`
+    Stop             []string       `json:"stop,omitempty"`
+    PresencePenalty  float32        `json:"presence_penalty,omitempty"`
+    FrequencyPenalty float32        `json:"frequency_penalty,omitempty"`
+    BestOf           int            `json:"best_of,omitempty"`
+    LogitBias        any            `json:"logit_bias,omitempty"`
+    User             string         `json:"user,omitempty"`
 }
 
 type CompletionChoice struct {