From 3d8a51e139c40d4ab37cb19ae3a8c7a96a1bbd94 Mon Sep 17 00:00:00 2001
From: MartialBE
Date: Wed, 29 May 2024 01:56:14 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=94=96=20chore:=20migration=20constants?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
cli/flag.go | 18 +-
common/common.go | 9 +-
common/config/config.go | 21 +-
common/{ => config}/constants.go | 2 +-
common/notify/channel/email.go | 8 +-
common/redis.go | 3 +-
common/stmp/email.go | 13 +-
common/stmp/email_test.go | 3 +-
common/stmp/template.go | 10 +-
common/telegram/command_aff.go | 6 +-
common/telegram/command_apikey.go | 10 +-
common/telegram/common.go | 6 +-
common/token.go | 7 +-
controller/billing.go | 14 +-
controller/channel-billing.go | 8 +-
controller/channel-test.go | 10 +-
controller/common.go | 10 +-
controller/github.go | 16 +-
controller/lark.go | 18 +-
controller/misc.go | 69 +++---
controller/option.go | 16 +-
controller/token.go | 7 +-
controller/user.go | 41 ++--
controller/wechat.go | 18 +-
main.go | 12 +-
middleware/auth.go | 10 +-
middleware/rate-limit.go | 9 +-
middleware/turnstile-check.go | 6 +-
model/ability.go | 3 +-
model/balancer.go | 10 +-
model/channel.go | 14 +-
model/common.go | 7 +-
model/log.go | 10 +-
model/main.go | 15 +-
model/option.go | 213 ++++++++---------
model/price.go | 260 ++++++++++-----------
model/redemption.go | 5 +-
model/token.go | 23 +-
model/user.go | 37 +--
model/utils.go | 4 +-
providers/ali/ali_test.go | 4 +-
providers/ali/chat.go | 3 +-
providers/ali/embeddings.go | 3 +-
providers/azure/image_generations.go | 3 +-
providers/azureSpeech/speech.go | 3 +-
providers/baichuan/chat.go | 6 +-
providers/baidu/chat.go | 3 +-
providers/baidu/embeddings.go | 3 +-
providers/base/common.go | 21 +-
providers/bedrock/chat.go | 3 +-
providers/claude/chat.go | 3 +-
providers/cohere/chat.go | 3 +-
providers/coze/chat.go | 3 +-
providers/groq/chat.go | 6 +-
providers/hunyuan/chat.go | 3 +-
providers/midjourney/base.go | 4 +-
providers/minimax/chat.go | 3 +-
providers/minimax/embeddings.go | 3 +-
providers/mistral/chat.go | 3 +-
providers/mistral/embeddings.go | 3 +-
providers/moonshot/chat.go | 5 +-
providers/ollama/chat.go | 3 +-
providers/ollama/embeddings.go | 3 +-
providers/openai/base.go | 9 +-
providers/openai/chat.go | 5 +-
providers/openai/completion.go | 5 +-
providers/openai/embeddings.go | 4 +-
providers/openai/image_edits.go | 3 +-
providers/openai/image_generations.go | 3 +-
providers/openai/image_variations.go | 4 +-
providers/openai/moderation.go | 4 +-
providers/openai/speech.go | 4 +-
providers/openai/transcriptions.go | 3 +-
providers/openai/translations.go | 3 +-
providers/palm/chat.go | 3 +-
providers/providers.go | 56 ++---
providers/stabilityAI/image_generations.go | 3 +-
providers/tencent/chat.go | 3 +-
providers/xunfei/chat.go | 3 +-
providers/zhipu/chat.go | 3 +-
providers/zhipu/embeddings.go | 3 +-
providers/zhipu/image_generations.go | 3 +-
relay/common.go | 3 +-
relay/main.go | 3 +-
relay/midjourney/relay-mj.go | 3 +-
relay/relay.go | 3 +-
relay/relay_util/cache.go | 7 +-
relay/relay_util/pricing.go | 4 +-
relay/relay_util/quota.go | 3 +-
relay/relay_util/type.go | 52 +++--
router/main.go | 4 +-
91 files changed, 670 insertions(+), 614 deletions(-)
rename common/{ => config}/constants.go (99%)
diff --git a/cli/flag.go b/cli/flag.go
index 804f6cb4..2dcfd5a3 100644
--- a/cli/flag.go
+++ b/cli/flag.go
@@ -3,7 +3,8 @@ package cli
import (
"flag"
"fmt"
- "one-api/common"
+ "one-api/common/config"
+ "one-api/common/utils"
"os"
"github.com/spf13/viper"
@@ -18,11 +19,11 @@ var (
export = flag.Bool("export", false, "Exports prices to a JSON file.")
)
-func FlagConfig() {
+func InitCli() {
flag.Parse()
if *printVersion {
- fmt.Println(common.Version)
+ fmt.Println(config.Version)
os.Exit(0)
}
@@ -44,10 +45,19 @@ func FlagConfig() {
os.Exit(0)
}
+ if Config != nil && !utils.IsFileExist(*Config) {
+ panic("Config file not found")
+ }
+
+ viper.SetConfigFile(*Config)
+ if err := viper.ReadInConfig(); err != nil {
+ panic(err)
+ }
+
}
func help() {
- fmt.Println("One API " + common.Version + " - All in one API service for OpenAI API.")
+ fmt.Println("One API " + config.Version + " - All in one API service for OpenAI API.")
fmt.Println("Copyright (C) 2024 MartialBE. All rights reserved.")
fmt.Println("Original copyright holder: JustSong")
fmt.Println("GitHub: https://github.com/MartialBE/one-api")
diff --git a/common/common.go b/common/common.go
index e3d6accc..41de9367 100644
--- a/common/common.go
+++ b/common/common.go
@@ -1,10 +1,13 @@
package common
-import "fmt"
+import (
+ "fmt"
+ "one-api/common/config"
+)
func LogQuota(quota int) string {
- if DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
+ if config.DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
diff --git a/common/config/config.go b/common/config/config.go
index 3998c293..da6110c9 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -5,37 +5,22 @@ import (
"strings"
"time"
- "one-api/cli"
- "one-api/common"
"one-api/common/utils"
"github.com/spf13/viper"
)
func InitConf() {
- cli.FlagConfig()
defaultConfig()
- setConfigFile()
setEnv()
if viper.GetBool("debug") {
logger.SysLog("running in debug mode")
}
- common.IsMasterNode = viper.GetString("node_type") != "slave"
- common.RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second
- common.SessionSecret = utils.GetOrDefault("session_secret", common.SessionSecret)
-}
-
-func setConfigFile() {
- if !utils.IsFileExist(*cli.Config) {
- return
- }
-
- viper.SetConfigFile(*cli.Config)
- if err := viper.ReadInConfig(); err != nil {
- panic(err)
- }
+ IsMasterNode = viper.GetString("node_type") != "slave"
+ RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second
+ SessionSecret = utils.GetOrDefault("session_secret", SessionSecret)
}
func setEnv() {
diff --git a/common/constants.go b/common/config/constants.go
similarity index 99%
rename from common/constants.go
rename to common/config/constants.go
index 1e0c5605..345c8b9e 100644
--- a/common/constants.go
+++ b/common/config/constants.go
@@ -1,4 +1,4 @@
-package common
+package config
import (
"sync"
diff --git a/common/notify/channel/email.go b/common/notify/channel/email.go
index 6fde98a5..05425b1a 100644
--- a/common/notify/channel/email.go
+++ b/common/notify/channel/email.go
@@ -3,7 +3,7 @@ package channel
import (
"context"
"errors"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/stmp"
"github.com/gomarkdown/markdown"
@@ -28,10 +28,10 @@ func (e *Email) Name() string {
func (e *Email) Send(ctx context.Context, title, message string) error {
to := e.To
if to == "" {
- to = common.RootUserEmail
+ to = config.RootUserEmail
}
- if common.SMTPServer == "" || common.SMTPAccount == "" || common.SMTPToken == "" || to == "" {
+ if config.SMTPServer == "" || config.SMTPAccount == "" || config.SMTPToken == "" || to == "" {
return errors.New("smtp config is not set, skip send email notifier")
}
@@ -44,7 +44,7 @@ func (e *Email) Send(ctx context.Context, title, message string) error {
body := markdown.Render(doc, renderer)
- emailClient := stmp.NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom)
+ emailClient := stmp.NewStmp(config.SMTPServer, config.SMTPPort, config.SMTPAccount, config.SMTPToken, config.SMTPFrom)
return emailClient.Send(to, title, string(body))
}
diff --git a/common/redis.go b/common/redis.go
index de537138..26fa3c31 100644
--- a/common/redis.go
+++ b/common/redis.go
@@ -2,6 +2,7 @@ package common
import (
"context"
+ "one-api/common/config"
"one-api/common/logger"
"time"
@@ -41,7 +42,7 @@ func InitRedisClient() (err error) {
} else {
RedisEnabled = true
// for compatibility with old versions
- MemoryCacheEnabled = true
+ config.MemoryCacheEnabled = true
}
return err
diff --git a/common/stmp/email.go b/common/stmp/email.go
index 1f8b6a03..8d3292c9 100644
--- a/common/stmp/email.go
+++ b/common/stmp/email.go
@@ -3,6 +3,7 @@ package stmp
import (
"fmt"
"one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"strings"
@@ -38,7 +39,7 @@ func (s *StmpConfig) Send(to, subject, body string) error {
message.Subject(subject)
message.SetGenHeader("References", s.getReferences())
message.SetBodyString(mail.TypeTextHTML, body)
- message.SetUserAgent(fmt.Sprintf("One API %s // https://github.com/MartialBE/one-api", common.Version))
+ message.SetUserAgent(fmt.Sprintf("One API %s // https://github.com/MartialBE/one-api", config.Version))
client, err := mail.NewClient(
s.Host,
@@ -78,11 +79,11 @@ func (s *StmpConfig) Render(to, subject, content string) error {
}
func GetSystemStmp() (*StmpConfig, error) {
- if common.SMTPServer == "" || common.SMTPPort == 0 || common.SMTPAccount == "" || common.SMTPToken == "" {
+ if config.SMTPServer == "" || config.SMTPPort == 0 || config.SMTPAccount == "" || config.SMTPToken == "" {
return nil, fmt.Errorf("SMTP 信息未配置")
}
- return NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom), nil
+ return NewStmp(config.SMTPServer, config.SMTPPort, config.SMTPAccount, config.SMTPToken, config.SMTPFrom), nil
}
func SendPasswordResetEmail(userName, email, link string) error {
@@ -106,7 +107,7 @@ func SendPasswordResetEmail(userName, email, link string) error {
重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
`
- subject := fmt.Sprintf("%s密码重置", common.SystemName)
+ subject := fmt.Sprintf("%s密码重置", config.SystemName)
content := fmt.Sprintf(contentTemp, userName, link, link, common.VerificationValidMinutes)
return stmp.Render(email, subject, content)
@@ -132,7 +133,7 @@ func SendVerificationCodeEmail(email, code string) error {
验证码 %d 分钟内有效,如果不是本人操作,请忽略。
`
- subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
+ subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes)
return stmp.Render(email, subject, content)
@@ -162,7 +163,7 @@ func SendQuotaWarningCodeEmail(userName, email string, quota int, noMoreQuota bo
if noMoreQuota {
subject = "您的额度已用尽"
}
- topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
+ topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
content := fmt.Sprintf(contentTemp, userName, subject, quota, topUpLink, topUpLink)
diff --git a/common/stmp/email_test.go b/common/stmp/email_test.go
index 778c6ce9..081de2f2 100644
--- a/common/stmp/email_test.go
+++ b/common/stmp/email_test.go
@@ -2,6 +2,7 @@ package stmp_test
import (
"fmt"
+ "one-api/common/config"
"testing"
"one-api/common"
@@ -56,7 +57,7 @@ func TestSend(t *testing.T) {
验证码 %d 分钟内有效,如果不是本人操作,请忽略。
`
- subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
+ subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes)
err := stmpClient.Render(email, subject, content)
diff --git a/common/stmp/template.go b/common/stmp/template.go
index 8ac44147..1eb884d9 100644
--- a/common/stmp/template.go
+++ b/common/stmp/template.go
@@ -1,17 +1,17 @@
package stmp
import (
- "one-api/common"
+ "one-api/common/config"
)
func getLogo() string {
- if common.Logo == "" {
+ if config.Logo == "" {
return ""
}
return `
-
|
@@ -19,11 +19,11 @@ func getLogo() string {
}
func getSystemName() string {
- if common.SystemName == "" {
+ if config.SystemName == "" {
return "One API"
}
- return common.SystemName
+ return config.SystemName
}
func getDefaultTemplate(content string) string {
diff --git a/common/telegram/command_aff.go b/common/telegram/command_aff.go
index 68733169..a7d3188b 100644
--- a/common/telegram/command_aff.go
+++ b/common/telegram/command_aff.go
@@ -1,7 +1,7 @@
package telegram
import (
- "one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"strings"
@@ -24,8 +24,8 @@ func commandAffStart(b *gotgbot.Bot, ctx *ext.Context) error {
}
messae := "您可以通过分享您的邀请码来邀请朋友,每次成功邀请将获得奖励。\n\n您的邀请码是: " + user.AffCode
- if common.ServerAddress != "" {
- serverAddress := strings.TrimSuffix(common.ServerAddress, "/")
+ if config.ServerAddress != "" {
+ serverAddress := strings.TrimSuffix(config.ServerAddress, "/")
messae += "\n\n页面地址:" + serverAddress + "/register?aff=" + user.AffCode
}
diff --git a/common/telegram/command_apikey.go b/common/telegram/command_apikey.go
index cd6295db..c0062b3c 100644
--- a/common/telegram/command_apikey.go
+++ b/common/telegram/command_apikey.go
@@ -3,7 +3,7 @@ package telegram
import (
"fmt"
"net/url"
- "one-api/common"
+ "one-api/common/config"
"one-api/model"
"strings"
@@ -56,7 +56,7 @@ func getApikeyList(userId, page int) (message string, pageParams *paginationPara
}
chatUrlTmp := ""
- if common.ServerAddress != "" {
+ if config.ServerAddress != "" {
chatUrlTmp = getChatUrl()
}
@@ -75,11 +75,11 @@ func getApikeyList(userId, page int) (message string, pageParams *paginationPara
}
func getChatUrl() string {
- serverAddress := strings.TrimSuffix(common.ServerAddress, "/")
+ serverAddress := strings.TrimSuffix(config.ServerAddress, "/")
chatNextUrl := fmt.Sprintf(`{"key":"setToken","url":"%s"}`, serverAddress)
chatNextUrl = "https://chat.oneapi.pro/#/?settings=" + url.QueryEscape(chatNextUrl)
- if common.ChatLink != "" {
- chatLink := strings.TrimSuffix(common.ChatLink, "/")
+ if config.ChatLink != "" {
+ chatLink := strings.TrimSuffix(config.ChatLink, "/")
chatNextUrl = strings.ReplaceAll(chatNextUrl, `https://chat.oneapi.pro`, chatLink)
}
diff --git a/common/telegram/common.go b/common/telegram/common.go
index 4a3a3160..9b039b76 100644
--- a/common/telegram/common.go
+++ b/common/telegram/common.go
@@ -7,7 +7,7 @@ import (
"net"
"net/http"
"net/url"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/model"
"strings"
@@ -56,13 +56,13 @@ func InitTelegramBot() {
func StartTelegramBot() {
botWebhook := viper.GetString("tg.webhook_secret")
if botWebhook != "" {
- if common.ServerAddress == "" {
+ if config.ServerAddress == "" {
logger.SysLog("Telegram bot is not enabled: Server address is not set")
StopTelegramBot()
return
}
TGWebHookSecret = botWebhook
- serverAddress := strings.TrimSuffix(common.ServerAddress, "/")
+ serverAddress := strings.TrimSuffix(config.ServerAddress, "/")
urlPath := fmt.Sprintf("/api/telegram/%s", viper.GetString("tg.bot_api_key"))
webHookOpts := &ext.AddWebhookOpts{
diff --git a/common/token.go b/common/token.go
index df125193..8d5090c3 100644
--- a/common/token.go
+++ b/common/token.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"math"
+ "one-api/common/config"
"one-api/common/logger"
"strings"
@@ -21,7 +22,7 @@ var gpt4oTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
if viper.GetBool("disable_token_encoders") {
- DISABLE_TOKEN_ENCODERS = true
+ config.DISABLE_TOKEN_ENCODERS = true
logger.SysLog("token encoders disabled")
return
}
@@ -46,7 +47,7 @@ func InitTokenEncoders() {
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
- if DISABLE_TOKEN_ENCODERS {
+ if config.DISABLE_TOKEN_ENCODERS {
return nil
}
@@ -75,7 +76,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
- if DISABLE_TOKEN_ENCODERS || ApproximateTokenEnabled {
+ if config.DISABLE_TOKEN_ENCODERS || config.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38)
}
return len(tokenEncoder.Encode(text, nil, nil))
diff --git a/controller/billing.go b/controller/billing.go
index 3de7c847..6e6192b8 100644
--- a/controller/billing.go
+++ b/controller/billing.go
@@ -1,7 +1,7 @@
package controller
import (
- "one-api/common"
+ "one-api/common/config"
"one-api/model"
"one-api/types"
@@ -14,7 +14,7 @@ func GetSubscription(c *gin.Context) {
var err error
var token *model.Token
var expiredTime int64
- if common.DisplayTokenStatEnabled {
+ if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
@@ -50,8 +50,8 @@ func GetSubscription(c *gin.Context) {
}
quota := remainQuota + usedQuota
amount := float64(quota)
- if common.DisplayInCurrencyEnabled {
- amount /= common.QuotaPerUnit
+ if config.DisplayInCurrencyEnabled {
+ amount /= config.QuotaPerUnit
}
if token != nil && token.UnlimitedQuota {
amount = 100000000
@@ -71,7 +71,7 @@ func GetUsage(c *gin.Context) {
var quota int
var err error
var token *model.Token
- if common.DisplayTokenStatEnabled {
+ if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota
@@ -90,8 +90,8 @@ func GetUsage(c *gin.Context) {
return
}
amount := float64(quota)
- if common.DisplayInCurrencyEnabled {
- amount /= common.QuotaPerUnit
+ if config.DisplayInCurrencyEnabled {
+ amount /= config.QuotaPerUnit
}
usage := OpenAIUsageResponse{
Object: "list",
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index ffb53fbd..2f3254d1 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -4,7 +4,7 @@ import (
"errors"
"net/http"
"net/http/httptest"
- "one-api/common"
+ "one-api/common/config"
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
@@ -109,11 +109,11 @@ func updateAllChannelsBalance() error {
return err
}
for _, channel := range channels {
- if channel.Status != common.ChannelStatusEnabled {
+ if channel.Status != config.ChannelStatusEnabled {
continue
}
// TODO: support Azure
- if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
+ if channel.Type != config.ChannelTypeOpenAI && channel.Type != config.ChannelTypeCustom {
continue
}
balance, err := updateChannelBalance(channel)
@@ -125,7 +125,7 @@ func updateAllChannelsBalance() error {
DisableChannel(channel.Id, channel.Name, "余额不足", true)
}
}
- time.Sleep(common.RequestInterval)
+ time.Sleep(config.RequestInterval)
}
return nil
}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index ee380893..677d48c7 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -6,7 +6,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/notify"
"one-api/common/utils"
@@ -145,16 +145,16 @@ func testAllChannels(isNotify bool) error {
if err != nil {
return err
}
- var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
+ var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
var sendMessage string
for _, channel := range channels {
- time.Sleep(common.RequestInterval)
+ time.Sleep(config.RequestInterval)
- isChannelEnabled := channel.Status == common.ChannelStatusEnabled
+ isChannelEnabled := channel.Status == config.ChannelStatusEnabled
sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", utils.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr())
tik := time.Now()
err, openaiErr := testChannel(channel, "")
@@ -173,7 +173,7 @@ func testAllChannels(isNotify bool) error {
// 如果已被禁用,但是请求成功,需要判断是否需要恢复
// 手动禁用的通道,不会自动恢复
if shouldEnableChannel(err, openaiErr) {
- if channel.Status == common.ChannelStatusAutoDisabled {
+ if channel.Status == config.ChannelStatusAutoDisabled {
EnableChannel(channel.Id, channel.Name, false)
sendMessage += "- 已被启用 \n\n"
} else {
diff --git a/controller/common.go b/controller/common.go
index bad43a3e..2a3e17f3 100644
--- a/controller/common.go
+++ b/controller/common.go
@@ -3,7 +3,7 @@ package controller
import (
"fmt"
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/notify"
"one-api/model"
"one-api/types"
@@ -13,7 +13,7 @@ import (
)
func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
- if !common.AutomaticEnableChannelEnabled {
+ if !config.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
@@ -26,7 +26,7 @@ func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
}
func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
- if !common.AutomaticDisableChannelEnabled {
+ if !config.AutomaticDisableChannelEnabled {
return false
}
@@ -74,7 +74,7 @@ func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
// disable & notify
func DisableChannel(channelId int, channelName string, reason string, sendNotify bool) {
- model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+ model.UpdateChannelStatusById(channelId, config.ChannelStatusAutoDisabled)
if !sendNotify {
return
}
@@ -86,7 +86,7 @@ func DisableChannel(channelId int, channelName string, reason string, sendNotify
// enable & notify
func EnableChannel(channelId int, channelName string, sendNotify bool) {
- model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
+ model.UpdateChannelStatusById(channelId, config.ChannelStatusEnabled)
if !sendNotify {
return
}
diff --git a/controller/github.go b/controller/github.go
index 2c452bc0..5c83a89e 100644
--- a/controller/github.go
+++ b/controller/github.go
@@ -6,7 +6,7 @@ import (
"errors"
"fmt"
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/utils"
"one-api/model"
@@ -33,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
- values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
+ values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
@@ -96,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
- if !common.GitHubOAuthEnabled {
+ if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
@@ -125,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
return
}
} else {
- if common.RegisterEnabled {
+ if config.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
@@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
user.DisplayName = "GitHub User"
}
user.Email = githubUser.Email
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
+ user.Role = config.RoleCommonUser
+ user.Status = config.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -152,7 +152,7 @@ func GitHubOAuth(c *gin.Context) {
}
}
- if user.Status != common.UserStatusEnabled {
+ if user.Status != config.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
@@ -163,7 +163,7 @@ func GitHubOAuth(c *gin.Context) {
}
func GitHubBind(c *gin.Context) {
- if !common.GitHubOAuthEnabled {
+ if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
diff --git a/controller/lark.go b/controller/lark.go
index b8e8d06f..7fc88951 100644
--- a/controller/lark.go
+++ b/controller/lark.go
@@ -6,7 +6,7 @@ import (
"errors"
"fmt"
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/model"
"strconv"
@@ -41,8 +41,8 @@ type LarkUser struct {
func getLarkAppAccessToken() (string, error) {
values := map[string]string{
- "app_id": common.LarkClientId,
- "app_secret": common.LarkClientSecret,
+ "app_id": config.LarkClientId,
+ "app_secret": config.LarkClientSecret,
}
jsonData, err := json.Marshal(values)
if err != nil {
@@ -148,7 +148,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
}
func LarkOAuth(c *gin.Context) {
- if !common.LarkAuthEnabled {
+ if !config.LarkAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过飞书登录以及注册",
"success": false,
@@ -191,15 +191,15 @@ func LarkOAuth(c *gin.Context) {
return
}
} else {
- if common.RegisterEnabled {
+ if config.RegisterEnabled {
user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
if larkUser.Data.Name != "" {
user.DisplayName = larkUser.Data.Name
} else {
user.DisplayName = "Lark User"
}
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
+ user.Role = config.RoleCommonUser
+ user.Status = config.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -217,7 +217,7 @@ func LarkOAuth(c *gin.Context) {
}
}
- if user.Status != common.UserStatusEnabled {
+ if user.Status != config.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
@@ -228,7 +228,7 @@ func LarkOAuth(c *gin.Context) {
}
func LarkBind(c *gin.Context) {
- if !common.LarkAuthEnabled {
+ if !config.LarkAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过飞书登录以及注册",
"success": false,
diff --git a/controller/misc.go b/controller/misc.go
index e2c37e50..f026262e 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/stmp"
"one-api/common/telegram"
"one-api/model"
@@ -23,60 +24,60 @@ func GetStatus(c *gin.Context) {
"success": true,
"message": "",
"data": gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "lark_login": common.LarkAuthEnabled,
- "lark_client_id": common.LarkClientId,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": common.ServerAddress,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "chat_link": common.ChatLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
+ "version": config.Version,
+ "start_time": config.StartTime,
+ "email_verification": config.EmailVerificationEnabled,
+ "github_oauth": config.GitHubOAuthEnabled,
+ "github_client_id": config.GitHubClientId,
+ "lark_login": config.LarkAuthEnabled,
+ "lark_client_id": config.LarkClientId,
+ "system_name": config.SystemName,
+ "logo": config.Logo,
+ "footer_html": config.Footer,
+ "wechat_qrcode": config.WeChatAccountQRCodeImageURL,
+ "wechat_login": config.WeChatAuthEnabled,
+ "server_address": config.ServerAddress,
+ "turnstile_check": config.TurnstileCheckEnabled,
+ "turnstile_site_key": config.TurnstileSiteKey,
+ "top_up_link": config.TopUpLink,
+ "chat_link": config.ChatLink,
+ "quota_per_unit": config.QuotaPerUnit,
+ "display_in_currency": config.DisplayInCurrencyEnabled,
"telegram_bot": telegram_bot,
- "mj_notify_enabled": common.MjNotifyEnabled,
- "chat_cache_enabled": common.ChatCacheEnabled,
- "chat_links": common.ChatLinks,
+ "mj_notify_enabled": config.MjNotifyEnabled,
+ "chat_cache_enabled": config.ChatCacheEnabled,
+ "chat_links": config.ChatLinks,
},
})
}
func GetNotice(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["Notice"],
+ "data": config.OptionMap["Notice"],
})
}
func GetAbout(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["About"],
+ "data": config.OptionMap["About"],
})
}
func GetHomePageContent(c *gin.Context) {
- common.OptionMapRWMutex.RLock()
- defer common.OptionMapRWMutex.RUnlock()
+ config.OptionMapRWMutex.RLock()
+ defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": common.OptionMap["HomePageContent"],
+ "data": config.OptionMap["HomePageContent"],
})
}
@@ -89,9 +90,9 @@ func SendEmailVerification(c *gin.Context) {
})
return
}
- if common.EmailDomainRestrictionEnabled {
+ if config.EmailDomainRestrictionEnabled {
allowed := false
- for _, domain := range common.EmailDomainWhitelist {
+ for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
@@ -157,7 +158,7 @@ func SendPasswordResetEmail(c *gin.Context) {
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
- link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
+ link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
err := stmp.SendPasswordResetEmail(userName, email, link)
if err != nil {
diff --git a/controller/option.go b/controller/option.go
index 99f36607..438688f0 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -3,7 +3,7 @@ package controller
import (
"encoding/json"
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"one-api/model"
"strings"
@@ -13,8 +13,8 @@ import (
func GetOptions(c *gin.Context) {
var options []*model.Option
- common.OptionMapRWMutex.Lock()
- for k, v := range common.OptionMap {
+ config.OptionMapRWMutex.Lock()
+ for k, v := range config.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue
}
@@ -23,7 +23,7 @@ func GetOptions(c *gin.Context) {
Value: utils.Interface2String(v),
})
}
- common.OptionMapRWMutex.Unlock()
+ config.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -44,7 +44,7 @@ func UpdateOption(c *gin.Context) {
}
switch option.Key {
case "GitHubOAuthEnabled":
- if option.Value == "true" && common.GitHubClientId == "" {
+ if option.Value == "true" && config.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
@@ -52,7 +52,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "EmailDomainRestrictionEnabled":
- if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
+ if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
@@ -60,7 +60,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "WeChatAuthEnabled":
- if option.Value == "true" && common.WeChatServerAddress == "" {
+ if option.Value == "true" && config.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
@@ -68,7 +68,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "TurnstileCheckEnabled":
- if option.Value == "true" && common.TurnstileSiteKey == "" {
+ if option.Value == "true" && config.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
diff --git a/controller/token.go b/controller/token.go
index 9f43b111..43eed435 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -3,6 +3,7 @@ package controller
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"one-api/model"
"strconv"
@@ -199,15 +200,15 @@ func UpdateToken(c *gin.Context) {
})
return
}
- if token.Status == common.TokenStatusEnabled {
- if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 {
+ if token.Status == config.TokenStatusEnabled {
+ if cleanToken.Status == config.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
})
return
}
- if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
+ if cleanToken.Status == config.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
diff --git a/controller/user.go b/controller/user.go
index 4973314d..82b4a556 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"one-api/model"
"strconv"
@@ -20,7 +21,7 @@ type LoginRequest struct {
}
func Login(c *gin.Context) {
- if !common.PasswordLoginEnabled {
+ if !config.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录",
"success": false,
@@ -107,14 +108,14 @@ func Logout(c *gin.Context) {
}
func Register(c *gin.Context) {
- if !common.RegisterEnabled {
+ if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册",
"success": false,
})
return
}
- if !common.PasswordRegisterEnabled {
+ if !config.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false,
@@ -137,7 +138,7 @@ func Register(c *gin.Context) {
})
return
}
- if common.EmailVerificationEnabled {
+ if config.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -161,7 +162,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username,
InviterId: inviterId,
}
- if common.EmailVerificationEnabled {
+ if config.EmailVerificationEnabled {
cleanUser.Email = user.Email
}
if err := cleanUser.Insert(inviterId); err != nil {
@@ -214,7 +215,7 @@ func GetUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
- if myRole <= user.Role && myRole != common.RoleRootUser {
+ if myRole <= user.Role && myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权获取同级或更高等级用户的信息",
@@ -360,14 +361,14 @@ func UpdateUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
- if myRole <= originUser.Role && myRole != common.RoleRootUser {
+ if myRole <= originUser.Role && myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
})
return
}
- if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
+ if myRole <= updatedUser.Role && myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
@@ -479,7 +480,7 @@ func DeleteSelf(c *gin.Context) {
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
- if user.Role == common.RoleRootUser {
+ if user.Role == config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不能删除超级管理员账户",
@@ -579,7 +580,7 @@ func ManageUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
- if myRole <= user.Role && myRole != common.RoleRootUser {
+ if myRole <= user.Role && myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
@@ -588,8 +589,8 @@ func ManageUser(c *gin.Context) {
}
switch req.Action {
case "disable":
- user.Status = common.UserStatusDisabled
- if user.Role == common.RoleRootUser {
+ user.Status = config.UserStatusDisabled
+ if user.Role == config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法禁用超级管理员用户",
@@ -597,9 +598,9 @@ func ManageUser(c *gin.Context) {
return
}
case "enable":
- user.Status = common.UserStatusEnabled
+ user.Status = config.UserStatusEnabled
case "delete":
- if user.Role == common.RoleRootUser {
+ if user.Role == config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法删除超级管理员用户",
@@ -614,37 +615,37 @@ func ManageUser(c *gin.Context) {
return
}
case "promote":
- if myRole != common.RoleRootUser {
+ if myRole != config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "普通管理员用户无法提升其他用户为管理员",
})
return
}
- if user.Role >= common.RoleAdminUser {
+ if user.Role >= config.RoleAdminUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是管理员",
})
return
}
- user.Role = common.RoleAdminUser
+ user.Role = config.RoleAdminUser
case "demote":
- if user.Role == common.RoleRootUser {
+ if user.Role == config.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法降级超级管理员用户",
})
return
}
- if user.Role == common.RoleCommonUser {
+ if user.Role == config.RoleCommonUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是普通用户",
})
return
}
- user.Role = common.RoleCommonUser
+ user.Role = config.RoleCommonUser
}
if err := user.Update(false); err != nil {
diff --git a/controller/wechat.go b/controller/wechat.go
index fbd7d2bd..2ff9140b 100644
--- a/controller/wechat.go
+++ b/controller/wechat.go
@@ -5,7 +5,7 @@ import (
"errors"
"fmt"
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/model"
"strconv"
"time"
@@ -23,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" {
return "", errors.New("无效的参数")
}
- req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
+ req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
if err != nil {
return "", err
}
- req.Header.Set("Authorization", common.WeChatServerToken)
+ req.Header.Set("Authorization", config.WeChatServerToken)
client := http.Client{
Timeout: 5 * time.Second,
}
@@ -51,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) {
}
func WeChatAuth(c *gin.Context) {
- if !common.WeChatAuthEnabled {
+ if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,
@@ -80,11 +80,11 @@ func WeChatAuth(c *gin.Context) {
return
}
} else {
- if common.RegisterEnabled {
+ if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User"
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
+ user.Role = config.RoleCommonUser
+ user.Status = config.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -102,7 +102,7 @@ func WeChatAuth(c *gin.Context) {
}
}
- if user.Status != common.UserStatusEnabled {
+ if user.Status != config.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
@@ -113,7 +113,7 @@ func WeChatAuth(c *gin.Context) {
}
func WeChatBind(c *gin.Context) {
- if !common.WeChatAuthEnabled {
+ if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册",
"success": false,
diff --git a/main.go b/main.go
index ef0c1984..cc68a8a4 100644
--- a/main.go
+++ b/main.go
@@ -3,6 +3,7 @@ package main
import (
"embed"
"fmt"
+ "one-api/cli"
"one-api/common"
"one-api/common/config"
"one-api/common/logger"
@@ -31,9 +32,10 @@ var buildFS embed.FS
var indexPage []byte
func main() {
+ cli.InitCli()
config.InitConf()
logger.SetupLogger()
- logger.SysLog("One API " + common.Version + " started")
+ logger.SysLog("One API " + config.Version + " started")
// Initialize SQL Database
model.SetupDB()
defer model.CloseDB()
@@ -60,10 +62,10 @@ func main() {
func initMemoryCache() {
if viper.GetBool("memory_cache_enabled") {
- common.MemoryCacheEnabled = true
+ config.MemoryCacheEnabled = true
}
- if !common.MemoryCacheEnabled {
+ if !config.MemoryCacheEnabled {
return
}
@@ -91,7 +93,7 @@ func initHttpServer() {
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
- store := cookie.NewStore([]byte(common.SessionSecret))
+ store := cookie.NewStore([]byte(config.SessionSecret))
server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS, indexPage)
@@ -105,7 +107,7 @@ func initHttpServer() {
func SyncChannelCache(frequency int) {
// 只有 从 服务器端获取数据的时候才会用到
- if common.IsMasterNode {
+ if config.IsMasterNode {
logger.SysLog("master node does't synchronize the channel")
return
}
diff --git a/middleware/auth.go b/middleware/auth.go
index 697acacf..e3834a4b 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -2,7 +2,7 @@ package middleware
import (
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"one-api/model"
"strings"
@@ -44,7 +44,7 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
- if status.(int) == common.UserStatusDisabled {
+ if status.(int) == config.UserStatusDisabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已被封禁",
@@ -68,19 +68,19 @@ func authHelper(c *gin.Context, minRole int) {
func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) {
- authHelper(c, common.RoleCommonUser)
+ authHelper(c, config.RoleCommonUser)
}
}
func AdminAuth() func(c *gin.Context) {
return func(c *gin.Context) {
- authHelper(c, common.RoleAdminUser)
+ authHelper(c, config.RoleAdminUser)
}
}
func RootAuth() func(c *gin.Context) {
return func(c *gin.Context) {
- authHelper(c, common.RoleRootUser)
+ authHelper(c, config.RoleRootUser)
}
}
diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go
index dae6639c..27f2f190 100644
--- a/middleware/rate-limit.go
+++ b/middleware/rate-limit.go
@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"time"
@@ -45,7 +46,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
}
if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
@@ -64,14 +65,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests)
c.Abort()
return
} else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
- rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
}
}
}
@@ -92,7 +93,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
}
} else {
// It's safe to call multi times.
- inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+ inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark)
}
diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go
index 6f295864..629395e7 100644
--- a/middleware/turnstile-check.go
+++ b/middleware/turnstile-check.go
@@ -6,7 +6,7 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"net/url"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
)
@@ -16,7 +16,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) {
- if common.TurnstileCheckEnabled {
+ if config.TurnstileCheckEnabled {
session := sessions.Default(c)
turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil {
@@ -33,7 +33,7 @@ func TurnstileCheck() gin.HandlerFunc {
return
}
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
- "secret": {common.TurnstileSecretKey},
+ "secret": {config.TurnstileSecretKey},
"response": {response},
"remoteip": {c.ClientIP()},
})
diff --git a/model/ability.go b/model/ability.go
index 8aeb2811..2b973e80 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
+ "one-api/common/config"
"strings"
)
@@ -66,7 +67,7 @@ func (channel *Channel) AddAbilities() error {
Group: group,
Model: model,
ChannelId: channel.Id,
- Enabled: channel.Status == common.ChannelStatusEnabled,
+ Enabled: channel.Status == config.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: channel.Weight,
}
diff --git a/model/balancer.go b/model/balancer.go
index 82cd0289..9873d90f 100644
--- a/model/balancer.go
+++ b/model/balancer.go
@@ -3,7 +3,7 @@ package model
import (
"errors"
"math/rand"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/utils"
"strings"
@@ -38,7 +38,7 @@ func FilterOnlyChat() ChannelsFilterFunc {
}
func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
- if common.RetryCooldownSeconds == 0 {
+ if config.RetryCooldownSeconds == 0 {
return false
}
cc.Lock()
@@ -47,7 +47,7 @@ func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
return false
}
- cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(common.RetryCooldownSeconds)
+ cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(config.RetryCooldownSeconds)
return true
}
@@ -159,7 +159,7 @@ var ChannelGroup = ChannelsChooser{}
func (cc *ChannelsChooser) Load() {
var channels []*Channel
- DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
+ DB.Where("status = ?", config.ChannelStatusEnabled).Find(&channels)
abilities, err := GetAbilityChannelGroup()
if err != nil {
@@ -173,7 +173,7 @@ func (cc *ChannelsChooser) Load() {
for _, channel := range channels {
if *channel.Weight == 0 {
- channel.Weight = &common.DefaultChannelWeight
+ channel.Weight = &config.DefaultChannelWeight
}
newChannels[channel.Id] = &ChannelChoice{
Channel: channel,
diff --git a/model/channel.go b/model/channel.go
index b2733ae7..ab38cefc 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -1,7 +1,7 @@
package model
import (
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/utils"
"strings"
@@ -270,11 +270,11 @@ func (channel *Channel) Delete() error {
func (channel *Channel) StatusToStr() string {
switch channel.Status {
- case common.ChannelStatusEnabled:
+ case config.ChannelStatusEnabled:
return "启用"
- case common.ChannelStatusAutoDisabled:
+ case config.ChannelStatusAutoDisabled:
return "自动禁用"
- case common.ChannelStatusManuallyDisabled:
+ case config.ChannelStatusManuallyDisabled:
return "手动禁用"
}
@@ -282,7 +282,7 @@ func (channel *Channel) StatusToStr() string {
}
func UpdateChannelStatusById(id int, status int) {
- err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
+ err := UpdateAbilityStatus(id, status == config.ChannelStatusEnabled)
if err != nil {
logger.SysError("failed to update ability status: " + err.Error())
}
@@ -298,7 +298,7 @@ func UpdateChannelStatusById(id int, status int) {
}
func UpdateChannelUsedQuota(id int, quota int) {
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
}
@@ -318,7 +318,7 @@ func DeleteChannelByStatus(status int64) (int64, error) {
}
func DeleteDisabledChannel() (int64, error) {
- result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
+ result := DB.Where("status = ? or status = ?", config.ChannelStatusAutoDisabled, config.ChannelStatusManuallyDisabled).Delete(&Channel{})
// 同时删除Ability
DB.Where("enabled = ?", false).Delete(&Ability{})
return result.RowsAffected, result.Error
diff --git a/model/common.go b/model/common.go
index c85c19ab..e5a77d09 100644
--- a/model/common.go
+++ b/model/common.go
@@ -3,6 +3,7 @@ package model
import (
"fmt"
"one-api/common"
+ "one-api/common/config"
"strings"
"gorm.io/gorm"
@@ -43,11 +44,11 @@ func PaginateAndOrder[T modelable](db *gorm.DB, params *PaginationParams, result
params.Page = 1
}
if params.Size < 1 {
- params.Size = common.ItemsPerPage
+ params.Size = config.ItemsPerPage
}
- if params.Size > common.MaxRecentItems {
- return nil, fmt.Errorf("size 参数不能超过 %d", common.MaxRecentItems)
+ if params.Size > config.MaxRecentItems {
+ return nil, fmt.Errorf("size 参数不能超过 %d", config.MaxRecentItems)
}
offset := (params.Page - 1) * params.Size
diff --git a/model/log.go b/model/log.go
index 6c8b2008..9bac2cc7 100644
--- a/model/log.go
+++ b/model/log.go
@@ -3,7 +3,7 @@ package model
import (
"context"
"fmt"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/utils"
@@ -37,7 +37,7 @@ const (
)
func RecordLog(userId int, logType int, content string) {
- if logType == LogTypeConsume && !common.LogConsumeEnabled {
+ if logType == LogTypeConsume && !config.LogConsumeEnabled {
return
}
log := &Log{
@@ -55,7 +55,7 @@ func RecordLog(userId int, logType int, content string) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, requestTime int) {
logger.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
- if !common.LogConsumeEnabled {
+ if !config.LogConsumeEnabled {
return
}
log := &Log{
@@ -156,12 +156,12 @@ func GetUserLogsList(userId int, params *LogsListParams) (*DataResult[Log], erro
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
- err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
+ err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
- err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
+ err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}
diff --git a/model/main.go b/model/main.go
index a81e1dc3..e9d5ac14 100644
--- a/model/main.go
+++ b/model/main.go
@@ -3,6 +3,7 @@ package model
import (
"fmt"
"one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/utils"
"strconv"
@@ -24,12 +25,12 @@ func SetupDB() {
logger.FatalLog("failed to initialize database: " + err.Error())
}
ChannelGroup.Load()
- common.RootUserEmail = GetRootUserEmail()
+ config.RootUserEmail = GetRootUserEmail()
if viper.GetBool("batch_update_enabled") {
- common.BatchUpdateEnabled = true
- common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5)
- logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+ config.BatchUpdateEnabled = true
+ config.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5)
+ logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
InitBatchUpdater()
}
}
@@ -46,8 +47,8 @@ func createRootAccountIfNeed() error {
rootUser := User{
Username: "root",
Password: hashedPassword,
- Role: common.RoleRootUser,
- Status: common.UserStatusEnabled,
+ Role: config.RoleRootUser,
+ Status: config.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: utils.GetUUID(),
Quota: 100000000,
@@ -102,7 +103,7 @@ func InitDB() (err error) {
sqlDB.SetMaxOpenConns(utils.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(utils.GetOrDefault("SQL_MAX_LIFETIME", 60)))
- if !common.IsMasterNode {
+ if !config.IsMasterNode {
return nil
}
logger.SysLog("database migration started")
diff --git a/model/option.go b/model/option.go
index 4e75f3f3..0dc6db0d 100644
--- a/model/option.go
+++ b/model/option.go
@@ -2,6 +2,7 @@ package model
import (
"one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"strconv"
"strings"
@@ -26,63 +27,63 @@ func GetOption(key string) (option Option, err error) {
}
func InitOptionMap() {
- common.OptionMapRWMutex.Lock()
- common.OptionMap = make(map[string]string)
- common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
- common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
- common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
- common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
- common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
- common.OptionMap["LarkAuthEnabled"] = strconv.FormatBool(common.LarkAuthEnabled)
- common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
- common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
- common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
- common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
- common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
- common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
- common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
- common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
- common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
- common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
- common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
- common.OptionMap["SMTPServer"] = ""
- common.OptionMap["SMTPFrom"] = ""
- common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
- common.OptionMap["SMTPAccount"] = ""
- common.OptionMap["SMTPToken"] = ""
- common.OptionMap["Notice"] = ""
- common.OptionMap["About"] = ""
- common.OptionMap["HomePageContent"] = ""
- common.OptionMap["Footer"] = common.Footer
- common.OptionMap["SystemName"] = common.SystemName
- common.OptionMap["Logo"] = common.Logo
- common.OptionMap["ServerAddress"] = ""
- common.OptionMap["GitHubClientId"] = ""
- common.OptionMap["GitHubClientSecret"] = ""
- common.OptionMap["WeChatServerAddress"] = ""
- common.OptionMap["WeChatServerToken"] = ""
- common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
- common.OptionMap["TurnstileSiteKey"] = ""
- common.OptionMap["TurnstileSecretKey"] = ""
- common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
- common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
- common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
- common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
- common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
- common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
- common.OptionMap["TopUpLink"] = common.TopUpLink
- common.OptionMap["ChatLink"] = common.ChatLink
- common.OptionMap["ChatLinks"] = common.ChatLinks
- common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
- common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
- common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
+ config.OptionMapRWMutex.Lock()
+ config.OptionMap = make(map[string]string)
+ config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
+ config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
+ config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
+ config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
+ config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
+ config.OptionMap["LarkAuthEnabled"] = strconv.FormatBool(config.LarkAuthEnabled)
+ config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
+ config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
+ config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
+ config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
+ config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
+ config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
+ config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
+ config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
+ config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
+ config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
+ config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
+ config.OptionMap["SMTPServer"] = ""
+ config.OptionMap["SMTPFrom"] = ""
+ config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
+ config.OptionMap["SMTPAccount"] = ""
+ config.OptionMap["SMTPToken"] = ""
+ config.OptionMap["Notice"] = ""
+ config.OptionMap["About"] = ""
+ config.OptionMap["HomePageContent"] = ""
+ config.OptionMap["Footer"] = config.Footer
+ config.OptionMap["SystemName"] = config.SystemName
+ config.OptionMap["Logo"] = config.Logo
+ config.OptionMap["ServerAddress"] = ""
+ config.OptionMap["GitHubClientId"] = ""
+ config.OptionMap["GitHubClientSecret"] = ""
+ config.OptionMap["WeChatServerAddress"] = ""
+ config.OptionMap["WeChatServerToken"] = ""
+ config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
+ config.OptionMap["TurnstileSiteKey"] = ""
+ config.OptionMap["TurnstileSecretKey"] = ""
+ config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser)
+ config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter)
+ config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee)
+ config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold)
+ config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota)
+ config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
+ config.OptionMap["TopUpLink"] = config.TopUpLink
+ config.OptionMap["ChatLink"] = config.ChatLink
+ config.OptionMap["ChatLinks"] = config.ChatLinks
+ config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
+ config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
+ config.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(config.RetryCooldownSeconds)
- common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
+ config.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(config.MjNotifyEnabled)
- common.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(common.ChatCacheEnabled)
- common.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(common.ChatCacheExpireMinute)
+ config.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(config.ChatCacheEnabled)
+ config.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(config.ChatCacheExpireMinute)
- common.OptionMapRWMutex.Unlock()
+ config.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -121,64 +122,64 @@ func UpdateOption(key string, value string) error {
}
var optionIntMap = map[string]*int{
- "SMTPPort": &common.SMTPPort,
- "QuotaForNewUser": &common.QuotaForNewUser,
- "QuotaForInviter": &common.QuotaForInviter,
- "QuotaForInvitee": &common.QuotaForInvitee,
- "QuotaRemindThreshold": &common.QuotaRemindThreshold,
- "PreConsumedQuota": &common.PreConsumedQuota,
- "RetryTimes": &common.RetryTimes,
- "RetryCooldownSeconds": &common.RetryCooldownSeconds,
- "ChatCacheExpireMinute": &common.ChatCacheExpireMinute,
+ "SMTPPort": &config.SMTPPort,
+ "QuotaForNewUser": &config.QuotaForNewUser,
+ "QuotaForInviter": &config.QuotaForInviter,
+ "QuotaForInvitee": &config.QuotaForInvitee,
+ "QuotaRemindThreshold": &config.QuotaRemindThreshold,
+ "PreConsumedQuota": &config.PreConsumedQuota,
+ "RetryTimes": &config.RetryTimes,
+ "RetryCooldownSeconds": &config.RetryCooldownSeconds,
+ "ChatCacheExpireMinute": &config.ChatCacheExpireMinute,
}
var optionBoolMap = map[string]*bool{
- "PasswordRegisterEnabled": &common.PasswordRegisterEnabled,
- "PasswordLoginEnabled": &common.PasswordLoginEnabled,
- "EmailVerificationEnabled": &common.EmailVerificationEnabled,
- "GitHubOAuthEnabled": &common.GitHubOAuthEnabled,
- "WeChatAuthEnabled": &common.WeChatAuthEnabled,
- "LarkAuthEnabled": &common.LarkAuthEnabled,
- "TurnstileCheckEnabled": &common.TurnstileCheckEnabled,
- "RegisterEnabled": &common.RegisterEnabled,
- "EmailDomainRestrictionEnabled": &common.EmailDomainRestrictionEnabled,
- "AutomaticDisableChannelEnabled": &common.AutomaticDisableChannelEnabled,
- "AutomaticEnableChannelEnabled": &common.AutomaticEnableChannelEnabled,
- "ApproximateTokenEnabled": &common.ApproximateTokenEnabled,
- "LogConsumeEnabled": &common.LogConsumeEnabled,
- "DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
- "DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
- "MjNotifyEnabled": &common.MjNotifyEnabled,
- "ChatCacheEnabled": &common.ChatCacheEnabled,
+ "PasswordRegisterEnabled": &config.PasswordRegisterEnabled,
+ "PasswordLoginEnabled": &config.PasswordLoginEnabled,
+ "EmailVerificationEnabled": &config.EmailVerificationEnabled,
+ "GitHubOAuthEnabled": &config.GitHubOAuthEnabled,
+ "WeChatAuthEnabled": &config.WeChatAuthEnabled,
+ "LarkAuthEnabled": &config.LarkAuthEnabled,
+ "TurnstileCheckEnabled": &config.TurnstileCheckEnabled,
+ "RegisterEnabled": &config.RegisterEnabled,
+ "EmailDomainRestrictionEnabled": &config.EmailDomainRestrictionEnabled,
+ "AutomaticDisableChannelEnabled": &config.AutomaticDisableChannelEnabled,
+ "AutomaticEnableChannelEnabled": &config.AutomaticEnableChannelEnabled,
+ "ApproximateTokenEnabled": &config.ApproximateTokenEnabled,
+ "LogConsumeEnabled": &config.LogConsumeEnabled,
+ "DisplayInCurrencyEnabled": &config.DisplayInCurrencyEnabled,
+ "DisplayTokenStatEnabled": &config.DisplayTokenStatEnabled,
+ "MjNotifyEnabled": &config.MjNotifyEnabled,
+ "ChatCacheEnabled": &config.ChatCacheEnabled,
}
var optionStringMap = map[string]*string{
- "SMTPServer": &common.SMTPServer,
- "SMTPAccount": &common.SMTPAccount,
- "SMTPFrom": &common.SMTPFrom,
- "SMTPToken": &common.SMTPToken,
- "ServerAddress": &common.ServerAddress,
- "GitHubClientId": &common.GitHubClientId,
- "GitHubClientSecret": &common.GitHubClientSecret,
- "Footer": &common.Footer,
- "SystemName": &common.SystemName,
- "Logo": &common.Logo,
- "WeChatServerAddress": &common.WeChatServerAddress,
- "WeChatServerToken": &common.WeChatServerToken,
- "WeChatAccountQRCodeImageURL": &common.WeChatAccountQRCodeImageURL,
- "TurnstileSiteKey": &common.TurnstileSiteKey,
- "TurnstileSecretKey": &common.TurnstileSecretKey,
- "TopUpLink": &common.TopUpLink,
- "ChatLink": &common.ChatLink,
- "ChatLinks": &common.ChatLinks,
- "LarkClientId": &common.LarkClientId,
- "LarkClientSecret": &common.LarkClientSecret,
+ "SMTPServer": &config.SMTPServer,
+ "SMTPAccount": &config.SMTPAccount,
+ "SMTPFrom": &config.SMTPFrom,
+ "SMTPToken": &config.SMTPToken,
+ "ServerAddress": &config.ServerAddress,
+ "GitHubClientId": &config.GitHubClientId,
+ "GitHubClientSecret": &config.GitHubClientSecret,
+ "Footer": &config.Footer,
+ "SystemName": &config.SystemName,
+ "Logo": &config.Logo,
+ "WeChatServerAddress": &config.WeChatServerAddress,
+ "WeChatServerToken": &config.WeChatServerToken,
+ "WeChatAccountQRCodeImageURL": &config.WeChatAccountQRCodeImageURL,
+ "TurnstileSiteKey": &config.TurnstileSiteKey,
+ "TurnstileSecretKey": &config.TurnstileSecretKey,
+ "TopUpLink": &config.TopUpLink,
+ "ChatLink": &config.ChatLink,
+ "ChatLinks": &config.ChatLinks,
+ "LarkClientId": &config.LarkClientId,
+ "LarkClientSecret": &config.LarkClientSecret,
}
func updateOptionMap(key string, value string) (err error) {
- common.OptionMapRWMutex.Lock()
- defer common.OptionMapRWMutex.Unlock()
- common.OptionMap[key] = value
+ config.OptionMapRWMutex.Lock()
+ defer config.OptionMapRWMutex.Unlock()
+ config.OptionMap[key] = value
if ptr, ok := optionIntMap[key]; ok {
*ptr, _ = strconv.Atoi(value)
return
@@ -196,13 +197,13 @@ func updateOptionMap(key string, value string) (err error) {
switch key {
case "EmailDomainWhitelist":
- common.EmailDomainWhitelist = strings.Split(value, ",")
+ config.EmailDomainWhitelist = strings.Split(value, ",")
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "ChannelDisableThreshold":
- common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
+ config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit":
- common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
+ config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
}
return err
}
diff --git a/model/price.go b/model/price.go
index c1f42123..709a7c66 100644
--- a/model/price.go
+++ b/model/price.go
@@ -1,7 +1,7 @@
package model
import (
- "one-api/common"
+ "one-api/common/config"
"github.com/shopspring/decimal"
"gorm.io/gorm"
@@ -114,211 +114,211 @@ type ModelType struct {
func GetDefaultPrice() []*Price {
ModelTypes := map[string]ModelType{
// $0.03 / 1K tokens $0.06 / 1K tokens
- "gpt-4": {[]float64{15, 30}, common.ChannelTypeOpenAI},
- "gpt-4-0314": {[]float64{15, 30}, common.ChannelTypeOpenAI},
- "gpt-4-0613": {[]float64{15, 30}, common.ChannelTypeOpenAI},
+ "gpt-4": {[]float64{15, 30}, config.ChannelTypeOpenAI},
+ "gpt-4-0314": {[]float64{15, 30}, config.ChannelTypeOpenAI},
+ "gpt-4-0613": {[]float64{15, 30}, config.ChannelTypeOpenAI},
// $0.06 / 1K tokens $0.12 / 1K tokens
- "gpt-4-32k": {[]float64{30, 60}, common.ChannelTypeOpenAI},
- "gpt-4-32k-0314": {[]float64{30, 60}, common.ChannelTypeOpenAI},
- "gpt-4-32k-0613": {[]float64{30, 60}, common.ChannelTypeOpenAI},
+ "gpt-4-32k": {[]float64{30, 60}, config.ChannelTypeOpenAI},
+ "gpt-4-32k-0314": {[]float64{30, 60}, config.ChannelTypeOpenAI},
+ "gpt-4-32k-0613": {[]float64{30, 60}, config.ChannelTypeOpenAI},
// $0.01 / 1K tokens $0.03 / 1K tokens
- "gpt-4-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
- "gpt-4-turbo": {[]float64{5, 15}, common.ChannelTypeOpenAI},
- "gpt-4-turbo-2024-04-09": {[]float64{5, 15}, common.ChannelTypeOpenAI},
- "gpt-4-1106-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
- "gpt-4-0125-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
- "gpt-4-turbo-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
- "gpt-4-vision-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
+ "gpt-4-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
+ "gpt-4-turbo": {[]float64{5, 15}, config.ChannelTypeOpenAI},
+ "gpt-4-turbo-2024-04-09": {[]float64{5, 15}, config.ChannelTypeOpenAI},
+ "gpt-4-1106-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
+ "gpt-4-0125-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
+ "gpt-4-turbo-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
+ "gpt-4-vision-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI},
// $0.005 / 1K tokens $0.015 / 1K tokens
- "gpt-4o": {[]float64{2.5, 7.5}, common.ChannelTypeOpenAI},
+ "gpt-4o": {[]float64{2.5, 7.5}, config.ChannelTypeOpenAI},
// $0.0005 / 1K tokens $0.0015 / 1K tokens
- "gpt-3.5-turbo": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI},
- "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI},
+ "gpt-3.5-turbo": {[]float64{0.25, 0.75}, config.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, config.ChannelTypeOpenAI},
// $0.0015 / 1K tokens $0.002 / 1K tokens
- "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
- "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
- "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, config.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, config.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, config.ChannelTypeOpenAI},
// $0.003 / 1K tokens $0.004 / 1K tokens
- "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, common.ChannelTypeOpenAI},
- "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, common.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, config.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, config.ChannelTypeOpenAI},
// $0.001 / 1K tokens $0.002 / 1K tokens
- "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, common.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, config.ChannelTypeOpenAI},
// $0.0020 / 1K tokens
- "davinci-002": {[]float64{1, 1}, common.ChannelTypeOpenAI},
+ "davinci-002": {[]float64{1, 1}, config.ChannelTypeOpenAI},
// $0.0004 / 1K tokens
- "babbage-002": {[]float64{0.2, 0.2}, common.ChannelTypeOpenAI},
+ "babbage-002": {[]float64{0.2, 0.2}, config.ChannelTypeOpenAI},
// $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
- "whisper-1": {[]float64{15, 15}, common.ChannelTypeOpenAI},
+ "whisper-1": {[]float64{15, 15}, config.ChannelTypeOpenAI},
// $0.015 / 1K characters
- "tts-1": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI},
- "tts-1-1106": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI},
+ "tts-1": {[]float64{7.5, 7.5}, config.ChannelTypeOpenAI},
+ "tts-1-1106": {[]float64{7.5, 7.5}, config.ChannelTypeOpenAI},
// $0.030 / 1K characters
- "tts-1-hd": {[]float64{15, 15}, common.ChannelTypeOpenAI},
- "tts-1-hd-1106": {[]float64{15, 15}, common.ChannelTypeOpenAI},
- "text-embedding-ada-002": {[]float64{0.05, 0.05}, common.ChannelTypeOpenAI},
+ "tts-1-hd": {[]float64{15, 15}, config.ChannelTypeOpenAI},
+ "tts-1-hd-1106": {[]float64{15, 15}, config.ChannelTypeOpenAI},
+ "text-embedding-ada-002": {[]float64{0.05, 0.05}, config.ChannelTypeOpenAI},
// $0.00002 / 1K tokens
- "text-embedding-3-small": {[]float64{0.01, 0.01}, common.ChannelTypeOpenAI},
+ "text-embedding-3-small": {[]float64{0.01, 0.01}, config.ChannelTypeOpenAI},
// $0.00013 / 1K tokens
- "text-embedding-3-large": {[]float64{0.065, 0.065}, common.ChannelTypeOpenAI},
- "text-moderation-stable": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI},
- "text-moderation-latest": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI},
+ "text-embedding-3-large": {[]float64{0.065, 0.065}, config.ChannelTypeOpenAI},
+ "text-moderation-stable": {[]float64{0.1, 0.1}, config.ChannelTypeOpenAI},
+ "text-moderation-latest": {[]float64{0.1, 0.1}, config.ChannelTypeOpenAI},
// $0.016 - $0.020 / image
- "dall-e-2": {[]float64{8, 8}, common.ChannelTypeOpenAI},
+ "dall-e-2": {[]float64{8, 8}, config.ChannelTypeOpenAI},
// $0.040 - $0.120 / image
- "dall-e-3": {[]float64{20, 20}, common.ChannelTypeOpenAI},
+ "dall-e-3": {[]float64{20, 20}, config.ChannelTypeOpenAI},
// $0.80/million tokens $2.40/million tokens
- "claude-instant-1.2": {[]float64{0.4, 1.2}, common.ChannelTypeAnthropic},
+ "claude-instant-1.2": {[]float64{0.4, 1.2}, config.ChannelTypeAnthropic},
// $8.00/million tokens $24.00/million tokens
- "claude-2.0": {[]float64{4, 12}, common.ChannelTypeAnthropic},
- "claude-2.1": {[]float64{4, 12}, common.ChannelTypeAnthropic},
+ "claude-2.0": {[]float64{4, 12}, config.ChannelTypeAnthropic},
+ "claude-2.1": {[]float64{4, 12}, config.ChannelTypeAnthropic},
// $15 / M $75 / M
- "claude-3-opus-20240229": {[]float64{7.5, 22.5}, common.ChannelTypeAnthropic},
+ "claude-3-opus-20240229": {[]float64{7.5, 22.5}, config.ChannelTypeAnthropic},
// $3 / M $15 / M
- "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, common.ChannelTypeAnthropic},
+ "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, config.ChannelTypeAnthropic},
// $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
- "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, common.ChannelTypeAnthropic},
+ "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, config.ChannelTypeAnthropic},
// ¥0.004 / 1k tokens ¥0.008 / 1k tokens
- "ERNIE-Speed": {[]float64{0.2857, 0.5714}, common.ChannelTypeBaidu},
+ "ERNIE-Speed": {[]float64{0.2857, 0.5714}, config.ChannelTypeBaidu},
// ¥0.012 / 1k tokens ¥0.012 / 1k tokens
- "ERNIE-Bot": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu},
- "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu},
+ "ERNIE-Bot": {[]float64{0.8572, 0.8572}, config.ChannelTypeBaidu},
+ "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, config.ChannelTypeBaidu},
// 0.024元/千tokens 0.048元/千tokens
- "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, common.ChannelTypeBaidu},
+ "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, config.ChannelTypeBaidu},
// ¥0.008 / 1k tokens ¥0.008 / 1k tokens
- "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaidu},
+ "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeBaidu},
// ¥0.12 / 1k tokens ¥0.12 / 1k tokens
- "ERNIE-Bot-4": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu},
- "ERNIE-4.0": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu},
+ "ERNIE-Bot-4": {[]float64{8.572, 8.572}, config.ChannelTypeBaidu},
+ "ERNIE-4.0": {[]float64{8.572, 8.572}, config.ChannelTypeBaidu},
// ¥0.002 / 1k tokens
- "Embedding-V1": {[]float64{0.1429, 0.1429}, common.ChannelTypeBaidu},
+ "Embedding-V1": {[]float64{0.1429, 0.1429}, config.ChannelTypeBaidu},
// ¥0.004 / 1k tokens
- "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, common.ChannelTypeBaidu},
+ "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, config.ChannelTypeBaidu},
- "PaLM-2": {[]float64{1, 1}, common.ChannelTypePaLM},
+ "PaLM-2": {[]float64{1, 1}, config.ChannelTypePaLM},
// $0.50 / 1 million tokens $1.50 / 1 million tokens
// 0.0005$ / 1k tokens 0.0015$ / 1k tokens
- "gemini-pro": {[]float64{0.25, 0.75}, common.ChannelTypeGemini},
- "gemini-pro-vision": {[]float64{0.25, 0.75}, common.ChannelTypeGemini},
- "gemini-1.0-pro": {[]float64{0.25, 0.75}, common.ChannelTypeGemini},
+ "gemini-pro": {[]float64{0.25, 0.75}, config.ChannelTypeGemini},
+ "gemini-pro-vision": {[]float64{0.25, 0.75}, config.ChannelTypeGemini},
+ "gemini-1.0-pro": {[]float64{0.25, 0.75}, config.ChannelTypeGemini},
// $7 / 1 million tokens $21 / 1 million tokens
- "gemini-1.5-pro": {[]float64{1.75, 5.25}, common.ChannelTypeGemini},
- "gemini-1.5-pro-latest": {[]float64{1.75, 5.25}, common.ChannelTypeGemini},
- "gemini-1.5-flash": {[]float64{0.175, 0.265}, common.ChannelTypeGemini},
- "gemini-1.5-flash-latest": {[]float64{0.175, 0.265}, common.ChannelTypeGemini},
- "gemini-ultra": {[]float64{1, 1}, common.ChannelTypeGemini},
+ "gemini-1.5-pro": {[]float64{1.75, 5.25}, config.ChannelTypeGemini},
+ "gemini-1.5-pro-latest": {[]float64{1.75, 5.25}, config.ChannelTypeGemini},
+ "gemini-1.5-flash": {[]float64{0.175, 0.265}, config.ChannelTypeGemini},
+ "gemini-1.5-flash-latest": {[]float64{0.175, 0.265}, config.ChannelTypeGemini},
+ "gemini-ultra": {[]float64{1, 1}, config.ChannelTypeGemini},
// ¥0.005 / 1k tokens
- "glm-3-turbo": {[]float64{0.3572, 0.3572}, common.ChannelTypeZhipu},
+ "glm-3-turbo": {[]float64{0.3572, 0.3572}, config.ChannelTypeZhipu},
// ¥0.1 / 1k tokens
- "glm-4": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu},
- "glm-4v": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu},
+ "glm-4": {[]float64{7.143, 7.143}, config.ChannelTypeZhipu},
+ "glm-4v": {[]float64{7.143, 7.143}, config.ChannelTypeZhipu},
// ¥0.0005 / 1k tokens
- "embedding-2": {[]float64{0.0357, 0.0357}, common.ChannelTypeZhipu},
+ "embedding-2": {[]float64{0.0357, 0.0357}, config.ChannelTypeZhipu},
// ¥0.25 / 1张图片
- "cogview-3": {[]float64{17.8571, 17.8571}, common.ChannelTypeZhipu},
+ "cogview-3": {[]float64{17.8571, 17.8571}, config.ChannelTypeZhipu},
// ¥0.008 / 1k tokens
- "qwen-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli},
+ "qwen-turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeAli},
// ¥0.02 / 1k tokens
- "qwen-plus": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli},
- "qwen-vl-max": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli},
+ "qwen-plus": {[]float64{1.4286, 1.4286}, config.ChannelTypeAli},
+ "qwen-vl-max": {[]float64{1.4286, 1.4286}, config.ChannelTypeAli},
// 0.12元/1,000tokens
- "qwen-max": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli},
- "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli},
+ "qwen-max": {[]float64{8.5714, 8.5714}, config.ChannelTypeAli},
+ "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, config.ChannelTypeAli},
// 0.008元/1,000tokens
- "qwen-vl-plus": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli},
+ "qwen-vl-plus": {[]float64{0.5715, 0.5715}, config.ChannelTypeAli},
// ¥0.0007 / 1k tokens
- "text-embedding-v1": {[]float64{0.05, 0.05}, common.ChannelTypeAli},
+ "text-embedding-v1": {[]float64{0.05, 0.05}, config.ChannelTypeAli},
// ¥0.018 / 1k tokens
- "SparkDesk": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
- "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
- "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
- "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
- "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
+ "SparkDesk": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
+ "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
+ "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
+ "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
+ "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei},
// ¥0.012 / 1k tokens
- "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, common.ChannelType360},
+ "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, config.ChannelType360},
// ¥0.001 / 1k tokens
- "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
- "embedding_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
- "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
+ "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, config.ChannelType360},
+ "embedding_s1_v1": {[]float64{0.0715, 0.0715}, config.ChannelType360},
+ "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, config.ChannelType360},
// ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
- "hunyuan": {[]float64{7.143, 7.143}, common.ChannelTypeTencent},
+ "hunyuan": {[]float64{7.143, 7.143}, config.ChannelTypeTencent},
// https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// ¥0.01 / 1k tokens
- "ChatStd": {[]float64{0.7143, 0.7143}, common.ChannelTypeTencent},
+ "ChatStd": {[]float64{0.7143, 0.7143}, config.ChannelTypeTencent},
//¥0.1 / 1k tokens
- "ChatPro": {[]float64{7.143, 7.143}, common.ChannelTypeTencent},
+ "ChatPro": {[]float64{7.143, 7.143}, config.ChannelTypeTencent},
- "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens
- "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, common.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens
- "Baichuan2-53B": {[]float64{1.4286, 1.4286}, common.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens
- "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, common.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens
+ "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens
+ "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, config.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens
+ "Baichuan2-53B": {[]float64{1.4286, 1.4286}, config.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens
+ "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, config.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens
- "abab5.5s-chat": {[]float64{0.3572, 0.3572}, common.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens
- "abab5.5-chat": {[]float64{1.0714, 1.0714}, common.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens
- "abab6-chat": {[]float64{14.2857, 14.2857}, common.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens
- "embo-01": {[]float64{0.0357, 0.0357}, common.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens
+ "abab5.5s-chat": {[]float64{0.3572, 0.3572}, config.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens
+ "abab5.5-chat": {[]float64{1.0714, 1.0714}, config.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens
+ "abab6-chat": {[]float64{14.2857, 14.2857}, config.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens
+ "embo-01": {[]float64{0.0357, 0.0357}, config.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens
- "deepseek-coder": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
- "deepseek-chat": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
+ "deepseek-coder": {[]float64{0.75, 0.75}, config.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
+ "deepseek-chat": {[]float64{0.75, 0.75}, config.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
- "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, common.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens
- "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, common.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens
- "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, common.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens
+ "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, config.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens
+ "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, config.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens
+ "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, config.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens
- "open-mistral-7b": {[]float64{0.125, 0.125}, common.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens
- "open-mixtral-8x7b": {[]float64{0.35, 0.35}, common.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens
- "mistral-small-latest": {[]float64{1, 3}, common.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens
- "mistral-medium-latest": {[]float64{1.35, 4.05}, common.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
- "mistral-large-latest": {[]float64{4, 12}, common.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
- "mistral-embed": {[]float64{0.05, 0.05}, common.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens
+ "open-mistral-7b": {[]float64{0.125, 0.125}, config.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens
+ "open-mixtral-8x7b": {[]float64{0.35, 0.35}, config.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens
+ "mistral-small-latest": {[]float64{1, 3}, config.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens
+ "mistral-medium-latest": {[]float64{1.35, 4.05}, config.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
+ "mistral-large-latest": {[]float64{4, 12}, config.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
+ "mistral-embed": {[]float64{0.05, 0.05}, config.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens
// $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens
- "llama2-70b-4096": {[]float64{0.35, 0.4}, common.ChannelTypeGroq},
+ "llama2-70b-4096": {[]float64{0.35, 0.4}, config.ChannelTypeGroq},
// $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens
- "llama2-7b-2048": {[]float64{0.05, 0.05}, common.ChannelTypeGroq},
- "gemma-7b-it": {[]float64{0.05, 0.05}, common.ChannelTypeGroq},
+ "llama2-7b-2048": {[]float64{0.05, 0.05}, config.ChannelTypeGroq},
+ "gemma-7b-it": {[]float64{0.05, 0.05}, config.ChannelTypeGroq},
// $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens
- "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, common.ChannelTypeGroq},
+ "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, config.ChannelTypeGroq},
// 2.5 元 / 1M tokens 0.0025 / 1k tokens
- "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, common.ChannelTypeLingyi},
+ "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, config.ChannelTypeLingyi},
// 12 元 / 1M tokens 0.012 / 1k tokens
- "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, common.ChannelTypeLingyi},
+ "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, config.ChannelTypeLingyi},
// 6 元 / 1M tokens 0.006 / 1k tokens
- "yi-vl-plus": {[]float64{0.4286, 0.4286}, common.ChannelTypeLingyi},
+ "yi-vl-plus": {[]float64{0.4286, 0.4286}, config.ChannelTypeLingyi},
- "@cf/stabilityai/stable-diffusion-xl-base-1.0": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
- "@cf/lykon/dreamshaper-8-lcm": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
- "@cf/bytedance/stable-diffusion-xl-lightning": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
- "@cf/qwen/qwen1.5-7b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
- "@cf/qwen/qwen1.5-14b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
- "@hf/thebloke/deepseek-coder-6.7b-base-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
- "@hf/google/gemma-7b-it": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
- "@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
- "@cf/openai/whisper": {[]float64{0, 0}, common.ChannelTypeCloudflareAI},
+ "@cf/stabilityai/stable-diffusion-xl-base-1.0": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
+ "@cf/lykon/dreamshaper-8-lcm": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
+ "@cf/bytedance/stable-diffusion-xl-lightning": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
+ "@cf/qwen/qwen1.5-7b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
+ "@cf/qwen/qwen1.5-14b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
+ "@hf/thebloke/deepseek-coder-6.7b-base-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
+ "@hf/google/gemma-7b-it": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
+ "@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
+ "@cf/openai/whisper": {[]float64{0, 0}, config.ChannelTypeCloudflareAI},
//$0.50 /1M TOKENS $1.50/1M TOKENS
- "command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere},
+ "command-r": {[]float64{0.25, 0.75}, config.ChannelTypeCohere},
//$3 /1M TOKENS $15/1M TOKENS
- "command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere},
+ "command-r-plus": {[]float64{1.5, 7.5}, config.ChannelTypeCohere},
// 0.065
- "sd3": {[]float64{32.5, 32.5}, common.ChannelTypeStabilityAI},
+ "sd3": {[]float64{32.5, 32.5}, config.ChannelTypeStabilityAI},
// 0.04
- "sd3-turbo": {[]float64{20, 20}, common.ChannelTypeStabilityAI},
+ "sd3-turbo": {[]float64{20, 20}, config.ChannelTypeStabilityAI},
// 0.03
- "stable-image-core": {[]float64{15, 15}, common.ChannelTypeStabilityAI},
+ "stable-image-core": {[]float64{15, 15}, config.ChannelTypeStabilityAI},
// hunyuan
- "hunyuan-lite": {[]float64{0, 0}, common.ChannelTypeHunyuan},
- "hunyuan-standard": {[]float64{0.3214, 0.3571}, common.ChannelTypeHunyuan},
- "hunyuan-standard-256k": {[]float64{1.0714, 4.2857}, common.ChannelTypeHunyuan},
- "hunyuan-pro": {[]float64{2.1429, 7.1429}, common.ChannelTypeHunyuan},
+ "hunyuan-lite": {[]float64{0, 0}, config.ChannelTypeHunyuan},
+ "hunyuan-standard": {[]float64{0.3214, 0.3571}, config.ChannelTypeHunyuan},
+ "hunyuan-standard-256k": {[]float64{1.0714, 4.2857}, config.ChannelTypeHunyuan},
+ "hunyuan-pro": {[]float64{2.1429, 7.1429}, config.ChannelTypeHunyuan},
}
var prices []*Price
@@ -355,7 +355,7 @@ func GetDefaultPrice() []*Price {
prices = append(prices, &Price{
Model: model,
Type: TimesPriceType,
- ChannelType: common.ChannelTypeMidjourney,
+ ChannelType: config.ChannelTypeMidjourney,
Input: mjPrice,
Output: mjPrice,
})
diff --git a/model/redemption.go b/model/redemption.go
index 821e244b..da6f524a 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"gorm.io/gorm"
@@ -69,7 +70,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return errors.New("无效的兑换码")
}
- if redemption.Status != common.RedemptionCodeStatusEnabled {
+ if redemption.Status != config.RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
@@ -77,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) {
return err
}
redemption.RedeemedTime = utils.GetTimestamp()
- redemption.Status = common.RedemptionCodeStatusUsed
+ redemption.Status = config.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err
})
diff --git a/model/token.go b/model/token.go
index 883b6d4c..a6281af0 100644
--- a/model/token.go
+++ b/model/token.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/stmp"
"one-api/common/utils"
@@ -49,7 +50,7 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e
// 获取状态为可用的令牌
func GetUserEnabledTokens(userId int) (tokens []*Token, err error) {
- err = DB.Where("user_id = ? and status = ?", userId, common.TokenStatusEnabled).Find(&tokens).Error
+ err = DB.Where("user_id = ? and status = ?", userId, config.TokenStatusEnabled).Find(&tokens).Error
return tokens, err
}
@@ -65,17 +66,17 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
return nil, errors.New("令牌验证失败")
}
- if token.Status == common.TokenStatusExhausted {
+ if token.Status == config.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
- } else if token.Status == common.TokenStatusExpired {
+ } else if token.Status == config.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
- if token.Status != common.TokenStatusEnabled {
+ if token.Status != config.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < utils.GetTimestamp() {
if !common.RedisEnabled {
- token.Status = common.TokenStatusExpired
+ token.Status = config.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
@@ -86,7 +87,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
- token.Status = common.TokenStatusExhausted
+ token.Status = config.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
@@ -128,7 +129,7 @@ func GetTokenByName(name string, user_id int) (*Token, error) {
}
func (token *Token) Insert() error {
- if token.ChatCache && !common.ChatCacheEnabled {
+ if token.ChatCache && !config.ChatCacheEnabled {
token.ChatCache = false
}
@@ -138,7 +139,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
- if token.ChatCache && !common.ChatCacheEnabled {
+ if token.ChatCache && !config.ChatCacheEnabled {
token.ChatCache = false
}
@@ -178,7 +179,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
}
@@ -200,7 +201,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
}
@@ -236,7 +237,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota {
return errors.New("用户额度不足")
}
- quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold
+ quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota {
go sendQuotaWarningEmail(token.UserId, userQuota, noMoreQuota)
diff --git a/model/user.go b/model/user.go
index f7f13ec8..2393ac5e 100644
--- a/model/user.go
+++ b/model/user.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/utils"
"strings"
@@ -116,7 +117,7 @@ func (user *User) Insert(inviterId int) error {
return err
}
}
- user.Quota = common.QuotaForNewUser
+ user.Quota = config.QuotaForNewUser
user.AccessToken = utils.GetUUID()
user.AffCode = utils.GetRandomString(4)
user.CreatedTime = utils.GetTimestamp()
@@ -124,17 +125,17 @@ func (user *User) Insert(inviterId int) error {
if result.Error != nil {
return result.Error
}
- if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
+ if config.QuotaForNewUser > 0 {
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
}
if inviterId != 0 {
- if common.QuotaForInvitee > 0 {
- _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
+ if config.QuotaForInvitee > 0 {
+ _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
}
- if common.QuotaForInviter > 0 {
- _ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
+ if config.QuotaForInviter > 0 {
+ _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
}
}
return nil
@@ -150,8 +151,8 @@ func (user *User) Update(updatePassword bool) error {
}
err = DB.Model(user).Updates(user).Error
- if err == nil && user.Role == common.RoleRootUser {
- common.RootUserEmail = user.Email
+ if err == nil && user.Role == config.RoleRootUser {
+ config.RootUserEmail = user.Email
}
return err
@@ -196,7 +197,7 @@ func (user *User) ValidateAndFill() (err error) {
}
}
okay := common.ValidatePasswordAndHash(password, user.Password)
- if !okay || user.Status != common.UserStatusEnabled {
+ if !okay || user.Status != config.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
}
return nil
@@ -310,7 +311,7 @@ func IsAdmin(userId int) bool {
logger.SysError("no such user " + err.Error())
return false
}
- return user.Role >= common.RoleAdminUser
+ return user.Role >= config.RoleAdminUser
}
func IsUserEnabled(userId int) (bool, error) {
@@ -322,7 +323,7 @@ func IsUserEnabled(userId int) (bool, error) {
if err != nil {
return false, err
}
- return user.Status == common.UserStatusEnabled, nil
+ return user.Status == config.UserStatusEnabled, nil
}
func ValidateAccessToken(token string) (user *User) {
@@ -366,7 +367,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
@@ -382,7 +383,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
}
@@ -395,12 +396,12 @@ func decreaseUserQuota(id int, quota int) (err error) {
}
func GetRootUserEmail() (email string) {
- DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+ DB.Model(&User{}).Where("role = ?", config.RoleRootUser).Select("email").Find(&email)
return email
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
- if common.BatchUpdateEnabled {
+ if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
diff --git a/model/utils.go b/model/utils.go
index e4797a78..e0826e0d 100644
--- a/model/utils.go
+++ b/model/utils.go
@@ -1,7 +1,7 @@
package model
import (
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"sync"
"time"
@@ -29,7 +29,7 @@ func init() {
func InitBatchUpdater() {
go func() {
for {
- time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
+ time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
diff --git a/providers/ali/ali_test.go b/providers/ali/ali_test.go
index c5cbc26f..a3f93ce5 100644
--- a/providers/ali/ali_test.go
+++ b/providers/ali/ali_test.go
@@ -2,7 +2,7 @@ package ali_test
import (
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/test"
"one-api/model"
)
@@ -20,5 +20,5 @@ func setupAliTestServer() (baseUrl string, server *test.ServerTest, teardown fun
}
func getAliChannel(baseUrl string) model.Channel {
- return test.GetChannel(common.ChannelTypeAli, baseUrl, "", "", "")
+ return test.GetChannel(config.ChannelTypeAli, baseUrl, "", "", "")
}
diff --git a/providers/ali/chat.go b/providers/ali/chat.go
index 8efc839a..f30a1472 100644
--- a/providers/ali/chat.go
+++ b/providers/ali/chat.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/common/utils"
"one-api/types"
@@ -55,7 +56,7 @@ func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRe
}
func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/ali/embeddings.go b/providers/ali/embeddings.go
index fdbe6ca6..c64961aa 100644
--- a/providers/ali/embeddings.go
+++ b/providers/ali/embeddings.go
@@ -3,11 +3,12 @@ package ali
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *AliProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/azure/image_generations.go b/providers/azure/image_generations.go
index 882d8268..f34db2f9 100644
--- a/providers/azure/image_generations.go
+++ b/providers/azure/image_generations.go
@@ -4,6 +4,7 @@ import (
"errors"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/providers/openai"
"one-api/types"
"time"
@@ -14,7 +15,7 @@ func (p *AzureProvider) CreateImageGenerations(request *types.ImageRequest) (*ty
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
}
- req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeImagesGenerations, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/azureSpeech/speech.go b/providers/azureSpeech/speech.go
index 9d03dbed..2210e5d6 100644
--- a/providers/azureSpeech/speech.go
+++ b/providers/azureSpeech/speech.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
"strings"
)
@@ -82,7 +83,7 @@ func (p *AzureSpeechProvider) getRequestBody(request *types.SpeechAudioRequest)
}
func (p *AzureSpeechProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeAudioSpeech)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeAudioSpeech)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/baichuan/chat.go b/providers/baichuan/chat.go
index 36ed4f33..3cacb6db 100644
--- a/providers/baichuan/chat.go
+++ b/providers/baichuan/chat.go
@@ -2,7 +2,7 @@ package baichuan
import (
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/providers/openai"
"one-api/types"
@@ -11,7 +11,7 @@ import (
func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody := p.getChatRequestBody(request)
- req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, requestBody)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, requestBody)
if errWithCode != nil {
return nil, errWithCode
}
@@ -51,7 +51,7 @@ func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatComplet
request.StreamOptions = nil
}
- req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go
index e93d1b1a..6b4551fa 100644
--- a/providers/baidu/chat.go
+++ b/providers/baidu/chat.go
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
"strings"
@@ -54,7 +55,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion
}
func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/baidu/embeddings.go b/providers/baidu/embeddings.go
index 5e13215e..a62fdf2f 100644
--- a/providers/baidu/embeddings.go
+++ b/providers/baidu/embeddings.go
@@ -3,11 +3,12 @@ package baidu
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *BaiduProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/base/common.go b/providers/base/common.go
index 37429880..9b952b64 100644
--- a/providers/base/common.go
+++ b/providers/base/common.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/model"
"one-api/types"
@@ -110,25 +111,25 @@ func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) {
func (p *BaseProvider) GetAPIUri(relayMode int) string {
switch relayMode {
- case common.RelayModeChatCompletions:
+ case config.RelayModeChatCompletions:
return p.Config.ChatCompletions
- case common.RelayModeCompletions:
+ case config.RelayModeCompletions:
return p.Config.Completions
- case common.RelayModeEmbeddings:
+ case config.RelayModeEmbeddings:
return p.Config.Embeddings
- case common.RelayModeAudioSpeech:
+ case config.RelayModeAudioSpeech:
return p.Config.AudioSpeech
- case common.RelayModeAudioTranscription:
+ case config.RelayModeAudioTranscription:
return p.Config.AudioTranscriptions
- case common.RelayModeAudioTranslation:
+ case config.RelayModeAudioTranslation:
return p.Config.AudioTranslations
- case common.RelayModeModerations:
+ case config.RelayModeModerations:
return p.Config.Moderation
- case common.RelayModeImagesGenerations:
+ case config.RelayModeImagesGenerations:
return p.Config.ImagesGenerations
- case common.RelayModeImagesEdits:
+ case config.RelayModeImagesEdits:
return p.Config.ImagesEdit
- case common.RelayModeImagesVariations:
+ case config.RelayModeImagesVariations:
return p.Config.ImagesVariations
default:
return ""
diff --git a/providers/bedrock/chat.go b/providers/bedrock/chat.go
index de24cdd2..654026e1 100644
--- a/providers/bedrock/chat.go
+++ b/providers/bedrock/chat.go
@@ -3,6 +3,7 @@ package bedrock
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/providers/bedrock/category"
"one-api/types"
@@ -48,7 +49,7 @@ func (p *BedrockProvider) getChatRequest(request *types.ChatCompletionRequest) (
return nil, common.StringErrorWrapper("bedrock provider not found", "bedrock_err", http.StatusInternalServerError)
}
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/claude/chat.go b/providers/claude/chat.go
index 7c3d53e9..7024ab20 100644
--- a/providers/claude/chat.go
+++ b/providers/claude/chat.go
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/image"
"one-api/common/requester"
"one-api/common/utils"
@@ -64,7 +65,7 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio
}
func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/cohere/chat.go b/providers/cohere/chat.go
index ae416ab6..921a7ddc 100644
--- a/providers/cohere/chat.go
+++ b/providers/cohere/chat.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/common/utils"
"one-api/providers/base"
@@ -56,7 +57,7 @@ func (p *CohereProvider) CreateChatCompletionStream(request *types.ChatCompletio
}
func (p *CohereProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/coze/chat.go b/providers/coze/chat.go
index 3913e28f..8f58775c 100644
--- a/providers/coze/chat.go
+++ b/providers/coze/chat.go
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/common/utils"
"one-api/types"
@@ -56,7 +57,7 @@ func (p *CozeProvider) CreateChatCompletionStream(request *types.ChatCompletionR
}
func (p *CozeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/groq/chat.go b/providers/groq/chat.go
index d41ccb30..ee3660ab 100644
--- a/providers/groq/chat.go
+++ b/providers/groq/chat.go
@@ -2,7 +2,7 @@ package groq
import (
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/providers/openai"
"one-api/types"
@@ -11,7 +11,7 @@ import (
func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
p.getChatRequestBody(request)
- req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
@@ -51,7 +51,7 @@ func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionR
request.StreamOptions = nil
}
p.getChatRequestBody(request)
- req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/hunyuan/chat.go b/providers/hunyuan/chat.go
index d5be9f06..f588d425 100644
--- a/providers/hunyuan/chat.go
+++ b/providers/hunyuan/chat.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
"strings"
@@ -53,7 +54,7 @@ func (p *HunyuanProvider) CreateChatCompletionStream(request *types.ChatCompleti
}
func (p *HunyuanProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- action, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ action, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/midjourney/base.go b/providers/midjourney/base.go
index 434457f9..21afe682 100644
--- a/providers/midjourney/base.go
+++ b/providers/midjourney/base.go
@@ -7,7 +7,7 @@ import (
"io"
"log"
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/requester"
"one-api/model"
@@ -48,7 +48,7 @@ func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyRe
return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
delete(mapResult, "accountFilter")
- if !common.MjNotifyEnabled {
+ if !config.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}
}
diff --git a/providers/minimax/chat.go b/providers/minimax/chat.go
index 88d95d71..f2a612c0 100644
--- a/providers/minimax/chat.go
+++ b/providers/minimax/chat.go
@@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
"strings"
@@ -55,7 +56,7 @@ func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompleti
}
func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/minimax/embeddings.go b/providers/minimax/embeddings.go
index 0eb84d5e..f1a2f9e7 100644
--- a/providers/minimax/embeddings.go
+++ b/providers/minimax/embeddings.go
@@ -3,11 +3,12 @@ package minimax
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *MiniMaxProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/mistral/chat.go b/providers/mistral/chat.go
index d9eaa122..a02f30f5 100644
--- a/providers/mistral/chat.go
+++ b/providers/mistral/chat.go
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
"strings"
@@ -56,7 +57,7 @@ func (p *MistralProvider) CreateChatCompletionStream(request *types.ChatCompleti
}
func (p *MistralProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/mistral/embeddings.go b/providers/mistral/embeddings.go
index 8bcc9fbb..bb49735f 100644
--- a/providers/mistral/embeddings.go
+++ b/providers/mistral/embeddings.go
@@ -3,11 +3,12 @@ package mistral
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *MistralProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/moonshot/chat.go b/providers/moonshot/chat.go
index e45ba2e2..9c711439 100644
--- a/providers/moonshot/chat.go
+++ b/providers/moonshot/chat.go
@@ -3,6 +3,7 @@ package moonshot
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/providers/openai"
"one-api/types"
@@ -10,7 +11,7 @@ import (
func (p *MoonshotProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
request.ClearEmptyMessages()
- req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
@@ -51,7 +52,7 @@ func (p *MoonshotProvider) CreateChatCompletion(request *types.ChatCompletionReq
func (p *MoonshotProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
request.ClearEmptyMessages()
- req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/ollama/chat.go b/providers/ollama/chat.go
index 5b01cf79..99ac9f2f 100644
--- a/providers/ollama/chat.go
+++ b/providers/ollama/chat.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/image"
"one-api/common/requester"
"one-api/common/utils"
@@ -56,7 +57,7 @@ func (p *OllamaProvider) CreateChatCompletionStream(request *types.ChatCompletio
}
func (p *OllamaProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/ollama/embeddings.go b/providers/ollama/embeddings.go
index 7231f9e4..2943f9e5 100644
--- a/providers/ollama/embeddings.go
+++ b/providers/ollama/embeddings.go
@@ -3,11 +3,12 @@ package ollama
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *OllamaProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/base.go b/providers/openai/base.go
index 34d5744d..b586345e 100644
--- a/providers/openai/base.go
+++ b/providers/openai/base.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/model"
"one-api/types"
@@ -32,11 +33,11 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter
// 创建 OpenAIProvider
// https://platform.openai.com/docs/api-reference/introduction
func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
- config := getOpenAIConfig(baseURL)
+ openaiConfig := getOpenAIConfig(baseURL)
OpenAIProvider := &OpenAIProvider{
BaseProvider: base.BaseProvider{
- Config: config,
+ Config: openaiConfig,
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle),
},
@@ -44,7 +45,7 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide
BalanceAction: true,
}
- if channel.Type == common.ChannelTypeOpenAI {
+ if channel.Type == config.ChannelTypeOpenAI {
OpenAIProvider.SupportStreamOptions = true
}
@@ -109,7 +110,7 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
requestURL = fmt.Sprintf("/openai%s?api-version=%s", requestURL, apiVersion)
}
- } else if p.Channel.Type == common.ChannelTypeCustom && p.Channel.Other != "" {
+ } else if p.Channel.Type == config.ChannelTypeCustom && p.Channel.Other != "" {
requestURL = strings.Replace(requestURL, "v1", p.Channel.Other, 1)
}
diff --git a/providers/openai/chat.go b/providers/openai/chat.go
index e57b7d04..8857b6fa 100644
--- a/providers/openai/chat.go
+++ b/providers/openai/chat.go
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
"strings"
@@ -17,7 +18,7 @@ type OpenAIStreamHandler struct {
}
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
@@ -67,7 +68,7 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio
// 避免误传导致报错
request.StreamOptions = nil
}
- req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/completion.go b/providers/openai/completion.go
index e52120e7..3a775605 100644
--- a/providers/openai/completion.go
+++ b/providers/openai/completion.go
@@ -5,13 +5,14 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
@@ -50,7 +51,7 @@ func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest
// 避免误传导致报错
request.StreamOptions = nil
}
- req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/embeddings.go b/providers/openai/embeddings.go
index aa7b48a0..325780d3 100644
--- a/providers/openai/embeddings.go
+++ b/providers/openai/embeddings.go
@@ -2,12 +2,12 @@ package openai
import (
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.GetRequestTextBody(common.RelayModeEmbeddings, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeEmbeddings, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/image_edits.go b/providers/openai/image_edits.go
index 376459bc..5943d0cc 100644
--- a/providers/openai/image_edits.go
+++ b/providers/openai/image_edits.go
@@ -5,12 +5,13 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
)
func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.getRequestImageBody(common.RelayModeEdits, request.Model, request)
+ req, errWithCode := p.getRequestImageBody(config.RelayModeEdits, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/image_generations.go b/providers/openai/image_generations.go
index 53d4796e..c3e505b0 100644
--- a/providers/openai/image_generations.go
+++ b/providers/openai/image_generations.go
@@ -3,6 +3,7 @@ package openai
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
)
@@ -11,7 +12,7 @@ func (p *OpenAIProvider) CreateImageGenerations(request *types.ImageRequest) (*t
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
}
- req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeImagesGenerations, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/image_variations.go b/providers/openai/image_variations.go
index 4160d28d..6ad5bd54 100644
--- a/providers/openai/image_variations.go
+++ b/providers/openai/image_variations.go
@@ -2,12 +2,12 @@ package openai
import (
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.getRequestImageBody(common.RelayModeImagesVariations, request.Model, request)
+ req, errWithCode := p.getRequestImageBody(config.RelayModeImagesVariations, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/moderation.go b/providers/openai/moderation.go
index 0cbd1bad..972e88b9 100644
--- a/providers/openai/moderation.go
+++ b/providers/openai/moderation.go
@@ -2,13 +2,13 @@ package openai
import (
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.GetRequestTextBody(common.RelayModeModerations, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeModerations, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/speech.go b/providers/openai/speech.go
index 5d3a6f65..077f0950 100644
--- a/providers/openai/speech.go
+++ b/providers/openai/speech.go
@@ -2,13 +2,13 @@ package openai
import (
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
)
func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.GetRequestTextBody(common.RelayModeAudioSpeech, request.Model, request)
+ req, errWithCode := p.GetRequestTextBody(config.RelayModeAudioSpeech, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/transcriptions.go b/providers/openai/transcriptions.go
index 8be9aeb4..b24dd9fa 100644
--- a/providers/openai/transcriptions.go
+++ b/providers/openai/transcriptions.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
"regexp"
@@ -15,7 +16,7 @@ import (
)
func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranscription, request.Model, request)
+ req, errWithCode := p.getRequestAudioBody(config.RelayModeAudioTranscription, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/openai/translations.go b/providers/openai/translations.go
index 1bc4b581..6f993e89 100644
--- a/providers/openai/translations.go
+++ b/providers/openai/translations.go
@@ -4,11 +4,12 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
- req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranslation, request.Model, request)
+ req, errWithCode := p.getRequestAudioBody(config.RelayModeAudioTranslation, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/palm/chat.go b/providers/palm/chat.go
index 41577c96..754e7d91 100644
--- a/providers/palm/chat.go
+++ b/providers/palm/chat.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/common/utils"
"one-api/types"
@@ -55,7 +56,7 @@ func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionR
}
func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/providers.go b/providers/providers.go
index fdbe6ad8..59127894 100644
--- a/providers/providers.go
+++ b/providers/providers.go
@@ -1,7 +1,7 @@
package providers
import (
- "one-api/common"
+ "one-api/common/config"
"one-api/model"
"one-api/providers/ali"
"one-api/providers/azure"
@@ -44,32 +44,32 @@ var providerFactories = make(map[int]ProviderFactory)
// 在程序启动时,添加所有的供应商工厂
func init() {
- providerFactories[common.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{}
- providerFactories[common.ChannelTypeAzure] = azure.AzureProviderFactory{}
- providerFactories[common.ChannelTypeAli] = ali.AliProviderFactory{}
- providerFactories[common.ChannelTypeTencent] = tencent.TencentProviderFactory{}
- providerFactories[common.ChannelTypeBaidu] = baidu.BaiduProviderFactory{}
- providerFactories[common.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{}
- providerFactories[common.ChannelTypePaLM] = palm.PalmProviderFactory{}
- providerFactories[common.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{}
- providerFactories[common.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{}
- providerFactories[common.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{}
- providerFactories[common.ChannelTypeGemini] = gemini.GeminiProviderFactory{}
- providerFactories[common.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{}
- providerFactories[common.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{}
- providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
- providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{}
- providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
- providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
- providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
- providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
- providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
- providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
- providerFactories[common.ChannelTypeCoze] = coze.CozeProviderFactory{}
- providerFactories[common.ChannelTypeOllama] = ollama.OllamaProviderFactory{}
- providerFactories[common.ChannelTypeMoonshot] = moonshot.MoonshotProviderFactory{}
- providerFactories[common.ChannelTypeLingyi] = lingyi.LingyiProviderFactory{}
- providerFactories[common.ChannelTypeHunyuan] = hunyuan.HunyuanProviderFactory{}
+ providerFactories[config.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{}
+ providerFactories[config.ChannelTypeAzure] = azure.AzureProviderFactory{}
+ providerFactories[config.ChannelTypeAli] = ali.AliProviderFactory{}
+ providerFactories[config.ChannelTypeTencent] = tencent.TencentProviderFactory{}
+ providerFactories[config.ChannelTypeBaidu] = baidu.BaiduProviderFactory{}
+ providerFactories[config.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{}
+ providerFactories[config.ChannelTypePaLM] = palm.PalmProviderFactory{}
+ providerFactories[config.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{}
+ providerFactories[config.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{}
+ providerFactories[config.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{}
+ providerFactories[config.ChannelTypeGemini] = gemini.GeminiProviderFactory{}
+ providerFactories[config.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{}
+ providerFactories[config.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{}
+ providerFactories[config.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{}
+ providerFactories[config.ChannelTypeMistral] = mistral.MistralProviderFactory{}
+ providerFactories[config.ChannelTypeGroq] = groq.GroqProviderFactory{}
+ providerFactories[config.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
+ providerFactories[config.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
+ providerFactories[config.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
+ providerFactories[config.ChannelTypeCohere] = cohere.CohereProviderFactory{}
+ providerFactories[config.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
+ providerFactories[config.ChannelTypeCoze] = coze.CozeProviderFactory{}
+ providerFactories[config.ChannelTypeOllama] = ollama.OllamaProviderFactory{}
+ providerFactories[config.ChannelTypeMoonshot] = moonshot.MoonshotProviderFactory{}
+ providerFactories[config.ChannelTypeLingyi] = lingyi.LingyiProviderFactory{}
+ providerFactories[config.ChannelTypeHunyuan] = hunyuan.HunyuanProviderFactory{}
}
@@ -79,7 +79,7 @@ func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface
var provider base.ProviderInterface
if !ok {
// 处理未找到的供应商工厂
- baseURL := common.ChannelBaseURLs[channel.Type]
+ baseURL := config.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
diff --git a/providers/stabilityAI/image_generations.go b/providers/stabilityAI/image_generations.go
index e79a46bd..37a2c7be 100644
--- a/providers/stabilityAI/image_generations.go
+++ b/providers/stabilityAI/image_generations.go
@@ -5,6 +5,7 @@ import (
"encoding/base64"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/storage"
"one-api/common/utils"
"one-api/types"
@@ -20,7 +21,7 @@ func convertModelName(modelName string) string {
}
func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeImagesGenerations)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go
index 25625cf1..fb4ef287 100644
--- a/providers/tencent/chat.go
+++ b/providers/tencent/chat.go
@@ -5,6 +5,7 @@ import (
"errors"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/common/utils"
"one-api/types"
@@ -55,7 +56,7 @@ func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompleti
}
func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go
index ee8d289f..fc740076 100644
--- a/providers/xunfei/chat.go
+++ b/providers/xunfei/chat.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/common/utils"
"one-api/types"
@@ -59,7 +60,7 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio
}
func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go
index fc2ee6a5..536f68d9 100644
--- a/providers/zhipu/chat.go
+++ b/providers/zhipu/chat.go
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/requester"
"one-api/types"
"strings"
@@ -54,7 +55,7 @@ func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletion
}
func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/zhipu/embeddings.go b/providers/zhipu/embeddings.go
index e33e1da5..681f8fea 100644
--- a/providers/zhipu/embeddings.go
+++ b/providers/zhipu/embeddings.go
@@ -3,11 +3,12 @@ package zhipu
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
)
func (p *ZhipuProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/providers/zhipu/image_generations.go b/providers/zhipu/image_generations.go
index 5bed657b..52622cf6 100644
--- a/providers/zhipu/image_generations.go
+++ b/providers/zhipu/image_generations.go
@@ -3,12 +3,13 @@ package zhipu
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/types"
"time"
)
func (p *ZhipuProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
- url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations)
+ url, errWithCode := p.GetSupportedAPIUri(config.RelayModeImagesGenerations)
if errWithCode != nil {
return nil, errWithCode
}
diff --git a/relay/common.go b/relay/common.go
index a6618e2f..97a68b4c 100644
--- a/relay/common.go
+++ b/relay/common.go
@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/requester"
"one-api/common/utils"
@@ -93,7 +94,7 @@ func fetchChannelById(channelId int) (*model.Channel, error) {
if err != nil {
return nil, errors.New("无效的渠道 Id")
}
- if channel.Status != common.ChannelStatusEnabled {
+ if channel.Status != config.ChannelStatusEnabled {
return nil, errors.New("该渠道已被禁用")
}
diff --git a/relay/main.go b/relay/main.go
index a035a025..26fb1d24 100644
--- a/relay/main.go
+++ b/relay/main.go
@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/model"
"one-api/relay/relay_util"
@@ -50,7 +51,7 @@ func Relay(c *gin.Context) {
channel := relay.getProvider().GetChannel()
go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
- retryTimes := common.RetryTimes
+ retryTimes := config.RetryTimes
if done || !shouldRetry(c, apiErr.StatusCode) {
logger.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode))
retryTimes = 0
diff --git a/relay/midjourney/relay-mj.go b/relay/midjourney/relay-mj.go
index 078e0ff4..0d5d3533 100644
--- a/relay/midjourney/relay-mj.go
+++ b/relay/midjourney/relay-mj.go
@@ -10,6 +10,7 @@ import (
"log"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/controller"
"one-api/model"
provider "one-api/providers/midjourney"
@@ -112,7 +113,7 @@ func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provid
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" {
- midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
+ midjourneyTask.ImageUrl = config.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
diff --git a/relay/relay.go b/relay/relay.go
index a0b493b8..c6e3889e 100644
--- a/relay/relay.go
+++ b/relay/relay.go
@@ -3,6 +3,7 @@ package relay
import (
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/model"
"one-api/providers/azure"
"one-api/providers/openai"
@@ -20,7 +21,7 @@ func RelayOnly(c *gin.Context) {
}
channel := provider.GetChannel()
- if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeAzure {
+ if channel.Type != config.ChannelTypeOpenAI && channel.Type != config.ChannelTypeAzure {
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai")
return
}
diff --git a/relay/relay_util/cache.go b/relay/relay_util/cache.go
index 8315e7a3..5130c753 100644
--- a/relay/relay_util/cache.go
+++ b/relay/relay_util/cache.go
@@ -5,6 +5,7 @@ import (
"encoding/hex"
"fmt"
"one-api/common"
+ "one-api/common/config"
"one-api/common/utils"
"one-api/model"
@@ -57,7 +58,7 @@ func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
return props
}
- if common.ChatCacheEnabled && c.GetBool("chat_cache") {
+ if config.ChatCacheEnabled && c.GetBool("chat_cache") {
props.Cache = true
}
@@ -113,7 +114,7 @@ func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens in
p.CompletionTokens = completionTokens
p.ModelName = modelName
- return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
+ return p.Driver.Set(p.getHash(), p, int64(config.ChatCacheExpireMinute))
}
func (p *ChatCacheProps) GetCache() *ChatCacheProps {
@@ -125,7 +126,7 @@ func (p *ChatCacheProps) GetCache() *ChatCacheProps {
}
func (p *ChatCacheProps) needCache() bool {
- return common.ChatCacheEnabled && p.Cache
+ return config.ChatCacheEnabled && p.Cache
}
func (p *ChatCacheProps) getHash() string {
diff --git a/relay/relay_util/pricing.go b/relay/relay_util/pricing.go
index b2cab192..72ee290e 100644
--- a/relay/relay_util/pricing.go
+++ b/relay/relay_util/pricing.go
@@ -3,7 +3,7 @@ package relay_util
import (
"encoding/json"
"errors"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/common/utils"
"one-api/model"
@@ -107,7 +107,7 @@ func (p *Pricing) GetPrice(modelName string) *model.Price {
return &model.Price{
Type: model.TokensPriceType,
- ChannelType: common.ChannelTypeUnknown,
+ ChannelType: config.ChannelTypeUnknown,
Input: model.DefaultPrice,
Output: model.DefaultPrice,
}
diff --git a/relay/relay_util/quota.go b/relay/relay_util/quota.go
index 48f6b094..e0e0add2 100644
--- a/relay/relay_util/quota.go
+++ b/relay/relay_util/quota.go
@@ -7,6 +7,7 @@ import (
"math"
"net/http"
"one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"one-api/model"
"one-api/types"
@@ -45,7 +46,7 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
if quota.price.Type == model.TimesPriceType {
quota.preConsumedQuota = int(1000 * quota.inputRatio)
} else {
- quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio)
+ quota.preConsumedQuota = int(float64(quota.promptTokens+config.PreConsumedQuota) * quota.inputRatio)
}
errWithCode := quota.preQuotaConsumption()
diff --git a/relay/relay_util/type.go b/relay/relay_util/type.go
index 9c2d72d3..3233b02c 100644
--- a/relay/relay_util/type.go
+++ b/relay/relay_util/type.go
@@ -1,35 +1,37 @@
package relay_util
-import "one-api/common"
+import (
+ "one-api/common/config"
+)
var UnknownOwnedBy = "未知"
var ModelOwnedBy map[int]string
func init() {
ModelOwnedBy = map[int]string{
- common.ChannelTypeOpenAI: "OpenAI",
- common.ChannelTypeAnthropic: "Anthropic",
- common.ChannelTypeBaidu: "Baidu",
- common.ChannelTypePaLM: "Google PaLM",
- common.ChannelTypeGemini: "Google Gemini",
- common.ChannelTypeZhipu: "Zhipu",
- common.ChannelTypeAli: "Ali",
- common.ChannelTypeXunfei: "Xunfei",
- common.ChannelType360: "360",
- common.ChannelTypeTencent: "Tencent",
- common.ChannelTypeBaichuan: "Baichuan",
- common.ChannelTypeMiniMax: "MiniMax",
- common.ChannelTypeDeepseek: "Deepseek",
- common.ChannelTypeMoonshot: "Moonshot",
- common.ChannelTypeMistral: "Mistral",
- common.ChannelTypeGroq: "Groq",
- common.ChannelTypeLingyi: "Lingyiwanwu",
- common.ChannelTypeMidjourney: "Midjourney",
- common.ChannelTypeCloudflareAI: "Cloudflare AI",
- common.ChannelTypeCohere: "Cohere",
- common.ChannelTypeStabilityAI: "Stability AI",
- common.ChannelTypeCoze: "Coze",
- common.ChannelTypeOllama: "Ollama",
- common.ChannelTypeHunyuan: "Hunyuan",
+ config.ChannelTypeOpenAI: "OpenAI",
+ config.ChannelTypeAnthropic: "Anthropic",
+ config.ChannelTypeBaidu: "Baidu",
+ config.ChannelTypePaLM: "Google PaLM",
+ config.ChannelTypeGemini: "Google Gemini",
+ config.ChannelTypeZhipu: "Zhipu",
+ config.ChannelTypeAli: "Ali",
+ config.ChannelTypeXunfei: "Xunfei",
+ config.ChannelType360: "360",
+ config.ChannelTypeTencent: "Tencent",
+ config.ChannelTypeBaichuan: "Baichuan",
+ config.ChannelTypeMiniMax: "MiniMax",
+ config.ChannelTypeDeepseek: "Deepseek",
+ config.ChannelTypeMoonshot: "Moonshot",
+ config.ChannelTypeMistral: "Mistral",
+ config.ChannelTypeGroq: "Groq",
+ config.ChannelTypeLingyi: "Lingyiwanwu",
+ config.ChannelTypeMidjourney: "Midjourney",
+ config.ChannelTypeCloudflareAI: "Cloudflare AI",
+ config.ChannelTypeCohere: "Cohere",
+ config.ChannelTypeStabilityAI: "Stability AI",
+ config.ChannelTypeCoze: "Coze",
+ config.ChannelTypeOllama: "Ollama",
+ config.ChannelTypeHunyuan: "Hunyuan",
}
}
diff --git a/router/main.go b/router/main.go
index 6702a7c5..b0ed0a64 100644
--- a/router/main.go
+++ b/router/main.go
@@ -4,7 +4,7 @@ import (
"embed"
"fmt"
"net/http"
- "one-api/common"
+ "one-api/common/config"
"one-api/common/logger"
"strings"
@@ -17,7 +17,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetDashboardRouter(router)
SetRelayRouter(router)
frontendBaseUrl := viper.GetString("frontend_base_url")
- if common.IsMasterNode && frontendBaseUrl != "" {
+ if config.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
}