Merge remote-tracking branch 'songquanpeng/main'

# Conflicts:
#	controller/relay-text.go
This commit is contained in:
quzard 2023-06-25 10:28:53 +08:00
commit 760dabd79d
2 changed files with 33 additions and 8 deletions

View File

@ -71,7 +71,7 @@ func GetUsage(c *gin.Context) {
}
usage := OpenAIUsageResponse{
Object: "list",
TotalUsage: amount,
TotalUsage: amount * 100,
}
c.JSON(200, usage)
return

View File

@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
@ -30,6 +31,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if relayMode == RelayModeModeration && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModeration:
if textRequest.Input == "" {
return errorWrapper(errors.New("input is required"), "required_field_missing", http.StatusBadRequest)
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
@ -140,21 +160,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, channelName)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
if strings.Contains(channelName, "免费") == false {
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
//tokenName := c.GetString("token_name")
//logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
//model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, channelName)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
if quota != 0 {
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}
}()