diff --git a/controller/relay-openai.go b/controller/relay-openai.go
index 6bdfbc08..fe4f2729 100644
--- a/controller/relay-openai.go
+++ b/controller/relay-openai.go
@@ -11,8 +11,10 @@ import (
 	"strings"
 )
 
-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
+func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string, string, string) {
 	responseText := ""
+	responseFunctionCallName := ""
+	responseFunctionCallArguments := ""
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -50,6 +52,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 					}
 					for _, choice := range streamResponse.Choices {
 						responseText += choice.Delta.Content
+						responseFunctionCallName += choice.Delta.FunctionCall.Name
+						responseFunctionCallArguments += choice.Delta.FunctionCall.Arguments
 					}
 				case RelayModeCompletions:
 					var streamResponse CompletionsStreamResponse
@@ -83,9 +87,9 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", "", ""
 	}
-	return nil, responseText
+	return nil, responseText, responseFunctionCallName, responseFunctionCallArguments
 }
 
 func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
diff --git a/controller/relay-text.go b/controller/relay-text.go
index e8dab514..c57e7553 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -177,6 +177,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	switch relayMode {
 	case RelayModeChatCompletions:
 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
+		promptTokens += countTokenFunctions(textRequest.Functions, textRequest.Model)
+		promptTokens += countTokenFunctionCall(textRequest.FunctionCall, textRequest.Model)
 	case RelayModeCompletions:
 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
 	case RelayModeModerations:
@@ -366,12 +368,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	switch apiType {
 	case APITypeOpenAI:
 		if isStream {
-			err, responseText := openaiStreamHandler(c, resp, relayMode)
+			err, responseText, responseFunctionCallName, responseFunctionCallArguments := openaiStreamHandler(c, resp, relayMode)
 			if err != nil {
 				return err
 			}
 			textResponse.Usage.PromptTokens = promptTokens
 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
+			if responseFunctionCallName != "" {
+				textResponse.Usage.CompletionTokens += countTokenFunctionCall(responseFunctionCallName, textRequest.Model)
+			}
+			if responseFunctionCallArguments != "" {
+				responseFunctionCallArguments = strings.ReplaceAll(responseFunctionCallArguments, "\\\"", "\"")
+				textResponse.Usage.CompletionTokens += countTokenFunctionCall(responseFunctionCallArguments, textRequest.Model)
+			}
 			return nil
 		} else {
 			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
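Reviewer note, not part of the patch: a minimal, self-contained sketch of the accumulation openaiStreamHandler now performs. It assumes delta payloads shaped like OpenAI's chat-completion stream, where the function name typically arrives in the first chunk and the arguments arrive as JSON fragments across later chunks, so concatenating in order reconstructs the full call. The sample chunks are hypothetical, not captured traffic.

package main

import (
	"encoding/json"
	"fmt"
)

// delta mirrors the new ChatCompletionsStreamResponseChoice.Delta shape.
type delta struct {
	Content      string `json:"content"`
	FunctionCall struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function_call"`
}

func main() {
	// Hypothetical stream chunks: name first, then argument fragments.
	chunks := []string{
		`{"content":"","function_call":{"name":"get_weather","arguments":""}}`,
		`{"content":"","function_call":{"arguments":"{\"location\":"}}`,
		`{"content":"","function_call":{"arguments":"\"Paris\"}"}}`,
	}
	var name, args string
	for _, raw := range chunks {
		var d delta
		if err := json.Unmarshal([]byte(raw), &d); err != nil {
			panic(err)
		}
		// Mirrors openaiStreamHandler: appending empty fields is a no-op.
		name += d.FunctionCall.Name
		args += d.FunctionCall.Arguments
	}
	fmt.Println(name) // get_weather
	fmt.Println(args) // {"location":"Paris"}
}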
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index 5b3e0274..3cc2cf15 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -1,10 +1,13 @@
 package controller
 
 import (
+	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
+	"gorm.io/gorm/utils"
 	"one-api/common"
+	"strings"
 )
 
 var stopFinishReason = "stop"
@@ -34,6 +37,75 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
+func countTokenFunctionCall(functionCall any, model string) int {
+	// A nil function_call would marshal to "null"; don't bill tokens for it.
+	if functionCall == nil {
+		return 0
+	}
+	tokenEncoder := getTokenEncoder(model)
+	jsonBytes, err := json.Marshal(functionCall)
+	if err != nil {
+		return 0
+	}
+	return getTokenNum(tokenEncoder, string(jsonBytes))
+}
+
+func countTokenFunctions(functions []Function, model string) int {
+	// https://community.openai.com/t/how-to-know-of-tokens-beforehand-when-i-make-function-calling-chat-history-request-witn-nodejs/289060/6
+	if len(functions) == 0 {
+		return 0
+	}
+	tokenEncoder := getTokenEncoder(model)
+
+	paramSignature := func(name string, pSpec Property, pRequired []string) string {
+		var requiredString string
+		if !utils.Contains(pRequired, name) {
+			requiredString = "?"
+		}
+		var enumString string
+		if len(pSpec.Enum) > 0 {
+			enumValues := make([]string, len(pSpec.Enum))
+			for i, v := range pSpec.Enum {
+				enumValues[i] = fmt.Sprintf("\"%s\"", v)
+			}
+			enumString = strings.Join(enumValues, " | ")
+		} else {
+			enumString = pSpec.Type
+		}
+		signature := fmt.Sprintf("%s%s: %s, ", name, requiredString, enumString)
+		if pSpec.Description != "" {
+			signature = fmt.Sprintf("// %s\n%s", pSpec.Description, signature)
+		}
+		return signature
+	}
+
+	functionSignature := func(fSpec Function) string {
+		var params []string
+		for name, p := range fSpec.Parameters.Properties {
+			params = append(params, paramSignature(name, p, fSpec.Parameters.Required))
+		}
+		var descriptionString string
+		if fSpec.Description != "" {
+			descriptionString = fmt.Sprintf("// %s\n", fSpec.Description)
+		}
+
+		var paramString string
+		if len(params) > 0 {
+			paramString = fmt.Sprintf("_: {\n%s\n}", strings.Join(params, "\n"))
+		}
+
+		return fmt.Sprintf("%stype %s = (%s) => any;", descriptionString, fSpec.Name, paramString)
+	}
+
+	var functionSignatures []string
+	for _, f := range functions {
+		functionSignatures = append(functionSignatures, functionSignature(f))
+	}
+	functionString := fmt.Sprintf("# Tools\n\n## functions\n\nnamespace functions {\n\n%s\n\n} // namespace functions", strings.Join(functionSignatures, "\n\n"))
+
+	return getTokenNum(tokenEncoder, functionString)
+}
+
 func countTokenMessages(messages []Message, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
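Reviewer note, not part of the patch: countTokenFunctions prices a request's functions by rendering them as TypeScript-like signatures inside a "namespace functions" block (the heuristic from the community post linked in the hunk) and tokenizing the result. The standalone sketch below duplicates that formatting for one hypothetical get_weather function so the exact text being tokenized can be inspected. Since the real code ranges over a map, parameter order, and hence the token count, is not deterministic.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Hypothetical parameters, fixed order here; "location" is required.
	params := []struct {
		name, typ, desc string
		enum            []string
		required        bool
	}{
		{"location", "string", "City name", nil, true},
		{"unit", "string", "", []string{"celsius", "fahrenheit"}, false},
	}

	var rendered []string
	for _, p := range params {
		opt := ""
		if !p.required {
			opt = "?" // optional params get a TypeScript-style "?"
		}
		typ := p.typ
		if len(p.enum) > 0 {
			quoted := make([]string, len(p.enum))
			for i, v := range p.enum {
				quoted[i] = fmt.Sprintf("%q", v)
			}
			typ = strings.Join(quoted, " | ") // enums become a union of string literals
		}
		sig := fmt.Sprintf("%s%s: %s, ", p.name, opt, typ)
		if p.desc != "" {
			sig = fmt.Sprintf("// %s\n%s", p.desc, sig)
		}
		rendered = append(rendered, sig)
	}

	fn := fmt.Sprintf("// %s\ntype %s = (_: {\n%s\n}) => any;",
		"Get the current weather", "get_weather", strings.Join(rendered, "\n"))
	// This is the string whose token length countTokenFunctions would return.
	fmt.Printf("# Tools\n\n## functions\n\nnamespace functions {\n\n%s\n\n} // namespace functions\n", fn)
}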
any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions []Function `json:"functions,omitempty"` + FunctionCall interface{} `json:"functioncall,omitempty"` } type ChatRequest struct { @@ -89,7 +113,8 @@ type TextResponse struct { type OpenAITextResponseChoice struct { Index int `json:"index"` Message `json:"message"` - FinishReason string `json:"finish_reason"` + FinishReason string `json:"finish_reason"` + Functions []Function `json:"functions,omitempty"` } type OpenAITextResponse struct { @@ -122,7 +147,11 @@ type ImageResponse struct { type ChatCompletionsStreamResponseChoice struct { Delta struct { - Content string `json:"content"` + Content string `json:"content"` + FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function_call"` } `json:"delta"` FinishReason *string `json:"finish_reason"` }