Refine OpenAI function call billing

glzjin 2023-08-13 21:32:04 +08:00
parent da1d81998f
commit 6b19dbba1f
4 changed files with 127 additions and 17 deletions

View File

@@ -11,8 +11,10 @@ import (
 	"strings"
 )
 
-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
+func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string, string, string) {
 	responseText := ""
+	responseFunctionCallName := ""
+	responseFunctionCallArguments := ""
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -50,6 +52,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 			}
 			for _, choice := range streamResponse.Choices {
 				responseText += choice.Delta.Content
+				responseFunctionCallName += choice.Delta.FunctionCall.Name
+				responseFunctionCallArguments += choice.Delta.FunctionCall.Arguments
 			}
 		case RelayModeCompletions:
 			var streamResponse CompletionsStreamResponse
@@ -83,9 +87,9 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", "", ""
 	}
-	return nil, responseText
+	return nil, responseText, responseFunctionCallName, responseFunctionCallArguments
 }
 
 func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
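For context on what the handler now accumulates: OpenAI streams a function call's name in an early delta and its JSON arguments split across many, so plain string concatenation, as above, reconstructs both. A minimal standalone sketch (the delta values are illustrative, not verbatim API output):

package main

import "fmt"

// functionCallDelta is an illustrative stand-in for the function_call
// part of ChatCompletionsStreamResponseChoice.Delta.
type functionCallDelta struct {
	Name      string
	Arguments string
}

func main() {
	deltas := []functionCallDelta{
		{Name: "get_weather"},
		{Arguments: `{"location`},
		{Arguments: `": "Boston"}`},
	}
	name, args := "", ""
	for _, d := range deltas {
		name += d.Name
		args += d.Arguments
	}
	fmt.Println(name, args) // get_weather {"location": "Boston"}
}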

View File

@@ -177,6 +177,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	switch relayMode {
 	case RelayModeChatCompletions:
 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
+		promptTokens += countTokenFunctions(textRequest.Functions, textRequest.Model)
+		promptTokens += countTokenFunctionCall(textRequest.FunctionCall, textRequest.Model)
 	case RelayModeCompletions:
		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
 	case RelayModeModerations:
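One caveat: countTokenFunctionCall (added later in this commit) marshals its argument to JSON, and a request without function_call leaves textRequest.FunctionCall nil, which still marshals to the literal null. The second added line therefore bills roughly one stray token even on requests that use no functions. A standalone check:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// json.Marshal(nil) produces "null", which tokenizes to about one
	// token, so a nil function_call is not counted as zero.
	b, err := json.Marshal(nil)
	fmt.Println(string(b), err) // null <nil>
}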
@@ -366,12 +368,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	switch apiType {
 	case APITypeOpenAI:
 		if isStream {
-			err, responseText := openaiStreamHandler(c, resp, relayMode)
+			err, responseText, responseFunctionCallName, responseFunctionCallArguments := openaiStreamHandler(c, resp, relayMode)
 			if err != nil {
 				return err
 			}
 			textResponse.Usage.PromptTokens = promptTokens
 			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
+			if responseFunctionCallName != "" {
+				textResponse.Usage.CompletionTokens += countTokenFunctionCall(responseFunctionCallName, textRequest.Model)
+			}
+			if responseFunctionCallArguments != "" {
+				responseFunctionCallArguments = strings.Replace(responseFunctionCallArguments, "\\\"", "\"", -1)
+				textResponse.Usage.CompletionTokens += countTokenFunctionCall(responseFunctionCallArguments, textRequest.Model)
+			}
 			return nil
 		} else {
 			err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
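The strings.Replace call is there because argument fragments captured from the raw stream can arrive with escaped quotes (\"); un-escaping before counting keeps the count closer to the text the model actually generated. Illustrated standalone (the payload is made up):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// A raw fragment with escaped quotes, as it might be accumulated
	// from the stream; the replacement restores plain quotes.
	raw := `{\"location\": \"Boston\"}`
	fmt.Println(strings.Replace(raw, `\"`, `"`, -1)) // {"location": "Boston"}
}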

View File

@@ -1,10 +1,13 @@
 package controller
 
 import (
+	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
+	"gorm.io/gorm/utils"
 	"one-api/common"
+	"strings"
 )
 
 var stopFinishReason = "stop"
@@ -34,6 +37,71 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
+func countTokenFunctionCall(functionCall any, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	jsonBytes, err := json.Marshal(functionCall)
+	if err != nil {
+		return 0
+	}
+	return getTokenNum(tokenEncoder, string(jsonBytes))
+}
+
+func countTokenFunctions(functions []Function, model string) int {
+	// https://community.openai.com/t/how-to-know-of-tokens-beforehand-when-i-make-function-calling-chat-history-request-witn-nodejs/289060/6
+	if len(functions) == 0 {
+		return 0
+	}
+	tokenEncoder := getTokenEncoder(model)
+	paramSignature := func(name string, pSpec Property, pRequired []string) string {
+		var requiredString string
+		if !utils.Contains(pRequired, name) {
+			requiredString = "?"
+		}
+		var enumString string
+		if len(pSpec.Enum) > 0 {
+			enumValues := make([]string, len(pSpec.Enum))
+			for i, v := range pSpec.Enum {
+				enumValues[i] = fmt.Sprintf("\"%s\"", v)
+			}
+			enumString = strings.Join(enumValues, " | ")
+		} else {
+			enumString = pSpec.Type
+		}
+		signature := fmt.Sprintf("%s%s: %s, ", name, requiredString, enumString)
+		if pSpec.Description != "" {
+			signature = fmt.Sprintf("// %s\n%s", pSpec.Description, signature)
+		}
+		return signature
+	}
+	functionSignature := func(fSpec Function) string {
+		var params []string
+		for name, p := range fSpec.Parameters.Properties {
+			params = append(params, paramSignature(name, p, fSpec.Parameters.Required))
+		}
+		var descriptionString string
+		if fSpec.Description != "" {
+			descriptionString = fmt.Sprintf("// %s\n", fSpec.Description)
+		}
+		var paramString string
+		if len(params) > 0 {
+			paramString = fmt.Sprintf("_: {\n%s\n}", strings.Join(params, "\n"))
+		}
+		return fmt.Sprintf("%stype %s = (%s) => any;", descriptionString, fSpec.Name, paramString)
+	}
+	var functionSignatures []string
+	for _, f := range functions {
+		functionSignatures = append(functionSignatures, functionSignature(f))
+	}
+	functionString := fmt.Sprintf("# Tools\n\n## functions\n\nnamespace functions {\n\n%s\n\n} // namespace functions", strings.Join(functionSignatures, "\n\n"))
+	return getTokenNum(tokenEncoder, functionString)
+}
+
 func countTokenMessages(messages []Message, model string) int {
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
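To make the rendering concrete: following the community post linked in the code, each function is rewritten in the TypeScript-like form OpenAI is believed to inject as hidden prompt text, and that text is tokenized. For a hypothetical get_current_weather function with a required location string and an optional unit enum, the rendered input would be:

# Tools

## functions

namespace functions {

// Get the current weather
type get_current_weather = (_: {
// The city, e.g. San Francisco
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;

} // namespace functions

Because Parameters.Properties is a map, Go iterates it in randomized order, so multi-parameter signatures (and hence the counted tokens) can vary slightly between calls.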

View File

@@ -16,6 +16,28 @@ type Message struct {
 	Name    *string `json:"name,omitempty"`
 }
 
+type Property struct {
+	Type        string   `json:"type"`
+	Description string   `json:"description"`
+	Enum        []string `json:"enum"`
+}
+
+type Parameter struct {
+	Type       string              `json:"type"`
+	Properties map[string]Property `json:"properties"`
+	Required   []string            `json:"required"`
+}
+
+type Function struct {
+	Name        string    `json:"name"`
+	Description string    `json:"description"`
+	Parameters  Parameter `json:"parameters"`
+}
+
+type FunctionCall struct {
+	Name string `json:"name"`
+}
+
 const (
 	RelayModeUnknown = iota
 	RelayModeChatCompletions
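These types mirror the function-definition shape of the OpenAI chat completion request. As a sketch of the data they hold (a hypothetical definition, assuming the types above):

package controller

// exampleFunction is a hypothetical definition; it corresponds to one
// entry of the request's "functions" array.
var exampleFunction = Function{
	Name:        "get_current_weather",
	Description: "Get the current weather",
	Parameters: Parameter{
		Type: "object",
		Properties: map[string]Property{
			"location": {Type: "string", Description: "The city, e.g. San Francisco"},
			"unit":     {Type: "string", Enum: []string{"celsius", "fahrenheit"}},
		},
		Required: []string{"location"},
	},
}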
@@ -29,17 +51,19 @@ const (
 
 // https://platform.openai.com/docs/api-reference/chat
 type GeneralOpenAIRequest struct {
 	Model        string     `json:"model,omitempty"`
 	Messages     []Message  `json:"messages,omitempty"`
 	Prompt       any        `json:"prompt,omitempty"`
 	Stream       bool       `json:"stream,omitempty"`
 	MaxTokens    int        `json:"max_tokens,omitempty"`
 	Temperature  float64    `json:"temperature,omitempty"`
 	TopP         float64    `json:"top_p,omitempty"`
 	N            int        `json:"n,omitempty"`
 	Input        any        `json:"input,omitempty"`
 	Instruction  string     `json:"instruction,omitempty"`
 	Size         string     `json:"size,omitempty"`
+	Functions    []Function `json:"functions,omitempty"`
+	FunctionCall any        `json:"function_call,omitempty"`
 }
 
 type ChatRequest struct {
@@ -89,7 +113,8 @@ type TextResponse struct {
 type OpenAITextResponseChoice struct {
 	Index        int        `json:"index"`
 	Message      `json:"message"`
 	FinishReason string     `json:"finish_reason"`
+	Functions    []Function `json:"functions,omitempty"`
 }
 
 type OpenAITextResponse struct {
@@ -122,7 +147,11 @@ type ImageResponse struct {
 type ChatCompletionsStreamResponseChoice struct {
 	Delta struct {
 		Content      string `json:"content"`
+		FunctionCall struct {
+			Name      string `json:"name"`
+			Arguments string `json:"arguments"`
+		} `json:"function_call"`
 	} `json:"delta"`
 	FinishReason *string `json:"finish_reason"`
 }