feat: support function call for ali (close #1242)

JustSong 2024-03-30 10:43:26 +08:00
parent 5e81e19bc8
commit 2ba28c72cb
9 changed files with 74 additions and 40 deletions

common/conv/any.go (new file, +6)

@@ -0,0 +1,6 @@
+package conv
+
+func AsString(v any) string {
+	str, _ := v.(string)
+	return str
+}
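Note: AsString uses the comma-ok type assertion, so a non-string value yields "" instead of a panic. A minimal usage sketch inside this module; the values are illustrative:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/conv"
)

func main() {
	var delta any = "hello"
	fmt.Println(conv.AsString(delta)) // "hello"

	delta = map[string]any{"type": "text"} // structured content
	fmt.Println(conv.AsString(delta))      // "": non-strings are dropped, not stringified
}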

relay/channel/ali/main.go

@@ -48,7 +48,10 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
 			MaxTokens:   request.MaxTokens,
 			Temperature: request.Temperature,
 			TopP:        request.TopP,
+			TopK:        request.TopK,
+			ResultFormat: "message",
 		},
+		Tools: request.Tools,
 	}
 }
@@ -117,19 +120,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
 }

 func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
-	choice := openai.TextResponseChoice{
-		Index: 0,
-		Message: model.Message{
-			Role:    "assistant",
-			Content: response.Output.Text,
-		},
-		FinishReason: response.Output.FinishReason,
-	}
 	fullTextResponse := openai.TextResponse{
 		Id:      response.RequestId,
 		Object:  "chat.completion",
 		Created: helper.GetTimestamp(),
-		Choices: []openai.TextResponseChoice{choice},
+		Choices: response.Output.Choices,
 		Usage: model.Usage{
 			PromptTokens:     response.Usage.InputTokens,
 			CompletionTokens: response.Usage.OutputTokens,
@@ -140,10 +135,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
 }

 func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+	if len(aliResponse.Output.Choices) == 0 {
+		return nil
+	}
+	aliChoice := aliResponse.Output.Choices[0]
 	var choice openai.ChatCompletionsStreamResponseChoice
-	choice.Delta.Content = aliResponse.Output.Text
-	if aliResponse.Output.FinishReason != "null" {
-		finishReason := aliResponse.Output.FinishReason
+	choice.Delta = aliChoice.Message
+	if aliChoice.FinishReason != "null" {
+		finishReason := aliChoice.FinishReason
 		choice.FinishReason = &finishReason
 	}
 	response := openai.ChatCompletionsStreamResponse{
@@ -204,6 +203,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 			}
 			response := streamResponseAli2OpenAI(&aliResponse)
+			if response == nil {
+				return true
+			}
 			//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 			//lastResponseText = aliResponse.Output.Text
 			jsonResponse, err := json.Marshal(response)
@@ -226,6 +228,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 }

 func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	ctx := c.Request.Context()
 	var aliResponse ChatResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
@@ -235,6 +238,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
+	logger.Debugf(ctx, "response body: %s\n", responseBody)
 	err = json.Unmarshal(responseBody, &aliResponse)
 	if err != nil {
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

relay/channel/ali/model.go

@@ -1,5 +1,10 @@
 package ali

+import (
+	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/model"
+)
+
 type Message struct {
 	Content string `json:"content"`
 	Role    string `json:"role"`
@@ -18,12 +23,14 @@ type Parameters struct {
 	IncrementalOutput bool    `json:"incremental_output,omitempty"`
 	MaxTokens         int     `json:"max_tokens,omitempty"`
 	Temperature       float64 `json:"temperature,omitempty"`
+	ResultFormat      string  `json:"result_format,omitempty"`
 }

 type ChatRequest struct {
 	Model      string       `json:"model"`
 	Input      Input        `json:"input"`
 	Parameters Parameters   `json:"parameters,omitempty"`
+	Tools      []model.Tool `json:"tools,omitempty"`
 }

 type EmbeddingRequest struct {
@@ -62,8 +69,9 @@ type Usage struct {
 }

 type Output struct {
-	Text         string `json:"text"`
-	FinishReason string `json:"finish_reason"`
+	//Text         string                      `json:"text"`
+	//FinishReason string                      `json:"finish_reason"`
+	Choices []openai.TextResponseChoice `json:"choices"`
 }

 type ChatResponse struct {
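Note: with result_format set to "message", Qwen responses carry an OpenAI-style choices array instead of output.text, which is what the commented-out fields and the new Choices field reflect. A sketch of decoding such a payload with trimmed stand-in structs; the sample JSON is illustrative, not an exact DashScope payload:

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-ins for openai.TextResponseChoice and ali.Output.
type textChoice struct {
	FinishReason string `json:"finish_reason"`
	Message      struct {
		Role    string `json:"role"`
		Content any    `json:"content"`
	} `json:"message"`
}

type aliOutput struct {
	Choices []textChoice `json:"choices"`
}

func main() {
	// Illustrative payload in the choices shape described above.
	raw := `{"choices":[{"finish_reason":"stop","message":{"role":"assistant","content":"hi"}}]}`
	var out aliOutput
	if err := json.Unmarshal([]byte(raw), &out); err != nil {
		panic(err)
	}
	fmt.Println(out.Choices[0].Message.Content) // hi
}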

relay/channel/openai/main.go

@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 				continue // just ignore the error
 			}
 			for _, choice := range streamResponse.Choices {
-				responseText += choice.Delta.Content
+				responseText += conv.AsString(choice.Delta.Content)
 			}
 			if streamResponse.Usage != nil {
 				usage = streamResponse.Usage

relay/channel/openai/model.go

@@ -118,12 +118,9 @@ type ImageResponse struct {
 }

 type ChatCompletionsStreamResponseChoice struct {
 	Index int `json:"index"`
-	Delta struct {
-		Content string `json:"content"`
-		Role    string `json:"role,omitempty"`
-	} `json:"delta"`
-	FinishReason *string `json:"finish_reason,omitempty"`
+	Delta        model.Message `json:"delta"`
+	FinishReason *string       `json:"finish_reason,omitempty"`
 }

 type ChatCompletionsStreamResponse struct {
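Note: widening Delta from an inline struct with a string Content to model.Message is what lets stream chunks carry tool_calls; it is also why the stream handlers above switched to conv.AsString, since Delta.Content is now any. A small sketch with stand-in types (toolCall stands in for model.Tool):

package main

import "fmt"

// Stand-ins: toolCall for model.Tool, message for model.Message.
type toolCall struct{ Name string }
type message struct {
	Role      string
	Content   any
	ToolCalls []toolCall
}
type streamChoice struct {
	Index int
	Delta message // previously an inline struct with Content string
}

func main() {
	c := streamChoice{Delta: message{Role: "assistant", ToolCalls: []toolCall{{Name: "get_weather"}}}}
	fmt.Println(len(c.Delta.ToolCalls)) // 1: a delta carrying a tool call, no text
}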

relay/channel/tencent/main.go

@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
@@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 			}
 			response := streamResponseTencent2OpenAI(&TencentResponse)
 			if len(response.Choices) != 0 {
-				responseText += response.Choices[0].Delta.Content
+				responseText += conv.AsString(response.Choices[0].Delta.Content)
 			}
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {

relay/model/general.go

@@ -5,25 +5,27 @@ type ResponseFormat struct {
 }

 type GeneralOpenAIRequest struct {
-	Model            string          `json:"model,omitempty"`
 	Messages         []Message       `json:"messages,omitempty"`
-	Prompt           any             `json:"prompt,omitempty"`
-	Stream           bool            `json:"stream,omitempty"`
-	MaxTokens        int             `json:"max_tokens,omitempty"`
-	Temperature      float64         `json:"temperature,omitempty"`
-	TopP             float64         `json:"top_p,omitempty"`
-	N                int             `json:"n,omitempty"`
-	Input            any             `json:"input,omitempty"`
-	Instruction      string          `json:"instruction,omitempty"`
-	Size             string          `json:"size,omitempty"`
-	Functions        any             `json:"functions,omitempty"`
+	Model            string          `json:"model,omitempty"`
 	FrequencyPenalty float64         `json:"frequency_penalty,omitempty"`
+	MaxTokens        int             `json:"max_tokens,omitempty"`
+	N                int             `json:"n,omitempty"`
 	PresencePenalty  float64         `json:"presence_penalty,omitempty"`
 	ResponseFormat   *ResponseFormat `json:"response_format,omitempty"`
 	Seed             float64         `json:"seed,omitempty"`
-	Tools            any             `json:"tools,omitempty"`
+	Stream           bool            `json:"stream,omitempty"`
+	Temperature      float64         `json:"temperature,omitempty"`
+	TopP             float64         `json:"top_p,omitempty"`
+	TopK             int             `json:"top_k,omitempty"`
+	Tools            []Tool          `json:"tools,omitempty"`
 	ToolChoice       any             `json:"tool_choice,omitempty"`
+	FunctionCall     any             `json:"function_call,omitempty"`
+	Functions        any             `json:"functions,omitempty"`
 	User             string          `json:"user,omitempty"`
+	Prompt           any             `json:"prompt,omitempty"`
+	Input            any             `json:"input,omitempty"`
+	Instruction      string          `json:"instruction,omitempty"`
+	Size             string          `json:"size,omitempty"`
 }

 func (r GeneralOpenAIRequest) ParseInput() []string {

relay/model/message.go

@@ -1,9 +1,10 @@
 package model

 type Message struct {
-	Role    string  `json:"role"`
-	Content any     `json:"content"`
+	Role      string  `json:"role,omitempty"`
+	Content   any     `json:"content,omitempty"`
 	Name      *string `json:"name,omitempty"`
+	ToolCalls []Tool  `json:"tool_calls,omitempty"`
 }

 func (m Message) IsStringContent() bool {

relay/model/tool.go (new file, +14)

@@ -0,0 +1,14 @@
+package model
+
+type Tool struct {
+	Id       string   `json:"id,omitempty"`
+	Type     string   `json:"type"`
+	Function Function `json:"function"`
+}
+
+type Function struct {
+	Description string `json:"description,omitempty"`
+	Name        string `json:"name"`
+	Parameters  any    `json:"parameters,omitempty"` // request
+	Arguments   any    `json:"arguments,omitempty"`  // response
+}
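Note: per the comments above, Parameters carries the request-side function schema and Arguments the response-side call payload. A sketch of building and serializing a tool definition with these types; the function name and schema are made up for illustration:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	// Hypothetical function definition; the name and schema are made up.
	tool := model.Tool{
		Type: "function",
		Function: model.Function{
			Name:        "get_weather",
			Description: "Look up the weather for a city",
			Parameters: map[string]any{ // request side: a JSON Schema object
				"type": "object",
				"properties": map[string]any{
					"city": map[string]any{"type": "string"},
				},
			},
		},
	}
	b, err := json.Marshal(tool)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
}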