✨ feat: Support stream_options

Parent: fa54ca7b50
Commit: eb260652b2
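In short: the relay now understands OpenAI's `stream_options.include_usage`. Channels of type `ChannelTypeOpenAI` forward the flag upstream and read usage from the final usage-only chunk; all other channels have the field stripped before the request body is serialized and keep the local token-count fallback. When a client asks for usage, the relay itself appends a usage-only chunk ahead of `data: [DONE]`.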
@@ -40,12 +40,26 @@ func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionReq
 }
 
 func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports streaming usage, the options need adjusting:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Avoid forwarding the field by mistake, which would cause an error
+		request.StreamOptions = nil
+	}
+
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
 		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the original options
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
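Each provider touched below repeats the same save/override/restore dance around body serialization. A minimal sketch of that pattern as a standalone helper — `withForcedUsage` is illustrative and not part of this commit; only the repo's `types` package is assumed:

```go
// withForcedUsage is a hypothetical helper naming the pattern that every
// CreateChatCompletionStream in this commit repeats inline.
func withForcedUsage(request *types.ChatCompletionRequest, supported bool, build func() error) error {
	saved := request.StreamOptions // keep the caller's original value
	if supported {
		// Ask the upstream to append a usage-only chunk to the stream.
		request.StreamOptions = &types.StreamOptions{IncludeUsage: true}
	} else {
		// Strip the field so an upstream that rejects it does not error.
		request.StreamOptions = nil
	}
	// Restore once the request body has been serialized inside build().
	defer func() { request.StreamOptions = saved }()
	return build()
}
```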
@@ -40,6 +40,16 @@ func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest
 }
 
 func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports streaming usage, the options need adjusting:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Avoid forwarding the field by mistake, which would cause an error
+		request.StreamOptions = nil
+	}
 	p.getChatRequestBody(request)
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
@@ -47,6 +57,9 @@ func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionR
 	}
 	defer req.Body.Close()
 
+	// Restore the original options
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
@@ -19,6 +19,7 @@ type OpenAIProvider struct {
 	base.BaseProvider
 	IsAzure              bool
 	BalanceAction        bool
+	SupportStreamOptions bool
 }
 
 // Create an OpenAIProvider
@@ -33,7 +34,7 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter
 func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
 	config := getOpenAIConfig(baseURL)
 
-	return &OpenAIProvider{
+	OpenAIProvider := &OpenAIProvider{
 		BaseProvider: base.BaseProvider{
 			Config:  config,
 			Channel: channel,
@@ -42,6 +43,12 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide
 		IsAzure:       false,
 		BalanceAction: true,
 	}
+
+	if channel.Type == common.ChannelTypeOpenAI {
+		OpenAIProvider.SupportStreamOptions = true
+	}
+
+	return OpenAIProvider
 }
 
 func getOpenAIConfig(baseURL string) base.ProviderConfig {
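Note that the flag is set only for native OpenAI channels; Azure (`IsAzure`) and OpenAI-compatible proxies keep the zero value and fall back to the local token counting in the stream handlers below. A quick sketch of the effect (base URL illustrative):

```go
p := CreateOpenAIProvider(channel, "https://api.openai.com/v1")
// true only when channel.Type == common.ChannelTypeOpenAI:
_ = p.SupportStreamOptions
```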
@@ -8,7 +8,6 @@ import (
 	"one-api/common/requester"
 	"one-api/types"
 	"strings"
-	"time"
 )
 
 type OpenAIStreamHandler struct {
@@ -58,12 +57,25 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque
 }
 
 func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports streaming usage, the options need adjusting:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Avoid forwarding the field by mistake, which would cause an error
+		request.StreamOptions = nil
+	}
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
 	if errWithCode != nil {
 		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the original options
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
@@ -110,18 +122,23 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s
 	}
 
+	if len(openaiResponse.Choices) == 0 {
+		if openaiResponse.Usage != nil {
+			*h.Usage = *openaiResponse.Usage
+		}
+		*rawLine = nil
+		return
+	}
+
 	dataChan <- string(*rawLine)
 
-	if h.isAzure {
-		// Block for 20ms
-		time.Sleep(20 * time.Millisecond)
-	}
-
 	if len(openaiResponse.Choices) > 0 && openaiResponse.Choices[0].Usage != nil {
 		*h.Usage = *openaiResponse.Choices[0].Usage
 	} else {
 		if h.Usage.TotalTokens == 0 {
 			h.Usage.TotalTokens = h.Usage.PromptTokens
 		}
 
 		countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
 		h.Usage.CompletionTokens += countTokenText
 		h.Usage.TotalTokens += countTokenText
 	}
 }
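For reference, the frame the new `len(openaiResponse.Choices) == 0` branch captures is the usage-only chunk OpenAI emits just before `[DONE]` when `include_usage` is set. A hedged sketch of handling it (field values illustrative; assumes `encoding/json` and this repo's `types` package):

```go
raw := []byte(`{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1715000000,` +
	`"model":"gpt-4o","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`)

var chunk types.ChatCompletionStreamResponse
if err := json.Unmarshal(raw, &chunk); err == nil && len(chunk.Choices) == 0 && chunk.Usage != nil {
	// The handler stores chunk.Usage and sets *rawLine = nil, so this
	// empty-choices frame is never forwarded to the client.
}
```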
@@ -40,12 +40,25 @@ func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (ope
 }
 
 func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[string], errWithCode *types.OpenAIErrorWithStatusCode) {
+	streamOptions := request.StreamOptions
+	// If the upstream supports streaming usage, the options need adjusting:
+	if p.SupportStreamOptions {
+		request.StreamOptions = &types.StreamOptions{
+			IncludeUsage: true,
+		}
+	} else {
+		// Avoid forwarding the field by mistake, which would cause an error
+		request.StreamOptions = nil
+	}
 	req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
 	if errWithCode != nil {
 		return nil, errWithCode
 	}
 	defer req.Body.Close()
 
+	// Restore the original options
+	request.StreamOptions = streamOptions
+
 	// Send the request
 	resp, errWithCode := p.Requester.SendRequestRaw(req)
 	if errWithCode != nil {
@@ -90,8 +103,19 @@ func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, dataChan
 		return
 	}
 
+	if len(openaiResponse.Choices) == 0 {
+		if openaiResponse.Usage != nil {
+			*h.Usage = *openaiResponse.Usage
+		}
+		*rawLine = nil
+		return
+	}
+
 	dataChan <- string(*rawLine)
 
+	if h.Usage.TotalTokens == 0 {
+		h.Usage.TotalTokens = h.Usage.PromptTokens
+	}
 	countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
 	h.Usage.CompletionTokens += countTokenText
 	h.Usage.TotalTokens += countTokenText
@@ -1,7 +1,9 @@
 package relay
 
 import (
+	"encoding/json"
 	"errors"
+	"fmt"
 	"math"
 	"net/http"
 	"one-api/common"
@@ -36,6 +38,10 @@ func (r *relayChat) setRequest() error {
 		r.c.Set("skip_only_chat", true)
 	}
 
+	if !r.chatRequest.Stream && r.chatRequest.StreamOptions != nil {
+		return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+	}
+
 	r.originalModel = r.chatRequest.Model
 
 	return nil
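This mirrors OpenAI's own validation: `stream_options` is rejected unless `stream` is true, before any upstream call is made. A request body like the following (model name illustrative) now fails fast with the error above:

```go
// Rejected by setRequest: stream is absent/false but stream_options is set.
body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}],"stream_options":{"include_usage":true}}`)
```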
@@ -66,7 +72,11 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 			return
 		}
 
-		err = responseStreamClient(r.c, response, r.cache)
+		doneStr := func() string {
+			return r.getUsageResponse()
+		}
+
+		err = responseStreamClient(r.c, response, r.cache, doneStr)
 	} else {
 		var response *types.ChatCompletionResponse
 		response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
@@ -86,3 +96,25 @@ func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
 
 	return
 }
+
+func (r *relayChat) getUsageResponse() string {
+	if r.chatRequest.StreamOptions != nil && r.chatRequest.StreamOptions.IncludeUsage {
+		usageResponse := types.ChatCompletionStreamResponse{
+			ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   r.chatRequest.Model,
+			Choices: []types.ChatCompletionStreamChoice{},
+			Usage:   r.provider.GetUsage(),
+		}
+
+		responseBody, err := json.Marshal(usageResponse)
+		if err != nil {
+			return ""
+		}
+
+		return string(responseBody)
+	}
+
+	return ""
+}
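Together with the `responseStreamClient` change below, a client that set `include_usage` now sees one extra frame ahead of termination. A sketch of the resulting tail of the SSE stream, matching the `Fprint` calls in the next hunk (JSON abbreviated; `w` is the stream writer inside the handler):

```go
fmt.Fprint(w, "data: "+`{"object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`+"\n\n")
fmt.Fprint(w, "data: [DONE]\n")
```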
@@ -140,7 +140,9 @@ func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWith
 	return nil
 }
 
-func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps) (errWithOP *types.OpenAIErrorWithStatusCode) {
+type StreamEndHandler func() string
+
+func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface[string], cache *util.ChatCacheProps, endHandler StreamEndHandler) (errWithOP *types.OpenAIErrorWithStatusCode) {
 	requester.SetEventStreamHeaders(c)
 	dataChan, errChan := stream.Recv()
 
@@ -160,6 +162,14 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
 				cache.NoCache()
 			}
 
+			if errWithOP == nil && endHandler != nil {
+				streamData := endHandler()
+				if streamData != "" {
+					fmt.Fprint(w, "data: "+streamData+"\n\n")
+					cache.SetResponse(streamData)
+				}
+			}
+
 			streamData := "data: [DONE]\n"
 			fmt.Fprint(w, streamData)
 			cache.SetResponse(streamData)
@@ -167,7 +177,7 @@ func responseStreamClient(c *gin.Context, stream requester.StreamReaderInterface
 		}
 	})
 
-	return errWithOP
+	return nil
 }
 
 func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
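The added parameter changes the signature for every caller of `responseStreamClient`. The two relays in this commit pass a closure; the `endHandler != nil` guard above means a caller with no trailing frame can simply pass nil (names as in this diff):

```go
// Emit a usage frame at end-of-stream, as relayChat does:
err := responseStreamClient(c, stream, cache, func() string { return r.getUsageResponse() })

// No trailing frame wanted:
err = responseStreamClient(c, stream, cache, nil)
```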
@@ -1,7 +1,9 @@
 package relay
 
 import (
+	"encoding/json"
 	"errors"
+	"fmt"
 	"math"
 	"net/http"
 	"one-api/common"
@@ -32,6 +34,10 @@ func (r *relayCompletions) setRequest() error {
 		return errors.New("max_tokens is invalid")
 	}
 
+	if !r.request.Stream && r.request.StreamOptions != nil {
+		return errors.New("The 'stream_options' parameter is only allowed when 'stream' is enabled.")
+	}
+
 	r.originalModel = r.request.Model
 
 	return nil
@@ -62,7 +68,11 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
 			return
 		}
 
-		err = responseStreamClient(r.c, response, r.cache)
+		doneStr := func() string {
+			return r.getUsageResponse()
+		}
+
+		err = responseStreamClient(r.c, response, r.cache, doneStr)
 	} else {
 		var response *types.CompletionResponse
 		response, err = provider.CreateCompletion(&r.request)
@@ -79,3 +89,25 @@ func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bo
 
 	return
 }
+
+func (r *relayCompletions) getUsageResponse() string {
+	if r.request.StreamOptions != nil && r.request.StreamOptions.IncludeUsage {
+		usageResponse := types.CompletionResponse{
+			ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   r.request.Model,
+			Choices: []types.CompletionChoice{},
+			Usage:   r.provider.GetUsage(),
+		}
+
+		responseBody, err := json.Marshal(usageResponse)
+		if err != nil {
+			return ""
+		}
+
+		return string(responseBody)
+	}
+
+	return ""
+}
@@ -165,6 +165,7 @@ type ChatCompletionRequest struct {
 	TopP            float64                       `json:"top_p,omitempty"`
 	N               int                           `json:"n,omitempty"`
 	Stream          bool                          `json:"stream,omitempty"`
+	StreamOptions   *StreamOptions                `json:"stream_options,omitempty"`
 	Stop            []string                      `json:"stop,omitempty"`
 	PresencePenalty float64                       `json:"presence_penalty,omitempty"`
 	ResponseFormat  *ChatCompletionResponseFormat `json:"response_format,omitempty"`
@@ -356,6 +357,7 @@ type ChatCompletionStreamChoice struct {
 	Delta                ChatCompletionStreamChoiceDelta `json:"delta"`
 	FinishReason         any                             `json:"finish_reason"`
 	ContentFilterResults any                             `json:"content_filter_results,omitempty"`
+	Usage                *Usage                          `json:"usage,omitempty"`
 }
 
 func (c *ChatCompletionStreamChoice) CheckChoice(request *ChatCompletionRequest) {
@@ -372,4 +374,5 @@ type ChatCompletionStreamResponse struct {
 	Model             string                       `json:"model"`
 	Choices           []ChatCompletionStreamChoice `json:"choices"`
 	PromptAnnotations any                          `json:"prompt_annotations,omitempty"`
+	Usage             *Usage                       `json:"usage,omitempty"`
 }
@@ -34,3 +34,7 @@ type OpenAIErrorWithStatusCode struct {
 type OpenAIErrorResponse struct {
 	Error OpenAIError `json:"error,omitempty"`
 }
+
+type StreamOptions struct {
+	IncludeUsage bool `json:"include_usage,omitempty"`
+}
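A minimal sketch of the new field in use; `omitempty` keeps both keys out of the serialized request unless they are set (model name illustrative, `encoding/json` assumed):

```go
req := types.ChatCompletionRequest{
	Model:         "gpt-4o",
	Stream:        true,
	StreamOptions: &types.StreamOptions{IncludeUsage: true},
}
b, _ := json.Marshal(req)
// b contains ... "stream":true,"stream_options":{"include_usage":true} ...
```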
@@ -9,6 +9,7 @@ type CompletionRequest struct {
 	TopP          float32        `json:"top_p,omitempty"`
 	N             int            `json:"n,omitempty"`
 	Stream        bool           `json:"stream,omitempty"`
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
 	LogProbs      int            `json:"logprobs,omitempty"`
 	Echo          bool           `json:"echo,omitempty"`
 	Stop          []string       `json:"stop,omitempty"`