From dd3e79a20db44aa755a46fbe21ef93754eff30b2 Mon Sep 17 00:00:00 2001
From: Buer <42402987+MartialBE@users.noreply.github.com>
Date: Wed, 6 Mar 2024 18:01:43 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20channel=20support=20weight?=
=?UTF-8?q?=20(#85)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* ✨ feat: channel support weight
* 💄 improve: show version
* 💄 improve: Channel add copy operation
* 💄 improve: Channel support batch add
---
.github/workflows/docker-image.yml | 10 +-
common/constants.go | 2 +
controller/channel-billing.go | 2 +-
controller/channel-test.go | 14 +-
controller/common.go | 75 ++++++++
controller/model.go | 2 +-
controller/relay-chat.go | 79 --------
controller/relay-completions.go | 79 --------
controller/relay-embeddings.go | 66 -------
controller/relay-image-edits.go | 79 --------
controller/relay-image-generations.go | 82 --------
controller/relay-image-variationsy.go | 74 --------
controller/relay-moderations.go | 66 -------
controller/relay-speech.go | 62 ------
controller/relay-transcriptions.go | 62 ------
controller/relay-translations.go | 62 ------
controller/relay.go | 63 -------
controller/relay/base.go | 53 ++++++
controller/relay/chat.go | 76 ++++++++
controller/relay/completions.go | 76 ++++++++
controller/relay/embeddings.go | 63 +++++++
controller/relay/image-edits.go | 71 +++++++
controller/relay/image-generations.go | 74 ++++++++
controller/relay/image-variationsy.go | 66 +++++++
controller/relay/main.go | 106 +++++++++++
controller/relay/moderations.go | 62 ++++++
controller/{ => relay}/quota.go | 5 +-
controller/relay/speech.go | 58 ++++++
controller/relay/transcriptions.go | 58 ++++++
controller/relay/translations.go | 58 ++++++
controller/{relay-utils.go => relay/utils.go} | 153 ++++++++-------
main.go | 4 +-
middleware/auth.go | 2 +-
model/ability.go | 48 +++++
model/balancer.go | 176 ++++++++++++++++++
model/cache.go | 106 -----------
model/channel.go | 9 +-
model/option.go | 3 +
router/relay-router.go | 21 ++-
web/src/views/Channel/component/EditModal.js | 58 ++++--
web/src/views/Channel/component/TableRow.js | 51 ++++-
web/src/views/Channel/index.js | 40 +++-
.../Setting/component/OperationSetting.js | 23 ++-
.../views/Setting/component/OtherSetting.js | 45 ++++-
44 files changed, 1425 insertions(+), 1019 deletions(-)
create mode 100644 controller/common.go
delete mode 100644 controller/relay-chat.go
delete mode 100644 controller/relay-completions.go
delete mode 100644 controller/relay-embeddings.go
delete mode 100644 controller/relay-image-edits.go
delete mode 100644 controller/relay-image-generations.go
delete mode 100644 controller/relay-image-variationsy.go
delete mode 100644 controller/relay-moderations.go
delete mode 100644 controller/relay-speech.go
delete mode 100644 controller/relay-transcriptions.go
delete mode 100644 controller/relay-translations.go
delete mode 100644 controller/relay.go
create mode 100644 controller/relay/base.go
create mode 100644 controller/relay/chat.go
create mode 100644 controller/relay/completions.go
create mode 100644 controller/relay/embeddings.go
create mode 100644 controller/relay/image-edits.go
create mode 100644 controller/relay/image-generations.go
create mode 100644 controller/relay/image-variationsy.go
create mode 100644 controller/relay/main.go
create mode 100644 controller/relay/moderations.go
rename controller/{ => relay}/quota.go (97%)
create mode 100644 controller/relay/speech.go
create mode 100644 controller/relay/transcriptions.go
create mode 100644 controller/relay/translations.go
rename controller/{relay-utils.go => relay/utils.go} (56%)
create mode 100644 model/balancer.go
diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
index 64f4e9c0..cb916557 100644
--- a/.github/workflows/docker-image.yml
+++ b/.github/workflows/docker-image.yml
@@ -25,7 +25,15 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
-
+ - name: Save version info
+ run: |
+ TAG=$(git describe --tags --exact-match 2> /dev/null)
+ if [ $? -eq 0 ]; then
+ echo $TAG > VERSION
+ else
+ HASH=$(git rev-parse --short=7 HEAD)
+ echo "dev-$HASH" > VERSION
+ fi
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
diff --git a/common/constants.go b/common/constants.go
index f704fc98..d6f9ddb2 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -83,6 +83,8 @@ var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
+var DefaultChannelWeight = uint(1)
+var RetryCooldownSeconds = 5
var RootUserEmail = ""
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 83bfce88..5f838067 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -122,7 +122,7 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
- disableChannel(channel.Id, channel.Name, "余额不足")
+ DisableChannel(channel.Id, channel.Name, "余额不足")
}
}
time.Sleep(common.RequestInterval)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 37d544c3..1c6206d3 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -140,14 +140,6 @@ func notifyRootUser(subject string, content string) {
}
}
-// disable & notify
-func disableChannel(channelId int, channelName string, reason string) {
- model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
- subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
- content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
- notifyRootUser(subject, content)
-}
-
// enable & notify
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
@@ -185,10 +177,10 @@ func testAllChannels(notify bool) error {
milliseconds := tok.Sub(tik).Milliseconds()
if milliseconds > disableThreshold {
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
- disableChannel(channel.Id, channel.Name, err.Error())
+ DisableChannel(channel.Id, channel.Name, err.Error())
}
- if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
- disableChannel(channel.Id, channel.Name, err.Error())
+ if isChannelEnabled && ShouldDisableChannel(openaiErr, -1) {
+ DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
diff --git a/controller/common.go b/controller/common.go
new file mode 100644
index 00000000..90349b2a
--- /dev/null
+++ b/controller/common.go
@@ -0,0 +1,75 @@
+package controller
+
+import (
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
+ if !common.AutomaticEnableChannelEnabled {
+ return false
+ }
+ if err != nil {
+ return false
+ }
+ if openAIErr != nil {
+ return false
+ }
+ return true
+}
+
+func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
+ if !common.AutomaticDisableChannelEnabled {
+ return false
+ }
+
+ if err == nil {
+ return false
+ }
+
+ if statusCode == http.StatusUnauthorized {
+ return true
+ }
+
+ if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
+ return true
+ }
+ return false
+}
+
+// disable & notify
+func DisableChannel(channelId int, channelName string, reason string) {
+ model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+ subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
+ content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
+ notifyRootUser(subject, content)
+}
+
+func RelayNotImplemented(c *gin.Context) {
+ err := types.OpenAIError{
+ Message: "API not implemented",
+ Type: "one_api_error",
+ Param: "",
+ Code: "api_not_implemented",
+ }
+ c.JSON(http.StatusNotImplemented, gin.H{
+ "error": err,
+ })
+}
+
+func RelayNotFound(c *gin.Context) {
+ err := types.OpenAIError{
+ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
+ Type: "invalid_request_error",
+ Param: "",
+ Code: "",
+ }
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": err,
+ })
+}
diff --git a/controller/model.go b/controller/model.go
index 61df4a79..98f99e50 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -70,7 +70,7 @@ func ListModels(c *gin.Context) {
groupName = user.Group
}
- models, err := model.CacheGetGroupModels(groupName)
+ models, err := model.ChannelGroup.GetGroupModels(groupName)
if err != nil {
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
return
diff --git a/controller/relay-chat.go b/controller/relay-chat.go
deleted file mode 100644
index 61267311..00000000
--- a/controller/relay-chat.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package controller
-
-import (
- "math"
- "net/http"
- "one-api/common"
- "one-api/common/requester"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayChat(c *gin.Context) {
-
- var chatRequest types.ChatCompletionRequest
- if err := common.UnmarshalBodyReusable(c, &chatRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
- common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
- return
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, chatRequest.Model)
- if fail {
- return
- }
- chatRequest.Model = modelName
-
- chatProvider, ok := provider.(providersBase.ChatInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, chatRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- if chatRequest.Stream {
- var response requester.StreamReaderInterface[string]
- response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseStreamClient(c, response)
- } else {
- var response *types.ChatCompletionResponse
- response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseJsonClient(c, response)
- }
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-completions.go b/controller/relay-completions.go
deleted file mode 100644
index 0898a016..00000000
--- a/controller/relay-completions.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package controller
-
-import (
- "math"
- "net/http"
- "one-api/common"
- "one-api/common/requester"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayCompletions(c *gin.Context) {
-
- var completionRequest types.CompletionRequest
- if err := common.UnmarshalBodyReusable(c, &completionRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
- common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
- return
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, completionRequest.Model)
- if fail {
- return
- }
- completionRequest.Model = modelName
-
- completionProvider, ok := provider.(providersBase.CompletionInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, completionRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- if completionRequest.Stream {
- var response requester.StreamReaderInterface[string]
- response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseStreamClient(c, response)
- } else {
- var response *types.CompletionResponse
- response, errWithCode = completionProvider.CreateCompletion(&completionRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseJsonClient(c, response)
- }
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-embeddings.go b/controller/relay-embeddings.go
deleted file mode 100644
index 58dffc48..00000000
--- a/controller/relay-embeddings.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- providersBase "one-api/providers/base"
- "one-api/types"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayEmbeddings(c *gin.Context) {
-
- var embeddingsRequest types.EmbeddingRequest
- if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
- embeddingsRequest.Model = c.Param("model")
- }
-
- if err := common.UnmarshalBodyReusable(c, &embeddingsRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, embeddingsRequest.Model)
- if fail {
- return
- }
- embeddingsRequest.Model = modelName
-
- embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- response, errWithCode := embeddingsProvider.CreateEmbeddings(&embeddingsRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseJsonClient(c, response)
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-image-edits.go b/controller/relay-image-edits.go
deleted file mode 100644
index ef301e2a..00000000
--- a/controller/relay-image-edits.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayImageEdits(c *gin.Context) {
-
- var imageEditRequest types.ImageEditRequest
-
- if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- if imageEditRequest.Prompt == "" {
- common.AbortWithMessage(c, http.StatusBadRequest, "field prompt is required")
- return
- }
-
- if imageEditRequest.Model == "" {
- imageEditRequest.Model = "dall-e-2"
- }
-
- if imageEditRequest.Size == "" {
- imageEditRequest.Size = "1024x1024"
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, imageEditRequest.Model)
- if fail {
- return
- }
- imageEditRequest.Model = modelName
-
- imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens, err := common.CountTokenImage(imageEditRequest)
- if err != nil {
- common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
- return
- }
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- response, errWithCode := imageEditsProvider.CreateImageEdits(&imageEditRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseJsonClient(c, response)
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-image-generations.go b/controller/relay-image-generations.go
deleted file mode 100644
index 4332274d..00000000
--- a/controller/relay-image-generations.go
+++ /dev/null
@@ -1,82 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayImageGenerations(c *gin.Context) {
-
- var imageRequest types.ImageRequest
-
- if err := common.UnmarshalBodyReusable(c, &imageRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- if imageRequest.Model == "" {
- imageRequest.Model = "dall-e-2"
- }
-
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
-
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
-
- if imageRequest.Quality == "" {
- imageRequest.Quality = "standard"
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, imageRequest.Model)
- if fail {
- return
- }
- imageRequest.Model = modelName
-
- imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens, err := common.CountTokenImage(imageRequest)
- if err != nil {
- common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
- return
- }
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, imageRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- response, errWithCode := imageGenerationsProvider.CreateImageGenerations(&imageRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseJsonClient(c, response)
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-image-variationsy.go b/controller/relay-image-variationsy.go
deleted file mode 100644
index 2c9069ee..00000000
--- a/controller/relay-image-variationsy.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayImageVariations(c *gin.Context) {
-
- var imageEditRequest types.ImageEditRequest
-
- if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- if imageEditRequest.Model == "" {
- imageEditRequest.Model = "dall-e-2"
- }
-
- if imageEditRequest.Size == "" {
- imageEditRequest.Size = "1024x1024"
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, imageEditRequest.Model)
- if fail {
- return
- }
- imageEditRequest.Model = modelName
-
- imageVariations, ok := provider.(providersBase.ImageVariationsInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens, err := common.CountTokenImage(imageEditRequest)
- if err != nil {
- common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
- return
- }
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- response, errWithCode := imageVariations.CreateImageVariations(&imageEditRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseJsonClient(c, response)
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-moderations.go b/controller/relay-moderations.go
deleted file mode 100644
index 136b6fdd..00000000
--- a/controller/relay-moderations.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayModerations(c *gin.Context) {
-
- var moderationRequest types.ModerationRequest
-
- if err := common.UnmarshalBodyReusable(c, &moderationRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- if moderationRequest.Model == "" {
- moderationRequest.Model = "text-moderation-stable"
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, moderationRequest.Model)
- if fail {
- return
- }
- moderationRequest.Model = modelName
-
- moderationProvider, ok := provider.(providersBase.ModerationInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, moderationRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- response, errWithCode := moderationProvider.CreateModeration(&moderationRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseJsonClient(c, response)
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-speech.go b/controller/relay-speech.go
deleted file mode 100644
index ec7fd7ce..00000000
--- a/controller/relay-speech.go
+++ /dev/null
@@ -1,62 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelaySpeech(c *gin.Context) {
-
- var speechRequest types.SpeechAudioRequest
-
- if err := common.UnmarshalBodyReusable(c, &speechRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, speechRequest.Model)
- if fail {
- return
- }
- speechRequest.Model = modelName
-
- speechProvider, ok := provider.(providersBase.SpeechInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens := len(speechRequest.Input)
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, speechRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- response, errWithCode := speechProvider.CreateSpeech(&speechRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseMultipart(c, response)
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-transcriptions.go b/controller/relay-transcriptions.go
deleted file mode 100644
index a6005963..00000000
--- a/controller/relay-transcriptions.go
+++ /dev/null
@@ -1,62 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayTranscriptions(c *gin.Context) {
-
- var audioRequest types.AudioRequest
-
- if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, audioRequest.Model)
- if fail {
- return
- }
- audioRequest.Model = modelName
-
- transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens := 0
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- response, errWithCode := transcriptionsProvider.CreateTranscriptions(&audioRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseCustom(c, response)
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay-translations.go b/controller/relay-translations.go
deleted file mode 100644
index c13935eb..00000000
--- a/controller/relay-translations.go
+++ /dev/null
@@ -1,62 +0,0 @@
-package controller
-
-import (
- "net/http"
- "one-api/common"
- providersBase "one-api/providers/base"
- "one-api/types"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayTranslations(c *gin.Context) {
-
- var audioRequest types.AudioRequest
-
- if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
- return
- }
-
- // 获取供应商
- provider, modelName, fail := getProvider(c, audioRequest.Model)
- if fail {
- return
- }
- audioRequest.Model = modelName
-
- translationProvider, ok := provider.(providersBase.TranslationInterface)
- if !ok {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
- return
- }
-
- // 获取Input Tokens
- promptTokens := 0
-
- usage := &types.Usage{
- PromptTokens: promptTokens,
- }
- provider.SetUsage(usage)
-
- quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
-
- response, errWithCode := translationProvider.CreateTranslation(&audioRequest)
- if errWithCode != nil {
- errorHelper(c, errWithCode)
- return
- }
- errWithCode = responseCustom(c, response)
-
- // 如果报错,则退还配额
- if errWithCode != nil {
- quotaInfo.undo(c, errWithCode)
- return
- }
-
- quotaInfo.consume(c, usage)
-}
diff --git a/controller/relay.go b/controller/relay.go
deleted file mode 100644
index 8f5d260a..00000000
--- a/controller/relay.go
+++ /dev/null
@@ -1,63 +0,0 @@
-package controller
-
-import (
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/types"
- "strconv"
-
- "github.com/gin-gonic/gin"
-)
-
-func RelayNotImplemented(c *gin.Context) {
- err := types.OpenAIError{
- Message: "API not implemented",
- Type: "one_api_error",
- Param: "",
- Code: "api_not_implemented",
- }
- c.JSON(http.StatusNotImplemented, gin.H{
- "error": err,
- })
-}
-
-func RelayNotFound(c *gin.Context) {
- err := types.OpenAIError{
- Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
- Type: "invalid_request_error",
- Param: "",
- Code: "",
- }
- c.JSON(http.StatusNotFound, gin.H{
- "error": err,
- })
-}
-
-func errorHelper(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
- requestId := c.GetString(common.RequestIdKey)
- retryTimesStr := c.Query("retry")
- retryTimes, _ := strconv.Atoi(retryTimesStr)
- if retryTimesStr == "" {
- retryTimes = common.RetryTimes
- }
- if retryTimes > 0 {
- c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
- } else {
- if err.StatusCode == http.StatusTooManyRequests {
- err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
- }
- err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
- c.JSON(err.StatusCode, gin.H{
- "error": err.OpenAIError,
- })
- }
- channelId := c.GetInt("channel_id")
- common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
- // https://platform.openai.com/docs/guides/error-codes/api-errors
- if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
- channelId := c.GetInt("channel_id")
- channelName := c.GetString("channel_name")
- disableChannel(channelId, channelName, err.Message)
- }
-}
diff --git a/controller/relay/base.go b/controller/relay/base.go
new file mode 100644
index 00000000..0f644f02
--- /dev/null
+++ b/controller/relay/base.go
@@ -0,0 +1,53 @@
+package relay
+
+import (
+ "one-api/types"
+
+ providersBase "one-api/providers/base"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayBase struct {
+ c *gin.Context
+ provider providersBase.ProviderInterface
+ originalModel string
+ modelName string
+}
+
+type RelayBaseInterface interface {
+ send() (err *types.OpenAIErrorWithStatusCode, done bool)
+ getPromptTokens() (int, error)
+ setRequest() error
+ setProvider(modelName string) error
+ getProvider() providersBase.ProviderInterface
+ getOriginalModel() string
+ getModelName() string
+ getContext() *gin.Context
+}
+
+func (r *relayBase) setProvider(modelName string) error {
+ provider, modelName, fail := getProvider(r.c, modelName)
+ if fail != nil {
+ return fail
+ }
+ r.provider = provider
+ r.modelName = modelName
+ return nil
+}
+
+func (r *relayBase) getContext() *gin.Context {
+ return r.c
+}
+
+func (r *relayBase) getProvider() providersBase.ProviderInterface {
+ return r.provider
+}
+
+func (r *relayBase) getOriginalModel() string {
+ return r.originalModel
+}
+
+func (r *relayBase) getModelName() string {
+ return r.modelName
+}
diff --git a/controller/relay/chat.go b/controller/relay/chat.go
new file mode 100644
index 00000000..33e2b469
--- /dev/null
+++ b/controller/relay/chat.go
@@ -0,0 +1,76 @@
+package relay
+
+import (
+ "errors"
+ "math"
+ "net/http"
+ "one-api/common"
+ "one-api/common/requester"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayChat struct {
+ relayBase
+ chatRequest types.ChatCompletionRequest
+}
+
+func NewRelayChat(c *gin.Context) *relayChat {
+ relay := &relayChat{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayChat) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.chatRequest); err != nil {
+ return err
+ }
+
+ if r.chatRequest.MaxTokens < 0 || r.chatRequest.MaxTokens > math.MaxInt32/2 {
+ return errors.New("max_tokens is invalid")
+ }
+
+ r.originalModel = r.chatRequest.Model
+
+ return nil
+}
+
+func (r *relayChat) getPromptTokens() (int, error) {
+ return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil
+}
+
+func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ chatProvider, ok := r.provider.(providersBase.ChatInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.chatRequest.Model = r.modelName
+
+ if r.chatRequest.Stream {
+ var response requester.StreamReaderInterface[string]
+ response, err = chatProvider.CreateChatCompletionStream(&r.chatRequest)
+ if err != nil {
+ return
+ }
+
+ err = responseStreamClient(r.c, response)
+ } else {
+ var response *types.ChatCompletionResponse
+ response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
+ if err != nil {
+ return
+ }
+ err = responseJsonClient(r.c, response)
+ }
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay/completions.go b/controller/relay/completions.go
new file mode 100644
index 00000000..fdfbd03e
--- /dev/null
+++ b/controller/relay/completions.go
@@ -0,0 +1,76 @@
+package relay
+
+import (
+ "errors"
+ "math"
+ "net/http"
+ "one-api/common"
+ "one-api/common/requester"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayCompletions struct {
+ relayBase
+ request types.CompletionRequest
+}
+
+func NewRelayCompletions(c *gin.Context) *relayCompletions {
+ relay := &relayCompletions{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayCompletions) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ if r.request.MaxTokens < 0 || r.request.MaxTokens > math.MaxInt32/2 {
+ return errors.New("max_tokens is invalid")
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relayCompletions) getPromptTokens() (int, error) {
+ return common.CountTokenInput(r.request.Prompt, r.modelName), nil
+}
+
+func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.CompletionInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ if r.request.Stream {
+ var response requester.StreamReaderInterface[string]
+ response, err = provider.CreateCompletionStream(&r.request)
+ if err != nil {
+ return
+ }
+
+ err = responseStreamClient(r.c, response)
+ } else {
+ var response *types.CompletionResponse
+ response, err = provider.CreateCompletion(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseJsonClient(r.c, response)
+ }
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay/embeddings.go b/controller/relay/embeddings.go
new file mode 100644
index 00000000..5a2a3bf7
--- /dev/null
+++ b/controller/relay/embeddings.go
@@ -0,0 +1,63 @@
+package relay
+
+import (
+ "net/http"
+ "one-api/common"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayEmbeddings struct {
+ relayBase
+ request types.EmbeddingRequest
+}
+
+func NewRelayEmbeddings(c *gin.Context) *relayEmbeddings {
+ relay := &relayEmbeddings{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayEmbeddings) setRequest() error {
+ if strings.HasSuffix(r.c.Request.URL.Path, "embeddings") {
+ r.request.Model = r.c.Param("model")
+ }
+
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relayEmbeddings) getPromptTokens() (int, error) {
+ return common.CountTokenInput(r.request.Input, r.modelName), nil
+}
+
+func (r *relayEmbeddings) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.EmbeddingsInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ response, err := provider.CreateEmbeddings(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseJsonClient(r.c, response)
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay/image-edits.go b/controller/relay/image-edits.go
new file mode 100644
index 00000000..16a77bdf
--- /dev/null
+++ b/controller/relay/image-edits.go
@@ -0,0 +1,71 @@
+package relay
+
+import (
+ "errors"
+ "net/http"
+ "one-api/common"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayImageEdits struct {
+ relayBase
+ request types.ImageEditRequest
+}
+
+func NewRelayImageEdits(c *gin.Context) *relayImageEdits {
+ relay := &relayImageEdits{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayImageEdits) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ if r.request.Prompt == "" {
+ return errors.New("field prompt is required")
+ }
+
+ if r.request.Model == "" {
+ r.request.Model = "dall-e-2"
+ }
+
+ if r.request.Size == "" {
+ r.request.Size = "1024x1024"
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relayImageEdits) getPromptTokens() (int, error) {
+ return common.CountTokenImage(r.request)
+}
+
+func (r *relayImageEdits) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.ImageEditsInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ response, err := provider.CreateImageEdits(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseJsonClient(r.c, response)
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay/image-generations.go b/controller/relay/image-generations.go
new file mode 100644
index 00000000..c63e3c91
--- /dev/null
+++ b/controller/relay/image-generations.go
@@ -0,0 +1,74 @@
+package relay
+
+import (
+ "net/http"
+ "one-api/common"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayImageGenerations struct {
+ relayBase
+ request types.ImageRequest
+}
+
+func NewRelayImageGenerations(c *gin.Context) *relayImageGenerations {
+ relay := &relayImageGenerations{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayImageGenerations) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ if r.request.Model == "" {
+ r.request.Model = "dall-e-2"
+ }
+
+ if r.request.N == 0 {
+ r.request.N = 1
+ }
+
+ if r.request.Size == "" {
+ r.request.Size = "1024x1024"
+ }
+
+ if r.request.Quality == "" {
+ r.request.Quality = "standard"
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relayImageGenerations) getPromptTokens() (int, error) {
+ return common.CountTokenImage(r.request)
+}
+
+func (r *relayImageGenerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.ImageGenerationsInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ response, err := provider.CreateImageGenerations(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseJsonClient(r.c, response)
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay/image-variationsy.go b/controller/relay/image-variationsy.go
new file mode 100644
index 00000000..2ef3decc
--- /dev/null
+++ b/controller/relay/image-variationsy.go
@@ -0,0 +1,66 @@
+package relay
+
+import (
+ "net/http"
+ "one-api/common"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayImageVariations struct {
+ relayBase
+ request types.ImageEditRequest
+}
+
+func NewRelayImageVariations(c *gin.Context) *relayImageVariations {
+ relay := &relayImageVariations{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayImageVariations) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ if r.request.Model == "" {
+ r.request.Model = "dall-e-2"
+ }
+
+ if r.request.Size == "" {
+ r.request.Size = "1024x1024"
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relayImageVariations) getPromptTokens() (int, error) {
+ return common.CountTokenImage(r.request)
+}
+
+func (r *relayImageVariations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.ImageVariationsInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ response, err := provider.CreateImageVariations(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseJsonClient(r.c, response)
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay/main.go b/controller/relay/main.go
new file mode 100644
index 00000000..6329dcb5
--- /dev/null
+++ b/controller/relay/main.go
@@ -0,0 +1,106 @@
+package relay
+
+import (
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func Relay(c *gin.Context) {
+ relay := Path2Relay(c, c.Request.URL.Path)
+ if relay == nil {
+ common.AbortWithMessage(c, http.StatusNotFound, "Not Found")
+ return
+ }
+
+ if err := relay.setRequest(); err != nil {
+ common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
+ return
+ }
+
+ if err := relay.setProvider(relay.getOriginalModel()); err != nil {
+ common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
+ return
+ }
+
+ apiErr, done := RelayHandler(relay)
+ if apiErr == nil {
+ return
+ }
+
+ channel := relay.getProvider().GetChannel()
+ go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
+
+ retryTimes := common.RetryTimes
+ if done || !shouldRetry(c, apiErr.StatusCode) {
+ common.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode))
+ retryTimes = 0
+ }
+
+ for i := retryTimes; i > 0; i-- {
+ // 冻结通道
+ model.ChannelGroup.Cooldowns(channel.Id)
+ if err := relay.setProvider(relay.getOriginalModel()); err != nil {
+ continue
+ }
+
+ channel = relay.getProvider().GetChannel()
+ common.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i))
+ apiErr, done = RelayHandler(relay)
+ if apiErr == nil {
+ return
+ }
+ go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
+ if done || !shouldRetry(c, apiErr.StatusCode) {
+ break
+ }
+ }
+
+ if apiErr != nil {
+ requestId := c.GetString(common.RequestIdKey)
+ if apiErr.StatusCode == http.StatusTooManyRequests {
+ apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
+ }
+ apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId)
+ c.JSON(apiErr.StatusCode, gin.H{
+ "error": apiErr.OpenAIError,
+ })
+
+ }
+}
+
+func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCode, done bool) {
+ promptTokens, tonkeErr := relay.getPromptTokens()
+ if tonkeErr != nil {
+ err = common.ErrorWrapper(tonkeErr, "token_error", http.StatusBadRequest)
+ done = true
+ return
+ }
+
+ usage := &types.Usage{
+ PromptTokens: promptTokens,
+ }
+
+ relay.getProvider().SetUsage(usage)
+
+ var quotaInfo *QuotaInfo
+ quotaInfo, err = generateQuotaInfo(relay.getContext(), relay.getModelName(), promptTokens)
+ if err != nil {
+ done = true
+ return
+ }
+
+ err, done = relay.send()
+
+ if err != nil {
+ quotaInfo.undo(relay.getContext())
+ return
+ }
+
+ quotaInfo.consume(relay.getContext(), usage)
+ return
+}
diff --git a/controller/relay/moderations.go b/controller/relay/moderations.go
new file mode 100644
index 00000000..471f1640
--- /dev/null
+++ b/controller/relay/moderations.go
@@ -0,0 +1,62 @@
+package relay
+
+import (
+ "net/http"
+ "one-api/common"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayModerations struct {
+ relayBase
+ request types.ModerationRequest
+}
+
+func NewRelayModerations(c *gin.Context) *relayModerations {
+ relay := &relayModerations{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayModerations) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ if r.request.Model == "" {
+ r.request.Model = "text-moderation-stable"
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relayModerations) getPromptTokens() (int, error) {
+ return common.CountTokenInput(r.request.Input, r.modelName), nil
+}
+
+func (r *relayModerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.ModerationInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ response, err := provider.CreateModeration(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseJsonClient(r.c, response)
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/quota.go b/controller/relay/quota.go
similarity index 97%
rename from controller/quota.go
rename to controller/relay/quota.go
index c2f96954..3343b6c5 100644
--- a/controller/quota.go
+++ b/controller/relay/quota.go
@@ -1,4 +1,4 @@
-package controller
+package relay
import (
"context"
@@ -144,7 +144,7 @@ func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName stri
return nil
}
-func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatusCode) {
+func (q *QuotaInfo) undo(c *gin.Context) {
tokenId := c.GetInt("token_id")
if q.HandelStatus {
go func(ctx context.Context) {
@@ -155,7 +155,6 @@ func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatu
}
}(c.Request.Context())
}
- errorHelper(c, errWithCode)
}
func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
diff --git a/controller/relay/speech.go b/controller/relay/speech.go
new file mode 100644
index 00000000..45d6f6d6
--- /dev/null
+++ b/controller/relay/speech.go
@@ -0,0 +1,58 @@
+package relay
+
+import (
+ "net/http"
+ "one-api/common"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relaySpeech struct {
+ relayBase
+ request types.SpeechAudioRequest
+}
+
+func NewRelaySpeech(c *gin.Context) *relaySpeech {
+ relay := &relaySpeech{}
+ relay.c = c
+ return relay
+}
+
+func (r *relaySpeech) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relaySpeech) getPromptTokens() (int, error) {
+ return len(r.request.Input), nil
+}
+
+func (r *relaySpeech) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.SpeechInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ response, err := provider.CreateSpeech(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseMultipart(r.c, response)
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay/transcriptions.go b/controller/relay/transcriptions.go
new file mode 100644
index 00000000..17605a00
--- /dev/null
+++ b/controller/relay/transcriptions.go
@@ -0,0 +1,58 @@
+package relay
+
+import (
+ "net/http"
+ "one-api/common"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayTranscriptions struct {
+ relayBase
+ request types.AudioRequest
+}
+
+func NewRelayTranscriptions(c *gin.Context) *relayTranscriptions {
+ relay := &relayTranscriptions{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayTranscriptions) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relayTranscriptions) getPromptTokens() (int, error) {
+ return 0, nil
+}
+
+func (r *relayTranscriptions) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.TranscriptionsInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ response, err := provider.CreateTranscriptions(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseCustom(r.c, response)
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay/translations.go b/controller/relay/translations.go
new file mode 100644
index 00000000..abb65296
--- /dev/null
+++ b/controller/relay/translations.go
@@ -0,0 +1,58 @@
+package relay
+
+import (
+ "net/http"
+ "one-api/common"
+ providersBase "one-api/providers/base"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type relayTranslations struct {
+ relayBase
+ request types.AudioRequest
+}
+
+func NewRelayTranslations(c *gin.Context) *relayTranslations {
+ relay := &relayTranslations{}
+ relay.c = c
+ return relay
+}
+
+func (r *relayTranslations) setRequest() error {
+ if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
+ return err
+ }
+
+ r.originalModel = r.request.Model
+
+ return nil
+}
+
+func (r *relayTranslations) getPromptTokens() (int, error) {
+ return 0, nil
+}
+
+func (r *relayTranslations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
+ provider, ok := r.provider.(providersBase.TranslationInterface)
+ if !ok {
+ err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
+ done = true
+ return
+ }
+
+ r.request.Model = r.modelName
+
+ response, err := provider.CreateTranslation(&r.request)
+ if err != nil {
+ return
+ }
+ err = responseCustom(r.c, response)
+
+ if err != nil {
+ done = true
+ }
+
+ return
+}
diff --git a/controller/relay-utils.go b/controller/relay/utils.go
similarity index 56%
rename from controller/relay-utils.go
rename to controller/relay/utils.go
index 1d5d019c..d3443f4e 100644
--- a/controller/relay-utils.go
+++ b/controller/relay/utils.go
@@ -1,6 +1,7 @@
-package controller
+package relay
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -8,127 +9,98 @@ import (
"net/http"
"one-api/common"
"one-api/common/requester"
+ "one-api/controller"
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
"one-api/types"
- "reflect"
+ "strings"
"github.com/gin-gonic/gin"
- "github.com/go-playground/validator/v10"
)
-func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail bool) {
+func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
+ if strings.HasPrefix(path, "/v1/chat/completions") {
+ return NewRelayChat(c)
+ } else if strings.HasPrefix(path, "/v1/completions") {
+ return NewRelayCompletions(c)
+ } else if strings.HasPrefix(path, "/v1/embeddings") {
+ return NewRelayEmbeddings(c)
+ } else if strings.HasPrefix(path, "/v1/moderations") {
+ return NewRelayModerations(c)
+ } else if strings.HasPrefix(path, "/v1/images/generations") {
+ return NewRelayImageGenerations(c)
+ } else if strings.HasPrefix(path, "/v1/images/edits") {
+ return NewRelayImageEdits(c)
+ } else if strings.HasPrefix(path, "/v1/images/variations") {
+ return NewRelayImageVariations(c)
+ } else if strings.HasPrefix(path, "/v1/audio/speech") {
+ return NewRelaySpeech(c)
+ } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
+ return NewRelayTranscriptions(c)
+ } else if strings.HasPrefix(path, "/v1/audio/translations") {
+ return NewRelayTranslations(c)
+ }
+
+ return nil
+}
+
+func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
channel, fail := fetchChannel(c, modeName)
- if fail {
+ if fail != nil {
return
}
+ c.Set("channel_id", channel.Id)
provider = providers.GetProvider(channel, c)
if provider == nil {
- common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
- fail = true
+ fail = errors.New("channel not found")
return
}
+ provider.SetOriginalModel(modeName)
- newModelName, err := provider.ModelMappingHandler(modeName)
- if err != nil {
- common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
- fail = true
+ newModelName, fail = provider.ModelMappingHandler(modeName)
+ if fail != nil {
return
}
return
}
-func GetValidFieldName(err error, obj interface{}) string {
- getObj := reflect.TypeOf(obj)
- if errs, ok := err.(validator.ValidationErrors); ok {
- for _, e := range errs {
- if f, exist := getObj.Elem().FieldByName(e.Field()); exist {
- return f.Name
- }
- }
- }
- return err.Error()
-}
-
-func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) {
- channelId := c.GetInt("channelId")
+func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail error) {
+ channelId := c.GetInt("specific_channel_id")
if channelId > 0 {
- channel, fail = fetchChannelById(c, channelId)
- if fail {
- return
- }
-
- }
- channel, fail = fetchChannelByModel(c, modelName)
- if fail {
- return
+ return fetchChannelById(channelId)
}
- c.Set("channel_id", channel.Id)
-
- return
+ return fetchChannelByModel(c, modelName)
}
-func fetchChannelById(c *gin.Context, channelId int) (*model.Channel, bool) {
+func fetchChannelById(channelId int) (*model.Channel, error) {
channel, err := model.GetChannelById(channelId, true)
if err != nil {
- common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
- return nil, true
+ return nil, errors.New("无效的渠道 Id")
}
if channel.Status != common.ChannelStatusEnabled {
- common.AbortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
- return nil, true
+ return nil, errors.New("该渠道已被禁用")
}
- return channel, false
+ return channel, nil
}
-func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool) {
+func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, error) {
group := c.GetString("group")
- channel, err := model.CacheGetRandomSatisfiedChannel(group, modelName)
+ channel, err := model.ChannelGroup.Next(group, modelName)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
- common.AbortWithMessage(c, http.StatusServiceUnavailable, message)
- return nil, true
+ return nil, errors.New(message)
}
- return channel, false
-}
-
-func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
- if !common.AutomaticDisableChannelEnabled {
- return false
- }
- if err == nil {
- return false
- }
- if statusCode == http.StatusUnauthorized {
- return true
- }
- if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
- return true
- }
- return false
-}
-
-func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
- if !common.AutomaticEnableChannelEnabled {
- return false
- }
- if err != nil {
- return false
- }
- if openAIErr != nil {
- return false
- }
- return true
+ return channel, nil
}
func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
@@ -201,3 +173,30 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types
return nil
}
+
+func shouldRetry(c *gin.Context, statusCode int) bool {
+ channelId := c.GetInt("specific_channel_id")
+ if channelId > 0 {
+ return false
+ }
+ if statusCode == http.StatusTooManyRequests {
+ return true
+ }
+ if statusCode/100 == 5 {
+ return true
+ }
+ if statusCode == http.StatusBadRequest {
+ return false
+ }
+ if statusCode/100 == 2 {
+ return false
+ }
+ return true
+}
+
+func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) {
+ common.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message))
+ if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) {
+ controller.DisableChannel(channelId, channelName, err.Message)
+ }
+}
diff --git a/main.go b/main.go
index 16ef5e48..5a13aa51 100644
--- a/main.go
+++ b/main.go
@@ -59,11 +59,11 @@ func main() {
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
- model.InitChannelCache()
+ model.InitChannelGroup()
}
if common.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency)
- go model.SyncChannelCache(common.SyncFrequency)
+ go model.SyncChannelGroup(common.SyncFrequency)
}
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
diff --git a/middleware/auth.go b/middleware/auth.go
index e6f48c62..2537ceae 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -114,7 +114,7 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
return
}
- c.Set("channelId", channelId)
+ c.Set("specific_channel_id", channelId)
} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
diff --git a/model/ability.go b/model/ability.go
index 1a53b3ef..8aeb2811 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -11,6 +11,7 @@ type Ability struct {
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
+ Weight *uint `json:"weight" gorm:"default:1"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
@@ -67,6 +68,7 @@ func (channel *Channel) AddAbilities() error {
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
+ Weight: channel.Weight,
}
abilities = append(abilities, ability)
}
@@ -98,3 +100,49 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
+
+func GetEnabledAbility() ([]*Ability, error) {
+ trueVal := "1"
+ if common.UsingPostgreSQL {
+ trueVal = "true"
+ }
+
+ var abilities []*Ability
+ err := DB.Where("enabled = ?", trueVal).Order("priority desc, weight desc").Find(&abilities).Error
+ return abilities, err
+}
+
+type AbilityChannelGroup struct {
+ Group string `json:"group"`
+ Model string `json:"model"`
+ Priority int `json:"priority"`
+ ChannelIds string `json:"channel_ids"`
+}
+
+func GetAbilityChannelGroup() ([]*AbilityChannelGroup, error) {
+ var abilities []*AbilityChannelGroup
+
+ var channelSql string
+ if common.UsingPostgreSQL {
+ channelSql = `string_agg("channel_id"::text, ',')`
+ } else if common.UsingSQLite {
+ channelSql = `group_concat("channel_id", ',')`
+ } else {
+ channelSql = "GROUP_CONCAT(`channel_id` SEPARATOR ',')"
+ }
+
+ trueVal := "1"
+ if common.UsingPostgreSQL {
+ trueVal = "true"
+ }
+
+ err := DB.Raw(`
+ SELECT `+quotePostgresField("group")+`, model, priority, `+channelSql+` as channel_ids
+ FROM abilities
+ WHERE enabled = ?
+ GROUP BY `+quotePostgresField("group")+`, model, priority
+ ORDER BY priority DESC
+ `, trueVal).Scan(&abilities).Error
+
+ return abilities, err
+}
diff --git a/model/balancer.go b/model/balancer.go
new file mode 100644
index 00000000..38a90c0e
--- /dev/null
+++ b/model/balancer.go
@@ -0,0 +1,176 @@
+package model
+
+import (
+ "errors"
+ "math/rand"
+ "one-api/common"
+ "strings"
+ "sync"
+ "time"
+)
+
+type ChannelChoice struct {
+ Channel *Channel
+ CooldownsTime int64
+}
+
+type ChannelsChooser struct {
+ sync.RWMutex
+ Channels map[int]*ChannelChoice
+ Rule map[string]map[string][][]int // group -> model -> priority -> channelIds
+}
+
+func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
+ if common.RetryCooldownSeconds == 0 {
+ return false
+ }
+ cc.Lock()
+ defer cc.Unlock()
+ if _, ok := cc.Channels[channelId]; !ok {
+ return false
+ }
+
+ cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(common.RetryCooldownSeconds)
+ return true
+}
+
+func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel {
+ nowTime := time.Now().Unix()
+ totalWeight := 0
+
+ validChannels := make([]*ChannelChoice, 0, len(channelIds))
+ for _, channelId := range channelIds {
+ if choice, ok := cc.Channels[channelId]; ok && choice.CooldownsTime < nowTime {
+ weight := int(*choice.Channel.Weight)
+ totalWeight += weight
+ validChannels = append(validChannels, choice)
+ }
+ }
+
+ if len(validChannels) == 0 {
+ return nil
+ }
+
+ if len(validChannels) == 1 {
+ return validChannels[0].Channel
+ }
+
+ choiceWeight := rand.Intn(totalWeight)
+ for _, choice := range validChannels {
+ weight := int(*choice.Channel.Weight)
+ choiceWeight -= weight
+ if choiceWeight < 0 {
+ return choice.Channel
+ }
+ }
+
+ return nil
+}
+
+func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) {
+ if !common.MemoryCacheEnabled {
+ return GetRandomSatisfiedChannel(group, model)
+ }
+ cc.RLock()
+ defer cc.RUnlock()
+ if _, ok := cc.Rule[group]; !ok {
+ return nil, errors.New("group not found")
+ }
+
+ if _, ok := cc.Rule[group][model]; !ok {
+ return nil, errors.New("model not found")
+ }
+
+ channelsPriority := cc.Rule[group][model]
+ if len(channelsPriority) == 0 {
+ return nil, errors.New("channel not found")
+ }
+
+ for _, priority := range channelsPriority {
+ channel := cc.Balancer(priority)
+ if channel != nil {
+ return channel, nil
+ }
+ }
+
+ return nil, errors.New("channel not found")
+}
+
+func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
+ if !common.MemoryCacheEnabled {
+ return GetGroupModels(group)
+ }
+
+ cc.RLock()
+ defer cc.RUnlock()
+
+ if _, ok := cc.Rule[group]; !ok {
+ return nil, errors.New("group not found")
+ }
+
+ models := make([]string, 0, len(cc.Rule[group]))
+ for model := range cc.Rule[group] {
+ models = append(models, model)
+ }
+
+ return models, nil
+}
+
+var ChannelGroup = ChannelsChooser{}
+
+func InitChannelGroup() {
+ var channels []*Channel
+ DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
+
+ abilities, err := GetAbilityChannelGroup()
+ if err != nil {
+ common.SysLog("get enabled abilities failed: " + err.Error())
+ return
+ }
+
+ newGroup := make(map[string]map[string][][]int)
+ newChannels := make(map[int]*ChannelChoice)
+
+ for _, channel := range channels {
+ if *channel.Weight == 0 {
+ channel.Weight = &common.DefaultChannelWeight
+ }
+ newChannels[channel.Id] = &ChannelChoice{
+ Channel: channel,
+ CooldownsTime: 0,
+ }
+ }
+
+ for _, ability := range abilities {
+ if _, ok := newGroup[ability.Group]; !ok {
+ newGroup[ability.Group] = make(map[string][][]int)
+ }
+
+ if _, ok := newGroup[ability.Group][ability.Model]; !ok {
+ newGroup[ability.Group][ability.Model] = make([][]int, 0)
+ }
+
+ var priorityIds []int
+ // 逗号分割 ability.ChannelId
+ channelIds := strings.Split(ability.ChannelIds, ",")
+ for _, channelId := range channelIds {
+ priorityIds = append(priorityIds, common.String2Int(channelId))
+ }
+
+ newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
+ }
+
+ ChannelGroup.Lock()
+ ChannelGroup.Rule = newGroup
+ ChannelGroup.Channels = newChannels
+ ChannelGroup.Unlock()
+ common.SysLog("channels synced from database")
+}
+
+func SyncChannelGroup(frequency int) {
+ for {
+ time.Sleep(time.Duration(frequency) * time.Second)
+ common.SysLog("syncing channels from database")
+ InitChannelGroup()
+ }
+}
diff --git a/model/cache.go b/model/cache.go
index e26995fe..9edb8a64 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -2,14 +2,9 @@ package model
import (
"encoding/json"
- "errors"
"fmt"
- "math/rand"
"one-api/common"
- "sort"
"strconv"
- "strings"
- "sync"
"time"
)
@@ -131,104 +126,3 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
return userEnabled, err
}
-
-var group2model2channels map[string]map[string][]*Channel
-var channelSyncLock sync.RWMutex
-
-func InitChannelCache() {
- newChannelId2channel := make(map[int]*Channel)
- var channels []*Channel
- DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
- for _, channel := range channels {
- newChannelId2channel[channel.Id] = channel
- }
- var abilities []*Ability
- DB.Find(&abilities)
- groups := make(map[string]bool)
- for _, ability := range abilities {
- groups[ability.Group] = true
- }
- newGroup2model2channels := make(map[string]map[string][]*Channel)
- for group := range groups {
- newGroup2model2channels[group] = make(map[string][]*Channel)
- }
- for _, channel := range channels {
- groups := strings.Split(channel.Group, ",")
- for _, group := range groups {
- models := strings.Split(channel.Models, ",")
- for _, model := range models {
- if _, ok := newGroup2model2channels[group][model]; !ok {
- newGroup2model2channels[group][model] = make([]*Channel, 0)
- }
- newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
- }
- }
- }
-
- // sort by priority
- for group, model2channels := range newGroup2model2channels {
- for model, channels := range model2channels {
- sort.Slice(channels, func(i, j int) bool {
- return channels[i].GetPriority() > channels[j].GetPriority()
- })
- newGroup2model2channels[group][model] = channels
- }
- }
-
- channelSyncLock.Lock()
- group2model2channels = newGroup2model2channels
- channelSyncLock.Unlock()
- common.SysLog("channels synced from database")
-}
-
-func SyncChannelCache(frequency int) {
- for {
- time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing channels from database")
- InitChannelCache()
- }
-}
-
-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
- if !common.MemoryCacheEnabled {
- return GetRandomSatisfiedChannel(group, model)
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
- channels := group2model2channels[group][model]
- if len(channels) == 0 {
- return nil, errors.New("channel not found")
- }
- endIdx := len(channels)
- // choose by priority
- firstChannel := channels[0]
- if firstChannel.GetPriority() > 0 {
- for i := range channels {
- if channels[i].GetPriority() != firstChannel.GetPriority() {
- endIdx = i
- break
- }
- }
- }
- idx := rand.Intn(endIdx)
- return channels[idx], nil
-}
-
-func CacheGetGroupModels(group string) ([]string, error) {
- if !common.MemoryCacheEnabled {
- return GetGroupModels(group)
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
-
- groupModels := group2model2channels[group]
- if groupModels == nil {
- return nil, errors.New("group not found")
- }
-
- models := make([]string, 0)
- for model := range groupModels {
- models = append(models, model)
- }
- return models, nil
-}
diff --git a/model/channel.go b/model/channel.go
index 30a3f61a..6d4413ec 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -13,7 +13,7 @@ type Channel struct {
Key string `json:"key" form:"key" gorm:"type:varchar(767);not null;index"`
Status int `json:"status" form:"status" gorm:"default:1"`
Name string `json:"name" form:"name" gorm:"index"`
- Weight *uint `json:"weight" gorm:"default:0"`
+ Weight *uint `json:"weight" gorm:"default:1"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds
@@ -95,11 +95,8 @@ func GetAllChannels() ([]*Channel, error) {
func GetChannelById(id int, selectAll bool) (*Channel, error) {
channel := Channel{Id: id}
var err error = nil
- if selectAll {
- err = DB.First(&channel, "id = ?", id).Error
- } else {
- err = DB.Omit("key").First(&channel, "id = ?", id).Error
- }
+ err = DB.First(&channel, "id = ?", id).Error
+
return &channel, err
}
diff --git a/model/option.go b/model/option.go
index 084427f4..24e9a20b 100644
--- a/model/option.go
+++ b/model/option.go
@@ -77,6 +77,8 @@ func InitOptionMap() {
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
+ common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
+
common.OptionMapRWMutex.Unlock()
initModelRatio()
loadOptionsFromDatabase()
@@ -146,6 +148,7 @@ var optionIntMap = map[string]*int{
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
"PreConsumedQuota": &common.PreConsumedQuota,
"RetryTimes": &common.RetryTimes,
+ "RetryCooldownSeconds": &common.RetryCooldownSeconds,
}
var optionBoolMap = map[string]*bool{
diff --git a/router/relay-router.go b/router/relay-router.go
index 519a764b..f6f548e7 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -2,6 +2,7 @@ package router
import (
"one-api/controller"
+ "one-api/controller/relay"
"one-api/middleware"
"github.com/gin-gonic/gin"
@@ -19,18 +20,18 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
{
- relayV1Router.POST("/completions", controller.RelayCompletions)
- relayV1Router.POST("/chat/completions", controller.RelayChat)
+ relayV1Router.POST("/completions", relay.Relay)
+ relayV1Router.POST("/chat/completions", relay.Relay)
// relayV1Router.POST("/edits", controller.Relay)
- relayV1Router.POST("/images/generations", controller.RelayImageGenerations)
- relayV1Router.POST("/images/edits", controller.RelayImageEdits)
- relayV1Router.POST("/images/variations", controller.RelayImageVariations)
- relayV1Router.POST("/embeddings", controller.RelayEmbeddings)
+ relayV1Router.POST("/images/generations", relay.Relay)
+ relayV1Router.POST("/images/edits", relay.Relay)
+ relayV1Router.POST("/images/variations", relay.Relay)
+ relayV1Router.POST("/embeddings", relay.Relay)
// relayV1Router.POST("/engines/:model/embeddings", controller.RelayEmbeddings)
- relayV1Router.POST("/audio/transcriptions", controller.RelayTranscriptions)
- relayV1Router.POST("/audio/translations", controller.RelayTranslations)
- relayV1Router.POST("/audio/speech", controller.RelaySpeech)
- relayV1Router.POST("/moderations", controller.RelayModerations)
+ relayV1Router.POST("/audio/transcriptions", relay.Relay)
+ relayV1Router.POST("/audio/translations", relay.Relay)
+ relayV1Router.POST("/audio/speech", relay.Relay)
+ relayV1Router.POST("/moderations", relay.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js
index a1fb429c..aa52b40d 100644
--- a/web/src/views/Channel/component/EditModal.js
+++ b/web/src/views/Channel/component/EditModal.js
@@ -21,7 +21,8 @@ import {
Container,
Autocomplete,
FormHelperText,
- Checkbox
+ Checkbox,
+ Switch
} from '@mui/material';
import { Formik } from 'formik';
@@ -73,6 +74,7 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
const [inputLabel, setInputLabel] = useState(defaultConfig.inputLabel); //
const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt);
const [modelOptions, setModelOptions] = useState([]);
+ const [batchAdd, setBatchAdd] = useState(false);
const initChannel = (typeValue) => {
if (typeConfig[typeValue]?.inputLabel) {
@@ -246,6 +248,7 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
}, []);
useEffect(() => {
+ setBatchAdd(false);
if (channelId) {
loadChannel().then();
} else {
@@ -479,18 +482,36 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
- {inputLabel.key}
-
+ {!batchAdd ? (
+ <>
+ {inputLabel.key}
+
+ >
+ ) : (
+
+ )}
+
{touched.key && errors.key ? (
{errors.key}
@@ -499,6 +520,17 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
{inputPrompt.key}
)}
+ {channelId === 0 && (
+
+ setBatchAdd(e.target.checked)} />
+ 批量添加
+
+ )}
+
{/* {inputLabel.model_mapping} */}
{
+ if (weightValve === '' || weightValve === item.weight) {
+ return;
+ }
+
+ if (weightValve <= 0) {
+ showError('权重不能小于 0');
+ return;
+ }
+
+ await manageChannel(item.id, 'weight', weightValve);
+ };
+
const handleResponseTime = async () => {
const { success, time } = await manageChannel(item.id, 'test', '');
if (success) {
@@ -176,6 +196,25 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal,
/>
+
+
+ 权重
+ setWeight(e.target.value)}
+ sx={{ textAlign: 'center' }}
+ endAdornment={
+
+
+
+
+
+ }
+ />
+
+
@@ -204,6 +243,16 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal,
编辑
+
+
+