diff --git a/controller/billing.go b/controller/billing.go index ec8e3ce3..7bc63425 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -71,7 +71,7 @@ func GetUsage(c *gin.Context) { } usage := OpenAIUsageResponse{ Object: "list", - TotalUsage: amount, + TotalUsage: amount * 100, } c.JSON(200, usage) return diff --git a/controller/relay-text.go b/controller/relay-text.go index bdf38f26..95c58eb8 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "fmt" "github.com/gin-gonic/gin" "io" @@ -30,6 +31,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if relayMode == RelayModeModeration && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } + // request validation + if textRequest.Model == "" { + return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) + } + switch relayMode { + case RelayModeCompletions: + if textRequest.Prompt == "" { + return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + } + case RelayModeChatCompletions: + if len(textRequest.Messages) == 0 { + return errorWrapper(errors.New("messages is required"), "required_field_missing", http.StatusBadRequest) + } + case RelayModeEmbeddings: + case RelayModeModeration: + if textRequest.Input == "" { + return errorWrapper(errors.New("input is required"), "required_field_missing", http.StatusBadRequest) + } + } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if c.GetString("base_url") != "" { @@ -140,21 +160,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if ratio != 0 && quota <= 0 { quota = 1 } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, channelName) + model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) if strings.Contains(channelName, "免费") == false { quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } - //tokenName := c.GetString("token_name") - //logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - //model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, channelName) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) + if quota != 0 { + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } } } }()