package controller
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/songquanpeng/one-api/common"
|
|
"github.com/songquanpeng/one-api/common/config"
|
|
"github.com/songquanpeng/one-api/common/logger"
|
|
"github.com/songquanpeng/one-api/model"
|
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
|
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
|
"github.com/songquanpeng/one-api/relay/meta"
|
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
|
"github.com/songquanpeng/one-api/relay/util"
|
|
"math"
|
|
"net/http"
|
|
)
|
|
|
|
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
|
|
textRequest := &relaymodel.GeneralOpenAIRequest{}
|
|
err := common.UnmarshalBodyReusable(c, textRequest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if relayMode == relaymode.Moderations && textRequest.Model == "" {
|
|
textRequest.Model = "text-moderation-latest"
|
|
}
|
|
if relayMode == relaymode.Embeddings && textRequest.Model == "" {
|
|
textRequest.Model = c.Param("model")
|
|
}
|
|
err = util.ValidateTextRequest(textRequest, relayMode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return textRequest, nil
|
|
}
|
|
|
|
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
|
|
imageRequest := &relaymodel.ImageRequest{}
|
|
err := common.UnmarshalBodyReusable(c, imageRequest)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if imageRequest.N == 0 {
|
|
imageRequest.N = 1
|
|
}
|
|
if imageRequest.Size == "" {
|
|
imageRequest.Size = "1024x1024"
|
|
}
|
|
if imageRequest.Model == "" {
|
|
imageRequest.Model = "dall-e-2"
|
|
}
|
|
return imageRequest, nil
|
|
}
|
|
|
|
func isValidImageSize(model string, size string) bool {
|
|
if model == "cogview-3" {
|
|
return true
|
|
}
|
|
_, ok := billingratio.ImageSizeRatios[model][size]
|
|
return ok
|
|
}
|
|
|
|
func getImageSizeRatio(model string, size string) float64 {
|
|
ratio, ok := billingratio.ImageSizeRatios[model][size]
|
|
if !ok {
|
|
return 1
|
|
}
|
|
return ratio
|
|
}
|
|
|
|
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
|
// model validation
|
|
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
|
|
if !hasValidSize {
|
|
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
|
}
|
|
// check prompt length
|
|
if imageRequest.Prompt == "" {
|
|
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
|
}
|
|
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
|
|
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
|
}
|
|
// Number of generated images validation
|
|
if !isWithinRange(imageRequest.Model, imageRequest.N) {
|
|
// channel not azure
|
|
if meta.ChannelType != channeltype.Azure {
|
|
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
|
|
if imageRequest == nil {
|
|
return 0, errors.New("imageRequest is nil")
|
|
}
|
|
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
|
|
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
|
|
if imageRequest.Size == "1024x1024" {
|
|
imageCostRatio *= 2
|
|
} else {
|
|
imageCostRatio *= 1.5
|
|
}
|
|
}
|
|
return imageCostRatio, nil
|
|
}
|
|
|
|
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
|
switch relayMode {
|
|
case relaymode.ChatCompletions:
|
|
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
|
|
case relaymode.Completions:
|
|
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
|
case relaymode.Moderations:
|
|
return openai.CountTokenInput(textRequest.Input, textRequest.Model)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 {
|
|
preConsumedTokens := config.PreConsumedQuota
|
|
if textRequest.MaxTokens != 0 {
|
|
preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens)
|
|
}
|
|
return int64(float64(preConsumedTokens) * ratio)
|
|
}
|
|
|
|
// preConsumeQuota reserves an estimated quota amount for the user before the
// request is relayed, so an aborted upstream call cannot leave usage unbilled.
// It returns the amount actually reserved (0 for trusted users) along with a
// wrapped error on failure; the reservation is settled in postConsumeQuota.
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) {
	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)

	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	// Reject up front when the estimate already exceeds the cached balance.
	if userQuota-preConsumedQuota < 0 {
		return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	// Trusted users (balance more than 100x the estimate) skip the
	// token-level reservation; returning 0 tells postConsumeQuota that
	// nothing was pre-consumed. NOTE(review): the cached user quota was
	// already decreased above even in this branch — presumably reconciled
	// by CacheUpdateUserQuota during settlement; confirm.
	if userQuota > 100*preConsumedQuota {
		// in this case, we do not pre-consume quota
		// because the user has enough quota
		preConsumedQuota = 0
		logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
	}
	if preConsumedQuota > 0 {
		err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
		if err != nil {
			return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
	}
	return preConsumedQuota, nil
}
// postConsumeQuota settles billing after the relayed request completes: it
// computes the final quota from the actual usage, charges (or refunds) the
// difference against the pre-consumed reservation, refreshes the cached user
// quota, and records the consumption log plus per-user/per-channel totals.
// Errors during settlement are logged, not returned.
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
	if usage == nil {
		logger.Error(ctx, "usage is nil, which is unexpected")
		return
	}
	var quota int64
	completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
	promptTokens := usage.PromptTokens
	completionTokens := usage.CompletionTokens
	// Completion tokens are weighted by the model's completion ratio before
	// the overall billing ratio applies; round up to a whole quota unit.
	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
	// Any billable (non-free) request costs at least one quota unit.
	if ratio != 0 && quota <= 0 {
		quota = 1
	}
	totalTokens := promptTokens + completionTokens
	if totalTokens == 0 {
		// in this case, must be some error happened
		// we cannot just return, because we may have to return the pre-consumed quota
		quota = 0
	}
	// Settle against the reservation: a positive delta charges the
	// remainder, a negative delta refunds the over-reserved amount.
	quotaDelta := quota - preConsumedQuota
	err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
	if err != nil {
		logger.Error(ctx, "error consuming token remain quota: "+err.Error())
	}
	err = model.CacheUpdateUserQuota(ctx, meta.UserId)
	if err != nil {
		logger.Error(ctx, "error update user quota cache: "+err.Error())
	}
	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
	model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
|