diff --git a/common/constants.go b/common/constants.go
index c7d3f222..f938f828 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -97,6 +97,8 @@
 var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
 
+var LogPrompt = os.Getenv("LOG_PROMPT") == "true"
+
 const (
 	RequestIdKey = "X-Oneapi-Request-Id"
 )
diff --git a/controller/relay-openai.go b/controller/relay-openai.go
index d90827f4..da593fe2 100644
--- a/controller/relay-openai.go
+++ b/controller/relay-openai.go
@@ -11,9 +11,11 @@ import (
 	"strings"
 )
 
-func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
-	// 1. 因为这个是空的
+func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string, string, string) {
 	responseText := ""
+	responseFunctionCallName := ""
+	responseFunctionCallArguments := ""
+
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -32,7 +34,6 @@
 	go func() {
 		for scanner.Scan() {
 			data := scanner.Text()
-			common.LogInfo(c, "stream received: "+data)
 			if len(data) < 6 { // ignore blank line or wrong format
 				continue
 			}
@@ -52,6 +53,10 @@
 			}
 			for _, choice := range streamResponse.Choices {
 				responseText += choice.Delta.Content
+				if choice.Delta.FunctionCall != nil {
+					responseFunctionCallName += choice.Delta.FunctionCall.Name
+					responseFunctionCallArguments += choice.Delta.FunctionCall.Arguments
+				}
 			}
 		case RelayModeCompletions:
 			var streamResponse CompletionsStreamResponse
@@ -85,10 +90,9 @@
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", "", ""
 	}
-	common.LogInfo(c, "stream ended, responseText: "+responseText)
-	return nil, responseText
+	return nil, responseText, responseFunctionCallName, responseFunctionCallArguments
 }
 
 func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
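
For reference, the deltas the handler concatenates look like this on the wire: the function name usually arrives once in an early chunk, while the arguments arrive as string fragments that only form valid JSON after concatenation. Below is a minimal, self-contained sketch of that accumulation; the sample chunks are invented, and only the field shapes follow the streaming format parsed above.

package main

import (
	"encoding/json"
	"fmt"
)

// functionCallDelta mirrors the FunctionCall delta type added in this PR.
type functionCallDelta struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

type streamChoice struct {
	Delta struct {
		Content      string             `json:"content"`
		FunctionCall *functionCallDelta `json:"function_call,omitempty"`
	} `json:"delta"`
}

type streamChunk struct {
	Choices []streamChoice `json:"choices"`
}

func main() {
	chunks := []string{
		`{"choices":[{"delta":{"function_call":{"name":"get_weather"}}}]}`,
		`{"choices":[{"delta":{"function_call":{"arguments":"{\"location\":"}}}]}`,
		`{"choices":[{"delta":{"function_call":{"arguments":"\"Boston\"}"}}}]}`,
	}
	var name, arguments string
	for _, data := range chunks {
		var chunk streamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			continue // the handler above is similarly lenient with malformed chunks
		}
		for _, choice := range chunk.Choices {
			if choice.Delta.FunctionCall != nil {
				name += choice.Delta.FunctionCall.Name
				arguments += choice.Delta.FunctionCall.Arguments
			}
		}
	}
	fmt.Println(name, arguments) // get_weather {"location":"Boston"}
}
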
logContent = "request content: " + string(requestRaw) + } + common.LogInfo(c, logContent) + } req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) @@ -447,12 +459,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { switch apiType { case APITypeOpenAI: if isStream { - err, responseText := openaiStreamHandler(c, resp, relayMode) + err, responseText, responseFunctionCallName, responseFunctionCallArguments := openaiStreamHandler(c, resp, relayMode) if err != nil { return err } textResponse.Usage.PromptTokens = promptTokens textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + if responseFunctionCallName != "" { + textResponse.Usage.CompletionTokens += countTokenFunctionCall(responseFunctionCallName, textRequest.Model) + } + if responseFunctionCallArguments != "" { + responseFunctionCallArguments = strings.Replace(responseFunctionCallArguments, "\\\"", "\"", -1) + textResponse.Usage.CompletionTokens += countTokenFunctionCall(responseFunctionCallArguments, textRequest.Model) + } return nil } else { err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index cf5d9b69..c4a2031c 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" + "gorm.io/gorm/utils" "io" "net/http" "one-api/common" @@ -114,6 +115,71 @@ func countTokenText(text string, model string) int { return getTokenNum(tokenEncoder, text) } +func countTokenFunctionCall(functionCall any, model string) int { + tokenEncoder := getTokenEncoder(model) + jsonBytes, err := json.Marshal(functionCall) + if err != nil { + return 0 + } + return getTokenNum(tokenEncoder, string(jsonBytes)) +} + +func countTokenFunctions(functions []Function, model string) int { + // https://community.openai.com/t/how-to-know-of-tokens-beforehand-when-i-make-function-calling-chat-history-request-witn-nodejs/289060/6 + if len(functions) == 0 { + return 0 + } + tokenEncoder := getTokenEncoder(model) + + paramSignature := func(name string, pSpec Property, pRequired []string) string { + var requiredString string + if utils.Contains(pRequired, name) == false { + requiredString = "?" 
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index cf5d9b69..c4a2031c 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/pkoukk/tiktoken-go"
+	"gorm.io/gorm/utils"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -114,6 +115,75 @@
 	return getTokenNum(tokenEncoder, text)
 }
 
+func countTokenFunctionCall(functionCall any, model string) int {
+	// nothing to count; avoid billing the "null" that json.Marshal produces for an absent function_call
+	if functionCall == nil {
+		return 0
+	}
+	tokenEncoder := getTokenEncoder(model)
+	jsonBytes, err := json.Marshal(functionCall)
+	if err != nil {
+		return 0
+	}
+	return getTokenNum(tokenEncoder, string(jsonBytes))
+}
+
+func countTokenFunctions(functions []Function, model string) int {
+	// https://community.openai.com/t/how-to-know-of-tokens-beforehand-when-i-make-function-calling-chat-history-request-witn-nodejs/289060/6
+	if len(functions) == 0 {
+		return 0
+	}
+	tokenEncoder := getTokenEncoder(model)
+
+	paramSignature := func(name string, pSpec Property, pRequired []string) string {
+		var requiredString string
+		if !utils.Contains(pRequired, name) {
+			requiredString = "?"
+		}
+		var enumString string
+		if len(pSpec.Enum) > 0 {
+			enumValues := make([]string, len(pSpec.Enum))
+			for i, v := range pSpec.Enum {
+				enumValues[i] = fmt.Sprintf("\"%s\"", v)
+			}
+			enumString = strings.Join(enumValues, " | ")
+		} else {
+			enumString = pSpec.Type
+		}
+		signature := fmt.Sprintf("%s%s: %s, ", name, requiredString, enumString)
+		if pSpec.Description != "" {
+			signature = fmt.Sprintf("// %s\n%s", pSpec.Description, signature)
+		}
+		return signature
+	}
+
+	functionSignature := func(fSpec Function) string {
+		var params []string
+		for name, p := range fSpec.Parameters.Properties {
+			params = append(params, paramSignature(name, p, fSpec.Parameters.Required))
+		}
+		var descriptionString string
+		if fSpec.Description != "" {
+			descriptionString = fmt.Sprintf("// %s\n", fSpec.Description)
+		}
+
+		var paramString string
+		if len(params) > 0 {
+			paramString = fmt.Sprintf("_: {\n%s\n}", strings.Join(params, "\n"))
+		}
+
+		return fmt.Sprintf("%stype %s = (%s) => any;", descriptionString, fSpec.Name, paramString)
+	}
+
+	var functionSignatures []string
+	for _, f := range functions {
+		functionSignatures = append(functionSignatures, functionSignature(f))
+	}
+	functionString := fmt.Sprintf("# Tools\n\n## functions\n\nnamespace functions {\n\n%s\n\n} // namespace functions", strings.Join(functionSignatures, "\n\n"))
+
+	return getTokenNum(tokenEncoder, functionString)
+}
+
 func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
 	openAIError := OpenAIError{
 		Message: err.Error(),
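
What countTokenFunctions actually feeds the encoder is the TypeScript-like rendering described in the linked community post. For a hypothetical get_weather function with a required, described location parameter and an optional unit enum, the serialized text would look roughly like this (property order varies with Go's map iteration, and each parameter line keeps the trailing ", " that the Sprintf emits):

# Tools

## functions

namespace functions {

// Get the current weather for a city
type get_weather = (_: {
// The city name
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;

} // namespace functions

Tokenizing this rendering rather than the raw JSON schema is what the linked thread suggests keeps the estimate close to OpenAI's own prompt-token accounting for functions.
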
diff --git a/controller/relay.go b/controller/relay.go
index 1926110e..45ed91a5 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -29,19 +29,38 @@ const (
 
 // https://platform.openai.com/docs/api-reference/chat
 
+type Property struct {
+	Type        string   `json:"type"`
+	Description string   `json:"description"`
+	Enum        []string `json:"enum"`
+}
+
+type Parameter struct {
+	Type       string              `json:"type"`
+	Properties map[string]Property `json:"properties"`
+	Required   []string            `json:"required"`
+}
+
+type Function struct {
+	Name        string    `json:"name"`
+	Description string    `json:"description"`
+	Parameters  Parameter `json:"parameters"`
+}
+
 type GeneralOpenAIRequest struct {
-	Model       string    `json:"model,omitempty"`
-	Messages    []Message `json:"messages,omitempty"`
-	Prompt      any       `json:"prompt,omitempty"`
-	Stream      bool      `json:"stream,omitempty"`
-	MaxTokens   int       `json:"max_tokens,omitempty"`
-	Temperature float64   `json:"temperature,omitempty"`
-	TopP        float64   `json:"top_p,omitempty"`
-	N           int       `json:"n,omitempty"`
-	Input       any       `json:"input,omitempty"`
-	Instruction string    `json:"instruction,omitempty"`
-	Size        string    `json:"size,omitempty"`
-	Functions   any       `json:"functions,omitempty"`
+	Model        string     `json:"model,omitempty"`
+	Messages     []Message  `json:"messages,omitempty"`
+	Prompt       any        `json:"prompt,omitempty"`
+	Stream       bool       `json:"stream,omitempty"`
+	MaxTokens    int        `json:"max_tokens,omitempty"`
+	Temperature  float64    `json:"temperature,omitempty"`
+	TopP         float64    `json:"top_p,omitempty"`
+	N            int        `json:"n,omitempty"`
+	Input        any        `json:"input,omitempty"`
+	Instruction  string     `json:"instruction,omitempty"`
+	Size         string     `json:"size,omitempty"`
+	Functions    []Function `json:"functions,omitempty"`
+	FunctionCall any        `json:"function_call,omitempty"`
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {
@@ -111,6 +130,11 @@ type TextResponse struct {
 	Error OpenAIError `json:"error"`
 }
 
+type FunctionCall struct {
+	Name      string `json:"name,omitempty"`
+	Arguments string `json:"arguments,omitempty"`
+}
+
 type OpenAITextResponseChoice struct {
 	Index int `json:"index"`
 	Message `json:"message"`
@@ -147,7 +171,8 @@ type ImageResponse struct {
 
 type ChatCompletionsStreamResponseChoice struct {
 	Delta struct {
-		Content string `json:"content"`
+		Content      string        `json:"content"`
+		FunctionCall *FunctionCall `json:"function_call,omitempty"`
 	} `json:"delta"`
 	FinishReason *string `json:"finish_reason"`
 }
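
Putting the pieces together, a request that exercises the new fields can be built like this. This is a fragment, not a full program: it assumes the package types above (Message is the existing one-api message struct), and every value is invented for illustration.

	req := GeneralOpenAIRequest{
		Model: "gpt-3.5-turbo",
		Messages: []Message{
			{Role: "user", Content: "What's the weather in Boston?"},
		},
		Functions: []Function{{
			Name:        "get_weather",
			Description: "Get the current weather for a city",
			Parameters: Parameter{
				Type: "object",
				Properties: map[string]Property{
					"location": {Type: "string", Description: "The city name"},
					"unit":     {Type: "string", Enum: []string{"celsius", "fahrenheit"}},
				},
				Required: []string{"location"},
			},
		}},
		FunctionCall: "auto", // marshals as "function_call" with the corrected tag
	}

	// Prompt-token accounting now mirrors the relay-text.go switch arm:
	promptTokens := countTokenMessages(req.Messages, req.Model) +
		countTokenFunctions(req.Functions, req.Model) +
		countTokenFunctionCall(req.FunctionCall, req.Model)
	_ = promptTokens

With the corrected function_call tag, the same payload a standard OpenAI client sends now binds into GeneralOpenAIRequest and is billed consistently on both the prompt and completion sides.
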