From de18d6fe16e41a18465db908b9851426e6340721 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 19:30:11 +0800 Subject: [PATCH 1/9] refactor: refactor image relay (close #1068) --- common/model-ratio.go | 23 ------- relay/constant/image.go | 24 +++++++ relay/controller/helper.go | 59 +++++++++++++++++ relay/controller/image.go | 132 +++++++++++-------------------------- 4 files changed, 122 insertions(+), 116 deletions(-) create mode 100644 relay/constant/image.go diff --git a/common/model-ratio.go b/common/model-ratio.go index 1594b534..ab0ad748 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -7,29 +7,6 @@ import ( "time" ) -var DalleSizeRatios = map[string]map[string]float64{ - "dall-e-2": { - "256x256": 1, - "512x512": 1.125, - "1024x1024": 1.25, - }, - "dall-e-3": { - "1024x1024": 1, - "1024x1792": 2, - "1792x1024": 2, - }, -} - -var DalleGenerationImageAmounts = map[string][2]int{ - "dall-e-2": {1, 10}, - "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. -} - -var DalleImagePromptLengthLimitations = map[string]int{ - "dall-e-2": 1000, - "dall-e-3": 4000, -} - const ( USD2RMB = 7 USD = 500 // $0.002 = 1 -> $1 = 500 diff --git a/relay/constant/image.go b/relay/constant/image.go new file mode 100644 index 00000000..5e04895f --- /dev/null +++ b/relay/constant/image.go @@ -0,0 +1,24 @@ +package constant + +var DalleSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, +} + +var DalleGenerationImageAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. +} + +var DalleImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, +} diff --git a/relay/controller/helper.go b/relay/controller/helper.go index a06b2768..d5078304 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -36,6 +36,65 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } +func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { + imageRequest := &openai.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + return imageRequest, nil +} + +func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { + // model validation + _, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] + if !hasValidSize { + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + // check prompt length + if imageRequest.Prompt == "" { + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + } + if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + } + // Number of generated images validation + if !isWithinRange(imageRequest.Model, imageRequest.N) { + // channel not azure + if meta.ChannelType != common.ChannelTypeAzure { + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } + } + return nil +} + +func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { + if imageRequest == nil { + return 0, errors.New("imageRequest is nil") + } + imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] + if !hasValidSize { + return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size) + } + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + return imageCostRatio, nil +} + func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case constant.RelayModeChatCompletions: diff --git a/relay/controller/image.go b/relay/controller/image.go index 6ec368f5..3ce3809b 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -10,6 +10,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "io" @@ -20,120 +21,65 @@ import ( ) func isWithinRange(element string, value int) bool { - if _, ok := common.DalleGenerationImageAmounts[element]; !ok { + if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { return false } - min := common.DalleGenerationImageAmounts[element][0] - max := common.DalleGenerationImageAmounts[element][1] + min := constant.DalleGenerationImageAmounts[element][0] + max := constant.DalleGenerationImageAmounts[element][1] return value >= min && value <= max } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { - imageModel := "dall-e-2" - imageSize := "1024x1024" - - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - - var imageRequest openai.ImageRequest - err := common.UnmarshalBodyReusable(c, &imageRequest) + ctx := c.Request.Context() + meta := util.GetRelayMeta(c) + imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { - return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if imageRequest.N == 0 { - imageRequest.N = 1 - } - - // Size validation - if imageRequest.Size != "" { - imageSize = imageRequest.Size - } - - // Model validation - if imageRequest.Model != "" { - imageModel = imageRequest.Model - } - - imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] - - // Check if model is supported - if hasValidSize { - if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { - if imageSize == "1024x1024" { - imageCostRatio *= 2 - } else { - imageCostRatio *= 1.5 - } - } - } else { - return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) - } - - // Prompt validation - if imageRequest.Prompt == "" { - return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) - } - - // Check prompt length - if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { - return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) - } - - // Number of generated images validation - if !isWithinRange(imageModel, imageRequest.N) { - // channel not azure - if channelType != common.ChannelTypeAzure { - return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) - } + logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[imageModel] != "" { - imageModel = modelMap[imageModel] - isModelMapped = true - } + var isModelMapped bool + meta.OriginModelName = imageRequest.Model + imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping) + meta.ActualModelName = imageRequest.Model + + // model validation + bizErr := validateImageRequest(imageRequest, meta) + if bizErr != nil { + return bizErr } - baseURL := common.ChannelBaseURLs[channelType] + + imageCostRatio, err := getImageCostRatio(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) + } + requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure { + fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) + if meta.ChannelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api apiVersion := util.GetAzureAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) } var requestBody io.Reader - if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body + if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { - return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { requestBody = c.Request.Body } - modelRatio := common.GetModelRatio(imageModel) - groupRatio := common.GetGroupRatio(group) + modelRatio := common.GetModelRatio(imageRequest.Model) + groupRatio := common.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(meta.UserId) quota := int(ratio*imageCostRatio*1000) * imageRequest.N @@ -146,7 +92,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } token := c.Request.Header.Get("Authorization") - if channelType == common.ChannelTypeAzure { // Azure authentication + if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication token = strings.TrimPrefix(token, "Bearer ") req.Header.Set("api-key", token) } else { @@ -169,25 +115,25 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var textResponse openai.ImageResponse + var imageResponse openai.ImageResponse defer func(ctx context.Context) { if resp.StatusCode != http.StatusOK { return } - err := model.PostConsumeTokenQuota(tokenId, quota) + err := model.PostConsumeTokenQuota(meta.TokenId, quota) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(userId) + err = model.CacheUpdateUserQuota(meta.UserId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } @@ -202,7 +148,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - err = json.Unmarshal(responseBody, &textResponse) + err = json.Unmarshal(responseBody, &imageResponse) if err != nil { return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } From 82e916b5ff9c6c3f8325c91c04003fb6751f024c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 20:51:28 +0800 Subject: [PATCH 2/9] fix: fix azure test (close #1069) --- controller/channel-test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/controller/channel-test.go b/controller/channel-test.go index b498f4f1..485d7702 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" @@ -51,6 +52,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error c.Request.Header.Set("Content-Type", "application/json") c.Set("channel", channel.Type) c.Set("base_url", channel.GetBaseURL()) + middleware.SetupContextForSelectedChannel(c, channel, "") meta := util.GetRelayMeta(c) apiType := constant.ChannelType2APIType(channel.Type) adaptor := helper.GetAdaptor(apiType) From b35f3523d305de758021286fff9f62bece08af02 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 21:03:04 +0800 Subject: [PATCH 3/9] feat: add gemini model alias (close #1064) --- relay/channel/gemini/constants.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go index 5bb0c168..4e7c57f9 100644 --- a/relay/channel/gemini/constants.go +++ b/relay/channel/gemini/constants.go @@ -1,6 +1,6 @@ package gemini var ModelList = []string{ - "gemini-pro", - "gemini-pro-vision", + "gemini-pro", "gemini-1.0-pro-001", + "gemini-pro-vision", "gemini-1.0-pro-vision-001", } From 9d8967f7d325999ab28616aef5b9eeaa1cccf6dc Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 21:46:45 +0800 Subject: [PATCH 4/9] feat: support Mistral's models now (close #1051) --- README.md | 1 + common/constants.go | 2 ++ common/model-ratio.go | 23 ++++++++++++++++--- controller/model.go | 12 ++++++++++ relay/channel/mistral/constants.go | 10 ++++++++ relay/channel/openai/adaptor.go | 5 ++++ relay/controller/helper.go | 10 ++++---- web/berry/src/constants/ChannelConstants.js | 6 +++++ .../src/constants/channel.constants.js | 1 + 9 files changed, 61 insertions(+), 9 deletions(-) create mode 100644 relay/channel/mistral/constants.go diff --git a/README.md b/README.md index a92142ae..69bb10ef 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + + [x] [Mistral 系列模型](https://mistral.ai/) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) diff --git a/common/constants.go b/common/constants.go index f67dc146..ac901139 100644 --- a/common/constants.go +++ b/common/constants.go @@ -66,6 +66,7 @@ const ( ChannelTypeMoonshot = 25 ChannelTypeBaichuan = 26 ChannelTypeMinimax = 27 + ChannelTypeMistral = 28 ) var ChannelBaseURLs = []string{ @@ -97,6 +98,7 @@ var ChannelBaseURLs = []string{ "https://api.moonshot.cn", // 25 "https://api.baichuan-ai.com", // 26 "https://api.minimax.chat", // 27 + "https://api.mistral.ai", // 28 } const ( diff --git a/common/model-ratio.go b/common/model-ratio.go index ab0ad748..2e66ac0d 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -17,7 +17,6 @@ const ( // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://openai.com/pricing -// TODO: when a new api is enabled, check the pricing here // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ @@ -116,15 +115,29 @@ var ModelRatio = map[string]float64{ "abab6-chat": 0.1 * RMB, "abab5.5-chat": 0.015 * RMB, "abab5.5s-chat": 0.005 * RMB, + // https://docs.mistral.ai/platform/pricing/ + "open-mistral-7b": 0.25 / 1000 * USD, + "open-mixtral-8x7b": 0.7 / 1000 * USD, + "mistral-small-latest": 2.0 / 1000 * USD, + "mistral-medium-latest": 2.7 / 1000 * USD, + "mistral-large-latest": 8.0 / 1000 * USD, + "mistral-embed": 0.1 / 1000 * USD, } +var CompletionRatio = map[string]float64{} + var DefaultModelRatio map[string]float64 +var DefaultCompletionRatio map[string]float64 func init() { DefaultModelRatio = make(map[string]float64) for k, v := range ModelRatio { DefaultModelRatio[k] = v } + DefaultCompletionRatio = make(map[string]float64) + for k, v := range CompletionRatio { + DefaultCompletionRatio[k] = v + } } func ModelRatio2JSONString() string { @@ -155,8 +168,6 @@ func GetModelRatio(name string) float64 { return ratio } -var CompletionRatio = map[string]float64{} - func CompletionRatio2JSONString() string { jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { @@ -174,6 +185,9 @@ func GetCompletionRatio(name string) float64 { if ratio, ok := CompletionRatio[name]; ok { return ratio } + if ratio, ok := DefaultCompletionRatio[name]; ok { + return ratio + } if strings.HasPrefix(name, "gpt-3.5") { if strings.HasSuffix(name, "0125") { // https://openai.com/blog/new-embedding-models-and-api-updates @@ -206,5 +220,8 @@ func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "claude-2") { return 2.965517 } + if strings.HasPrefix(name, "mistral-") { + return 3 + } return 1 } diff --git a/controller/model.go b/controller/model.go index 0f33f919..0d0d2658 100644 --- a/controller/model.go +++ b/controller/model.go @@ -6,6 +6,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/minimax" + "github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" @@ -122,6 +123,17 @@ func init() { Parent: nil, }) } + for _, modelName := range mistral.ModelList { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "mistralai", + Permission: permission, + Root: modelName, + Parent: nil, + }) + } openAIModelsMap = make(map[string]OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model diff --git a/relay/channel/mistral/constants.go b/relay/channel/mistral/constants.go new file mode 100644 index 00000000..cdb157f5 --- /dev/null +++ b/relay/channel/mistral/constants.go @@ -0,0 +1,10 @@ +package mistral + +var ModelList = []string{ + "open-mistral-7b", + "open-mixtral-8x7b", + "mistral-small-latest", + "mistral-medium-latest", + "mistral-large-latest", + "mistral-embed", +} diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 27d0fc27..5a04a768 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -9,6 +9,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/minimax" + "github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/moonshot" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" @@ -94,6 +95,8 @@ func (a *Adaptor) GetModelList() []string { return baichuan.ModelList case common.ChannelTypeMinimax: return minimax.ModelList + case common.ChannelTypeMistral: + return mistral.ModelList default: return ModelList } @@ -111,6 +114,8 @@ func (a *Adaptor) GetChannelName() string { return "baichuan" case common.ChannelTypeMinimax: return "minimax" + case common.ChannelTypeMistral: + return "mistralai" default: return "openai" } diff --git a/relay/controller/helper.go b/relay/controller/helper.go index d5078304..89fc69ce 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -172,10 +172,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) - model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) - model.UpdateChannelUsedQuota(meta.ChannelId, quota) - } + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) + model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) + model.UpdateChannelUsedQuota(meta.ChannelId, quota) } diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 98ceaebf..31c45048 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -29,6 +29,12 @@ export const CHANNEL_OPTIONS = { value: 24, color: 'orange' }, + 28: { + key: 28, + text: 'Mistral AI', + value: 28, + color: 'orange' + }, 15: { key: 15, text: '百度文心千帆', diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index beb0adb1..b21bb15d 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [ { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, + { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, From 2df877a352f3ae150b276602172fc1ebcac21446 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 22:14:07 +0800 Subject: [PATCH 5/9] feat: switch priority when retry (close #1048) --- common/random.go | 8 ++++++++ controller/relay.go | 2 +- middleware/distributor.go | 2 +- model/cache.go | 7 ++++++- 4 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 common/random.go diff --git a/common/random.go b/common/random.go new file mode 100644 index 00000000..44bd2856 --- /dev/null +++ b/common/random.go @@ -0,0 +1,8 @@ +package common + +import "math/rand" + +// RandRange returns a random number between min and max (max is not included) +func RandRange(min, max int) int { + return min + rand.Intn(max-min) +} diff --git a/controller/relay.go b/controller/relay.go index 278c0b32..33a8243d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,7 +62,7 @@ func Relay(c *gin.Context) { retryTimes = 0 } for i := retryTimes; i > 0; i-- { - channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) + channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, true) if err != nil { logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) break diff --git a/middleware/distributor.go b/middleware/distributor.go index aeb2796a..e845c2f8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) { } } requestModel = modelRequest.Model - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { diff --git a/model/cache.go b/model/cache.go index 04a60348..3c3575b8 100644 --- a/model/cache.go +++ b/model/cache.go @@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) { } } -func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { if !config.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model) } @@ -213,5 +213,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error } } idx := rand.Intn(endIdx) + if ignoreFirstPriority { + if endIdx < len(channels) { // which means there are more than one priority + idx = common.RandRange(endIdx, len(channels)) + } + } return channels[idx], nil } From 10a926b8f3912846c5493cd2378b6d07ff36dcbe Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 22:16:34 +0800 Subject: [PATCH 6/9] feat: only use the top priority when first retry (#1048) --- controller/relay.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/relay.go b/controller/relay.go index 33a8243d..9b2d462c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,7 +62,7 @@ func Relay(c *gin.Context) { retryTimes = 0 } for i := retryTimes; i > 0; i-- { - channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, true) + channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) if err != nil { logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) break From c6ace985c2ab8649cf614d252f8d2e7914bcd81c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 22:51:01 +0800 Subject: [PATCH 7/9] fix: set missing ali parameters (close #1028) --- relay/channel/ali/main.go | 6 ++++++ relay/channel/ali/model.go | 2 ++ 2 files changed, 8 insertions(+) diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index b9625584..62115d58 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -33,6 +33,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { enableSearch = true aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) } + if request.TopP >= 1 { + request.TopP = 0.9999 + } return &ChatRequest{ Model: aliModel, Input: Input{ @@ -42,6 +45,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { EnableSearch: enableSearch, IncrementalOutput: request.Stream, Seed: uint64(request.Seed), + MaxTokens: request.MaxTokens, + Temperature: request.Temperature, + TopP: request.TopP, }, } } diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go index 54f13041..76e814d1 100644 --- a/relay/channel/ali/model.go +++ b/relay/channel/ali/model.go @@ -16,6 +16,8 @@ type Parameters struct { Seed uint64 `json:"seed,omitempty"` EnableSearch bool `json:"enable_search,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` } type ChatRequest struct { From 95cfb8e8c952d6795acf69e896f30d1ddea08248 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Mar 2024 22:58:41 +0800 Subject: [PATCH 8/9] fix: using the first available model if default model is not found (close #1021) --- controller/channel-test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/controller/channel-test.go b/controller/channel-test.go index 485d7702..7007e205 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -19,6 +19,7 @@ import ( "net/http/httptest" "net/url" "strconv" + "strings" "sync" "time" @@ -61,6 +62,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error } adaptor.Init(meta) modelName := adaptor.GetModelList()[0] + if !strings.Contains(channel.Models, modelName) { + modelNames := strings.Split(channel.Models, ",") + if len(modelNames) > 0 { + modelName = modelNames[0] + } + } request := buildTestRequest() request.Model = modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName From 4fb22ad4ce7cafba7473e19e3a180ede93cf35cc Mon Sep 17 00:00:00 2001 From: momomobinx Date: Sun, 3 Mar 2024 23:50:28 +0800 Subject: [PATCH 9/9] feat: support third part models of baidu (#1046) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 百度千帆平台上的第三方大模型调用 --- relay/channel/baidu/adaptor.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index d2d06ce0..066a8107 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -36,6 +36,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" case "Embedding-V1": fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" + default: + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + meta.ActualModelName } var accessToken string var err error