feat: add support for Claude 3 tool use (function calling) (#1587)
* feat: add tool support for AWS & Claude
* fix: add {} for openai compatibility in streaming tool_use
This commit is contained in:
parent 1ce1e529ee
commit 0fc07ea558
@@ -29,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string {
|
|||||||
return "stop"
|
return "stop"
|
||||||
case "max_tokens":
|
case "max_tokens":
|
||||||
return "length"
|
return "length"
|
||||||
|
case "tool_use":
|
||||||
|
return "tool_calls"
|
||||||
default:
|
default:
|
||||||
return *reason
|
return *reason
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
||||||
|
claudeTools := make([]Tool, 0, len(textRequest.Tools))
|
||||||
|
|
||||||
|
for _, tool := range textRequest.Tools {
|
||||||
|
if params, ok := tool.Function.Parameters.(map[string]any); ok {
|
||||||
|
claudeTools = append(claudeTools, Tool{
|
||||||
|
Name: tool.Function.Name,
|
||||||
|
Description: tool.Function.Description,
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Type: params["type"].(string),
|
||||||
|
Properties: params["properties"],
|
||||||
|
Required: params["required"],
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
claudeRequest := Request{
|
claudeRequest := Request{
|
||||||
Model: textRequest.Model,
|
Model: textRequest.Model,
|
||||||
MaxTokens: textRequest.MaxTokens,
|
MaxTokens: textRequest.MaxTokens,
|
||||||
@@ -42,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
|||||||
TopP: textRequest.TopP,
|
TopP: textRequest.TopP,
|
||||||
TopK: textRequest.TopK,
|
TopK: textRequest.TopK,
|
||||||
Stream: textRequest.Stream,
|
Stream: textRequest.Stream,
|
||||||
|
Tools: claudeTools,
|
||||||
|
}
|
||||||
|
if len(claudeTools) > 0 {
|
||||||
|
claudeToolChoice := struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
|
||||||
|
if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
|
||||||
|
if function, ok := choice["function"].(map[string]any); ok {
|
||||||
|
claudeToolChoice.Type = "tool"
|
||||||
|
claudeToolChoice.Name = function["name"].(string)
|
||||||
|
}
|
||||||
|
} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok {
|
||||||
|
if toolChoiceType == "any" {
|
||||||
|
claudeToolChoice.Type = toolChoiceType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
claudeRequest.ToolChoice = claudeToolChoice
|
||||||
}
|
}
|
||||||
if claudeRequest.MaxTokens == 0 {
|
if claudeRequest.MaxTokens == 0 {
|
||||||
claudeRequest.MaxTokens = 4096
|
claudeRequest.MaxTokens = 4096
|
||||||
@@ -64,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
|
|||||||
if message.IsStringContent() {
|
if message.IsStringContent() {
|
||||||
content.Type = "text"
|
content.Type = "text"
|
||||||
content.Text = message.StringContent()
|
content.Text = message.StringContent()
|
||||||
|
if message.Role == "tool" {
|
||||||
|
claudeMessage.Role = "user"
|
||||||
|
content.Type = "tool_result"
|
||||||
|
content.Content = content.Text
|
||||||
|
content.Text = ""
|
||||||
|
content.ToolUseId = message.ToolCallId
|
||||||
|
}
|
||||||
claudeMessage.Content = append(claudeMessage.Content, content)
|
claudeMessage.Content = append(claudeMessage.Content, content)
|
||||||
|
for i := range message.ToolCalls {
|
||||||
|
inputParam := make(map[string]any)
|
||||||
|
_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam)
|
||||||
|
claudeMessage.Content = append(claudeMessage.Content, Content{
|
||||||
|
Type: "tool_use",
|
||||||
|
Id: message.ToolCalls[i].Id,
|
||||||
|
Name: message.ToolCalls[i].Function.Name,
|
||||||
|
Input: inputParam,
|
||||||
|
})
|
||||||
|
}
|
||||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -97,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
|
|||||||
var response *Response
|
var response *Response
|
||||||
var responseText string
|
var responseText string
|
||||||
var stopReason string
|
var stopReason string
|
||||||
|
tools := make([]model.Tool, 0)
|
||||||
|
|
||||||
switch claudeResponse.Type {
|
switch claudeResponse.Type {
|
||||||
case "message_start":
|
case "message_start":
|
||||||
return nil, claudeResponse.Message
|
return nil, claudeResponse.Message
|
||||||
case "content_block_start":
|
case "content_block_start":
|
||||||
if claudeResponse.ContentBlock != nil {
|
if claudeResponse.ContentBlock != nil {
|
||||||
responseText = claudeResponse.ContentBlock.Text
|
responseText = claudeResponse.ContentBlock.Text
|
||||||
|
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||||
|
tools = append(tools, model.Tool{
|
||||||
|
Id: claudeResponse.ContentBlock.Id,
|
||||||
|
Type: "function",
|
||||||
|
Function: model.Function{
|
||||||
|
Name: claudeResponse.ContentBlock.Name,
|
||||||
|
Arguments: "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case "content_block_delta":
|
case "content_block_delta":
|
||||||
if claudeResponse.Delta != nil {
|
if claudeResponse.Delta != nil {
|
||||||
responseText = claudeResponse.Delta.Text
|
responseText = claudeResponse.Delta.Text
|
||||||
|
if claudeResponse.Delta.Type == "input_json_delta" {
|
||||||
|
tools = append(tools, model.Tool{
|
||||||
|
Function: model.Function{
|
||||||
|
Arguments: claudeResponse.Delta.PartialJson,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case "message_delta":
|
case "message_delta":
|
||||||
if claudeResponse.Usage != nil {
|
if claudeResponse.Usage != nil {
|
||||||
@@ -120,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
|
|||||||
}
|
}
|
||||||
var choice openai.ChatCompletionsStreamResponseChoice
|
var choice openai.ChatCompletionsStreamResponseChoice
|
||||||
choice.Delta.Content = responseText
|
choice.Delta.Content = responseText
|
||||||
|
if len(tools) > 0 {
|
||||||
|
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
|
||||||
|
choice.Delta.ToolCalls = tools
|
||||||
|
}
|
||||||
choice.Delta.Role = "assistant"
|
choice.Delta.Role = "assistant"
|
||||||
finishReason := stopReasonClaude2OpenAI(&stopReason)
|
finishReason := stopReasonClaude2OpenAI(&stopReason)
|
||||||
if finishReason != "null" {
|
if finishReason != "null" {
|
||||||
@@ -136,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
|
|||||||
if len(claudeResponse.Content) > 0 {
|
if len(claudeResponse.Content) > 0 {
|
||||||
responseText = claudeResponse.Content[0].Text
|
responseText = claudeResponse.Content[0].Text
|
||||||
}
|
}
|
||||||
|
tools := make([]model.Tool, 0)
|
||||||
|
for _, v := range claudeResponse.Content {
|
||||||
|
if v.Type == "tool_use" {
|
||||||
|
args, _ := json.Marshal(v.Input)
|
||||||
|
tools = append(tools, model.Tool{
|
||||||
|
Id: v.Id,
|
||||||
|
Type: "function", // compatible with other OpenAI derivative applications
|
||||||
|
Function: model.Function{
|
||||||
|
Name: v.Name,
|
||||||
|
Arguments: string(args),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
choice := openai.TextResponseChoice{
|
choice := openai.TextResponseChoice{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Message: model.Message{
|
Message: model.Message{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: responseText,
|
Content: responseText,
|
||||||
Name: nil,
|
Name: nil,
|
||||||
|
ToolCalls: tools,
|
||||||
},
|
},
|
||||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||||
}
|
}
|
||||||
@@ -176,6 +267,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
|||||||
var usage model.Usage
|
var usage model.Usage
|
||||||
var modelName string
|
var modelName string
|
||||||
var id string
|
var id string
|
||||||
|
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
@@ -196,9 +288,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
|||||||
if meta != nil {
|
if meta != nil {
|
||||||
usage.PromptTokens += meta.Usage.InputTokens
|
usage.PromptTokens += meta.Usage.InputTokens
|
||||||
usage.CompletionTokens += meta.Usage.OutputTokens
|
usage.CompletionTokens += meta.Usage.OutputTokens
|
||||||
|
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
|
||||||
modelName = meta.Model
|
modelName = meta.Model
|
||||||
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
|
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
|
||||||
continue
|
continue
|
||||||
|
} else { // finish_reason case
|
||||||
|
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
|
||||||
|
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
|
||||||
|
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
|
||||||
|
lastArgs.Arguments = "{}"
|
||||||
|
response.Choices[len(response.Choices)-1].Delta.Content = nil
|
||||||
|
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if response == nil {
|
if response == nil {
|
||||||
continue
|
continue
|
||||||
@@ -207,6 +310,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
|||||||
response.Id = id
|
response.Id = id
|
||||||
response.Model = modelName
|
response.Model = modelName
|
||||||
response.Created = createdTime
|
response.Created = createdTime
|
||||||
|
|
||||||
|
for _, choice := range response.Choices {
|
||||||
|
if len(choice.Delta.ToolCalls) > 0 {
|
||||||
|
lastToolCallChoice = choice
|
||||||
|
}
|
||||||
|
}
|
||||||
err = render.ObjectData(c, response)
|
err = render.ObjectData(c, response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError(err.Error())
|
logger.SysError(err.Error())
|
||||||
|
@@ -16,6 +16,12 @@ type Content struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
Source *ImageSource `json:"source,omitempty"`
|
Source *ImageSource `json:"source,omitempty"`
|
||||||
|
// tool_calls
|
||||||
|
Id string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Input any `json:"input,omitempty"`
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
ToolUseId string `json:"tool_use_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
@@ -23,6 +29,18 @@ type Message struct {
|
|||||||
Content []Content `json:"content"`
|
Content []Content `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
InputSchema InputSchema `json:"input_schema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type InputSchema struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties any `json:"properties,omitempty"`
|
||||||
|
Required any `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
@@ -33,6 +51,8 @@ type Request struct {
|
|||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
ToolChoice any `json:"tool_choice,omitempty"`
|
||||||
//Metadata `json:"metadata,omitempty"`
|
//Metadata `json:"metadata,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,6 +81,7 @@ type Response struct {
|
|||||||
type Delta struct {
|
type Delta struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
|
PartialJson string `json:"partial_json,omitempty"`
|
||||||
StopReason *string `json:"stop_reason"`
|
StopReason *string `json:"stop_reason"`
|
||||||
StopSequence *string `json:"stop_sequence"`
|
StopSequence *string `json:"stop_sequence"`
|
||||||
}
|
}
|
||||||
|
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@@ -143,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
|
|||||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
var usage relaymodel.Usage
|
var usage relaymodel.Usage
|
||||||
var id string
|
var id string
|
||||||
|
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
|
||||||
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
event, ok := <-stream.Events()
|
event, ok := <-stream.Events()
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -163,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
|
|||||||
if meta != nil {
|
if meta != nil {
|
||||||
usage.PromptTokens += meta.Usage.InputTokens
|
usage.PromptTokens += meta.Usage.InputTokens
|
||||||
usage.CompletionTokens += meta.Usage.OutputTokens
|
usage.CompletionTokens += meta.Usage.OutputTokens
|
||||||
|
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
|
||||||
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
|
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
|
||||||
return true
|
return true
|
||||||
|
} else { // finish_reason case
|
||||||
|
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
|
||||||
|
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
|
||||||
|
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
|
||||||
|
lastArgs.Arguments = "{}"
|
||||||
|
response.Choices[len(response.Choices)-1].Delta.Content = nil
|
||||||
|
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if response == nil {
|
if response == nil {
|
||||||
return true
|
return true
|
||||||
@@ -172,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
|
|||||||
response.Id = id
|
response.Id = id
|
||||||
response.Model = c.GetString(ctxkey.OriginalModel)
|
response.Model = c.GetString(ctxkey.OriginalModel)
|
||||||
response.Created = createdTime
|
response.Created = createdTime
|
||||||
|
|
||||||
|
for _, choice := range response.Choices {
|
||||||
|
if len(choice.Delta.ToolCalls) > 0 {
|
||||||
|
lastToolCallChoice = choice
|
||||||
|
}
|
||||||
|
}
|
||||||
jsonStr, err := json.Marshal(response)
|
jsonStr, err := json.Marshal(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.SysError("error marshalling stream response: " + err.Error())
|
logger.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
@@ -9,9 +9,12 @@ type Request struct {
|
|||||||
// AnthropicVersion should be "bedrock-2023-05-31"
|
// AnthropicVersion should be "bedrock-2023-05-31"
|
||||||
AnthropicVersion string `json:"anthropic_version"`
|
AnthropicVersion string `json:"anthropic_version"`
|
||||||
Messages []anthropic.Message `json:"messages"`
|
Messages []anthropic.Message `json:"messages"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
Tools []anthropic.Tool `json:"tools,omitempty"`
|
||||||
|
ToolChoice any `json:"tool_choice,omitempty"`
|
||||||
}
|
}
|
||||||
|
@@ -5,6 +5,7 @@ type Message struct {
|
|||||||
Content any `json:"content,omitempty"`
|
Content any `json:"content,omitempty"`
|
||||||
Name *string `json:"name,omitempty"`
|
Name *string `json:"name,omitempty"`
|
||||||
ToolCalls []Tool `json:"tool_calls,omitempty"`
|
ToolCalls []Tool `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Message) IsStringContent() bool {
|
func (m Message) IsStringContent() bool {
|
||||||
|
@@ -2,13 +2,13 @@ package model
|
|||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Id string `json:"id,omitempty"`
|
Id string `json:"id,omitempty"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
|
||||||
Function Function `json:"function"`
|
Function Function `json:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Function struct {
|
type Function struct {
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name,omitempty"` // when splicing claude tools stream messages, it is empty
|
||||||
Parameters any `json:"parameters,omitempty"` // request
|
Parameters any `json:"parameters,omitempty"` // request
|
||||||
Arguments any `json:"arguments,omitempty"` // response
|
Arguments any `json:"arguments,omitempty"` // response
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user