fix: add {} for openai compatibility in streaming tool_use
commit d16cd6152e
parent ea1c293d39
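Background for the change below: Claude's streaming API ends a tool_use with a finish_reason event that carries no arguments payload, whereas OpenAI clients expect the accumulated arguments string of a streamed function call to parse as JSON, at minimum the empty object {}. Both touched handlers therefore remember the last tool-call choice they forwarded and, on the finish_reason event, patch an empty arguments string to "{}". A minimal, self-contained sketch of that normalization (simplified stand-in types; normalizeToolArgs is a hypothetical name, not a helper in this repo):

package main

import "fmt"

// Simplified stand-ins for the relay's streaming types (assumption: trimmed
// to the fields this fix touches).
type Function struct {
	Name      string
	Arguments any // accumulated arguments fragment; a string while streaming
}

type ToolCall struct {
	Function Function
}

type Delta struct {
	ToolCalls []ToolCall
}

// normalizeToolArgs mirrors the patch: if the last tool call finished with
// an empty arguments string, substitute "{}" so the result parses as an
// empty JSON object, matching what OpenAI itself streams.
func normalizeToolArgs(d *Delta) {
	if len(d.ToolCalls) == 0 {
		return
	}
	last := &d.ToolCalls[len(d.ToolCalls)-1].Function
	if s, ok := last.Arguments.(string); ok && len(s) == 0 {
		last.Arguments = "{}"
	}
}

func main() {
	d := Delta{ToolCalls: []ToolCall{{Function: Function{Name: "get_time", Arguments: ""}}}}
	normalizeToolArgs(&d)
	fmt.Println(d.ToolCalls[0].Function.Arguments) // prints: {}
}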
@@ -267,6 +267,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	var usage model.Usage
 	var modelName string
 	var id string
+	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
 
 	for scanner.Scan() {
 		data := scanner.Text()
@@ -287,9 +288,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 		if meta != nil {
 			usage.PromptTokens += meta.Usage.InputTokens
 			usage.CompletionTokens += meta.Usage.OutputTokens
-			modelName = meta.Model
-			id = fmt.Sprintf("chatcmpl-%s", meta.Id)
-			continue
+			if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
+				modelName = meta.Model
+				id = fmt.Sprintf("chatcmpl-%s", meta.Id)
+				continue
+			} else { // finish_reason case
+				if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
+					lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
+					if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
+						lastArgs.Arguments = "{}"
+						response.Choices[len(response.Choices)-1].Delta.Content = nil
+						response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
+					}
+				}
+			}
 		}
 		if response == nil {
 			continue
@@ -298,6 +310,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 		response.Id = id
 		response.Model = modelName
 		response.Created = createdTime
+
+		for _, choice := range response.Choices {
+			if len(choice.Delta.ToolCalls) > 0 {
+				lastToolCallChoice = choice
+			}
+		}
 		err = render.ObjectData(c, response)
 		if err != nil {
 			logger.SysError(err.Error())
@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"github.com/songquanpeng/one-api/common/ctxkey"
+	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"io"
 	"net/http"
 
@@ -143,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 	c.Writer.Header().Set("Content-Type", "text/event-stream")
 	var usage relaymodel.Usage
 	var id string
+	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
+
 	c.Stream(func(w io.Writer) bool {
 		event, ok := <-stream.Events()
 		if !ok {
@@ -163,9 +166,18 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 			if meta != nil {
 				usage.PromptTokens += meta.Usage.InputTokens
 				usage.CompletionTokens += meta.Usage.OutputTokens
-				if len(meta.Id) > 0 { // only message_start has id. else it's finish_reason
+				if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
 					id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 					return true
+				} else { // finish_reason case
+					if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
+						lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
+						if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
+							lastArgs.Arguments = "{}"
+							response.Choices[len(response.Choices)-1].Delta.Content = nil
+							response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
+						}
+					}
+				}
 				}
 			}
 			if response == nil {
@@ -174,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 			response.Id = id
 			response.Model = c.GetString(ctxkey.OriginalModel)
 			response.Created = createdTime
+
+			for _, choice := range response.Choices {
+				if len(choice.Delta.ToolCalls) > 0 {
+					lastToolCallChoice = choice
+				}
+			}
 			jsonStr, err := json.Marshal(response)
 			if err != nil {
 				logger.SysError("error marshalling stream response: " + err.Error())
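For illustration only (not output captured from this commit), the final streamed delta after this normalization serializes with an explicit empty-object arguments string; a small self-contained demo of the trimmed wire shape:

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed wire shapes, for illustration only.
type fn struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

type toolCall struct {
	Function fn `json:"function"`
}

type delta struct {
	ToolCalls []toolCall `json:"tool_calls"`
}

func main() {
	d := delta{ToolCalls: []toolCall{{Function: fn{Name: "get_time", Arguments: "{}"}}}}
	b, err := json.Marshal(d)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
	// {"tool_calls":[{"function":{"name":"get_time","arguments":"{}"}}]}
}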