From dd3e79a20db44aa755a46fbe21ef93754eff30b2 Mon Sep 17 00:00:00 2001 From: Buer <42402987+MartialBE@users.noreply.github.com> Date: Wed, 6 Mar 2024 18:01:43 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20channel=20support=20weight?= =?UTF-8?q?=20(#85)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ feat: channel support weight * 💄 improve: show version * 💄 improve: Channel add copy operation * 💄 improve: Channel support batch add --- .github/workflows/docker-image.yml | 10 +- common/constants.go | 2 + controller/channel-billing.go | 2 +- controller/channel-test.go | 14 +- controller/common.go | 75 ++++++++ controller/model.go | 2 +- controller/relay-chat.go | 79 -------- controller/relay-completions.go | 79 -------- controller/relay-embeddings.go | 66 ------- controller/relay-image-edits.go | 79 -------- controller/relay-image-generations.go | 82 -------- controller/relay-image-variationsy.go | 74 -------- controller/relay-moderations.go | 66 ------- controller/relay-speech.go | 62 ------ controller/relay-transcriptions.go | 62 ------ controller/relay-translations.go | 62 ------ controller/relay.go | 63 ------- controller/relay/base.go | 53 ++++++ controller/relay/chat.go | 76 ++++++++ controller/relay/completions.go | 76 ++++++++ controller/relay/embeddings.go | 63 +++++++ controller/relay/image-edits.go | 71 +++++++ controller/relay/image-generations.go | 74 ++++++++ controller/relay/image-variationsy.go | 66 +++++++ controller/relay/main.go | 106 +++++++++++ controller/relay/moderations.go | 62 ++++++ controller/{ => relay}/quota.go | 5 +- controller/relay/speech.go | 58 ++++++ controller/relay/transcriptions.go | 58 ++++++ controller/relay/translations.go | 58 ++++++ controller/{relay-utils.go => relay/utils.go} | 153 ++++++++------- main.go | 4 +- middleware/auth.go | 2 +- model/ability.go | 48 +++++ model/balancer.go | 176 ++++++++++++++++++ model/cache.go | 106 ----------- model/channel.go | 9 +- model/option.go | 3 + router/relay-router.go | 21 ++- web/src/views/Channel/component/EditModal.js | 58 ++++-- web/src/views/Channel/component/TableRow.js | 51 ++++- web/src/views/Channel/index.js | 40 +++- .../Setting/component/OperationSetting.js | 23 ++- .../views/Setting/component/OtherSetting.js | 45 ++++- 44 files changed, 1425 insertions(+), 1019 deletions(-) create mode 100644 controller/common.go delete mode 100644 controller/relay-chat.go delete mode 100644 controller/relay-completions.go delete mode 100644 controller/relay-embeddings.go delete mode 100644 controller/relay-image-edits.go delete mode 100644 controller/relay-image-generations.go delete mode 100644 controller/relay-image-variationsy.go delete mode 100644 controller/relay-moderations.go delete mode 100644 controller/relay-speech.go delete mode 100644 controller/relay-transcriptions.go delete mode 100644 controller/relay-translations.go delete mode 100644 controller/relay.go create mode 100644 controller/relay/base.go create mode 100644 controller/relay/chat.go create mode 100644 controller/relay/completions.go create mode 100644 controller/relay/embeddings.go create mode 100644 controller/relay/image-edits.go create mode 100644 controller/relay/image-generations.go create mode 100644 controller/relay/image-variationsy.go create mode 100644 controller/relay/main.go create mode 100644 controller/relay/moderations.go rename controller/{ => relay}/quota.go (97%) create mode 100644 controller/relay/speech.go create mode 100644 controller/relay/transcriptions.go create mode 100644 controller/relay/translations.go rename controller/{relay-utils.go => relay/utils.go} (56%) create mode 100644 model/balancer.go diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 64f4e9c0..cb916557 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -25,7 +25,15 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 - + - name: Save version info + run: | + TAG=$(git describe --tags --exact-match 2> /dev/null) + if [ $? -eq 0 ]; then + echo $TAG > VERSION + else + HASH=$(git rev-parse --short=7 HEAD) + echo "dev-$HASH" > VERSION + fi - name: Set up QEMU uses: docker/setup-qemu-action@v2 diff --git a/common/constants.go b/common/constants.go index f704fc98..d6f9ddb2 100644 --- a/common/constants.go +++ b/common/constants.go @@ -83,6 +83,8 @@ var QuotaRemindThreshold = 1000 var PreConsumedQuota = 500 var ApproximateTokenEnabled = false var RetryTimes = 0 +var DefaultChannelWeight = uint(1) +var RetryCooldownSeconds = 5 var RootUserEmail = "" diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 83bfce88..5f838067 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -122,7 +122,7 @@ func updateAllChannelsBalance() error { } else { // err is nil & balance <= 0 means quota is used up if balance <= 0 { - disableChannel(channel.Id, channel.Name, "余额不足") + DisableChannel(channel.Id, channel.Name, "余额不足") } } time.Sleep(common.RequestInterval) diff --git a/controller/channel-test.go b/controller/channel-test.go index 37d544c3..1c6206d3 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -140,14 +140,6 @@ func notifyRootUser(subject string, content string) { } } -// disable & notify -func disableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) - notifyRootUser(subject, content) -} - // enable & notify func enableChannel(channelId int, channelName string) { model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) @@ -185,10 +177,10 @@ func testAllChannels(notify bool) error { milliseconds := tok.Sub(tik).Milliseconds() if milliseconds > disableThreshold { err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) - disableChannel(channel.Id, channel.Name, err.Error()) + DisableChannel(channel.Id, channel.Name, err.Error()) } - if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { - disableChannel(channel.Id, channel.Name, err.Error()) + if isChannelEnabled && ShouldDisableChannel(openaiErr, -1) { + DisableChannel(channel.Id, channel.Name, err.Error()) } if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { enableChannel(channel.Id, channel.Name) diff --git a/controller/common.go b/controller/common.go new file mode 100644 index 00000000..90349b2a --- /dev/null +++ b/controller/common.go @@ -0,0 +1,75 @@ +package controller + +import ( + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { + if !common.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} + +func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool { + if !common.AutomaticDisableChannelEnabled { + return false + } + + if err == nil { + return false + } + + if statusCode == http.StatusUnauthorized { + return true + } + + if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + return false +} + +// disable & notify +func DisableChannel(channelId int, channelName string, reason string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + notifyRootUser(subject, content) +} + +func RelayNotImplemented(c *gin.Context) { + err := types.OpenAIError{ + Message: "API not implemented", + Type: "one_api_error", + Param: "", + Code: "api_not_implemented", + } + c.JSON(http.StatusNotImplemented, gin.H{ + "error": err, + }) +} + +func RelayNotFound(c *gin.Context) { + err := types.OpenAIError{ + Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), + Type: "invalid_request_error", + Param: "", + Code: "", + } + c.JSON(http.StatusNotFound, gin.H{ + "error": err, + }) +} diff --git a/controller/model.go b/controller/model.go index 61df4a79..98f99e50 100644 --- a/controller/model.go +++ b/controller/model.go @@ -70,7 +70,7 @@ func ListModels(c *gin.Context) { groupName = user.Group } - models, err := model.CacheGetGroupModels(groupName) + models, err := model.ChannelGroup.GetGroupModels(groupName) if err != nil { common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error()) return diff --git a/controller/relay-chat.go b/controller/relay-chat.go deleted file mode 100644 index 61267311..00000000 --- a/controller/relay-chat.go +++ /dev/null @@ -1,79 +0,0 @@ -package controller - -import ( - "math" - "net/http" - "one-api/common" - "one-api/common/requester" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelayChat(c *gin.Context) { - - var chatRequest types.ChatCompletionRequest - if err := common.UnmarshalBodyReusable(c, &chatRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 { - common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid") - return - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, chatRequest.Model) - if fail { - return - } - chatRequest.Model = modelName - - chatProvider, ok := provider.(providersBase.ChatInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model) - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, chatRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - if chatRequest.Stream { - var response requester.StreamReaderInterface[string] - response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseStreamClient(c, response) - } else { - var response *types.ChatCompletionResponse - response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseJsonClient(c, response) - } - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-completions.go b/controller/relay-completions.go deleted file mode 100644 index 0898a016..00000000 --- a/controller/relay-completions.go +++ /dev/null @@ -1,79 +0,0 @@ -package controller - -import ( - "math" - "net/http" - "one-api/common" - "one-api/common/requester" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelayCompletions(c *gin.Context) { - - var completionRequest types.CompletionRequest - if err := common.UnmarshalBodyReusable(c, &completionRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 { - common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid") - return - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, completionRequest.Model) - if fail { - return - } - completionRequest.Model = modelName - - completionProvider, ok := provider.(providersBase.CompletionInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model) - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, completionRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - if completionRequest.Stream { - var response requester.StreamReaderInterface[string] - response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseStreamClient(c, response) - } else { - var response *types.CompletionResponse - response, errWithCode = completionProvider.CreateCompletion(&completionRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseJsonClient(c, response) - } - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-embeddings.go b/controller/relay-embeddings.go deleted file mode 100644 index 58dffc48..00000000 --- a/controller/relay-embeddings.go +++ /dev/null @@ -1,66 +0,0 @@ -package controller - -import ( - "net/http" - "one-api/common" - providersBase "one-api/providers/base" - "one-api/types" - "strings" - - "github.com/gin-gonic/gin" -) - -func RelayEmbeddings(c *gin.Context) { - - var embeddingsRequest types.EmbeddingRequest - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - embeddingsRequest.Model = c.Param("model") - } - - if err := common.UnmarshalBodyReusable(c, &embeddingsRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, embeddingsRequest.Model) - if fail { - return - } - embeddingsRequest.Model = modelName - - embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model) - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, embeddingsRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - response, errWithCode := embeddingsProvider.CreateEmbeddings(&embeddingsRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseJsonClient(c, response) - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-image-edits.go b/controller/relay-image-edits.go deleted file mode 100644 index ef301e2a..00000000 --- a/controller/relay-image-edits.go +++ /dev/null @@ -1,79 +0,0 @@ -package controller - -import ( - "net/http" - "one-api/common" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelayImageEdits(c *gin.Context) { - - var imageEditRequest types.ImageEditRequest - - if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - if imageEditRequest.Prompt == "" { - common.AbortWithMessage(c, http.StatusBadRequest, "field prompt is required") - return - } - - if imageEditRequest.Model == "" { - imageEditRequest.Model = "dall-e-2" - } - - if imageEditRequest.Size == "" { - imageEditRequest.Size = "1024x1024" - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, imageEditRequest.Model) - if fail { - return - } - imageEditRequest.Model = modelName - - imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens, err := common.CountTokenImage(imageEditRequest) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - response, errWithCode := imageEditsProvider.CreateImageEdits(&imageEditRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseJsonClient(c, response) - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-image-generations.go b/controller/relay-image-generations.go deleted file mode 100644 index 4332274d..00000000 --- a/controller/relay-image-generations.go +++ /dev/null @@ -1,82 +0,0 @@ -package controller - -import ( - "net/http" - "one-api/common" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelayImageGenerations(c *gin.Context) { - - var imageRequest types.ImageRequest - - if err := common.UnmarshalBodyReusable(c, &imageRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - - if imageRequest.N == 0 { - imageRequest.N = 1 - } - - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - - if imageRequest.Quality == "" { - imageRequest.Quality = "standard" - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, imageRequest.Model) - if fail { - return - } - imageRequest.Model = modelName - - imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens, err := common.CountTokenImage(imageRequest) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, imageRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - response, errWithCode := imageGenerationsProvider.CreateImageGenerations(&imageRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseJsonClient(c, response) - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-image-variationsy.go b/controller/relay-image-variationsy.go deleted file mode 100644 index 2c9069ee..00000000 --- a/controller/relay-image-variationsy.go +++ /dev/null @@ -1,74 +0,0 @@ -package controller - -import ( - "net/http" - "one-api/common" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelayImageVariations(c *gin.Context) { - - var imageEditRequest types.ImageEditRequest - - if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - if imageEditRequest.Model == "" { - imageEditRequest.Model = "dall-e-2" - } - - if imageEditRequest.Size == "" { - imageEditRequest.Size = "1024x1024" - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, imageEditRequest.Model) - if fail { - return - } - imageEditRequest.Model = modelName - - imageVariations, ok := provider.(providersBase.ImageVariationsInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens, err := common.CountTokenImage(imageEditRequest) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - return - } - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - response, errWithCode := imageVariations.CreateImageVariations(&imageEditRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseJsonClient(c, response) - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-moderations.go b/controller/relay-moderations.go deleted file mode 100644 index 136b6fdd..00000000 --- a/controller/relay-moderations.go +++ /dev/null @@ -1,66 +0,0 @@ -package controller - -import ( - "net/http" - "one-api/common" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelayModerations(c *gin.Context) { - - var moderationRequest types.ModerationRequest - - if err := common.UnmarshalBodyReusable(c, &moderationRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - if moderationRequest.Model == "" { - moderationRequest.Model = "text-moderation-stable" - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, moderationRequest.Model) - if fail { - return - } - moderationRequest.Model = modelName - - moderationProvider, ok := provider.(providersBase.ModerationInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model) - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, moderationRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - response, errWithCode := moderationProvider.CreateModeration(&moderationRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseJsonClient(c, response) - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-speech.go b/controller/relay-speech.go deleted file mode 100644 index ec7fd7ce..00000000 --- a/controller/relay-speech.go +++ /dev/null @@ -1,62 +0,0 @@ -package controller - -import ( - "net/http" - "one-api/common" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelaySpeech(c *gin.Context) { - - var speechRequest types.SpeechAudioRequest - - if err := common.UnmarshalBodyReusable(c, &speechRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, speechRequest.Model) - if fail { - return - } - speechRequest.Model = modelName - - speechProvider, ok := provider.(providersBase.SpeechInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens := len(speechRequest.Input) - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, speechRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - response, errWithCode := speechProvider.CreateSpeech(&speechRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseMultipart(c, response) - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-transcriptions.go b/controller/relay-transcriptions.go deleted file mode 100644 index a6005963..00000000 --- a/controller/relay-transcriptions.go +++ /dev/null @@ -1,62 +0,0 @@ -package controller - -import ( - "net/http" - "one-api/common" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelayTranscriptions(c *gin.Context) { - - var audioRequest types.AudioRequest - - if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, audioRequest.Model) - if fail { - return - } - audioRequest.Model = modelName - - transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens := 0 - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - response, errWithCode := transcriptionsProvider.CreateTranscriptions(&audioRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseCustom(c, response) - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay-translations.go b/controller/relay-translations.go deleted file mode 100644 index c13935eb..00000000 --- a/controller/relay-translations.go +++ /dev/null @@ -1,62 +0,0 @@ -package controller - -import ( - "net/http" - "one-api/common" - providersBase "one-api/providers/base" - "one-api/types" - - "github.com/gin-gonic/gin" -) - -func RelayTranslations(c *gin.Context) { - - var audioRequest types.AudioRequest - - if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) - return - } - - // 获取供应商 - provider, modelName, fail := getProvider(c, audioRequest.Model) - if fail { - return - } - audioRequest.Model = modelName - - translationProvider, ok := provider.(providersBase.TranslationInterface) - if !ok { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented") - return - } - - // 获取Input Tokens - promptTokens := 0 - - usage := &types.Usage{ - PromptTokens: promptTokens, - } - provider.SetUsage(usage) - - quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - - response, errWithCode := translationProvider.CreateTranslation(&audioRequest) - if errWithCode != nil { - errorHelper(c, errWithCode) - return - } - errWithCode = responseCustom(c, response) - - // 如果报错,则退还配额 - if errWithCode != nil { - quotaInfo.undo(c, errWithCode) - return - } - - quotaInfo.consume(c, usage) -} diff --git a/controller/relay.go b/controller/relay.go deleted file mode 100644 index 8f5d260a..00000000 --- a/controller/relay.go +++ /dev/null @@ -1,63 +0,0 @@ -package controller - -import ( - "fmt" - "net/http" - "one-api/common" - "one-api/types" - "strconv" - - "github.com/gin-gonic/gin" -) - -func RelayNotImplemented(c *gin.Context) { - err := types.OpenAIError{ - Message: "API not implemented", - Type: "one_api_error", - Param: "", - Code: "api_not_implemented", - } - c.JSON(http.StatusNotImplemented, gin.H{ - "error": err, - }) -} - -func RelayNotFound(c *gin.Context) { - err := types.OpenAIError{ - Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), - Type: "invalid_request_error", - Param: "", - Code: "", - } - c.JSON(http.StatusNotFound, gin.H{ - "error": err, - }) -} - -func errorHelper(c *gin.Context, err *types.OpenAIErrorWithStatusCode) { - requestId := c.GetString(common.RequestIdKey) - retryTimesStr := c.Query("retry") - retryTimes, _ := strconv.Atoi(retryTimesStr) - if retryTimesStr == "" { - retryTimes = common.RetryTimes - } - if retryTimes > 0 { - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) - } else { - if err.StatusCode == http.StatusTooManyRequests { - err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" - } - err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) - c.JSON(err.StatusCode, gin.H{ - "error": err.OpenAIError, - }) - } - channelId := c.GetInt("channel_id") - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) - // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { - channelId := c.GetInt("channel_id") - channelName := c.GetString("channel_name") - disableChannel(channelId, channelName, err.Message) - } -} diff --git a/controller/relay/base.go b/controller/relay/base.go new file mode 100644 index 00000000..0f644f02 --- /dev/null +++ b/controller/relay/base.go @@ -0,0 +1,53 @@ +package relay + +import ( + "one-api/types" + + providersBase "one-api/providers/base" + + "github.com/gin-gonic/gin" +) + +type relayBase struct { + c *gin.Context + provider providersBase.ProviderInterface + originalModel string + modelName string +} + +type RelayBaseInterface interface { + send() (err *types.OpenAIErrorWithStatusCode, done bool) + getPromptTokens() (int, error) + setRequest() error + setProvider(modelName string) error + getProvider() providersBase.ProviderInterface + getOriginalModel() string + getModelName() string + getContext() *gin.Context +} + +func (r *relayBase) setProvider(modelName string) error { + provider, modelName, fail := getProvider(r.c, modelName) + if fail != nil { + return fail + } + r.provider = provider + r.modelName = modelName + return nil +} + +func (r *relayBase) getContext() *gin.Context { + return r.c +} + +func (r *relayBase) getProvider() providersBase.ProviderInterface { + return r.provider +} + +func (r *relayBase) getOriginalModel() string { + return r.originalModel +} + +func (r *relayBase) getModelName() string { + return r.modelName +} diff --git a/controller/relay/chat.go b/controller/relay/chat.go new file mode 100644 index 00000000..33e2b469 --- /dev/null +++ b/controller/relay/chat.go @@ -0,0 +1,76 @@ +package relay + +import ( + "errors" + "math" + "net/http" + "one-api/common" + "one-api/common/requester" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relayChat struct { + relayBase + chatRequest types.ChatCompletionRequest +} + +func NewRelayChat(c *gin.Context) *relayChat { + relay := &relayChat{} + relay.c = c + return relay +} + +func (r *relayChat) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.chatRequest); err != nil { + return err + } + + if r.chatRequest.MaxTokens < 0 || r.chatRequest.MaxTokens > math.MaxInt32/2 { + return errors.New("max_tokens is invalid") + } + + r.originalModel = r.chatRequest.Model + + return nil +} + +func (r *relayChat) getPromptTokens() (int, error) { + return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil +} + +func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + chatProvider, ok := r.provider.(providersBase.ChatInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.chatRequest.Model = r.modelName + + if r.chatRequest.Stream { + var response requester.StreamReaderInterface[string] + response, err = chatProvider.CreateChatCompletionStream(&r.chatRequest) + if err != nil { + return + } + + err = responseStreamClient(r.c, response) + } else { + var response *types.ChatCompletionResponse + response, err = chatProvider.CreateChatCompletion(&r.chatRequest) + if err != nil { + return + } + err = responseJsonClient(r.c, response) + } + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay/completions.go b/controller/relay/completions.go new file mode 100644 index 00000000..fdfbd03e --- /dev/null +++ b/controller/relay/completions.go @@ -0,0 +1,76 @@ +package relay + +import ( + "errors" + "math" + "net/http" + "one-api/common" + "one-api/common/requester" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relayCompletions struct { + relayBase + request types.CompletionRequest +} + +func NewRelayCompletions(c *gin.Context) *relayCompletions { + relay := &relayCompletions{} + relay.c = c + return relay +} + +func (r *relayCompletions) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + if r.request.MaxTokens < 0 || r.request.MaxTokens > math.MaxInt32/2 { + return errors.New("max_tokens is invalid") + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relayCompletions) getPromptTokens() (int, error) { + return common.CountTokenInput(r.request.Prompt, r.modelName), nil +} + +func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.CompletionInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + if r.request.Stream { + var response requester.StreamReaderInterface[string] + response, err = provider.CreateCompletionStream(&r.request) + if err != nil { + return + } + + err = responseStreamClient(r.c, response) + } else { + var response *types.CompletionResponse + response, err = provider.CreateCompletion(&r.request) + if err != nil { + return + } + err = responseJsonClient(r.c, response) + } + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay/embeddings.go b/controller/relay/embeddings.go new file mode 100644 index 00000000..5a2a3bf7 --- /dev/null +++ b/controller/relay/embeddings.go @@ -0,0 +1,63 @@ +package relay + +import ( + "net/http" + "one-api/common" + providersBase "one-api/providers/base" + "one-api/types" + "strings" + + "github.com/gin-gonic/gin" +) + +type relayEmbeddings struct { + relayBase + request types.EmbeddingRequest +} + +func NewRelayEmbeddings(c *gin.Context) *relayEmbeddings { + relay := &relayEmbeddings{} + relay.c = c + return relay +} + +func (r *relayEmbeddings) setRequest() error { + if strings.HasSuffix(r.c.Request.URL.Path, "embeddings") { + r.request.Model = r.c.Param("model") + } + + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relayEmbeddings) getPromptTokens() (int, error) { + return common.CountTokenInput(r.request.Input, r.modelName), nil +} + +func (r *relayEmbeddings) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.EmbeddingsInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + response, err := provider.CreateEmbeddings(&r.request) + if err != nil { + return + } + err = responseJsonClient(r.c, response) + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay/image-edits.go b/controller/relay/image-edits.go new file mode 100644 index 00000000..16a77bdf --- /dev/null +++ b/controller/relay/image-edits.go @@ -0,0 +1,71 @@ +package relay + +import ( + "errors" + "net/http" + "one-api/common" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relayImageEdits struct { + relayBase + request types.ImageEditRequest +} + +func NewRelayImageEdits(c *gin.Context) *relayImageEdits { + relay := &relayImageEdits{} + relay.c = c + return relay +} + +func (r *relayImageEdits) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + if r.request.Prompt == "" { + return errors.New("field prompt is required") + } + + if r.request.Model == "" { + r.request.Model = "dall-e-2" + } + + if r.request.Size == "" { + r.request.Size = "1024x1024" + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relayImageEdits) getPromptTokens() (int, error) { + return common.CountTokenImage(r.request) +} + +func (r *relayImageEdits) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.ImageEditsInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + response, err := provider.CreateImageEdits(&r.request) + if err != nil { + return + } + err = responseJsonClient(r.c, response) + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay/image-generations.go b/controller/relay/image-generations.go new file mode 100644 index 00000000..c63e3c91 --- /dev/null +++ b/controller/relay/image-generations.go @@ -0,0 +1,74 @@ +package relay + +import ( + "net/http" + "one-api/common" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relayImageGenerations struct { + relayBase + request types.ImageRequest +} + +func NewRelayImageGenerations(c *gin.Context) *relayImageGenerations { + relay := &relayImageGenerations{} + relay.c = c + return relay +} + +func (r *relayImageGenerations) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + if r.request.Model == "" { + r.request.Model = "dall-e-2" + } + + if r.request.N == 0 { + r.request.N = 1 + } + + if r.request.Size == "" { + r.request.Size = "1024x1024" + } + + if r.request.Quality == "" { + r.request.Quality = "standard" + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relayImageGenerations) getPromptTokens() (int, error) { + return common.CountTokenImage(r.request) +} + +func (r *relayImageGenerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.ImageGenerationsInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + response, err := provider.CreateImageGenerations(&r.request) + if err != nil { + return + } + err = responseJsonClient(r.c, response) + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay/image-variationsy.go b/controller/relay/image-variationsy.go new file mode 100644 index 00000000..2ef3decc --- /dev/null +++ b/controller/relay/image-variationsy.go @@ -0,0 +1,66 @@ +package relay + +import ( + "net/http" + "one-api/common" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relayImageVariations struct { + relayBase + request types.ImageEditRequest +} + +func NewRelayImageVariations(c *gin.Context) *relayImageVariations { + relay := &relayImageVariations{} + relay.c = c + return relay +} + +func (r *relayImageVariations) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + if r.request.Model == "" { + r.request.Model = "dall-e-2" + } + + if r.request.Size == "" { + r.request.Size = "1024x1024" + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relayImageVariations) getPromptTokens() (int, error) { + return common.CountTokenImage(r.request) +} + +func (r *relayImageVariations) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.ImageVariationsInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + response, err := provider.CreateImageVariations(&r.request) + if err != nil { + return + } + err = responseJsonClient(r.c, response) + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay/main.go b/controller/relay/main.go new file mode 100644 index 00000000..6329dcb5 --- /dev/null +++ b/controller/relay/main.go @@ -0,0 +1,106 @@ +package relay + +import ( + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +func Relay(c *gin.Context) { + relay := Path2Relay(c, c.Request.URL.Path) + if relay == nil { + common.AbortWithMessage(c, http.StatusNotFound, "Not Found") + return + } + + if err := relay.setRequest(); err != nil { + common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + + if err := relay.setProvider(relay.getOriginalModel()); err != nil { + common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error()) + return + } + + apiErr, done := RelayHandler(relay) + if apiErr == nil { + return + } + + channel := relay.getProvider().GetChannel() + go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr) + + retryTimes := common.RetryTimes + if done || !shouldRetry(c, apiErr.StatusCode) { + common.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode)) + retryTimes = 0 + } + + for i := retryTimes; i > 0; i-- { + // 冻结通道 + model.ChannelGroup.Cooldowns(channel.Id) + if err := relay.setProvider(relay.getOriginalModel()); err != nil { + continue + } + + channel = relay.getProvider().GetChannel() + common.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i)) + apiErr, done = RelayHandler(relay) + if apiErr == nil { + return + } + go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr) + if done || !shouldRetry(c, apiErr.StatusCode) { + break + } + } + + if apiErr != nil { + requestId := c.GetString(common.RequestIdKey) + if apiErr.StatusCode == http.StatusTooManyRequests { + apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" + } + apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId) + c.JSON(apiErr.StatusCode, gin.H{ + "error": apiErr.OpenAIError, + }) + + } +} + +func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCode, done bool) { + promptTokens, tonkeErr := relay.getPromptTokens() + if tonkeErr != nil { + err = common.ErrorWrapper(tonkeErr, "token_error", http.StatusBadRequest) + done = true + return + } + + usage := &types.Usage{ + PromptTokens: promptTokens, + } + + relay.getProvider().SetUsage(usage) + + var quotaInfo *QuotaInfo + quotaInfo, err = generateQuotaInfo(relay.getContext(), relay.getModelName(), promptTokens) + if err != nil { + done = true + return + } + + err, done = relay.send() + + if err != nil { + quotaInfo.undo(relay.getContext()) + return + } + + quotaInfo.consume(relay.getContext(), usage) + return +} diff --git a/controller/relay/moderations.go b/controller/relay/moderations.go new file mode 100644 index 00000000..471f1640 --- /dev/null +++ b/controller/relay/moderations.go @@ -0,0 +1,62 @@ +package relay + +import ( + "net/http" + "one-api/common" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relayModerations struct { + relayBase + request types.ModerationRequest +} + +func NewRelayModerations(c *gin.Context) *relayModerations { + relay := &relayModerations{} + relay.c = c + return relay +} + +func (r *relayModerations) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + if r.request.Model == "" { + r.request.Model = "text-moderation-stable" + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relayModerations) getPromptTokens() (int, error) { + return common.CountTokenInput(r.request.Input, r.modelName), nil +} + +func (r *relayModerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.ModerationInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + response, err := provider.CreateModeration(&r.request) + if err != nil { + return + } + err = responseJsonClient(r.c, response) + + if err != nil { + done = true + } + + return +} diff --git a/controller/quota.go b/controller/relay/quota.go similarity index 97% rename from controller/quota.go rename to controller/relay/quota.go index c2f96954..3343b6c5 100644 --- a/controller/quota.go +++ b/controller/relay/quota.go @@ -1,4 +1,4 @@ -package controller +package relay import ( "context" @@ -144,7 +144,7 @@ func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName stri return nil } -func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatusCode) { +func (q *QuotaInfo) undo(c *gin.Context) { tokenId := c.GetInt("token_id") if q.HandelStatus { go func(ctx context.Context) { @@ -155,7 +155,6 @@ func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatu } }(c.Request.Context()) } - errorHelper(c, errWithCode) } func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) { diff --git a/controller/relay/speech.go b/controller/relay/speech.go new file mode 100644 index 00000000..45d6f6d6 --- /dev/null +++ b/controller/relay/speech.go @@ -0,0 +1,58 @@ +package relay + +import ( + "net/http" + "one-api/common" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relaySpeech struct { + relayBase + request types.SpeechAudioRequest +} + +func NewRelaySpeech(c *gin.Context) *relaySpeech { + relay := &relaySpeech{} + relay.c = c + return relay +} + +func (r *relaySpeech) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relaySpeech) getPromptTokens() (int, error) { + return len(r.request.Input), nil +} + +func (r *relaySpeech) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.SpeechInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + response, err := provider.CreateSpeech(&r.request) + if err != nil { + return + } + err = responseMultipart(r.c, response) + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay/transcriptions.go b/controller/relay/transcriptions.go new file mode 100644 index 00000000..17605a00 --- /dev/null +++ b/controller/relay/transcriptions.go @@ -0,0 +1,58 @@ +package relay + +import ( + "net/http" + "one-api/common" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relayTranscriptions struct { + relayBase + request types.AudioRequest +} + +func NewRelayTranscriptions(c *gin.Context) *relayTranscriptions { + relay := &relayTranscriptions{} + relay.c = c + return relay +} + +func (r *relayTranscriptions) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relayTranscriptions) getPromptTokens() (int, error) { + return 0, nil +} + +func (r *relayTranscriptions) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.TranscriptionsInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + response, err := provider.CreateTranscriptions(&r.request) + if err != nil { + return + } + err = responseCustom(r.c, response) + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay/translations.go b/controller/relay/translations.go new file mode 100644 index 00000000..abb65296 --- /dev/null +++ b/controller/relay/translations.go @@ -0,0 +1,58 @@ +package relay + +import ( + "net/http" + "one-api/common" + providersBase "one-api/providers/base" + "one-api/types" + + "github.com/gin-gonic/gin" +) + +type relayTranslations struct { + relayBase + request types.AudioRequest +} + +func NewRelayTranslations(c *gin.Context) *relayTranslations { + relay := &relayTranslations{} + relay.c = c + return relay +} + +func (r *relayTranslations) setRequest() error { + if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil { + return err + } + + r.originalModel = r.request.Model + + return nil +} + +func (r *relayTranslations) getPromptTokens() (int, error) { + return 0, nil +} + +func (r *relayTranslations) send() (err *types.OpenAIErrorWithStatusCode, done bool) { + provider, ok := r.provider.(providersBase.TranslationInterface) + if !ok { + err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable) + done = true + return + } + + r.request.Model = r.modelName + + response, err := provider.CreateTranslation(&r.request) + if err != nil { + return + } + err = responseCustom(r.c, response) + + if err != nil { + done = true + } + + return +} diff --git a/controller/relay-utils.go b/controller/relay/utils.go similarity index 56% rename from controller/relay-utils.go rename to controller/relay/utils.go index 1d5d019c..d3443f4e 100644 --- a/controller/relay-utils.go +++ b/controller/relay/utils.go @@ -1,6 +1,7 @@ -package controller +package relay import ( + "context" "encoding/json" "errors" "fmt" @@ -8,127 +9,98 @@ import ( "net/http" "one-api/common" "one-api/common/requester" + "one-api/controller" "one-api/model" "one-api/providers" providersBase "one-api/providers/base" "one-api/types" - "reflect" + "strings" "github.com/gin-gonic/gin" - "github.com/go-playground/validator/v10" ) -func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail bool) { +func Path2Relay(c *gin.Context, path string) RelayBaseInterface { + if strings.HasPrefix(path, "/v1/chat/completions") { + return NewRelayChat(c) + } else if strings.HasPrefix(path, "/v1/completions") { + return NewRelayCompletions(c) + } else if strings.HasPrefix(path, "/v1/embeddings") { + return NewRelayEmbeddings(c) + } else if strings.HasPrefix(path, "/v1/moderations") { + return NewRelayModerations(c) + } else if strings.HasPrefix(path, "/v1/images/generations") { + return NewRelayImageGenerations(c) + } else if strings.HasPrefix(path, "/v1/images/edits") { + return NewRelayImageEdits(c) + } else if strings.HasPrefix(path, "/v1/images/variations") { + return NewRelayImageVariations(c) + } else if strings.HasPrefix(path, "/v1/audio/speech") { + return NewRelaySpeech(c) + } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { + return NewRelayTranscriptions(c) + } else if strings.HasPrefix(path, "/v1/audio/translations") { + return NewRelayTranslations(c) + } + + return nil +} + +func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) { channel, fail := fetchChannel(c, modeName) - if fail { + if fail != nil { return } + c.Set("channel_id", channel.Id) provider = providers.GetProvider(channel, c) if provider == nil { - common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found") - fail = true + fail = errors.New("channel not found") return } + provider.SetOriginalModel(modeName) - newModelName, err := provider.ModelMappingHandler(modeName) - if err != nil { - common.AbortWithMessage(c, http.StatusInternalServerError, err.Error()) - fail = true + newModelName, fail = provider.ModelMappingHandler(modeName) + if fail != nil { return } return } -func GetValidFieldName(err error, obj interface{}) string { - getObj := reflect.TypeOf(obj) - if errs, ok := err.(validator.ValidationErrors); ok { - for _, e := range errs { - if f, exist := getObj.Elem().FieldByName(e.Field()); exist { - return f.Name - } - } - } - return err.Error() -} - -func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) { - channelId := c.GetInt("channelId") +func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail error) { + channelId := c.GetInt("specific_channel_id") if channelId > 0 { - channel, fail = fetchChannelById(c, channelId) - if fail { - return - } - - } - channel, fail = fetchChannelByModel(c, modelName) - if fail { - return + return fetchChannelById(channelId) } - c.Set("channel_id", channel.Id) - - return + return fetchChannelByModel(c, modelName) } -func fetchChannelById(c *gin.Context, channelId int) (*model.Channel, bool) { +func fetchChannelById(channelId int) (*model.Channel, error) { channel, err := model.GetChannelById(channelId, true) if err != nil { - common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") - return nil, true + return nil, errors.New("无效的渠道 Id") } if channel.Status != common.ChannelStatusEnabled { - common.AbortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") - return nil, true + return nil, errors.New("该渠道已被禁用") } - return channel, false + return channel, nil } -func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool) { +func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, error) { group := c.GetString("group") - channel, err := model.CacheGetRandomSatisfiedChannel(group, modelName) + channel, err := model.ChannelGroup.Next(group, modelName) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName) if channel != nil { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } - common.AbortWithMessage(c, http.StatusServiceUnavailable, message) - return nil, true + return nil, errors.New(message) } - return channel, false -} - -func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool { - if !common.AutomaticDisableChannelEnabled { - return false - } - if err == nil { - return false - } - if statusCode == http.StatusUnauthorized { - return true - } - if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { - return true - } - return false -} - -func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { - if !common.AutomaticEnableChannelEnabled { - return false - } - if err != nil { - return false - } - if openAIErr != nil { - return false - } - return true + return channel, nil } func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode { @@ -201,3 +173,30 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types return nil } + +func shouldRetry(c *gin.Context, statusCode int) bool { + channelId := c.GetInt("specific_channel_id") + if channelId > 0 { + return false + } + if statusCode == http.StatusTooManyRequests { + return true + } + if statusCode/100 == 5 { + return true + } + if statusCode == http.StatusBadRequest { + return false + } + if statusCode/100 == 2 { + return false + } + return true +} + +func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) { + common.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message)) + if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) { + controller.DisableChannel(channelId, channelName, err.Message) + } +} diff --git a/main.go b/main.go index 16ef5e48..5a13aa51 100644 --- a/main.go +++ b/main.go @@ -59,11 +59,11 @@ func main() { if common.MemoryCacheEnabled { common.SysLog("memory cache enabled") common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) - model.InitChannelCache() + model.InitChannelGroup() } if common.MemoryCacheEnabled { go model.SyncOptions(common.SyncFrequency) - go model.SyncChannelCache(common.SyncFrequency) + go model.SyncChannelGroup(common.SyncFrequency) } if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) diff --git a/middleware/auth.go b/middleware/auth.go index e6f48c62..2537ceae 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -114,7 +114,7 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id") return } - c.Set("channelId", channelId) + c.Set("specific_channel_id", channelId) } else { abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return diff --git a/model/ability.go b/model/ability.go index 1a53b3ef..8aeb2811 100644 --- a/model/ability.go +++ b/model/ability.go @@ -11,6 +11,7 @@ type Ability struct { ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` + Weight *uint `json:"weight" gorm:"default:1"` } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { @@ -67,6 +68,7 @@ func (channel *Channel) AddAbilities() error { ChannelId: channel.Id, Enabled: channel.Status == common.ChannelStatusEnabled, Priority: channel.Priority, + Weight: channel.Weight, } abilities = append(abilities, ability) } @@ -98,3 +100,49 @@ func (channel *Channel) UpdateAbilities() error { func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } + +func GetEnabledAbility() ([]*Ability, error) { + trueVal := "1" + if common.UsingPostgreSQL { + trueVal = "true" + } + + var abilities []*Ability + err := DB.Where("enabled = ?", trueVal).Order("priority desc, weight desc").Find(&abilities).Error + return abilities, err +} + +type AbilityChannelGroup struct { + Group string `json:"group"` + Model string `json:"model"` + Priority int `json:"priority"` + ChannelIds string `json:"channel_ids"` +} + +func GetAbilityChannelGroup() ([]*AbilityChannelGroup, error) { + var abilities []*AbilityChannelGroup + + var channelSql string + if common.UsingPostgreSQL { + channelSql = `string_agg("channel_id"::text, ',')` + } else if common.UsingSQLite { + channelSql = `group_concat("channel_id", ',')` + } else { + channelSql = "GROUP_CONCAT(`channel_id` SEPARATOR ',')" + } + + trueVal := "1" + if common.UsingPostgreSQL { + trueVal = "true" + } + + err := DB.Raw(` + SELECT `+quotePostgresField("group")+`, model, priority, `+channelSql+` as channel_ids + FROM abilities + WHERE enabled = ? + GROUP BY `+quotePostgresField("group")+`, model, priority + ORDER BY priority DESC + `, trueVal).Scan(&abilities).Error + + return abilities, err +} diff --git a/model/balancer.go b/model/balancer.go new file mode 100644 index 00000000..38a90c0e --- /dev/null +++ b/model/balancer.go @@ -0,0 +1,176 @@ +package model + +import ( + "errors" + "math/rand" + "one-api/common" + "strings" + "sync" + "time" +) + +type ChannelChoice struct { + Channel *Channel + CooldownsTime int64 +} + +type ChannelsChooser struct { + sync.RWMutex + Channels map[int]*ChannelChoice + Rule map[string]map[string][][]int // group -> model -> priority -> channelIds +} + +func (cc *ChannelsChooser) Cooldowns(channelId int) bool { + if common.RetryCooldownSeconds == 0 { + return false + } + cc.Lock() + defer cc.Unlock() + if _, ok := cc.Channels[channelId]; !ok { + return false + } + + cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(common.RetryCooldownSeconds) + return true +} + +func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel { + nowTime := time.Now().Unix() + totalWeight := 0 + + validChannels := make([]*ChannelChoice, 0, len(channelIds)) + for _, channelId := range channelIds { + if choice, ok := cc.Channels[channelId]; ok && choice.CooldownsTime < nowTime { + weight := int(*choice.Channel.Weight) + totalWeight += weight + validChannels = append(validChannels, choice) + } + } + + if len(validChannels) == 0 { + return nil + } + + if len(validChannels) == 1 { + return validChannels[0].Channel + } + + choiceWeight := rand.Intn(totalWeight) + for _, choice := range validChannels { + weight := int(*choice.Channel.Weight) + choiceWeight -= weight + if choiceWeight < 0 { + return choice.Channel + } + } + + return nil +} + +func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) { + if !common.MemoryCacheEnabled { + return GetRandomSatisfiedChannel(group, model) + } + cc.RLock() + defer cc.RUnlock() + if _, ok := cc.Rule[group]; !ok { + return nil, errors.New("group not found") + } + + if _, ok := cc.Rule[group][model]; !ok { + return nil, errors.New("model not found") + } + + channelsPriority := cc.Rule[group][model] + if len(channelsPriority) == 0 { + return nil, errors.New("channel not found") + } + + for _, priority := range channelsPriority { + channel := cc.Balancer(priority) + if channel != nil { + return channel, nil + } + } + + return nil, errors.New("channel not found") +} + +func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) { + if !common.MemoryCacheEnabled { + return GetGroupModels(group) + } + + cc.RLock() + defer cc.RUnlock() + + if _, ok := cc.Rule[group]; !ok { + return nil, errors.New("group not found") + } + + models := make([]string, 0, len(cc.Rule[group])) + for model := range cc.Rule[group] { + models = append(models, model) + } + + return models, nil +} + +var ChannelGroup = ChannelsChooser{} + +func InitChannelGroup() { + var channels []*Channel + DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + + abilities, err := GetAbilityChannelGroup() + if err != nil { + common.SysLog("get enabled abilities failed: " + err.Error()) + return + } + + newGroup := make(map[string]map[string][][]int) + newChannels := make(map[int]*ChannelChoice) + + for _, channel := range channels { + if *channel.Weight == 0 { + channel.Weight = &common.DefaultChannelWeight + } + newChannels[channel.Id] = &ChannelChoice{ + Channel: channel, + CooldownsTime: 0, + } + } + + for _, ability := range abilities { + if _, ok := newGroup[ability.Group]; !ok { + newGroup[ability.Group] = make(map[string][][]int) + } + + if _, ok := newGroup[ability.Group][ability.Model]; !ok { + newGroup[ability.Group][ability.Model] = make([][]int, 0) + } + + var priorityIds []int + // 逗号分割 ability.ChannelId + channelIds := strings.Split(ability.ChannelIds, ",") + for _, channelId := range channelIds { + priorityIds = append(priorityIds, common.String2Int(channelId)) + } + + newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds) + } + + ChannelGroup.Lock() + ChannelGroup.Rule = newGroup + ChannelGroup.Channels = newChannels + ChannelGroup.Unlock() + common.SysLog("channels synced from database") +} + +func SyncChannelGroup(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + common.SysLog("syncing channels from database") + InitChannelGroup() + } +} diff --git a/model/cache.go b/model/cache.go index e26995fe..9edb8a64 100644 --- a/model/cache.go +++ b/model/cache.go @@ -2,14 +2,9 @@ package model import ( "encoding/json" - "errors" "fmt" - "math/rand" "one-api/common" - "sort" "strconv" - "strings" - "sync" "time" ) @@ -131,104 +126,3 @@ func CacheIsUserEnabled(userId int) (bool, error) { } return userEnabled, err } - -var group2model2channels map[string]map[string][]*Channel -var channelSyncLock sync.RWMutex - -func InitChannelCache() { - newChannelId2channel := make(map[int]*Channel) - var channels []*Channel - DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) - for _, channel := range channels { - newChannelId2channel[channel.Id] = channel - } - var abilities []*Ability - DB.Find(&abilities) - groups := make(map[string]bool) - for _, ability := range abilities { - groups[ability.Group] = true - } - newGroup2model2channels := make(map[string]map[string][]*Channel) - for group := range groups { - newGroup2model2channels[group] = make(map[string][]*Channel) - } - for _, channel := range channels { - groups := strings.Split(channel.Group, ",") - for _, group := range groups { - models := strings.Split(channel.Models, ",") - for _, model := range models { - if _, ok := newGroup2model2channels[group][model]; !ok { - newGroup2model2channels[group][model] = make([]*Channel, 0) - } - newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel) - } - } - } - - // sort by priority - for group, model2channels := range newGroup2model2channels { - for model, channels := range model2channels { - sort.Slice(channels, func(i, j int) bool { - return channels[i].GetPriority() > channels[j].GetPriority() - }) - newGroup2model2channels[group][model] = channels - } - } - - channelSyncLock.Lock() - group2model2channels = newGroup2model2channels - channelSyncLock.Unlock() - common.SysLog("channels synced from database") -} - -func SyncChannelCache(frequency int) { - for { - time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing channels from database") - InitChannelCache() - } -} - -func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - if !common.MemoryCacheEnabled { - return GetRandomSatisfiedChannel(group, model) - } - channelSyncLock.RLock() - defer channelSyncLock.RUnlock() - channels := group2model2channels[group][model] - if len(channels) == 0 { - return nil, errors.New("channel not found") - } - endIdx := len(channels) - // choose by priority - firstChannel := channels[0] - if firstChannel.GetPriority() > 0 { - for i := range channels { - if channels[i].GetPriority() != firstChannel.GetPriority() { - endIdx = i - break - } - } - } - idx := rand.Intn(endIdx) - return channels[idx], nil -} - -func CacheGetGroupModels(group string) ([]string, error) { - if !common.MemoryCacheEnabled { - return GetGroupModels(group) - } - channelSyncLock.RLock() - defer channelSyncLock.RUnlock() - - groupModels := group2model2channels[group] - if groupModels == nil { - return nil, errors.New("group not found") - } - - models := make([]string, 0) - for model := range groupModels { - models = append(models, model) - } - return models, nil -} diff --git a/model/channel.go b/model/channel.go index 30a3f61a..6d4413ec 100644 --- a/model/channel.go +++ b/model/channel.go @@ -13,7 +13,7 @@ type Channel struct { Key string `json:"key" form:"key" gorm:"type:varchar(767);not null;index"` Status int `json:"status" form:"status" gorm:"default:1"` Name string `json:"name" form:"name" gorm:"index"` - Weight *uint `json:"weight" gorm:"default:0"` + Weight *uint `json:"weight" gorm:"default:1"` CreatedTime int64 `json:"created_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds @@ -95,11 +95,8 @@ func GetAllChannels() ([]*Channel, error) { func GetChannelById(id int, selectAll bool) (*Channel, error) { channel := Channel{Id: id} var err error = nil - if selectAll { - err = DB.First(&channel, "id = ?", id).Error - } else { - err = DB.Omit("key").First(&channel, "id = ?", id).Error - } + err = DB.First(&channel, "id = ?", id).Error + return &channel, err } diff --git a/model/option.go b/model/option.go index 084427f4..24e9a20b 100644 --- a/model/option.go +++ b/model/option.go @@ -77,6 +77,8 @@ func InitOptionMap() { common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) + common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds) + common.OptionMapRWMutex.Unlock() initModelRatio() loadOptionsFromDatabase() @@ -146,6 +148,7 @@ var optionIntMap = map[string]*int{ "QuotaRemindThreshold": &common.QuotaRemindThreshold, "PreConsumedQuota": &common.PreConsumedQuota, "RetryTimes": &common.RetryTimes, + "RetryCooldownSeconds": &common.RetryCooldownSeconds, } var optionBoolMap = map[string]*bool{ diff --git a/router/relay-router.go b/router/relay-router.go index 519a764b..f6f548e7 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -2,6 +2,7 @@ package router import ( "one-api/controller" + "one-api/controller/relay" "one-api/middleware" "github.com/gin-gonic/gin" @@ -19,18 +20,18 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) { - relayV1Router.POST("/completions", controller.RelayCompletions) - relayV1Router.POST("/chat/completions", controller.RelayChat) + relayV1Router.POST("/completions", relay.Relay) + relayV1Router.POST("/chat/completions", relay.Relay) // relayV1Router.POST("/edits", controller.Relay) - relayV1Router.POST("/images/generations", controller.RelayImageGenerations) - relayV1Router.POST("/images/edits", controller.RelayImageEdits) - relayV1Router.POST("/images/variations", controller.RelayImageVariations) - relayV1Router.POST("/embeddings", controller.RelayEmbeddings) + relayV1Router.POST("/images/generations", relay.Relay) + relayV1Router.POST("/images/edits", relay.Relay) + relayV1Router.POST("/images/variations", relay.Relay) + relayV1Router.POST("/embeddings", relay.Relay) // relayV1Router.POST("/engines/:model/embeddings", controller.RelayEmbeddings) - relayV1Router.POST("/audio/transcriptions", controller.RelayTranscriptions) - relayV1Router.POST("/audio/translations", controller.RelayTranslations) - relayV1Router.POST("/audio/speech", controller.RelaySpeech) - relayV1Router.POST("/moderations", controller.RelayModerations) + relayV1Router.POST("/audio/transcriptions", relay.Relay) + relayV1Router.POST("/audio/translations", relay.Relay) + relayV1Router.POST("/audio/speech", relay.Relay) + relayV1Router.POST("/moderations", relay.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js index a1fb429c..aa52b40d 100644 --- a/web/src/views/Channel/component/EditModal.js +++ b/web/src/views/Channel/component/EditModal.js @@ -21,7 +21,8 @@ import { Container, Autocomplete, FormHelperText, - Checkbox + Checkbox, + Switch } from '@mui/material'; import { Formik } from 'formik'; @@ -73,6 +74,7 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => { const [inputLabel, setInputLabel] = useState(defaultConfig.inputLabel); // const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt); const [modelOptions, setModelOptions] = useState([]); + const [batchAdd, setBatchAdd] = useState(false); const initChannel = (typeValue) => { if (typeConfig[typeValue]?.inputLabel) { @@ -246,6 +248,7 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => { }, []); useEffect(() => { + setBatchAdd(false); if (channelId) { loadChannel().then(); } else { @@ -479,18 +482,36 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => { - {inputLabel.key} - + {!batchAdd ? ( + <> + {inputLabel.key} + + + ) : ( + + )} + {touched.key && errors.key ? ( {errors.key} @@ -499,6 +520,17 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => { {inputPrompt.key} )} + {channelId === 0 && ( + + setBatchAdd(e.target.checked)} /> + 批量添加 + + )} + {/* {inputLabel.model_mapping} */} { + if (weightValve === '' || weightValve === item.weight) { + return; + } + + if (weightValve <= 0) { + showError('权重不能小于 0'); + return; + } + + await manageChannel(item.id, 'weight', weightValve); + }; + const handleResponseTime = async () => { const { success, time } = await manageChannel(item.id, 'test', ''); if (success) { @@ -176,6 +196,25 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, /> + + + 权重 + setWeight(e.target.value)} + sx={{ textAlign: 'center' }} + endAdornment={ + + + + + + } + /> + + @@ -204,6 +243,16 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, 编辑 + + { + handleCloseMenu(); + manageChannel(item.id, 'copy'); + }} + > + 复制{' '} + + 删除 diff --git a/web/src/views/Channel/index.js b/web/src/views/Channel/index.js index dedce917..dadebec4 100644 --- a/web/src/views/Channel/index.js +++ b/web/src/views/Channel/index.js @@ -11,6 +11,7 @@ import LinearProgress from '@mui/material/LinearProgress'; import ButtonGroup from '@mui/material/ButtonGroup'; import Toolbar from '@mui/material/Toolbar'; import useMediaQuery from '@mui/material/useMediaQuery'; +import Alert from '@mui/material/Alert'; import { Button, IconButton, Card, Box, Stack, Container, Typography, Divider } from '@mui/material'; import ChannelTableRow from './component/TableRow'; @@ -116,6 +117,19 @@ export default function ChannelPage() { try { switch (action) { + case 'copy': { + let oldRes = await API.get(`/api/channel/${id}`); + const { success, message, data } = oldRes.data; + if (!success) { + showError(message); + return; + } + // 删除 data.id + delete data.id; + data.name = data.name + '_copy'; + res = await API.post(`/api/channel/`, { ...data }); + break; + } case 'delete': res = await API.delete(url + id); break; @@ -134,6 +148,15 @@ export default function ChannelPage() { priority: parseInt(value) }); break; + case 'weight': + if (value === '') { + return; + } + res = await API.put(url, { + ...data, + weight: parseInt(value) + }); + break; case 'test': res = await API.get(url + `test/${id}`); break; @@ -141,7 +164,7 @@ export default function ChannelPage() { const { success, message } = res.data; if (success) { showSuccess('操作成功完成!'); - if (action === 'delete') { + if (action === 'delete' || action === 'copy') { await handleRefresh(); } } else { @@ -271,6 +294,20 @@ export default function ChannelPage() { + + + 优先级/权重解释: +
+ 1. 优先级越大,越优先使用;(只有该优先级下的节点都冻结或者禁用了,才会使用低优先级的节点) +
+ 2. 相同优先级下:如果“MEMORY_CACHE_ENABLED”启用,则根据权重进行负载均衡(加权随机);否则忽略权重直接随机 +
+ 3. 如果在设置-通用设置中设置了“重试次数”和“重试间隔”,则会在失败后重试。 +
+ 4. + 重试逻辑:1)先在高优先级中的节点重试,如果高优先级中的节点都冻结了,才会在低优先级中的节点重试。2)如果设置了“重试间隔”,则某一渠道失败后,会冻结一段时间,所有人都不会再使用这个渠道,直到冻结时间结束。3)重试次数用完后,直接结束。 +
+
@@ -349,6 +386,7 @@ export default function ChannelPage() { { id: 'response_time', label: '响应时间', disableSort: false }, { id: 'balance', label: '余额', disableSort: false }, { id: 'priority', label: '优先级', disableSort: false }, + { id: 'weight', label: '权重', disableSort: false }, { id: 'action', label: '操作', disableSort: true } ]} /> diff --git a/web/src/views/Setting/component/OperationSetting.js b/web/src/views/Setting/component/OperationSetting.js index ec2bc715..96680c71 100644 --- a/web/src/views/Setting/component/OperationSetting.js +++ b/web/src/views/Setting/component/OperationSetting.js @@ -30,7 +30,8 @@ const OperationSetting = () => { DisplayInCurrencyEnabled: '', DisplayTokenStatEnabled: '', ApproximateTokenEnabled: '', - RetryTimes: 0 + RetryTimes: 0, + RetryCooldownSeconds: 0 }); const [originInputs, setOriginInputs] = useState({}); const [newModelRatioView, setNewModelRatioView] = useState(false); @@ -139,6 +140,11 @@ const OperationSetting = () => { } break; case 'general': + if (inputs.QuotaPerUnit < 0 || inputs.RetryTimes < 0 || inputs.RetryCooldownSeconds < 0) { + showError('单位额度、重试次数、冷却时间不能为负数'); + return; + } + if (originInputs['TopUpLink'] !== inputs.TopUpLink) { await updateOption('TopUpLink', inputs.TopUpLink); } @@ -151,6 +157,9 @@ const OperationSetting = () => { if (originInputs['RetryTimes'] !== inputs.RetryTimes) { await updateOption('RetryTimes', inputs.RetryTimes); } + if (originInputs['RetryCooldownSeconds'] !== inputs.RetryCooldownSeconds) { + await updateOption('RetryCooldownSeconds', inputs.RetryCooldownSeconds); + } break; } @@ -224,6 +233,18 @@ const OperationSetting = () => { disabled={loading} /> + + 重试间隔(秒) + + { const checkUpdate = async () => { try { - const res = await API.get('https://api.github.com/repos/MartialBE/one-api/releases/latest'); - const { tag_name, body } = res.data; - if (tag_name === process.env.REACT_APP_VERSION) { - showSuccess(`已是最新版本:${tag_name}`); + if (!process.env.REACT_APP_VERSION) { + showError('无法获取当前版本号'); + return; + } + + // 如果版本前缀是v开头的 + if (process.env.REACT_APP_VERSION.startsWith('v')) { + const res = await API.get('https://api.github.com/repos/MartialBE/one-api/releases/latest'); + const { tag_name, body } = res.data; + if (tag_name === process.env.REACT_APP_VERSION) { + showSuccess(`已是最新版本:${tag_name}`); + } else { + setUpdateData({ + tag_name: tag_name, + content: marked.parse(body) + }); + setShowUpdateModal(true); + } } else { - setUpdateData({ - tag_name: tag_name, - content: marked.parse(body) - }); - setShowUpdateModal(true); + const res = await API.get('https://api.github.com/repos/MartialBE/one-api/commits/main'); + const { sha, commit } = res.data; + const newVersion = 'dev-' + sha.substr(0, 7); + if (newVersion === process.env.REACT_APP_VERSION) { + showSuccess(`已是最新版本:${newVersion}`); + } else { + setUpdateData({ + tag_name: newVersion, + content: marked.parse(commit.message) + }); + setShowUpdateModal(true); + } } } catch (error) { return; @@ -137,6 +159,9 @@ const OtherSetting = () => { + + 当前版本:{process.env.REACT_APP_VERSION} +