fix: calculate usage if not given in non-stream mode (#352)

This commit is contained in:
glzjin 2023-08-06 17:40:31 +08:00 committed by GitHub
parent 1dfa190e79
commit 446337c329
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 4 deletions

View File

@ -92,7 +92,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
return nil, responseText
}
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
@ -132,5 +132,17 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*Ope
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range textResponse.Choices {
completionTokens += countTokenText(choice.Message.Content, model)
}
textResponse.Usage = Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
return nil, &textResponse.Usage
}

View File

@ -362,7 +362,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := openaiHandler(c, resp, consumeQuota)
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
if err != nil {
return err
}

View File

@ -81,8 +81,9 @@ type OpenAIErrorWithStatusCode struct {
}
type TextResponse struct {
Usage `json:"usage"`
Error OpenAIError `json:"error"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {