From d8b13b2c077e39e80e1109c51ca4b9649b1b5056 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Sat, 2 Dec 2023 14:29:30 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20=E5=88=A0=E9=99=A4=E5=86=97?= =?UTF-8?q?=E4=BD=99=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 10 +- controller/relay-chat.go | 5 +- controller/relay-completions.go | 5 +- controller/relay-embeddings.go | 5 +- controller/relay-helper.go | 437 -------------------------- controller/relay-image-edits.go | 5 +- controller/relay-image-generations.go | 5 +- controller/relay-image-variationsy.go | 5 +- controller/relay-moderations.go | 5 +- controller/relay-speech.go | 5 +- controller/relay-transcriptions.go | 5 +- controller/relay-translations.go | 5 +- controller/relay-utils.go | 20 +- middleware/distributor.go | 116 ------- 14 files changed, 27 insertions(+), 606 deletions(-) delete mode 100644 controller/relay-helper.go diff --git a/controller/channel-test.go b/controller/channel-test.go index 64556107..62ada94e 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -28,12 +28,8 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req - c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - c.Set("model_mapping", channel.GetModelMapping()) - c.Set("api_key", channel.Key) - c.Set("base_url", channel.GetBaseURL()) + + setChannelToContext(c, channel) switch channel.Type { case common.ChannelTypePaLM: @@ -70,7 +66,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e } isModelMapped := false - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { return err, nil } diff --git a/controller/relay-chat.go b/controller/relay-chat.go index 50c05d4b..17dc8039 100644 --- a/controller/relay-chat.go +++ b/controller/relay-chat.go @@ -24,12 +24,9 @@ func RelayChat(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-completions.go b/controller/relay-completions.go index 0e2c8791..c6f7ab86 100644 --- a/controller/relay-completions.go +++ b/controller/relay-completions.go @@ -24,12 +24,9 @@ func RelayCompletions(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-embeddings.go b/controller/relay-embeddings.go index cc5b2d9a..5d3f9aec 100644 --- a/controller/relay-embeddings.go +++ b/controller/relay-embeddings.go @@ -29,12 +29,9 @@ func RelayEmbeddings(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-helper.go b/controller/relay-helper.go deleted file mode 100644 index 5233d936..00000000 --- a/controller/relay-helper.go +++ /dev/null @@ -1,437 +0,0 @@ -package controller - -import ( - "context" - "errors" - "fmt" - "net/http" - "one-api/common" - "one-api/model" - "one-api/providers" - providers_base "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode { - // 获取请求参数 - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - group := c.GetString("group") - - // 获取 Provider - provider := providers.GetProvider(channelType, c) - if provider == nil { - return common.ErrorWrapper(errors.New("channel not found"), "channel_not_found", http.StatusNotImplemented) - } - - if !provider.SupportAPI(relayMode) { - return common.ErrorWrapper(errors.New("channel does not support this API"), "channel_not_support_api", http.StatusNotImplemented) - } - - modelMap, err := parseModelMapping(c.GetString("model_mapping")) - if err != nil { - return common.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - - quotaInfo := &QuotaInfo{ - modelName: "", - promptTokens: 0, - userId: userId, - channelId: channelId, - tokenId: tokenId, - } - - var usage *types.Usage - var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode - - switch relayMode { - case common.RelayModeChatCompletions: - usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group) - case common.RelayModeCompletions: - usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group) - case common.RelayModeEmbeddings: - usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group) - case common.RelayModeModerations: - usage, openAIErrorWithStatusCode = handleModerations(c, provider, modelMap, quotaInfo, group) - case common.RelayModeAudioSpeech: - usage, openAIErrorWithStatusCode = handleSpeech(c, provider, modelMap, quotaInfo, group) - case common.RelayModeAudioTranscription: - usage, openAIErrorWithStatusCode = handleTranscriptions(c, provider, modelMap, quotaInfo, group) - case common.RelayModeAudioTranslation: - usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group) - case common.RelayModeImagesGenerations: - usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group) - case common.RelayModeImagesEdits: - usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "edit") - case common.RelayModeImagesVariations: - usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "variation") - default: - return common.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest) - } - - if openAIErrorWithStatusCode != nil { - if quotaInfo.preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - return openAIErrorWithStatusCode - } - - tokenName := c.GetString("token_name") - defer func(ctx context.Context) { - go func() { - err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx) - if err != nil { - common.LogError(ctx, err.Error()) - } - }() - }(c.Request.Context()) - - return nil -} - -func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var chatRequest types.ChatCompletionRequest - isModelMapped := false - - chatProvider, ok := provider.(providers_base.ChatInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - - err := common.UnmarshalBodyReusable(c, &chatRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if chatRequest.Messages == nil || len(chatRequest.Messages) == 0 { - return nil, common.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) - } - - if modelMap != nil && modelMap[chatRequest.Model] != "" { - chatRequest.Model = modelMap[chatRequest.Model] - isModelMapped = true - } - promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model) - - quotaInfo.modelName = chatRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - return chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens) -} - -func handleCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var completionRequest types.CompletionRequest - isModelMapped := false - completionProvider, ok := provider.(providers_base.CompletionInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - - err := common.UnmarshalBodyReusable(c, &completionRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if completionRequest.Prompt == "" { - return nil, common.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) - } - - if modelMap != nil && modelMap[completionRequest.Model] != "" { - completionRequest.Model = modelMap[completionRequest.Model] - isModelMapped = true - } - promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model) - - quotaInfo.modelName = completionRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - return completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens) -} - -func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var embeddingsRequest types.EmbeddingRequest - isModelMapped := false - embeddingsProvider, ok := provider.(providers_base.EmbeddingsInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - - err := common.UnmarshalBodyReusable(c, &embeddingsRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if embeddingsRequest.Input == "" { - return nil, common.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) - } - - if modelMap != nil && modelMap[embeddingsRequest.Model] != "" { - embeddingsRequest.Model = modelMap[embeddingsRequest.Model] - isModelMapped = true - } - promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model) - - quotaInfo.modelName = embeddingsRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens) -} - -func handleModerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var moderationRequest types.ModerationRequest - isModelMapped := false - moderationProvider, ok := provider.(providers_base.ModerationInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - - err := common.UnmarshalBodyReusable(c, &moderationRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if moderationRequest.Input == "" { - return nil, common.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) - } - - if moderationRequest.Model == "" { - moderationRequest.Model = "text-moderation-latest" - } - - if modelMap != nil && modelMap[moderationRequest.Model] != "" { - moderationRequest.Model = modelMap[moderationRequest.Model] - isModelMapped = true - } - promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model) - - quotaInfo.modelName = moderationRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - return moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens) -} - -func handleSpeech(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var speechRequest types.SpeechAudioRequest - isModelMapped := false - speechProvider, ok := provider.(providers_base.SpeechInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - - err := common.UnmarshalBodyReusable(c, &speechRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if speechRequest.Input == "" { - return nil, common.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) - } - - if modelMap != nil && modelMap[speechRequest.Model] != "" { - speechRequest.Model = modelMap[speechRequest.Model] - isModelMapped = true - } - promptTokens := len(speechRequest.Input) - - quotaInfo.modelName = speechRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - return speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens) -} - -func handleTranscriptions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var audioRequest types.AudioRequest - isModelMapped := false - speechProvider, ok := provider.(providers_base.TranscriptionsInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - - err := common.UnmarshalBodyReusable(c, &audioRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if audioRequest.File == nil { - fmt.Println(audioRequest) - return nil, common.ErrorWrapper(errors.New("field file is required"), "required_field_missing", http.StatusBadRequest) - } - - if modelMap != nil && modelMap[audioRequest.Model] != "" { - audioRequest.Model = modelMap[audioRequest.Model] - isModelMapped = true - } - promptTokens := 0 - - quotaInfo.modelName = audioRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - return speechProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens) -} - -func handleTranslations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var audioRequest types.AudioRequest - isModelMapped := false - speechProvider, ok := provider.(providers_base.TranslationInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - - err := common.UnmarshalBodyReusable(c, &audioRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if audioRequest.File == nil { - fmt.Println(audioRequest) - return nil, common.ErrorWrapper(errors.New("field file is required"), "required_field_missing", http.StatusBadRequest) - } - - if modelMap != nil && modelMap[audioRequest.Model] != "" { - audioRequest.Model = modelMap[audioRequest.Model] - isModelMapped = true - } - promptTokens := 0 - - quotaInfo.modelName = audioRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - return speechProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens) -} - -func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var imageRequest types.ImageRequest - isModelMapped := false - imageGenerationsProvider, ok := provider.(providers_base.ImageGenerationsInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - - err := common.UnmarshalBodyReusable(c, &imageRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - - if modelMap != nil && modelMap[imageRequest.Model] != "" { - imageRequest.Model = modelMap[imageRequest.Model] - isModelMapped = true - } - promptTokens, err := common.CountTokenImage(imageRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError) - } - - quotaInfo.modelName = imageRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - return imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens) -} - -func handleImageEdits(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string, imageType string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { - var imageEditRequest types.ImageEditRequest - isModelMapped := false - var imageEditsProvider providers_base.ImageEditsInterface - var imageVariations providers_base.ImageVariationsInterface - var ok bool - if imageType == "edit" { - imageEditsProvider, ok = provider.(providers_base.ImageEditsInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - } else { - imageVariations, ok = provider.(providers_base.ImageVariationsInterface) - if !ok { - return nil, common.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) - } - } - - err := common.UnmarshalBodyReusable(c, &imageEditRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - - if imageEditRequest.Model == "" { - imageEditRequest.Model = "dall-e-2" - } - - if imageEditRequest.Size == "" { - imageEditRequest.Size = "1024x1024" - } - - if modelMap != nil && modelMap[imageEditRequest.Model] != "" { - imageEditRequest.Model = modelMap[imageEditRequest.Model] - isModelMapped = true - } - promptTokens, err := common.CountTokenImage(imageEditRequest) - if err != nil { - return nil, common.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError) - } - - quotaInfo.modelName = imageEditRequest.Model - quotaInfo.promptTokens = promptTokens - quotaInfo.initQuotaInfo(group) - quota_err := quotaInfo.preQuotaConsumption() - if quota_err != nil { - return nil, quota_err - } - - if imageType == "edit" { - return imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens) - } - - return imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens) -} diff --git a/controller/relay-image-edits.go b/controller/relay-image-edits.go index 091e4fb8..fb7c8850 100644 --- a/controller/relay-image-edits.go +++ b/controller/relay-image-edits.go @@ -38,12 +38,9 @@ func RelayImageEdits(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-image-generations.go b/controller/relay-image-generations.go index 3310ba3f..f339b79c 100644 --- a/controller/relay-image-generations.go +++ b/controller/relay-image-generations.go @@ -37,12 +37,9 @@ func RelayImageGenerations(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-image-variationsy.go b/controller/relay-image-variationsy.go index 4bbce43a..c128625a 100644 --- a/controller/relay-image-variationsy.go +++ b/controller/relay-image-variationsy.go @@ -33,12 +33,9 @@ func RelayImageVariations(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-moderations.go b/controller/relay-moderations.go index bd8b6936..2ffda2da 100644 --- a/controller/relay-moderations.go +++ b/controller/relay-moderations.go @@ -29,12 +29,9 @@ func RelayModerations(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-speech.go b/controller/relay-speech.go index dbe0d2a1..03ac3151 100644 --- a/controller/relay-speech.go +++ b/controller/relay-speech.go @@ -25,12 +25,9 @@ func RelaySpeech(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-transcriptions.go b/controller/relay-transcriptions.go index 7bdc1c2a..cf0f1831 100644 --- a/controller/relay-transcriptions.go +++ b/controller/relay-transcriptions.go @@ -25,12 +25,9 @@ func RelayTranscriptions(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-translations.go b/controller/relay-translations.go index 9a16c6ba..776a6f7d 100644 --- a/controller/relay-translations.go +++ b/controller/relay-translations.go @@ -25,12 +25,9 @@ func RelayTranslations(c *gin.Context) { return } - // 写入渠道信息 - setChannelToContext(c, channel) - // 解析模型映射 var isModelMapped bool - modelMap, err := parseModelMapping(c.GetString("model_mapping")) + modelMap, err := parseModelMapping(channel.GetModelMapping()) if err != nil { common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 6368f191..6e930df3 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -31,13 +31,22 @@ func GetValidFieldName(err error, obj interface{}) string { return err.Error() } -func fetchChannel(c *gin.Context, modelName string) (*model.Channel, bool) { +func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pass bool) { channelId, ok := c.Get("channelId") if ok { - return fetchChannelById(c, channelId.(int)) - } - return fetchChannelByModel(c, modelName) + channel, pass = fetchChannelById(c, channelId.(int)) + if pass { + return + } + } + channel, pass = fetchChannelByModel(c, modelName) + if pass { + return + } + + setChannelToContext(c, channel) + return } func fetchChannelById(c *gin.Context, channelId any) (*model.Channel, bool) { @@ -91,10 +100,9 @@ func getProvider(c *gin.Context, channelType int, relayMode int) (providersBase. } func setChannelToContext(c *gin.Context, channel *model.Channel) { - c.Set("channel", channel.Type) + // c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) - c.Set("model_mapping", channel.GetModelMapping()) c.Set("api_key", channel.Key) c.Set("base_url", channel.GetBaseURL()) switch channel.Type { diff --git a/middleware/distributor.go b/middleware/distributor.go index 811b512e..72a1b362 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -1,133 +1,17 @@ package middleware import ( - "fmt" - "net/http" - "one-api/common" "one-api/model" - "strconv" - "strings" "github.com/gin-gonic/gin" ) -type ModelRequestInterface interface { - GetModel() string - SetModel(string) -} - -type ModelRequest struct { - Model string `json:"model"` -} - -func (m *ModelRequest) GetModel() string { - return m.Model -} - -func (m *ModelRequest) SetModel(model string) { - m.Model = model -} - -type ModelFormRequest struct { - Model string `form:"model"` -} - -func (m *ModelFormRequest) GetModel() string { - return m.Model -} - -func (m *ModelFormRequest) SetModel(model string) { - m.Model = model -} - func Distribute() func(c *gin.Context) { return func(c *gin.Context) { userId := c.GetInt("id") userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) - var channel *model.Channel - channelId, ok := c.Get("channelId") - if ok { - id, err := strconv.Atoi(channelId.(string)) - if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") - return - } - channel, err = model.GetChannelById(id, true) - if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") - return - } - if channel.Status != common.ChannelStatusEnabled { - abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") - return - } - } else { - // Select a channel for the user - modelRequest := getModelRequest(c) - err := common.UnmarshalBodyReusable(c, modelRequest) - if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求") - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.GetModel() == "" { - modelRequest.SetModel("text-moderation-stable") - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.GetModel() == "" { - modelRequest.SetModel(c.Param("model")) - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.GetModel() == "" { - modelRequest.SetModel("dall-e-2") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - if modelRequest.GetModel() == "" { - modelRequest.SetModel("whisper-1") - } - } - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.GetModel()) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.GetModel()) - if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) - message = "数据库一致性已被破坏,请联系管理员" - } - abortWithMessage(c, http.StatusServiceUnavailable, message) - return - } - } - c.Set("channel", channel.Type) - c.Set("channel_id", channel.Id) - c.Set("channel_name", channel.Name) - c.Set("model_mapping", channel.GetModelMapping()) - c.Set("api_key", channel.Key) - // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.GetBaseURL()) - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - case common.ChannelTypeAIProxyLibrary: - c.Set("library_id", channel.Other) - } c.Next() } } - -func getModelRequest(c *gin.Context) (modelRequest ModelRequestInterface) { - contentType := c.Request.Header.Get("Content-Type") - if strings.HasPrefix(contentType, "application/json") { - modelRequest = &ModelRequest{} - } else if strings.HasPrefix(contentType, "multipart/form-data") { - modelRequest = &ModelFormRequest{} - } - - return -}