diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index 64e3d869..d3e306c8 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -267,6 +267,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC var usage model.Usage var modelName string var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice for scanner.Scan() { data := scanner.Text() @@ -287,9 +288,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens - modelName = meta.Model - id = fmt.Sprintf("chatcmpl-%s", meta.Id) - continue + if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. + modelName = meta.Model + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + continue + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. + lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } + } } if response == nil { continue @@ -298,6 +310,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC response.Id = id response.Model = modelName response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 8b2c5753..72f40ddc 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "io" "net/http" @@ -143,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E c.Writer.Header().Set("Content-Type", "text/event-stream") var usage relaymodel.Usage var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice + c.Stream(func(w io.Writer) bool { event, ok := <-stream.Events() if !ok { @@ -163,9 +166,18 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens - if len(meta.Id) > 0 { // only message_start has id. else it's finish_reason + if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. id = fmt.Sprintf("chatcmpl-%s", meta.Id) return true + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. + lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } } } if response == nil { @@ -174,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E response.Id = id response.Model = c.GetString(ctxkey.OriginalModel) response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } jsonStr, err := json.Marshal(response) if err != nil { logger.SysError("error marshalling stream response: " + err.Error())