From 2ba28c72cbd41ed1020d71d3fd57367ea99be7fd Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sat, 30 Mar 2024 10:43:26 +0800
Subject: [PATCH] feat: support function call for ali (close #1242)

---
 common/conv/any.go            |  6 ++++++
 relay/channel/ali/main.go     | 28 ++++++++++++++++------------
 relay/channel/ali/model.go    | 18 +++++++++++++-----
 relay/channel/openai/main.go  |  3 ++-
 relay/channel/openai/model.go |  9 +++------
 relay/channel/tencent/main.go |  3 ++-
 relay/model/general.go        | 26 ++++++++++++++------------
 relay/model/message.go        |  7 ++++---
 relay/model/tool.go           | 14 ++++++++++++++
 9 files changed, 74 insertions(+), 40 deletions(-)
 create mode 100644 common/conv/any.go
 create mode 100644 relay/model/tool.go

diff --git a/common/conv/any.go b/common/conv/any.go
new file mode 100644
index 00000000..467e8bb7
--- /dev/null
+++ b/common/conv/any.go
@@ -0,0 +1,6 @@
+package conv
+
+func AsString(v any) string {
+    str, _ := v.(string)
+    return str
+}
diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go
index 62115d58..6fdfa4d4 100644
--- a/relay/channel/ali/main.go
+++ b/relay/channel/ali/main.go
@@ -48,7 +48,10 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
             MaxTokens:    request.MaxTokens,
             Temperature:  request.Temperature,
             TopP:         request.TopP,
+            TopK:         request.TopK,
+            ResultFormat: "message",
         },
+        Tools: request.Tools,
     }
 }
 
@@ -117,19 +120,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
 }
 
 func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
-    choice := openai.TextResponseChoice{
-        Index: 0,
-        Message: model.Message{
-            Role:    "assistant",
-            Content: response.Output.Text,
-        },
-        FinishReason: response.Output.FinishReason,
-    }
     fullTextResponse := openai.TextResponse{
         Id:      response.RequestId,
         Object:  "chat.completion",
         Created: helper.GetTimestamp(),
-        Choices: []openai.TextResponseChoice{choice},
+        Choices: response.Output.Choices,
         Usage: model.Usage{
             PromptTokens:     response.Usage.InputTokens,
             CompletionTokens: response.Usage.OutputTokens,
@@ -140,10 +135,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
 }
 
 func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+    if len(aliResponse.Output.Choices) == 0 {
+        return nil
+    }
+    aliChoice := aliResponse.Output.Choices[0]
     var choice openai.ChatCompletionsStreamResponseChoice
-    choice.Delta.Content = aliResponse.Output.Text
-    if aliResponse.Output.FinishReason != "null" {
-        finishReason := aliResponse.Output.FinishReason
+    choice.Delta = aliChoice.Message
+    if aliChoice.FinishReason != "null" {
+        finishReason := aliChoice.FinishReason
         choice.FinishReason = &finishReason
     }
     response := openai.ChatCompletionsStreamResponse{
@@ -204,6 +203,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
             usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
         }
         response := streamResponseAli2OpenAI(&aliResponse)
+        if response == nil {
+            return true
+        }
         //response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
         //lastResponseText = aliResponse.Output.Text
         jsonResponse, err := json.Marshal(response)
@@ -226,6 +228,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 }
 
 func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+    ctx := c.Request.Context()
     var aliResponse ChatResponse
     responseBody, err := io.ReadAll(resp.Body)
     if err != nil {
@@ -235,6 +238,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
     if err != nil {
         return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }
+    logger.Debugf(ctx, "response body: %s\n", responseBody)
     err = json.Unmarshal(responseBody, &aliResponse)
     if err != nil {
         return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
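For illustration only, not part of the commit: with the ConvertRequest change above, an OpenAI-style request that carries tools is forwarded to Ali as-is, and result_format is pinned to "message" so the upstream reply comes back as choices. The model name and the weather tool in this sketch are invented example values.

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/channel/ali"
	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	// Invented example request; "qwen-turbo" and the weather tool are
	// placeholders, not values taken from the patch.
	req := model.GeneralOpenAIRequest{
		Model: "qwen-turbo",
		Messages: []model.Message{
			{Role: "user", Content: "What is the weather in Paris?"},
		},
		Tools: []model.Tool{{
			Type: "function",
			Function: model.Function{
				Name:        "get_weather",
				Description: "Look up the current weather for a city",
				Parameters: map[string]any{
					"type": "object",
					"properties": map[string]any{
						"city": map[string]any{"type": "string"},
					},
				},
			},
		}},
	}

	aliReq := ali.ConvertRequest(req)
	fmt.Println(aliReq.Parameters.ResultFormat) // "message"
	fmt.Println(len(aliReq.Tools))              // 1, tools are passed through unchanged
}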
diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go
index 76e814d1..e19d427a 100644
--- a/relay/channel/ali/model.go
+++ b/relay/channel/ali/model.go
@@ -1,5 +1,10 @@
 package ali
 
+import (
+    "github.com/songquanpeng/one-api/relay/channel/openai"
+    "github.com/songquanpeng/one-api/relay/model"
+)
+
 type Message struct {
     Content string `json:"content"`
     Role    string `json:"role"`
@@ -18,12 +23,14 @@ type Parameters struct {
     IncrementalOutput bool    `json:"incremental_output,omitempty"`
     MaxTokens         int     `json:"max_tokens,omitempty"`
     Temperature       float64 `json:"temperature,omitempty"`
+    ResultFormat      string  `json:"result_format,omitempty"`
 }
 
 type ChatRequest struct {
-    Model      string     `json:"model"`
-    Input      Input      `json:"input"`
-    Parameters Parameters `json:"parameters,omitempty"`
+    Model      string       `json:"model"`
+    Input      Input        `json:"input"`
+    Parameters Parameters   `json:"parameters,omitempty"`
+    Tools      []model.Tool `json:"tools,omitempty"`
 }
 
 type EmbeddingRequest struct {
@@ -62,8 +69,9 @@ type Usage struct {
 }
 
 type Output struct {
-    Text         string `json:"text"`
-    FinishReason string `json:"finish_reason"`
+    //Text         string `json:"text"`
+    //FinishReason string `json:"finish_reason"`
+    Choices []openai.TextResponseChoice `json:"choices"`
 }
 
 type ChatResponse struct {
diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go
index d47cd164..63cb9ae8 100644
--- a/relay/channel/openai/main.go
+++ b/relay/channel/openai/main.go
@@ -6,6 +6,7 @@ import (
     "encoding/json"
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/conv"
     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/model"
@@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
                     continue // just ignore the error
                 }
                 for _, choice := range streamResponse.Choices {
-                    responseText += choice.Delta.Content
+                    responseText += conv.AsString(choice.Delta.Content)
                 }
                 if streamResponse.Usage != nil {
                     usage = streamResponse.Usage
diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go
index 6c0b2c53..30d77739 100644
--- a/relay/channel/openai/model.go
+++ b/relay/channel/openai/model.go
@@ -118,12 +118,9 @@ type ImageResponse struct {
 }
 
 type ChatCompletionsStreamResponseChoice struct {
-    Index int `json:"index"`
-    Delta struct {
-        Content string `json:"content"`
-        Role    string `json:"role,omitempty"`
-    } `json:"delta"`
-    FinishReason *string `json:"finish_reason,omitempty"`
+    Index        int           `json:"index"`
+    Delta        model.Message `json:"delta"`
+    FinishReason *string       `json:"finish_reason,omitempty"`
 }
 
 type ChatCompletionsStreamResponse struct {
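For illustration only, not part of the commit: Delta is now a model.Message, so Delta.Content has type any and may hold structured content rather than a plain string. The stream handlers therefore accumulate text through conv.AsString, which yields the empty string for non-string values instead of panicking on a type assertion; a minimal sketch:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/conv"
)

func main() {
	var plain any = "hello"                                     // ordinary text delta
	var structured any = []any{map[string]any{"type": "text"}} // non-string content

	fmt.Printf("%q\n", conv.AsString(plain))      // "hello"
	fmt.Printf("%q\n", conv.AsString(structured)) // "" (non-string values yield the empty string)
}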
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/channel/openai" @@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC } response := streamResponseTencent2OpenAI(&TencentResponse) if len(response.Choices) != 0 { - responseText += response.Choices[0].Delta.Content + responseText += conv.AsString(response.Choices[0].Delta.Content) } jsonResponse, err := json.Marshal(response) if err != nil { diff --git a/relay/model/general.go b/relay/model/general.go index fbcc04e8..86facf04 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -5,25 +5,27 @@ type ResponseFormat struct { } type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` + Model string `json:"model,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + FunctionCall any `json:"function_call,omitempty"` + Functions any `json:"functions,omitempty"` User string `json:"user,omitempty"` + Prompt any `json:"prompt,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/model/message.go b/relay/model/message.go index c6c8a271..32a1055b 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,9 +1,10 @@ package model type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` } func (m Message) IsStringContent() bool { diff --git a/relay/model/tool.go b/relay/model/tool.go new file mode 100644 index 00000000..253dca35 --- /dev/null +++ b/relay/model/tool.go @@ -0,0 +1,14 @@ +package model + +type Tool struct { + Id string `json:"id,omitempty"` + Type string `json:"type"` + Function Function `json:"function"` +} + +type Function struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Parameters any `json:"parameters,omitempty"` // request + Arguments any `json:"arguments,omitempty"` // response +}