✨ feat: channel support weight (#85)
* ✨ feat: channel support weight * 💄 improve: show version * 💄 improve: Channel add copy operation * 💄 improve: Channel support batch add
This commit is contained in:
parent
7c78ed9fad
commit
dd3e79a20d
10
.github/workflows/docker-image.yml
vendored
10
.github/workflows/docker-image.yml
vendored
@ -25,7 +25,15 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Save version info
|
||||||
|
run: |
|
||||||
|
TAG=$(git describe --tags --exact-match 2> /dev/null)
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
echo $TAG > VERSION
|
||||||
|
else
|
||||||
|
HASH=$(git rev-parse --short=7 HEAD)
|
||||||
|
echo "dev-$HASH" > VERSION
|
||||||
|
fi
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@v2
|
uses: docker/setup-qemu-action@v2
|
||||||
|
|
||||||
|
@ -83,6 +83,8 @@ var QuotaRemindThreshold = 1000
|
|||||||
var PreConsumedQuota = 500
|
var PreConsumedQuota = 500
|
||||||
var ApproximateTokenEnabled = false
|
var ApproximateTokenEnabled = false
|
||||||
var RetryTimes = 0
|
var RetryTimes = 0
|
||||||
|
var DefaultChannelWeight = uint(1)
|
||||||
|
var RetryCooldownSeconds = 5
|
||||||
|
|
||||||
var RootUserEmail = ""
|
var RootUserEmail = ""
|
||||||
|
|
||||||
|
@ -122,7 +122,7 @@ func updateAllChannelsBalance() error {
|
|||||||
} else {
|
} else {
|
||||||
// err is nil & balance <= 0 means quota is used up
|
// err is nil & balance <= 0 means quota is used up
|
||||||
if balance <= 0 {
|
if balance <= 0 {
|
||||||
disableChannel(channel.Id, channel.Name, "余额不足")
|
DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(common.RequestInterval)
|
||||||
|
@ -140,14 +140,6 @@ func notifyRootUser(subject string, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// disable & notify
|
|
||||||
func disableChannel(channelId int, channelName string, reason string) {
|
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
|
||||||
notifyRootUser(subject, content)
|
|
||||||
}
|
|
||||||
|
|
||||||
// enable & notify
|
// enable & notify
|
||||||
func enableChannel(channelId int, channelName string) {
|
func enableChannel(channelId int, channelName string) {
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
||||||
@ -185,10 +177,10 @@ func testAllChannels(notify bool) error {
|
|||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
if milliseconds > disableThreshold {
|
if milliseconds > disableThreshold {
|
||||||
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
||||||
disableChannel(channel.Id, channel.Name, err.Error())
|
DisableChannel(channel.Id, channel.Name, err.Error())
|
||||||
}
|
}
|
||||||
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
|
if isChannelEnabled && ShouldDisableChannel(openaiErr, -1) {
|
||||||
disableChannel(channel.Id, channel.Name, err.Error())
|
DisableChannel(channel.Id, channel.Name, err.Error())
|
||||||
}
|
}
|
||||||
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
|
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
|
||||||
enableChannel(channel.Id, channel.Name)
|
enableChannel(channel.Id, channel.Name)
|
||||||
|
75
controller/common.go
Normal file
75
controller/common.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
|
||||||
|
if !common.AutomaticEnableChannelEnabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if openAIErr != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||||
|
if !common.AutomaticDisableChannelEnabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusCode == http.StatusUnauthorized {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// disable & notify
|
||||||
|
func DisableChannel(channelId int, channelName string, reason string) {
|
||||||
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||||
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||||
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||||
|
notifyRootUser(subject, content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayNotImplemented(c *gin.Context) {
|
||||||
|
err := types.OpenAIError{
|
||||||
|
Message: "API not implemented",
|
||||||
|
Type: "one_api_error",
|
||||||
|
Param: "",
|
||||||
|
Code: "api_not_implemented",
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusNotImplemented, gin.H{
|
||||||
|
"error": err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayNotFound(c *gin.Context) {
|
||||||
|
err := types.OpenAIError{
|
||||||
|
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
Param: "",
|
||||||
|
Code: "",
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{
|
||||||
|
"error": err,
|
||||||
|
})
|
||||||
|
}
|
@ -70,7 +70,7 @@ func ListModels(c *gin.Context) {
|
|||||||
groupName = user.Group
|
groupName = user.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
models, err := model.CacheGetGroupModels(groupName)
|
models, err := model.ChannelGroup.GetGroupModels(groupName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
|
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
|
||||||
return
|
return
|
||||||
|
@ -1,79 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/common/requester"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayChat(c *gin.Context) {
|
|
||||||
|
|
||||||
var chatRequest types.ChatCompletionRequest
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &chatRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, chatRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
chatRequest.Model = modelName
|
|
||||||
|
|
||||||
chatProvider, ok := provider.(providersBase.ChatInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, chatRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if chatRequest.Stream {
|
|
||||||
var response requester.StreamReaderInterface[string]
|
|
||||||
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseStreamClient(c, response)
|
|
||||||
} else {
|
|
||||||
var response *types.ChatCompletionResponse
|
|
||||||
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseJsonClient(c, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,79 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/common/requester"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayCompletions(c *gin.Context) {
|
|
||||||
|
|
||||||
var completionRequest types.CompletionRequest
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &completionRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, completionRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
completionRequest.Model = modelName
|
|
||||||
|
|
||||||
completionProvider, ok := provider.(providersBase.CompletionInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, completionRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if completionRequest.Stream {
|
|
||||||
var response requester.StreamReaderInterface[string]
|
|
||||||
response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseStreamClient(c, response)
|
|
||||||
} else {
|
|
||||||
var response *types.CompletionResponse
|
|
||||||
response, errWithCode = completionProvider.CreateCompletion(&completionRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseJsonClient(c, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,66 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayEmbeddings(c *gin.Context) {
|
|
||||||
|
|
||||||
var embeddingsRequest types.EmbeddingRequest
|
|
||||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
|
||||||
embeddingsRequest.Model = c.Param("model")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &embeddingsRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, embeddingsRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
embeddingsRequest.Model = modelName
|
|
||||||
|
|
||||||
embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, errWithCode := embeddingsProvider.CreateEmbeddings(&embeddingsRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseJsonClient(c, response)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,79 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayImageEdits(c *gin.Context) {
|
|
||||||
|
|
||||||
var imageEditRequest types.ImageEditRequest
|
|
||||||
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageEditRequest.Prompt == "" {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, "field prompt is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageEditRequest.Model == "" {
|
|
||||||
imageEditRequest.Model = "dall-e-2"
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageEditRequest.Size == "" {
|
|
||||||
imageEditRequest.Size = "1024x1024"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
imageEditRequest.Model = modelName
|
|
||||||
|
|
||||||
imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens, err := common.CountTokenImage(imageEditRequest)
|
|
||||||
if err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, errWithCode := imageEditsProvider.CreateImageEdits(&imageEditRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseJsonClient(c, response)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,82 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayImageGenerations(c *gin.Context) {
|
|
||||||
|
|
||||||
var imageRequest types.ImageRequest
|
|
||||||
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &imageRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageRequest.Model == "" {
|
|
||||||
imageRequest.Model = "dall-e-2"
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageRequest.N == 0 {
|
|
||||||
imageRequest.N = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageRequest.Size == "" {
|
|
||||||
imageRequest.Size = "1024x1024"
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageRequest.Quality == "" {
|
|
||||||
imageRequest.Quality = "standard"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, imageRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
imageRequest.Model = modelName
|
|
||||||
|
|
||||||
imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens, err := common.CountTokenImage(imageRequest)
|
|
||||||
if err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, errWithCode := imageGenerationsProvider.CreateImageGenerations(&imageRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseJsonClient(c, response)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,74 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayImageVariations(c *gin.Context) {
|
|
||||||
|
|
||||||
var imageEditRequest types.ImageEditRequest
|
|
||||||
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageEditRequest.Model == "" {
|
|
||||||
imageEditRequest.Model = "dall-e-2"
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageEditRequest.Size == "" {
|
|
||||||
imageEditRequest.Size = "1024x1024"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
imageEditRequest.Model = modelName
|
|
||||||
|
|
||||||
imageVariations, ok := provider.(providersBase.ImageVariationsInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens, err := common.CountTokenImage(imageEditRequest)
|
|
||||||
if err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, errWithCode := imageVariations.CreateImageVariations(&imageEditRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseJsonClient(c, response)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,66 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayModerations(c *gin.Context) {
|
|
||||||
|
|
||||||
var moderationRequest types.ModerationRequest
|
|
||||||
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &moderationRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if moderationRequest.Model == "" {
|
|
||||||
moderationRequest.Model = "text-moderation-stable"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, moderationRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
moderationRequest.Model = modelName
|
|
||||||
|
|
||||||
moderationProvider, ok := provider.(providersBase.ModerationInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, moderationRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, errWithCode := moderationProvider.CreateModeration(&moderationRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseJsonClient(c, response)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,62 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelaySpeech(c *gin.Context) {
|
|
||||||
|
|
||||||
var speechRequest types.SpeechAudioRequest
|
|
||||||
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &speechRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, speechRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
speechRequest.Model = modelName
|
|
||||||
|
|
||||||
speechProvider, ok := provider.(providersBase.SpeechInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens := len(speechRequest.Input)
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, speechRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, errWithCode := speechProvider.CreateSpeech(&speechRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseMultipart(c, response)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,62 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayTranscriptions(c *gin.Context) {
|
|
||||||
|
|
||||||
var audioRequest types.AudioRequest
|
|
||||||
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, audioRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
audioRequest.Model = modelName
|
|
||||||
|
|
||||||
transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens := 0
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, errWithCode := transcriptionsProvider.CreateTranscriptions(&audioRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseCustom(c, response)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,62 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
providersBase "one-api/providers/base"
|
|
||||||
"one-api/types"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayTranslations(c *gin.Context) {
|
|
||||||
|
|
||||||
var audioRequest types.AudioRequest
|
|
||||||
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
|
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取供应商
|
|
||||||
provider, modelName, fail := getProvider(c, audioRequest.Model)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
audioRequest.Model = modelName
|
|
||||||
|
|
||||||
translationProvider, ok := provider.(providersBase.TranslationInterface)
|
|
||||||
if !ok {
|
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取Input Tokens
|
|
||||||
promptTokens := 0
|
|
||||||
|
|
||||||
usage := &types.Usage{
|
|
||||||
PromptTokens: promptTokens,
|
|
||||||
}
|
|
||||||
provider.SetUsage(usage)
|
|
||||||
|
|
||||||
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response, errWithCode := translationProvider.CreateTranslation(&audioRequest)
|
|
||||||
if errWithCode != nil {
|
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errWithCode = responseCustom(c, response)
|
|
||||||
|
|
||||||
// 如果报错,则退还配额
|
|
||||||
if errWithCode != nil {
|
|
||||||
quotaInfo.undo(c, errWithCode)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
quotaInfo.consume(c, usage)
|
|
||||||
}
|
|
@ -1,63 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
|
||||||
"one-api/types"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
func RelayNotImplemented(c *gin.Context) {
|
|
||||||
err := types.OpenAIError{
|
|
||||||
Message: "API not implemented",
|
|
||||||
Type: "one_api_error",
|
|
||||||
Param: "",
|
|
||||||
Code: "api_not_implemented",
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusNotImplemented, gin.H{
|
|
||||||
"error": err,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func RelayNotFound(c *gin.Context) {
|
|
||||||
err := types.OpenAIError{
|
|
||||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
Param: "",
|
|
||||||
Code: "",
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{
|
|
||||||
"error": err,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func errorHelper(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
|
||||||
retryTimesStr := c.Query("retry")
|
|
||||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
|
||||||
if retryTimesStr == "" {
|
|
||||||
retryTimes = common.RetryTimes
|
|
||||||
}
|
|
||||||
if retryTimes > 0 {
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
|
||||||
} else {
|
|
||||||
if err.StatusCode == http.StatusTooManyRequests {
|
|
||||||
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
||||||
}
|
|
||||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
|
||||||
c.JSON(err.StatusCode, gin.H{
|
|
||||||
"error": err.OpenAIError,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
|
||||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
|
||||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
channelName := c.GetString("channel_name")
|
|
||||||
disableChannel(channelId, channelName, err.Message)
|
|
||||||
}
|
|
||||||
}
|
|
53
controller/relay/base.go
Normal file
53
controller/relay/base.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayBase struct {
|
||||||
|
c *gin.Context
|
||||||
|
provider providersBase.ProviderInterface
|
||||||
|
originalModel string
|
||||||
|
modelName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayBaseInterface interface {
|
||||||
|
send() (err *types.OpenAIErrorWithStatusCode, done bool)
|
||||||
|
getPromptTokens() (int, error)
|
||||||
|
setRequest() error
|
||||||
|
setProvider(modelName string) error
|
||||||
|
getProvider() providersBase.ProviderInterface
|
||||||
|
getOriginalModel() string
|
||||||
|
getModelName() string
|
||||||
|
getContext() *gin.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayBase) setProvider(modelName string) error {
|
||||||
|
provider, modelName, fail := getProvider(r.c, modelName)
|
||||||
|
if fail != nil {
|
||||||
|
return fail
|
||||||
|
}
|
||||||
|
r.provider = provider
|
||||||
|
r.modelName = modelName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayBase) getContext() *gin.Context {
|
||||||
|
return r.c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayBase) getProvider() providersBase.ProviderInterface {
|
||||||
|
return r.provider
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayBase) getOriginalModel() string {
|
||||||
|
return r.originalModel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayBase) getModelName() string {
|
||||||
|
return r.modelName
|
||||||
|
}
|
76
controller/relay/chat.go
Normal file
76
controller/relay/chat.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/common/requester"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayChat struct {
|
||||||
|
relayBase
|
||||||
|
chatRequest types.ChatCompletionRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayChat(c *gin.Context) *relayChat {
|
||||||
|
relay := &relayChat{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayChat) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.chatRequest); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.chatRequest.MaxTokens < 0 || r.chatRequest.MaxTokens > math.MaxInt32/2 {
|
||||||
|
return errors.New("max_tokens is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.chatRequest.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayChat) getPromptTokens() (int, error) {
|
||||||
|
return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
chatProvider, ok := r.provider.(providersBase.ChatInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.chatRequest.Model = r.modelName
|
||||||
|
|
||||||
|
if r.chatRequest.Stream {
|
||||||
|
var response requester.StreamReaderInterface[string]
|
||||||
|
response, err = chatProvider.CreateChatCompletionStream(&r.chatRequest)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = responseStreamClient(r.c, response)
|
||||||
|
} else {
|
||||||
|
var response *types.ChatCompletionResponse
|
||||||
|
response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseJsonClient(r.c, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
76
controller/relay/completions.go
Normal file
76
controller/relay/completions.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/common/requester"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayCompletions struct {
|
||||||
|
relayBase
|
||||||
|
request types.CompletionRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayCompletions(c *gin.Context) *relayCompletions {
|
||||||
|
relay := &relayCompletions{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayCompletions) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.MaxTokens < 0 || r.request.MaxTokens > math.MaxInt32/2 {
|
||||||
|
return errors.New("max_tokens is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayCompletions) getPromptTokens() (int, error) {
|
||||||
|
return common.CountTokenInput(r.request.Prompt, r.modelName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.CompletionInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
if r.request.Stream {
|
||||||
|
var response requester.StreamReaderInterface[string]
|
||||||
|
response, err = provider.CreateCompletionStream(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = responseStreamClient(r.c, response)
|
||||||
|
} else {
|
||||||
|
var response *types.CompletionResponse
|
||||||
|
response, err = provider.CreateCompletion(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseJsonClient(r.c, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
63
controller/relay/embeddings.go
Normal file
63
controller/relay/embeddings.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayEmbeddings struct {
|
||||||
|
relayBase
|
||||||
|
request types.EmbeddingRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayEmbeddings(c *gin.Context) *relayEmbeddings {
|
||||||
|
relay := &relayEmbeddings{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayEmbeddings) setRequest() error {
|
||||||
|
if strings.HasSuffix(r.c.Request.URL.Path, "embeddings") {
|
||||||
|
r.request.Model = r.c.Param("model")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayEmbeddings) getPromptTokens() (int, error) {
|
||||||
|
return common.CountTokenInput(r.request.Input, r.modelName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayEmbeddings) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.EmbeddingsInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
response, err := provider.CreateEmbeddings(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseJsonClient(r.c, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
71
controller/relay/image-edits.go
Normal file
71
controller/relay/image-edits.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayImageEdits struct {
|
||||||
|
relayBase
|
||||||
|
request types.ImageEditRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayImageEdits(c *gin.Context) *relayImageEdits {
|
||||||
|
relay := &relayImageEdits{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageEdits) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Prompt == "" {
|
||||||
|
return errors.New("field prompt is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Model == "" {
|
||||||
|
r.request.Model = "dall-e-2"
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Size == "" {
|
||||||
|
r.request.Size = "1024x1024"
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageEdits) getPromptTokens() (int, error) {
|
||||||
|
return common.CountTokenImage(r.request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageEdits) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.ImageEditsInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
response, err := provider.CreateImageEdits(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseJsonClient(r.c, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
74
controller/relay/image-generations.go
Normal file
74
controller/relay/image-generations.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayImageGenerations struct {
|
||||||
|
relayBase
|
||||||
|
request types.ImageRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayImageGenerations(c *gin.Context) *relayImageGenerations {
|
||||||
|
relay := &relayImageGenerations{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageGenerations) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Model == "" {
|
||||||
|
r.request.Model = "dall-e-2"
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.N == 0 {
|
||||||
|
r.request.N = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Size == "" {
|
||||||
|
r.request.Size = "1024x1024"
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Quality == "" {
|
||||||
|
r.request.Quality = "standard"
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageGenerations) getPromptTokens() (int, error) {
|
||||||
|
return common.CountTokenImage(r.request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageGenerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.ImageGenerationsInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
response, err := provider.CreateImageGenerations(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseJsonClient(r.c, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
66
controller/relay/image-variationsy.go
Normal file
66
controller/relay/image-variationsy.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayImageVariations struct {
|
||||||
|
relayBase
|
||||||
|
request types.ImageEditRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayImageVariations(c *gin.Context) *relayImageVariations {
|
||||||
|
relay := &relayImageVariations{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageVariations) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Model == "" {
|
||||||
|
r.request.Model = "dall-e-2"
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Size == "" {
|
||||||
|
r.request.Size = "1024x1024"
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageVariations) getPromptTokens() (int, error) {
|
||||||
|
return common.CountTokenImage(r.request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayImageVariations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.ImageVariationsInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
response, err := provider.CreateImageVariations(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseJsonClient(r.c, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
106
controller/relay/main.go
Normal file
106
controller/relay/main.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Relay(c *gin.Context) {
|
||||||
|
relay := Path2Relay(c, c.Request.URL.Path)
|
||||||
|
if relay == nil {
|
||||||
|
common.AbortWithMessage(c, http.StatusNotFound, "Not Found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := relay.setRequest(); err != nil {
|
||||||
|
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := relay.setProvider(relay.getOriginalModel()); err != nil {
|
||||||
|
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiErr, done := RelayHandler(relay)
|
||||||
|
if apiErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
channel := relay.getProvider().GetChannel()
|
||||||
|
go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
|
||||||
|
|
||||||
|
retryTimes := common.RetryTimes
|
||||||
|
if done || !shouldRetry(c, apiErr.StatusCode) {
|
||||||
|
common.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode))
|
||||||
|
retryTimes = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := retryTimes; i > 0; i-- {
|
||||||
|
// 冻结通道
|
||||||
|
model.ChannelGroup.Cooldowns(channel.Id)
|
||||||
|
if err := relay.setProvider(relay.getOriginalModel()); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = relay.getProvider().GetChannel()
|
||||||
|
common.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i))
|
||||||
|
apiErr, done = RelayHandler(relay)
|
||||||
|
if apiErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
|
||||||
|
if done || !shouldRetry(c, apiErr.StatusCode) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiErr != nil {
|
||||||
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
|
if apiErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
|
}
|
||||||
|
apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId)
|
||||||
|
c.JSON(apiErr.StatusCode, gin.H{
|
||||||
|
"error": apiErr.OpenAIError,
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
promptTokens, tonkeErr := relay.getPromptTokens()
|
||||||
|
if tonkeErr != nil {
|
||||||
|
err = common.ErrorWrapper(tonkeErr, "token_error", http.StatusBadRequest)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
relay.getProvider().SetUsage(usage)
|
||||||
|
|
||||||
|
var quotaInfo *QuotaInfo
|
||||||
|
quotaInfo, err = generateQuotaInfo(relay.getContext(), relay.getModelName(), promptTokens)
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err, done = relay.send()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
quotaInfo.undo(relay.getContext())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
quotaInfo.consume(relay.getContext(), usage)
|
||||||
|
return
|
||||||
|
}
|
62
controller/relay/moderations.go
Normal file
62
controller/relay/moderations.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayModerations struct {
|
||||||
|
relayBase
|
||||||
|
request types.ModerationRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayModerations(c *gin.Context) *relayModerations {
|
||||||
|
relay := &relayModerations{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayModerations) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.request.Model == "" {
|
||||||
|
r.request.Model = "text-moderation-stable"
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayModerations) getPromptTokens() (int, error) {
|
||||||
|
return common.CountTokenInput(r.request.Input, r.modelName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayModerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.ModerationInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
response, err := provider.CreateModeration(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseJsonClient(r.c, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package controller
|
package relay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -144,7 +144,7 @@ func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName stri
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatusCode) {
|
func (q *QuotaInfo) undo(c *gin.Context) {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
if q.HandelStatus {
|
if q.HandelStatus {
|
||||||
go func(ctx context.Context) {
|
go func(ctx context.Context) {
|
||||||
@ -155,7 +155,6 @@ func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatu
|
|||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}(c.Request.Context())
|
||||||
}
|
}
|
||||||
errorHelper(c, errWithCode)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
|
func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
|
58
controller/relay/speech.go
Normal file
58
controller/relay/speech.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relaySpeech struct {
|
||||||
|
relayBase
|
||||||
|
request types.SpeechAudioRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelaySpeech(c *gin.Context) *relaySpeech {
|
||||||
|
relay := &relaySpeech{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relaySpeech) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relaySpeech) getPromptTokens() (int, error) {
|
||||||
|
return len(r.request.Input), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relaySpeech) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.SpeechInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
response, err := provider.CreateSpeech(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseMultipart(r.c, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
58
controller/relay/transcriptions.go
Normal file
58
controller/relay/transcriptions.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayTranscriptions struct {
|
||||||
|
relayBase
|
||||||
|
request types.AudioRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayTranscriptions(c *gin.Context) *relayTranscriptions {
|
||||||
|
relay := &relayTranscriptions{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayTranscriptions) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayTranscriptions) getPromptTokens() (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayTranscriptions) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.TranscriptionsInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
response, err := provider.CreateTranscriptions(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseCustom(r.c, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
58
controller/relay/translations.go
Normal file
58
controller/relay/translations.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
providersBase "one-api/providers/base"
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type relayTranslations struct {
|
||||||
|
relayBase
|
||||||
|
request types.AudioRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayTranslations(c *gin.Context) *relayTranslations {
|
||||||
|
relay := &relayTranslations{}
|
||||||
|
relay.c = c
|
||||||
|
return relay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayTranslations) setRequest() error {
|
||||||
|
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.originalModel = r.request.Model
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayTranslations) getPromptTokens() (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *relayTranslations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||||
|
provider, ok := r.provider.(providersBase.TranslationInterface)
|
||||||
|
if !ok {
|
||||||
|
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||||
|
done = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.request.Model = r.modelName
|
||||||
|
|
||||||
|
response, err := provider.CreateTranslation(&r.request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = responseCustom(r.c, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
done = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package controller
|
package relay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -8,127 +9,98 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
|
"one-api/controller"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/providers"
|
"one-api/providers"
|
||||||
providersBase "one-api/providers/base"
|
providersBase "one-api/providers/base"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"reflect"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-playground/validator/v10"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail bool) {
|
func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
|
||||||
|
if strings.HasPrefix(path, "/v1/chat/completions") {
|
||||||
|
return NewRelayChat(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/completions") {
|
||||||
|
return NewRelayCompletions(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/embeddings") {
|
||||||
|
return NewRelayEmbeddings(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/moderations") {
|
||||||
|
return NewRelayModerations(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/images/generations") {
|
||||||
|
return NewRelayImageGenerations(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/images/edits") {
|
||||||
|
return NewRelayImageEdits(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/images/variations") {
|
||||||
|
return NewRelayImageVariations(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/audio/speech") {
|
||||||
|
return NewRelaySpeech(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
|
||||||
|
return NewRelayTranscriptions(c)
|
||||||
|
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
||||||
|
return NewRelayTranslations(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
|
||||||
channel, fail := fetchChannel(c, modeName)
|
channel, fail := fetchChannel(c, modeName)
|
||||||
if fail {
|
if fail != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.Set("channel_id", channel.Id)
|
||||||
|
|
||||||
provider = providers.GetProvider(channel, c)
|
provider = providers.GetProvider(channel, c)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
|
fail = errors.New("channel not found")
|
||||||
fail = true
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
provider.SetOriginalModel(modeName)
|
||||||
|
|
||||||
newModelName, err := provider.ModelMappingHandler(modeName)
|
newModelName, fail = provider.ModelMappingHandler(modeName)
|
||||||
if err != nil {
|
if fail != nil {
|
||||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
|
||||||
fail = true
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetValidFieldName(err error, obj interface{}) string {
|
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail error) {
|
||||||
getObj := reflect.TypeOf(obj)
|
channelId := c.GetInt("specific_channel_id")
|
||||||
if errs, ok := err.(validator.ValidationErrors); ok {
|
|
||||||
for _, e := range errs {
|
|
||||||
if f, exist := getObj.Elem().FieldByName(e.Field()); exist {
|
|
||||||
return f.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) {
|
|
||||||
channelId := c.GetInt("channelId")
|
|
||||||
if channelId > 0 {
|
if channelId > 0 {
|
||||||
channel, fail = fetchChannelById(c, channelId)
|
return fetchChannelById(channelId)
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
channel, fail = fetchChannelByModel(c, modelName)
|
|
||||||
if fail {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set("channel_id", channel.Id)
|
return fetchChannelByModel(c, modelName)
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchChannelById(c *gin.Context, channelId int) (*model.Channel, bool) {
|
func fetchChannelById(channelId int) (*model.Channel, error) {
|
||||||
channel, err := model.GetChannelById(channelId, true)
|
channel, err := model.GetChannelById(channelId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
return nil, errors.New("无效的渠道 Id")
|
||||||
return nil, true
|
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
common.AbortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
return nil, errors.New("该渠道已被禁用")
|
||||||
return nil, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return channel, false
|
return channel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool) {
|
func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, error) {
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, modelName)
|
channel, err := model.ChannelGroup.Next(group, modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
}
|
||||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, message)
|
return nil, errors.New(message)
|
||||||
return nil, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return channel, false
|
return channel, nil
|
||||||
}
|
|
||||||
|
|
||||||
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
|
||||||
if !common.AutomaticDisableChannelEnabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if statusCode == http.StatusUnauthorized {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
|
|
||||||
if !common.AutomaticEnableChannelEnabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if openAIErr != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
|
func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
|
||||||
@ -201,3 +173,30 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||||
|
channelId := c.GetInt("specific_channel_id")
|
||||||
|
if channelId > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if statusCode/100 == 5 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusBadRequest {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode/100 == 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message))
|
||||||
|
if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||||
|
controller.DisableChannel(channelId, channelName, err.Message)
|
||||||
|
}
|
||||||
|
}
|
4
main.go
4
main.go
@ -59,11 +59,11 @@ func main() {
|
|||||||
if common.MemoryCacheEnabled {
|
if common.MemoryCacheEnabled {
|
||||||
common.SysLog("memory cache enabled")
|
common.SysLog("memory cache enabled")
|
||||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||||
model.InitChannelCache()
|
model.InitChannelGroup()
|
||||||
}
|
}
|
||||||
if common.MemoryCacheEnabled {
|
if common.MemoryCacheEnabled {
|
||||||
go model.SyncOptions(common.SyncFrequency)
|
go model.SyncOptions(common.SyncFrequency)
|
||||||
go model.SyncChannelCache(common.SyncFrequency)
|
go model.SyncChannelGroup(common.SyncFrequency)
|
||||||
}
|
}
|
||||||
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
||||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
||||||
|
@ -114,7 +114,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
|
abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("channelId", channelId)
|
c.Set("specific_channel_id", channelId)
|
||||||
} else {
|
} else {
|
||||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
return
|
return
|
||||||
|
@ -11,6 +11,7 @@ type Ability struct {
|
|||||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||||
|
Weight *uint `json:"weight" gorm:"default:1"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||||
@ -67,6 +68,7 @@ func (channel *Channel) AddAbilities() error {
|
|||||||
ChannelId: channel.Id,
|
ChannelId: channel.Id,
|
||||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||||
Priority: channel.Priority,
|
Priority: channel.Priority,
|
||||||
|
Weight: channel.Weight,
|
||||||
}
|
}
|
||||||
abilities = append(abilities, ability)
|
abilities = append(abilities, ability)
|
||||||
}
|
}
|
||||||
@ -98,3 +100,49 @@ func (channel *Channel) UpdateAbilities() error {
|
|||||||
func UpdateAbilityStatus(channelId int, status bool) error {
|
func UpdateAbilityStatus(channelId int, status bool) error {
|
||||||
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
|
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetEnabledAbility() ([]*Ability, error) {
|
||||||
|
trueVal := "1"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
trueVal = "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
var abilities []*Ability
|
||||||
|
err := DB.Where("enabled = ?", trueVal).Order("priority desc, weight desc").Find(&abilities).Error
|
||||||
|
return abilities, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type AbilityChannelGroup struct {
|
||||||
|
Group string `json:"group"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Priority int `json:"priority"`
|
||||||
|
ChannelIds string `json:"channel_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAbilityChannelGroup() ([]*AbilityChannelGroup, error) {
|
||||||
|
var abilities []*AbilityChannelGroup
|
||||||
|
|
||||||
|
var channelSql string
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
channelSql = `string_agg("channel_id"::text, ',')`
|
||||||
|
} else if common.UsingSQLite {
|
||||||
|
channelSql = `group_concat("channel_id", ',')`
|
||||||
|
} else {
|
||||||
|
channelSql = "GROUP_CONCAT(`channel_id` SEPARATOR ',')"
|
||||||
|
}
|
||||||
|
|
||||||
|
trueVal := "1"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
trueVal = "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
err := DB.Raw(`
|
||||||
|
SELECT `+quotePostgresField("group")+`, model, priority, `+channelSql+` as channel_ids
|
||||||
|
FROM abilities
|
||||||
|
WHERE enabled = ?
|
||||||
|
GROUP BY `+quotePostgresField("group")+`, model, priority
|
||||||
|
ORDER BY priority DESC
|
||||||
|
`, trueVal).Scan(&abilities).Error
|
||||||
|
|
||||||
|
return abilities, err
|
||||||
|
}
|
||||||
|
176
model/balancer.go
Normal file
176
model/balancer.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math/rand"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChannelChoice struct {
|
||||||
|
Channel *Channel
|
||||||
|
CooldownsTime int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChannelsChooser struct {
|
||||||
|
sync.RWMutex
|
||||||
|
Channels map[int]*ChannelChoice
|
||||||
|
Rule map[string]map[string][][]int // group -> model -> priority -> channelIds
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
|
||||||
|
if common.RetryCooldownSeconds == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cc.Lock()
|
||||||
|
defer cc.Unlock()
|
||||||
|
if _, ok := cc.Channels[channelId]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(common.RetryCooldownSeconds)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel {
|
||||||
|
nowTime := time.Now().Unix()
|
||||||
|
totalWeight := 0
|
||||||
|
|
||||||
|
validChannels := make([]*ChannelChoice, 0, len(channelIds))
|
||||||
|
for _, channelId := range channelIds {
|
||||||
|
if choice, ok := cc.Channels[channelId]; ok && choice.CooldownsTime < nowTime {
|
||||||
|
weight := int(*choice.Channel.Weight)
|
||||||
|
totalWeight += weight
|
||||||
|
validChannels = append(validChannels, choice)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validChannels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validChannels) == 1 {
|
||||||
|
return validChannels[0].Channel
|
||||||
|
}
|
||||||
|
|
||||||
|
choiceWeight := rand.Intn(totalWeight)
|
||||||
|
for _, choice := range validChannels {
|
||||||
|
weight := int(*choice.Channel.Weight)
|
||||||
|
choiceWeight -= weight
|
||||||
|
if choiceWeight < 0 {
|
||||||
|
return choice.Channel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) {
|
||||||
|
if !common.MemoryCacheEnabled {
|
||||||
|
return GetRandomSatisfiedChannel(group, model)
|
||||||
|
}
|
||||||
|
cc.RLock()
|
||||||
|
defer cc.RUnlock()
|
||||||
|
if _, ok := cc.Rule[group]; !ok {
|
||||||
|
return nil, errors.New("group not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := cc.Rule[group][model]; !ok {
|
||||||
|
return nil, errors.New("model not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
channelsPriority := cc.Rule[group][model]
|
||||||
|
if len(channelsPriority) == 0 {
|
||||||
|
return nil, errors.New("channel not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, priority := range channelsPriority {
|
||||||
|
channel := cc.Balancer(priority)
|
||||||
|
if channel != nil {
|
||||||
|
return channel, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("channel not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
|
||||||
|
if !common.MemoryCacheEnabled {
|
||||||
|
return GetGroupModels(group)
|
||||||
|
}
|
||||||
|
|
||||||
|
cc.RLock()
|
||||||
|
defer cc.RUnlock()
|
||||||
|
|
||||||
|
if _, ok := cc.Rule[group]; !ok {
|
||||||
|
return nil, errors.New("group not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]string, 0, len(cc.Rule[group]))
|
||||||
|
for model := range cc.Rule[group] {
|
||||||
|
models = append(models, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ChannelGroup = ChannelsChooser{}
|
||||||
|
|
||||||
|
func InitChannelGroup() {
|
||||||
|
var channels []*Channel
|
||||||
|
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
||||||
|
|
||||||
|
abilities, err := GetAbilityChannelGroup()
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("get enabled abilities failed: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newGroup := make(map[string]map[string][][]int)
|
||||||
|
newChannels := make(map[int]*ChannelChoice)
|
||||||
|
|
||||||
|
for _, channel := range channels {
|
||||||
|
if *channel.Weight == 0 {
|
||||||
|
channel.Weight = &common.DefaultChannelWeight
|
||||||
|
}
|
||||||
|
newChannels[channel.Id] = &ChannelChoice{
|
||||||
|
Channel: channel,
|
||||||
|
CooldownsTime: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ability := range abilities {
|
||||||
|
if _, ok := newGroup[ability.Group]; !ok {
|
||||||
|
newGroup[ability.Group] = make(map[string][][]int)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := newGroup[ability.Group][ability.Model]; !ok {
|
||||||
|
newGroup[ability.Group][ability.Model] = make([][]int, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
var priorityIds []int
|
||||||
|
// 逗号分割 ability.ChannelId
|
||||||
|
channelIds := strings.Split(ability.ChannelIds, ",")
|
||||||
|
for _, channelId := range channelIds {
|
||||||
|
priorityIds = append(priorityIds, common.String2Int(channelId))
|
||||||
|
}
|
||||||
|
|
||||||
|
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
ChannelGroup.Lock()
|
||||||
|
ChannelGroup.Rule = newGroup
|
||||||
|
ChannelGroup.Channels = newChannels
|
||||||
|
ChannelGroup.Unlock()
|
||||||
|
common.SysLog("channels synced from database")
|
||||||
|
}
|
||||||
|
|
||||||
|
func SyncChannelGroup(frequency int) {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
|
common.SysLog("syncing channels from database")
|
||||||
|
InitChannelGroup()
|
||||||
|
}
|
||||||
|
}
|
106
model/cache.go
106
model/cache.go
@ -2,14 +2,9 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"sort"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -131,104 +126,3 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
|||||||
}
|
}
|
||||||
return userEnabled, err
|
return userEnabled, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]*Channel
|
|
||||||
var channelSyncLock sync.RWMutex
|
|
||||||
|
|
||||||
func InitChannelCache() {
|
|
||||||
newChannelId2channel := make(map[int]*Channel)
|
|
||||||
var channels []*Channel
|
|
||||||
DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
|
|
||||||
for _, channel := range channels {
|
|
||||||
newChannelId2channel[channel.Id] = channel
|
|
||||||
}
|
|
||||||
var abilities []*Ability
|
|
||||||
DB.Find(&abilities)
|
|
||||||
groups := make(map[string]bool)
|
|
||||||
for _, ability := range abilities {
|
|
||||||
groups[ability.Group] = true
|
|
||||||
}
|
|
||||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
|
||||||
for group := range groups {
|
|
||||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
|
||||||
}
|
|
||||||
for _, channel := range channels {
|
|
||||||
groups := strings.Split(channel.Group, ",")
|
|
||||||
for _, group := range groups {
|
|
||||||
models := strings.Split(channel.Models, ",")
|
|
||||||
for _, model := range models {
|
|
||||||
if _, ok := newGroup2model2channels[group][model]; !ok {
|
|
||||||
newGroup2model2channels[group][model] = make([]*Channel, 0)
|
|
||||||
}
|
|
||||||
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort by priority
|
|
||||||
for group, model2channels := range newGroup2model2channels {
|
|
||||||
for model, channels := range model2channels {
|
|
||||||
sort.Slice(channels, func(i, j int) bool {
|
|
||||||
return channels[i].GetPriority() > channels[j].GetPriority()
|
|
||||||
})
|
|
||||||
newGroup2model2channels[group][model] = channels
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
channelSyncLock.Lock()
|
|
||||||
group2model2channels = newGroup2model2channels
|
|
||||||
channelSyncLock.Unlock()
|
|
||||||
common.SysLog("channels synced from database")
|
|
||||||
}
|
|
||||||
|
|
||||||
func SyncChannelCache(frequency int) {
|
|
||||||
for {
|
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
|
||||||
common.SysLog("syncing channels from database")
|
|
||||||
InitChannelCache()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
|
||||||
if !common.MemoryCacheEnabled {
|
|
||||||
return GetRandomSatisfiedChannel(group, model)
|
|
||||||
}
|
|
||||||
channelSyncLock.RLock()
|
|
||||||
defer channelSyncLock.RUnlock()
|
|
||||||
channels := group2model2channels[group][model]
|
|
||||||
if len(channels) == 0 {
|
|
||||||
return nil, errors.New("channel not found")
|
|
||||||
}
|
|
||||||
endIdx := len(channels)
|
|
||||||
// choose by priority
|
|
||||||
firstChannel := channels[0]
|
|
||||||
if firstChannel.GetPriority() > 0 {
|
|
||||||
for i := range channels {
|
|
||||||
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
|
||||||
endIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
idx := rand.Intn(endIdx)
|
|
||||||
return channels[idx], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CacheGetGroupModels(group string) ([]string, error) {
|
|
||||||
if !common.MemoryCacheEnabled {
|
|
||||||
return GetGroupModels(group)
|
|
||||||
}
|
|
||||||
channelSyncLock.RLock()
|
|
||||||
defer channelSyncLock.RUnlock()
|
|
||||||
|
|
||||||
groupModels := group2model2channels[group]
|
|
||||||
if groupModels == nil {
|
|
||||||
return nil, errors.New("group not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
models := make([]string, 0)
|
|
||||||
for model := range groupModels {
|
|
||||||
models = append(models, model)
|
|
||||||
}
|
|
||||||
return models, nil
|
|
||||||
}
|
|
||||||
|
@ -13,7 +13,7 @@ type Channel struct {
|
|||||||
Key string `json:"key" form:"key" gorm:"type:varchar(767);not null;index"`
|
Key string `json:"key" form:"key" gorm:"type:varchar(767);not null;index"`
|
||||||
Status int `json:"status" form:"status" gorm:"default:1"`
|
Status int `json:"status" form:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" form:"name" gorm:"index"`
|
Name string `json:"name" form:"name" gorm:"index"`
|
||||||
Weight *uint `json:"weight" gorm:"default:0"`
|
Weight *uint `json:"weight" gorm:"default:1"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||||||
ResponseTime int `json:"response_time"` // in milliseconds
|
ResponseTime int `json:"response_time"` // in milliseconds
|
||||||
@ -95,11 +95,8 @@ func GetAllChannels() ([]*Channel, error) {
|
|||||||
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||||
channel := Channel{Id: id}
|
channel := Channel{Id: id}
|
||||||
var err error = nil
|
var err error = nil
|
||||||
if selectAll {
|
err = DB.First(&channel, "id = ?", id).Error
|
||||||
err = DB.First(&channel, "id = ?", id).Error
|
|
||||||
} else {
|
|
||||||
err = DB.Omit("key").First(&channel, "id = ?", id).Error
|
|
||||||
}
|
|
||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,6 +77,8 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["ChatLink"] = common.ChatLink
|
common.OptionMap["ChatLink"] = common.ChatLink
|
||||||
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
||||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||||
|
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
|
||||||
|
|
||||||
common.OptionMapRWMutex.Unlock()
|
common.OptionMapRWMutex.Unlock()
|
||||||
initModelRatio()
|
initModelRatio()
|
||||||
loadOptionsFromDatabase()
|
loadOptionsFromDatabase()
|
||||||
@ -146,6 +148,7 @@ var optionIntMap = map[string]*int{
|
|||||||
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
|
"QuotaRemindThreshold": &common.QuotaRemindThreshold,
|
||||||
"PreConsumedQuota": &common.PreConsumedQuota,
|
"PreConsumedQuota": &common.PreConsumedQuota,
|
||||||
"RetryTimes": &common.RetryTimes,
|
"RetryTimes": &common.RetryTimes,
|
||||||
|
"RetryCooldownSeconds": &common.RetryCooldownSeconds,
|
||||||
}
|
}
|
||||||
|
|
||||||
var optionBoolMap = map[string]*bool{
|
var optionBoolMap = map[string]*bool{
|
||||||
|
@ -2,6 +2,7 @@ package router
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
|
"one-api/controller/relay"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@ -19,18 +20,18 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayV1Router := router.Group("/v1")
|
relayV1Router := router.Group("/v1")
|
||||||
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
|
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
|
||||||
{
|
{
|
||||||
relayV1Router.POST("/completions", controller.RelayCompletions)
|
relayV1Router.POST("/completions", relay.Relay)
|
||||||
relayV1Router.POST("/chat/completions", controller.RelayChat)
|
relayV1Router.POST("/chat/completions", relay.Relay)
|
||||||
// relayV1Router.POST("/edits", controller.Relay)
|
// relayV1Router.POST("/edits", controller.Relay)
|
||||||
relayV1Router.POST("/images/generations", controller.RelayImageGenerations)
|
relayV1Router.POST("/images/generations", relay.Relay)
|
||||||
relayV1Router.POST("/images/edits", controller.RelayImageEdits)
|
relayV1Router.POST("/images/edits", relay.Relay)
|
||||||
relayV1Router.POST("/images/variations", controller.RelayImageVariations)
|
relayV1Router.POST("/images/variations", relay.Relay)
|
||||||
relayV1Router.POST("/embeddings", controller.RelayEmbeddings)
|
relayV1Router.POST("/embeddings", relay.Relay)
|
||||||
// relayV1Router.POST("/engines/:model/embeddings", controller.RelayEmbeddings)
|
// relayV1Router.POST("/engines/:model/embeddings", controller.RelayEmbeddings)
|
||||||
relayV1Router.POST("/audio/transcriptions", controller.RelayTranscriptions)
|
relayV1Router.POST("/audio/transcriptions", relay.Relay)
|
||||||
relayV1Router.POST("/audio/translations", controller.RelayTranslations)
|
relayV1Router.POST("/audio/translations", relay.Relay)
|
||||||
relayV1Router.POST("/audio/speech", controller.RelaySpeech)
|
relayV1Router.POST("/audio/speech", relay.Relay)
|
||||||
relayV1Router.POST("/moderations", controller.RelayModerations)
|
relayV1Router.POST("/moderations", relay.Relay)
|
||||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||||
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
||||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||||
|
@ -21,7 +21,8 @@ import {
|
|||||||
Container,
|
Container,
|
||||||
Autocomplete,
|
Autocomplete,
|
||||||
FormHelperText,
|
FormHelperText,
|
||||||
Checkbox
|
Checkbox,
|
||||||
|
Switch
|
||||||
} from '@mui/material';
|
} from '@mui/material';
|
||||||
|
|
||||||
import { Formik } from 'formik';
|
import { Formik } from 'formik';
|
||||||
@ -73,6 +74,7 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
|
|||||||
const [inputLabel, setInputLabel] = useState(defaultConfig.inputLabel); //
|
const [inputLabel, setInputLabel] = useState(defaultConfig.inputLabel); //
|
||||||
const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt);
|
const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt);
|
||||||
const [modelOptions, setModelOptions] = useState([]);
|
const [modelOptions, setModelOptions] = useState([]);
|
||||||
|
const [batchAdd, setBatchAdd] = useState(false);
|
||||||
|
|
||||||
const initChannel = (typeValue) => {
|
const initChannel = (typeValue) => {
|
||||||
if (typeConfig[typeValue]?.inputLabel) {
|
if (typeConfig[typeValue]?.inputLabel) {
|
||||||
@ -246,6 +248,7 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
setBatchAdd(false);
|
||||||
if (channelId) {
|
if (channelId) {
|
||||||
loadChannel().then();
|
loadChannel().then();
|
||||||
} else {
|
} else {
|
||||||
@ -479,18 +482,36 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
|
|||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
</Container>
|
</Container>
|
||||||
<FormControl fullWidth error={Boolean(touched.key && errors.key)} sx={{ ...theme.typography.otherInput }}>
|
<FormControl fullWidth error={Boolean(touched.key && errors.key)} sx={{ ...theme.typography.otherInput }}>
|
||||||
<InputLabel htmlFor="channel-key-label">{inputLabel.key}</InputLabel>
|
{!batchAdd ? (
|
||||||
<OutlinedInput
|
<>
|
||||||
id="channel-key-label"
|
<InputLabel htmlFor="channel-key-label">{inputLabel.key}</InputLabel>
|
||||||
label={inputLabel.key}
|
<OutlinedInput
|
||||||
type="text"
|
id="channel-key-label"
|
||||||
value={values.key}
|
label={inputLabel.key}
|
||||||
name="key"
|
type="text"
|
||||||
onBlur={handleBlur}
|
value={values.key}
|
||||||
onChange={handleChange}
|
name="key"
|
||||||
inputProps={{}}
|
onBlur={handleBlur}
|
||||||
aria-describedby="helper-text-channel-key-label"
|
onChange={handleChange}
|
||||||
/>
|
inputProps={{}}
|
||||||
|
aria-describedby="helper-text-channel-key-label"
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<TextField
|
||||||
|
multiline
|
||||||
|
id="channel-key-label"
|
||||||
|
label={inputLabel.key}
|
||||||
|
value={values.key}
|
||||||
|
name="key"
|
||||||
|
onBlur={handleBlur}
|
||||||
|
onChange={handleChange}
|
||||||
|
aria-describedby="helper-text-channel-key-label"
|
||||||
|
minRows={5}
|
||||||
|
placeholder={inputPrompt.key + ',一行一个密钥'}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
{touched.key && errors.key ? (
|
{touched.key && errors.key ? (
|
||||||
<FormHelperText error id="helper-tex-channel-key-label">
|
<FormHelperText error id="helper-tex-channel-key-label">
|
||||||
{errors.key}
|
{errors.key}
|
||||||
@ -499,6 +520,17 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
|
|||||||
<FormHelperText id="helper-tex-channel-key-label"> {inputPrompt.key} </FormHelperText>
|
<FormHelperText id="helper-tex-channel-key-label"> {inputPrompt.key} </FormHelperText>
|
||||||
)}
|
)}
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
{channelId === 0 && (
|
||||||
|
<Container
|
||||||
|
sx={{
|
||||||
|
textAlign: 'right'
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Switch checked={batchAdd} onChange={(e) => setBatchAdd(e.target.checked)} />
|
||||||
|
批量添加
|
||||||
|
</Container>
|
||||||
|
)}
|
||||||
|
|
||||||
<FormControl fullWidth error={Boolean(touched.model_mapping && errors.model_mapping)} sx={{ ...theme.typography.otherInput }}>
|
<FormControl fullWidth error={Boolean(touched.model_mapping && errors.model_mapping)} sx={{ ...theme.typography.otherInput }}>
|
||||||
{/* <InputLabel htmlFor="channel-model_mapping-label">{inputLabel.model_mapping}</InputLabel> */}
|
{/* <InputLabel htmlFor="channel-model_mapping-label">{inputLabel.model_mapping}</InputLabel> */}
|
||||||
<TextField
|
<TextField
|
||||||
|
@ -34,7 +34,7 @@ import TableSwitch from 'ui-component/Switch';
|
|||||||
import ResponseTimeLabel from './ResponseTimeLabel';
|
import ResponseTimeLabel from './ResponseTimeLabel';
|
||||||
import GroupLabel from './GroupLabel';
|
import GroupLabel from './GroupLabel';
|
||||||
|
|
||||||
import { IconDotsVertical, IconEdit, IconTrash, IconPencil } from '@tabler/icons-react';
|
import { IconDotsVertical, IconEdit, IconTrash, IconPencil, IconCopy } from '@tabler/icons-react';
|
||||||
import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown';
|
import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown';
|
||||||
import KeyboardArrowUpIcon from '@mui/icons-material/KeyboardArrowUp';
|
import KeyboardArrowUpIcon from '@mui/icons-material/KeyboardArrowUp';
|
||||||
import { copy } from 'utils/common';
|
import { copy } from 'utils/common';
|
||||||
@ -44,6 +44,7 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal,
|
|||||||
const [openDelete, setOpenDelete] = useState(false);
|
const [openDelete, setOpenDelete] = useState(false);
|
||||||
const [statusSwitch, setStatusSwitch] = useState(item.status);
|
const [statusSwitch, setStatusSwitch] = useState(item.status);
|
||||||
const [priorityValve, setPriority] = useState(item.priority);
|
const [priorityValve, setPriority] = useState(item.priority);
|
||||||
|
const [weightValve, setWeight] = useState(item.weight);
|
||||||
const [responseTimeData, setResponseTimeData] = useState({ test_time: item.test_time, response_time: item.response_time });
|
const [responseTimeData, setResponseTimeData] = useState({ test_time: item.test_time, response_time: item.response_time });
|
||||||
const [itemBalance, setItemBalance] = useState(item.balance);
|
const [itemBalance, setItemBalance] = useState(item.balance);
|
||||||
|
|
||||||
@ -81,9 +82,28 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal,
|
|||||||
if (priorityValve === '' || priorityValve === item.priority) {
|
if (priorityValve === '' || priorityValve === item.priority) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (priorityValve < 0) {
|
||||||
|
showError('优先级不能小于 0');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
await manageChannel(item.id, 'priority', priorityValve);
|
await manageChannel(item.id, 'priority', priorityValve);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleWeight = async () => {
|
||||||
|
if (weightValve === '' || weightValve === item.weight) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weightValve <= 0) {
|
||||||
|
showError('权重不能小于 0');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
await manageChannel(item.id, 'weight', weightValve);
|
||||||
|
};
|
||||||
|
|
||||||
const handleResponseTime = async () => {
|
const handleResponseTime = async () => {
|
||||||
const { success, time } = await manageChannel(item.id, 'test', '');
|
const { success, time } = await manageChannel(item.id, 'test', '');
|
||||||
if (success) {
|
if (success) {
|
||||||
@ -176,6 +196,25 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal,
|
|||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
<FormControl sx={{ m: 1, width: '70px' }} variant="standard">
|
||||||
|
<InputLabel htmlFor={`priority-${item.id}`}>权重</InputLabel>
|
||||||
|
<Input
|
||||||
|
id={`weight-${item.id}`}
|
||||||
|
type="text"
|
||||||
|
value={weightValve}
|
||||||
|
onChange={(e) => setWeight(e.target.value)}
|
||||||
|
sx={{ textAlign: 'center' }}
|
||||||
|
endAdornment={
|
||||||
|
<InputAdornment position="end">
|
||||||
|
<IconButton onClick={handleWeight} sx={{ color: 'rgb(99, 115, 129)' }} size="small">
|
||||||
|
<IconPencil />
|
||||||
|
</IconButton>
|
||||||
|
</InputAdornment>
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</TableCell>
|
||||||
|
|
||||||
<TableCell>
|
<TableCell>
|
||||||
<IconButton onClick={handleOpenMenu} sx={{ color: 'rgb(99, 115, 129)' }}>
|
<IconButton onClick={handleOpenMenu} sx={{ color: 'rgb(99, 115, 129)' }}>
|
||||||
@ -204,6 +243,16 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal,
|
|||||||
<IconEdit style={{ marginRight: '16px' }} />
|
<IconEdit style={{ marginRight: '16px' }} />
|
||||||
编辑
|
编辑
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
|
||||||
|
<MenuItem
|
||||||
|
onClick={() => {
|
||||||
|
handleCloseMenu();
|
||||||
|
manageChannel(item.id, 'copy');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IconCopy style={{ marginRight: '16px' }} /> 复制{' '}
|
||||||
|
</MenuItem>
|
||||||
|
|
||||||
<MenuItem onClick={handleDeleteOpen} sx={{ color: 'error.main' }}>
|
<MenuItem onClick={handleDeleteOpen} sx={{ color: 'error.main' }}>
|
||||||
<IconTrash style={{ marginRight: '16px' }} />
|
<IconTrash style={{ marginRight: '16px' }} />
|
||||||
删除
|
删除
|
||||||
|
@ -11,6 +11,7 @@ import LinearProgress from '@mui/material/LinearProgress';
|
|||||||
import ButtonGroup from '@mui/material/ButtonGroup';
|
import ButtonGroup from '@mui/material/ButtonGroup';
|
||||||
import Toolbar from '@mui/material/Toolbar';
|
import Toolbar from '@mui/material/Toolbar';
|
||||||
import useMediaQuery from '@mui/material/useMediaQuery';
|
import useMediaQuery from '@mui/material/useMediaQuery';
|
||||||
|
import Alert from '@mui/material/Alert';
|
||||||
|
|
||||||
import { Button, IconButton, Card, Box, Stack, Container, Typography, Divider } from '@mui/material';
|
import { Button, IconButton, Card, Box, Stack, Container, Typography, Divider } from '@mui/material';
|
||||||
import ChannelTableRow from './component/TableRow';
|
import ChannelTableRow from './component/TableRow';
|
||||||
@ -116,6 +117,19 @@ export default function ChannelPage() {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
switch (action) {
|
switch (action) {
|
||||||
|
case 'copy': {
|
||||||
|
let oldRes = await API.get(`/api/channel/${id}`);
|
||||||
|
const { success, message, data } = oldRes.data;
|
||||||
|
if (!success) {
|
||||||
|
showError(message);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// 删除 data.id
|
||||||
|
delete data.id;
|
||||||
|
data.name = data.name + '_copy';
|
||||||
|
res = await API.post(`/api/channel/`, { ...data });
|
||||||
|
break;
|
||||||
|
}
|
||||||
case 'delete':
|
case 'delete':
|
||||||
res = await API.delete(url + id);
|
res = await API.delete(url + id);
|
||||||
break;
|
break;
|
||||||
@ -134,6 +148,15 @@ export default function ChannelPage() {
|
|||||||
priority: parseInt(value)
|
priority: parseInt(value)
|
||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
|
case 'weight':
|
||||||
|
if (value === '') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
res = await API.put(url, {
|
||||||
|
...data,
|
||||||
|
weight: parseInt(value)
|
||||||
|
});
|
||||||
|
break;
|
||||||
case 'test':
|
case 'test':
|
||||||
res = await API.get(url + `test/${id}`);
|
res = await API.get(url + `test/${id}`);
|
||||||
break;
|
break;
|
||||||
@ -141,7 +164,7 @@ export default function ChannelPage() {
|
|||||||
const { success, message } = res.data;
|
const { success, message } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
showSuccess('操作成功完成!');
|
showSuccess('操作成功完成!');
|
||||||
if (action === 'delete') {
|
if (action === 'delete' || action === 'copy') {
|
||||||
await handleRefresh();
|
await handleRefresh();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -271,6 +294,20 @@ export default function ChannelPage() {
|
|||||||
</Button>
|
</Button>
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
</Stack>
|
</Stack>
|
||||||
|
<Stack mb={5}>
|
||||||
|
<Alert severity="info">
|
||||||
|
优先级/权重解释:
|
||||||
|
<br />
|
||||||
|
1. 优先级越大,越优先使用;(只有该优先级下的节点都冻结或者禁用了,才会使用低优先级的节点)
|
||||||
|
<br />
|
||||||
|
2. 相同优先级下:如果“MEMORY_CACHE_ENABLED”启用,则根据权重进行负载均衡(加权随机);否则忽略权重直接随机
|
||||||
|
<br />
|
||||||
|
3. 如果在设置-通用设置中设置了“重试次数”和“重试间隔”,则会在失败后重试。
|
||||||
|
<br />
|
||||||
|
4.
|
||||||
|
重试逻辑:1)先在高优先级中的节点重试,如果高优先级中的节点都冻结了,才会在低优先级中的节点重试。2)如果设置了“重试间隔”,则某一渠道失败后,会冻结一段时间,所有人都不会再使用这个渠道,直到冻结时间结束。3)重试次数用完后,直接结束。
|
||||||
|
</Alert>
|
||||||
|
</Stack>
|
||||||
<Card>
|
<Card>
|
||||||
<Box component="form" noValidate>
|
<Box component="form" noValidate>
|
||||||
<TableToolBar filterName={toolBarValue} handleFilterName={handleToolBarValue} groupOptions={groupOptions} />
|
<TableToolBar filterName={toolBarValue} handleFilterName={handleToolBarValue} groupOptions={groupOptions} />
|
||||||
@ -349,6 +386,7 @@ export default function ChannelPage() {
|
|||||||
{ id: 'response_time', label: '响应时间', disableSort: false },
|
{ id: 'response_time', label: '响应时间', disableSort: false },
|
||||||
{ id: 'balance', label: '余额', disableSort: false },
|
{ id: 'balance', label: '余额', disableSort: false },
|
||||||
{ id: 'priority', label: '优先级', disableSort: false },
|
{ id: 'priority', label: '优先级', disableSort: false },
|
||||||
|
{ id: 'weight', label: '权重', disableSort: false },
|
||||||
{ id: 'action', label: '操作', disableSort: true }
|
{ id: 'action', label: '操作', disableSort: true }
|
||||||
]}
|
]}
|
||||||
/>
|
/>
|
||||||
|
@ -30,7 +30,8 @@ const OperationSetting = () => {
|
|||||||
DisplayInCurrencyEnabled: '',
|
DisplayInCurrencyEnabled: '',
|
||||||
DisplayTokenStatEnabled: '',
|
DisplayTokenStatEnabled: '',
|
||||||
ApproximateTokenEnabled: '',
|
ApproximateTokenEnabled: '',
|
||||||
RetryTimes: 0
|
RetryTimes: 0,
|
||||||
|
RetryCooldownSeconds: 0
|
||||||
});
|
});
|
||||||
const [originInputs, setOriginInputs] = useState({});
|
const [originInputs, setOriginInputs] = useState({});
|
||||||
const [newModelRatioView, setNewModelRatioView] = useState(false);
|
const [newModelRatioView, setNewModelRatioView] = useState(false);
|
||||||
@ -139,6 +140,11 @@ const OperationSetting = () => {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case 'general':
|
case 'general':
|
||||||
|
if (inputs.QuotaPerUnit < 0 || inputs.RetryTimes < 0 || inputs.RetryCooldownSeconds < 0) {
|
||||||
|
showError('单位额度、重试次数、冷却时间不能为负数');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
|
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
|
||||||
await updateOption('TopUpLink', inputs.TopUpLink);
|
await updateOption('TopUpLink', inputs.TopUpLink);
|
||||||
}
|
}
|
||||||
@ -151,6 +157,9 @@ const OperationSetting = () => {
|
|||||||
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
|
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
|
||||||
await updateOption('RetryTimes', inputs.RetryTimes);
|
await updateOption('RetryTimes', inputs.RetryTimes);
|
||||||
}
|
}
|
||||||
|
if (originInputs['RetryCooldownSeconds'] !== inputs.RetryCooldownSeconds) {
|
||||||
|
await updateOption('RetryCooldownSeconds', inputs.RetryCooldownSeconds);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,6 +233,18 @@ const OperationSetting = () => {
|
|||||||
disabled={loading}
|
disabled={loading}
|
||||||
/>
|
/>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
<FormControl fullWidth>
|
||||||
|
<InputLabel htmlFor="RetryCooldownSeconds">重试间隔(秒)</InputLabel>
|
||||||
|
<OutlinedInput
|
||||||
|
id="RetryCooldownSeconds"
|
||||||
|
name="RetryCooldownSeconds"
|
||||||
|
value={inputs.RetryCooldownSeconds}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
label="重试间隔(秒)"
|
||||||
|
placeholder="重试间隔(秒)"
|
||||||
|
disabled={loading}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
</Stack>
|
</Stack>
|
||||||
<Stack
|
<Stack
|
||||||
direction={{ sm: 'column', md: 'row' }}
|
direction={{ sm: 'column', md: 'row' }}
|
||||||
|
@ -12,7 +12,8 @@ import {
|
|||||||
DialogTitle,
|
DialogTitle,
|
||||||
DialogActions,
|
DialogActions,
|
||||||
DialogContent,
|
DialogContent,
|
||||||
Divider
|
Divider,
|
||||||
|
Typography
|
||||||
} from '@mui/material';
|
} from '@mui/material';
|
||||||
import Grid from '@mui/material/Unstable_Grid2';
|
import Grid from '@mui/material/Unstable_Grid2';
|
||||||
import { showError, showSuccess } from 'utils/common'; //,
|
import { showError, showSuccess } from 'utils/common'; //,
|
||||||
@ -115,16 +116,37 @@ const OtherSetting = () => {
|
|||||||
|
|
||||||
const checkUpdate = async () => {
|
const checkUpdate = async () => {
|
||||||
try {
|
try {
|
||||||
const res = await API.get('https://api.github.com/repos/MartialBE/one-api/releases/latest');
|
if (!process.env.REACT_APP_VERSION) {
|
||||||
const { tag_name, body } = res.data;
|
showError('无法获取当前版本号');
|
||||||
if (tag_name === process.env.REACT_APP_VERSION) {
|
return;
|
||||||
showSuccess(`已是最新版本:${tag_name}`);
|
}
|
||||||
|
|
||||||
|
// 如果版本前缀是v开头的
|
||||||
|
if (process.env.REACT_APP_VERSION.startsWith('v')) {
|
||||||
|
const res = await API.get('https://api.github.com/repos/MartialBE/one-api/releases/latest');
|
||||||
|
const { tag_name, body } = res.data;
|
||||||
|
if (tag_name === process.env.REACT_APP_VERSION) {
|
||||||
|
showSuccess(`已是最新版本:${tag_name}`);
|
||||||
|
} else {
|
||||||
|
setUpdateData({
|
||||||
|
tag_name: tag_name,
|
||||||
|
content: marked.parse(body)
|
||||||
|
});
|
||||||
|
setShowUpdateModal(true);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
setUpdateData({
|
const res = await API.get('https://api.github.com/repos/MartialBE/one-api/commits/main');
|
||||||
tag_name: tag_name,
|
const { sha, commit } = res.data;
|
||||||
content: marked.parse(body)
|
const newVersion = 'dev-' + sha.substr(0, 7);
|
||||||
});
|
if (newVersion === process.env.REACT_APP_VERSION) {
|
||||||
setShowUpdateModal(true);
|
showSuccess(`已是最新版本:${newVersion}`);
|
||||||
|
} else {
|
||||||
|
setUpdateData({
|
||||||
|
tag_name: newVersion,
|
||||||
|
content: marked.parse(commit.message)
|
||||||
|
});
|
||||||
|
setShowUpdateModal(true);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
return;
|
return;
|
||||||
@ -137,6 +159,9 @@ const OtherSetting = () => {
|
|||||||
<SubCard title="通用设置">
|
<SubCard title="通用设置">
|
||||||
<Grid container spacing={{ xs: 3, sm: 2, md: 4 }}>
|
<Grid container spacing={{ xs: 3, sm: 2, md: 4 }}>
|
||||||
<Grid xs={12}>
|
<Grid xs={12}>
|
||||||
|
<Typography variant="h6" gutterBottom>
|
||||||
|
当前版本:{process.env.REACT_APP_VERSION}
|
||||||
|
</Typography>
|
||||||
<Button variant="contained" onClick={checkUpdate}>
|
<Button variant="contained" onClick={checkUpdate}>
|
||||||
检查更新
|
检查更新
|
||||||
</Button>
|
</Button>
|
||||||
|
Loading…
Reference in New Issue
Block a user