From 71171c63f5ac2be88eb0be9f56648d959dfcaef2 Mon Sep 17 00:00:00 2001 From: Buer <42402987+MartialBE@users.noreply.github.com> Date: Wed, 20 Mar 2024 14:12:47 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20support=20configuration=20f?= =?UTF-8?q?ile=20(#117)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ refactor: move file directory * ♻️ refactor: move file directory * ♻️ refactor: support multiple config methods * 🔥 del: remove unused code * 💩 refactor: Refactor channel management and synchronization * 💄 improve: add channel website * ✨ feat: allow recording 0 consumption --- .gitignore | 5 +- common/config/config.go | 54 +++++++++ common/config/flag.go | 49 ++++++++ common/constants.go | 47 +------- common/custom-event.go | 82 ------------- common/database.go | 3 - common/embed-file-system.go | 32 ----- common/init.go | 69 ----------- common/logger.go | 70 ++++++++--- common/redis.go | 25 ++-- common/requester/http_client.go | 9 +- common/{ => requester}/marshaller.go | 2 +- common/requester/request_builder.go | 5 +- common/telegram/common.go | 14 ++- common/utils.go | 25 ++-- config.example.yaml | 44 +++++++ controller/channel-billing.go | 4 + controller/channel-test.go | 4 + go.mod | 20 +++- go.sum | 36 ++++++ main.go | 110 +++++++----------- middleware/rate-limit.go | 32 ++++- middleware/telegram.go | 4 +- model/balancer.go | 32 ++--- model/cache.go | 13 +-- model/channel.go | 36 +++++- model/main.go | 29 ++++- model/option.go | 24 ++-- {controller/relay => relay}/base.go | 0 {controller/relay => relay}/chat.go | 0 controller/relay/utils.go => relay/common.go | 0 {controller/relay => relay}/completions.go | 0 {controller/relay => relay}/embeddings.go | 0 {controller/relay => relay}/image-edits.go | 0 .../relay => relay}/image-generations.go | 0 .../relay => relay}/image-variationsy.go | 0 {controller/relay => relay}/main.go | 9 +- {controller => relay}/model.go | 2 +- {controller/relay => relay}/moderations.go | 0 {controller/relay => relay}/speech.go | 0 {controller/relay => relay}/transcriptions.go | 0 {controller/relay => relay}/translations.go | 0 {controller/relay => relay/util}/quota.go | 61 +++++----- router/api-router.go | 5 +- router/main.go | 7 +- router/relay-router.go | 6 +- router/web-router.go | 10 +- web/src/constants/ChannelConstants.js | 63 ++++++---- web/src/views/Channel/component/TableRow.js | 14 ++- web/src/views/Log/component/TableRow.js | 6 +- 50 files changed, 581 insertions(+), 481 deletions(-) create mode 100644 common/config/config.go create mode 100644 common/config/flag.go delete mode 100644 common/custom-event.go delete mode 100644 common/embed-file-system.go delete mode 100644 common/init.go rename common/{ => requester}/marshaller.go (92%) create mode 100644 config.example.yaml rename {controller/relay => relay}/base.go (100%) rename {controller/relay => relay}/chat.go (100%) rename controller/relay/utils.go => relay/common.go (100%) rename {controller/relay => relay}/completions.go (100%) rename {controller/relay => relay}/embeddings.go (100%) rename {controller/relay => relay}/image-edits.go (100%) rename {controller/relay => relay}/image-generations.go (100%) rename {controller/relay => relay}/image-variationsy.go (100%) rename {controller/relay => relay}/main.go (92%) rename {controller => relay}/model.go (99%) rename {controller/relay => relay}/moderations.go (100%) rename {controller/relay => relay}/speech.go (100%) rename {controller/relay => relay}/transcriptions.go (100%) rename {controller/relay => relay}/translations.go (100%) rename {controller/relay => relay/util}/quota.go (70%) diff --git a/.gitignore b/.gitignore index 1921ddf3..c84fb556 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,7 @@ logs data tmp/ /test/ -.env \ No newline at end of file +.env +common/plugin/ +common/balancer/ +config.yaml \ No newline at end of file diff --git a/common/config/config.go b/common/config/config.go new file mode 100644 index 00000000..e24e27a6 --- /dev/null +++ b/common/config/config.go @@ -0,0 +1,54 @@ +package config + +import ( + "strings" + "time" + + "one-api/common" + + "github.com/spf13/viper" +) + +func InitConf() { + flagConfig() + defaultConfig() + setConfigFile() + setEnv() + + if viper.GetBool("debug") { + common.SysLog("running in debug mode") + } + + common.IsMasterNode = viper.GetString("NODE_TYPE") != "slave" + common.RequestInterval = time.Duration(viper.GetInt("POLLING_INTERVAL")) * time.Second + common.SessionSecret = common.GetOrDefault("SESSION_SECRET", common.SessionSecret) +} + +func setConfigFile() { + if !common.IsFileExist(*config) { + return + } + + viper.SetConfigFile(*config) + if err := viper.ReadInConfig(); err != nil { + panic(err) + } +} + +func setEnv() { + viper.AutomaticEnv() + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) +} + +func defaultConfig() { + viper.SetDefault("port", "3000") + viper.SetDefault("gin_mode", "release") + viper.SetDefault("log_dir", "./logs") + viper.SetDefault("sqlite_path", "one-api.db") + viper.SetDefault("sqlite_busy_timeout", 3000) + viper.SetDefault("sync_frequency", 600) + viper.SetDefault("batch_update_interval", 5) + viper.SetDefault("global.api_rate_limit", 180) + viper.SetDefault("global.web_rate_limit", 100) + viper.SetDefault("connect_timeout", 5) +} diff --git a/common/config/flag.go b/common/config/flag.go new file mode 100644 index 00000000..85f56107 --- /dev/null +++ b/common/config/flag.go @@ -0,0 +1,49 @@ +package config + +import ( + "flag" + "fmt" + "one-api/common" + "os" + + "github.com/spf13/viper" +) + +var ( + port = flag.Int("port", 0, "the listening port") + printVersion = flag.Bool("version", false, "print version and exit") + printHelp = flag.Bool("help", false, "print help and exit") + logDir = flag.String("log-dir", "", "specify the log directory") + config = flag.String("config", "config.yaml", "specify the config.yaml path") +) + +func flagConfig() { + flag.Parse() + + if *printVersion { + fmt.Println(common.Version) + os.Exit(0) + } + + if *printHelp { + help() + os.Exit(0) + } + + if *port != 0 { + viper.Set("port", *port) + } + + if *logDir != "" { + viper.Set("log_dir", *logDir) + } + +} + +func help() { + fmt.Println("One API " + common.Version + " - All in one API service for OpenAI API.") + fmt.Println("Copyright (C) 2024 MartialBE. All rights reserved.") + fmt.Println("Original copyright holder: JustSong") + fmt.Println("GitHub: https://github.com/MartialBE/one-api") + fmt.Println("Usage: one-api [--port ] [--log-dir ] [--config ] [--version] [--help]") +} diff --git a/common/constants.go b/common/constants.go index f8ea101c..97ed4212 100644 --- a/common/constants.go +++ b/common/constants.go @@ -1,8 +1,6 @@ package common import ( - "os" - "strconv" "sync" "time" @@ -52,8 +50,7 @@ var EmailDomainWhitelist = []string{ "foxmail.com", } -var DebugEnabled = os.Getenv("DEBUG") == "true" -var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" +var MemoryCacheEnabled = false var LogConsumeEnabled = true @@ -88,22 +85,12 @@ var RetryCooldownSeconds = 5 var RootUserEmail = "" -var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" +var IsMasterNode = true -var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) -var RequestInterval = time.Duration(requestInterval) * time.Second - -var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second +var RequestInterval time.Duration var BatchUpdateEnabled = false -var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) - -var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 600) // unit is second -var ConnectTimeout = GetOrDefault("CONNECT_TIMEOUT", 5) // unit is second - -const ( - RequestIdKey = "X-Oneapi-Request-Id" -) +var BatchUpdateInterval = 5 const ( RoleGuestUser = 0 @@ -112,32 +99,6 @@ const ( RoleRootUser = 100 ) -var ( - FileUploadPermission = RoleGuestUser - FileDownloadPermission = RoleGuestUser - ImageUploadPermission = RoleGuestUser - ImageDownloadPermission = RoleGuestUser -) - -// All duration's unit is seconds -// Shouldn't larger then RateLimitKeyExpirationDuration -var ( - GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) - GlobalApiRateLimitDuration int64 = 3 * 60 - - GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 100) - GlobalWebRateLimitDuration int64 = 3 * 60 - - UploadRateLimitNum = 10 - UploadRateLimitDuration int64 = 60 - - DownloadRateLimitNum = 10 - DownloadRateLimitDuration int64 = 60 - - CriticalRateLimitNum = 20 - CriticalRateLimitDuration int64 = 20 * 60 -) - var RateLimitKeyExpirationDuration = 20 * time.Minute const ( diff --git a/common/custom-event.go b/common/custom-event.go deleted file mode 100644 index 69da4bc4..00000000 --- a/common/custom-event.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2014 Manu Martinez-Almeida. All rights reserved. -// Use of this source code is governed by a MIT style -// license that can be found in the LICENSE file. - -package common - -import ( - "fmt" - "io" - "net/http" - "strings" -) - -type stringWriter interface { - io.Writer - writeString(string) (int, error) -} - -type stringWrapper struct { - io.Writer -} - -func (w stringWrapper) writeString(str string) (int, error) { - return w.Writer.Write([]byte(str)) -} - -func checkWriter(writer io.Writer) stringWriter { - if w, ok := writer.(stringWriter); ok { - return w - } else { - return stringWrapper{writer} - } -} - -// Server-Sent Events -// W3C Working Draft 29 October 2009 -// http://www.w3.org/TR/2009/WD-eventsource-20091029/ - -var contentType = []string{"text/event-stream"} -var noCache = []string{"no-cache"} - -var fieldReplacer = strings.NewReplacer( - "\n", "\\n", - "\r", "\\r") - -var dataReplacer = strings.NewReplacer( - "\n", "\ndata:", - "\r", "\\r") - -type CustomEvent struct { - Event string - Id string - Retry uint - Data interface{} -} - -func encode(writer io.Writer, event CustomEvent) error { - w := checkWriter(writer) - return writeData(w, event.Data) -} - -func writeData(w stringWriter, data interface{}) error { - dataReplacer.WriteString(w, fmt.Sprint(data)) - if strings.HasPrefix(data.(string), "data") { - w.writeString("\n\n") - } - return nil -} - -func (r CustomEvent) Render(w http.ResponseWriter) error { - r.WriteContentType(w) - return encode(w, r) -} - -func (r CustomEvent) WriteContentType(w http.ResponseWriter) { - header := w.Header() - header["Content-Type"] = contentType - - if _, exist := header["Cache-Control"]; !exist { - header["Cache-Control"] = noCache - } -} diff --git a/common/database.go b/common/database.go index 76f2cd55..f3863bf6 100644 --- a/common/database.go +++ b/common/database.go @@ -2,6 +2,3 @@ package common var UsingSQLite = false var UsingPostgreSQL = false - -var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/embed-file-system.go b/common/embed-file-system.go deleted file mode 100644 index 3ea02cf8..00000000 --- a/common/embed-file-system.go +++ /dev/null @@ -1,32 +0,0 @@ -package common - -import ( - "embed" - "github.com/gin-contrib/static" - "io/fs" - "net/http" -) - -// Credit: https://github.com/gin-contrib/static/issues/19 - -type embedFileSystem struct { - http.FileSystem -} - -func (e embedFileSystem) Exists(prefix string, path string) bool { - _, err := e.Open(path) - if err != nil { - return false - } - return true -} - -func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { - efs, err := fs.Sub(fsEmbed, targetPath) - if err != nil { - panic(err) - } - return embedFileSystem{ - FileSystem: http.FS(efs), - } -} diff --git a/common/init.go b/common/init.go deleted file mode 100644 index 5e852638..00000000 --- a/common/init.go +++ /dev/null @@ -1,69 +0,0 @@ -package common - -import ( - "flag" - "fmt" - "log" - "os" - "path/filepath" - - "github.com/joho/godotenv" -) - -var ( - Port = flag.Int("port", 3000, "the listening port") - PrintVersion = flag.Bool("version", false, "print version and exit") - PrintHelp = flag.Bool("help", false, "print help and exit") - LogDir = flag.String("log-dir", "./logs", "specify the log directory") -) - -func printHelp() { - fmt.Println("One API " + Version + " - All in one API service for OpenAI API.") - fmt.Println("Copyright (C) 2023 MartialBE. All rights reserved.") - fmt.Println("Original copyright holder: JustSong") - fmt.Println("GitHub: https://github.com/MartialBE/one-api") - fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]") -} - -func init() { - // 加载.env文件 - err := godotenv.Load() - if err != nil { - SysLog("failed to load .env file: " + err.Error()) - } - flag.Parse() - - if *PrintVersion { - fmt.Println(Version) - os.Exit(0) - } - - if *PrintHelp { - printHelp() - os.Exit(0) - } - - if os.Getenv("SESSION_SECRET") != "" { - if os.Getenv("SESSION_SECRET") == "random_string" { - SysError("SESSION_SECRET is set to an example value, please change it to a random string.") - } else { - SessionSecret = os.Getenv("SESSION_SECRET") - } - } - if os.Getenv("SQLITE_PATH") != "" { - SQLitePath = os.Getenv("SQLITE_PATH") - } - if *LogDir != "" { - var err error - *LogDir, err = filepath.Abs(*LogDir) - if err != nil { - log.Fatal(err) - } - if _, err := os.Stat(*LogDir); os.IsNotExist(err) { - err = os.Mkdir(*LogDir, 0777) - if err != nil { - log.Fatal(err) - } - } - } -} diff --git a/common/logger.go b/common/logger.go index 61627217..d2548679 100644 --- a/common/logger.go +++ b/common/logger.go @@ -3,13 +3,15 @@ package common import ( "context" "fmt" - "github.com/gin-gonic/gin" "io" "log" "os" "path/filepath" "sync" "time" + + "github.com/gin-gonic/gin" + "github.com/spf13/viper" ) const ( @@ -17,6 +19,9 @@ const ( loggerWarn = "WARN" loggerError = "ERR" ) +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) const maxLogCount = 1000000 @@ -24,25 +29,54 @@ var logCount int var setupLogLock sync.Mutex var setupLogWorking bool +var defaultLogDir = "./logs" + func SetupLogger() { - if *LogDir != "" { - ok := setupLogLock.TryLock() - if !ok { - log.Println("setup log is already working") - return - } - defer func() { - setupLogLock.Unlock() - setupLogWorking = false - }() - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) - fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatal("failed to open log file") - } - gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) - gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) + logDir := getLogDir() + if logDir == "" { + return } + + ok := setupLogLock.TryLock() + if !ok { + log.Println("setup log is already working") + return + } + defer func() { + setupLogLock.Unlock() + setupLogWorking = false + }() + logPath := filepath.Join(logDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal("failed to open log file") + } + gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) +} + +func getLogDir() string { + logDir := viper.GetString("log_dir") + if logDir == "" { + logDir = defaultLogDir + } + + var err error + logDir, err = filepath.Abs(viper.GetString("log_dir")) + if err != nil { + log.Fatal(err) + return "" + } + + if !IsFileExist(logDir) { + err = os.Mkdir(logDir, 0777) + if err != nil { + log.Fatal(err) + return "" + } + } + + return logDir } func SysLog(s string) { diff --git a/common/redis.go b/common/redis.go index 12c477b8..8657dff1 100644 --- a/common/redis.go +++ b/common/redis.go @@ -2,30 +2,32 @@ package common import ( "context" - "github.com/go-redis/redis/v8" - "os" "time" + + "github.com/go-redis/redis/v8" + "github.com/spf13/viper" ) var RDB *redis.Client -var RedisEnabled = true +var RedisEnabled = false // InitRedisClient This function is called after init() func InitRedisClient() (err error) { - if os.Getenv("REDIS_CONN_STRING") == "" { - RedisEnabled = false + redisConn := viper.GetString("REDIS_CONN_STRING") + + if redisConn == "" { SysLog("REDIS_CONN_STRING not set, Redis is not enabled") return nil } - if os.Getenv("SYNC_FREQUENCY") == "" { - RedisEnabled = false + if viper.GetInt("SYNC_FREQUENCY") == 0 { SysLog("SYNC_FREQUENCY not set, Redis is disabled") return nil } SysLog("Redis is enabled") - opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) + opt, err := redis.ParseURL(redisConn) if err != nil { FatalLog("failed to parse Redis connection string: " + err.Error()) + return } RDB = redis.NewClient(opt) @@ -35,12 +37,17 @@ func InitRedisClient() (err error) { _, err = RDB.Ping(ctx).Result() if err != nil { FatalLog("Redis ping test failed: " + err.Error()) + } else { + RedisEnabled = true + // for compatibility with old versions + MemoryCacheEnabled = true } + return err } func ParseRedisOption() *redis.Options { - opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) + opt, err := redis.ParseURL(viper.GetString("REDIS_CONN_STRING")) if err != nil { FatalLog("failed to parse Redis connection string: " + err.Error()) } diff --git a/common/requester/http_client.go b/common/requester/http_client.go index d049d3c2..a441a42f 100644 --- a/common/requester/http_client.go +++ b/common/requester/http_client.go @@ -39,7 +39,7 @@ func proxyFunc(req *http.Request) (*url.URL, error) { func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) { // 设置TCP超时 dialer := &net.Dialer{ - Timeout: time.Duration(common.ConnectTimeout) * time.Second, + Timeout: time.Duration(common.GetOrDefault("CONNECT_TIMEOUT", 5)) * time.Second, KeepAlive: 30 * time.Second, } @@ -64,7 +64,7 @@ func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error var HTTPClient *http.Client -func init() { +func InitHttpClient() { trans := &http.Transport{ DialContext: socks5ProxyFunc, Proxy: proxyFunc, @@ -74,7 +74,8 @@ func init() { Transport: trans, } - if common.RelayTimeout != 0 { - HTTPClient.Timeout = time.Duration(common.RelayTimeout) * time.Second + relayTimeout := common.GetOrDefault("RELAY_TIMEOUT", 600) + if relayTimeout != 0 { + HTTPClient.Timeout = time.Duration(relayTimeout) * time.Second } } diff --git a/common/marshaller.go b/common/requester/marshaller.go similarity index 92% rename from common/marshaller.go rename to common/requester/marshaller.go index 0ef9d5da..4577af0c 100644 --- a/common/marshaller.go +++ b/common/requester/marshaller.go @@ -1,4 +1,4 @@ -package common +package requester import ( "encoding/json" diff --git a/common/requester/request_builder.go b/common/requester/request_builder.go index a7f4ee83..bcf7920a 100644 --- a/common/requester/request_builder.go +++ b/common/requester/request_builder.go @@ -5,7 +5,6 @@ import ( "context" "io" "net/http" - "one-api/common" ) type RequestBuilder interface { @@ -13,12 +12,12 @@ type RequestBuilder interface { } type HTTPRequestBuilder struct { - marshaller common.Marshaller + marshaller Marshaller } func NewRequestBuilder() *HTTPRequestBuilder { return &HTTPRequestBuilder{ - marshaller: &common.JSONMarshaller{}, + marshaller: &JSONMarshaller{}, } } diff --git a/common/telegram/common.go b/common/telegram/common.go index 4d92cee4..106138af 100644 --- a/common/telegram/common.go +++ b/common/telegram/common.go @@ -5,7 +5,6 @@ import ( "fmt" "one-api/common" "one-api/model" - "os" "strings" "time" @@ -14,6 +13,7 @@ import ( "github.com/PaulSonOfLars/gotgbot/v2/ext/handlers" "github.com/PaulSonOfLars/gotgbot/v2/ext/handlers/filters/callbackquery" "github.com/PaulSonOfLars/gotgbot/v2/ext/handlers/filters/message" + "github.com/spf13/viper" ) var TGupdater *ext.Updater @@ -28,13 +28,14 @@ func InitTelegramBot() { return } - if os.Getenv("TG_BOT_API_KEY") == "" { + botKey := viper.GetString("TG_BOT_API_KEY") + if botKey == "" { common.SysLog("Telegram bot is not enabled") return } var err error - TGBot, err = gotgbot.NewBot(os.Getenv("TG_BOT_API_KEY"), nil) + TGBot, err = gotgbot.NewBot(botKey, nil) if err != nil { common.SysLog("failed to create new telegram bot: " + err.Error()) return @@ -47,15 +48,16 @@ func InitTelegramBot() { } func StartTelegramBot() { - if os.Getenv("TG_WEBHOOK_SECRET") != "" { + botWebhook := viper.GetString("TG_WEBHOOK_SECRET") + if botWebhook != "" { if common.ServerAddress == "" { common.SysLog("Telegram bot is not enabled: Server address is not set") StopTelegramBot() return } - TGWebHookSecret = os.Getenv("TG_WEBHOOK_SECRET") + TGWebHookSecret = botWebhook serverAddress := strings.TrimSuffix(common.ServerAddress, "/") - urlPath := fmt.Sprintf("/api/telegram/%s", os.Getenv("TG_BOT_API_KEY")) + urlPath := fmt.Sprintf("/api/telegram/%s", viper.GetString("TG_BOT_API_KEY")) webHookOpts := &ext.AddWebhookOpts{ SecretToken: TGWebHookSecret, diff --git a/common/utils.go b/common/utils.go index 21bec8f5..f9e574b9 100644 --- a/common/utils.go +++ b/common/utils.go @@ -2,7 +2,6 @@ package common import ( "fmt" - "github.com/google/uuid" "html/template" "log" "math/rand" @@ -13,6 +12,9 @@ import ( "strconv" "strings" "time" + + "github.com/google/uuid" + "github.com/spf13/viper" ) func OpenBrowser(url string) { @@ -184,16 +186,14 @@ func Max(a int, b int) int { } } -func GetOrDefault(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue +func GetOrDefault[T any](env string, defaultValue T) T { + if viper.IsSet(env) { + value := viper.Get(env) + if v, ok := value.(T); ok { + return v + } } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num + return defaultValue } func MessageWithRequestId(message string, id string) string { @@ -207,3 +207,8 @@ func String2Int(str string) int { } return num } + +func IsFileExist(path string) bool { + _, err := os.Stat(path) + return err == nil || os.IsExist(err) +} diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 00000000..cf750964 --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,44 @@ +# 服务器设置 +port: 3000 # 服务端口 +gin_mode: "release" # gin 模式,可选值为 "release" 或 "debug",默认为 "release"。 +debug: false # 是否启用调试模式,启用后将输出更多日志信息。 +log_dir: "./logs" # 日志目录 +session_secret: "" # 会话密钥,未设置则使用随机值。 + +# 数据库设置 +sql_dsn: "" # 设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL +sqlite_path: "one-api.db" # sqlite 数据库文件路径 +sqlite_busy_timeout: 3000 # sqlite 数据库繁忙超时时间,单位为毫秒,默认为 3000。 +redis_conn_string: "" # 设置之后将使用指定 Redis 作为缓存,格式为 "redis://default:redispw@localhost:49153",未设置则不使用 Redis。 + +memory_cache_enabled: false # 是否启用内存缓存,启用后将缓存部分数据,减少数据库查询次数。 +sync_frequency: 600 # 在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 600 秒 +node_type: "master" # 节点类型,可选值为 "master" 或 "slave",默认为 "master"。 +frontend_base_url: "" # 设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 +polling_interval: 0 # 批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 +batch_update_interval: 5 # 批量更新聚合的时间间隔,单位为秒,默认为 5。 +batch_update_enabled: false # 启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 true 和 false,未设置则默认为 false + +# 全局设置 +global: + api_rate_limit: 180 # 全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 180。 + web_rate_limit: 100 # 全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 100。 + +# 频道更新设置 +channel: + update_frequency: 0 # 设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + test_frequency: 0 # 设置之后将定期检查渠道,单位为分钟,未设置则不进行检查 + +# 连接设置 +relay_timeout: 0 # 中继请求超时时间,单位为秒,默认为 0。 +connect_timeout: 5 # 连接超时时间,单位为秒,默认为 5。 + +# 默认程序启动时会联网下载一些通用的词元的编码,如:gpt-3.5-turbo,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 +tiktoken_cache_dir: "" +# 目前该配置作用与 TIKTOKEN_CACHE_DIR 一致,但是优先级没有它高。 +data_gym_cache_dir: "" + +# Telegram设置 +tg: + bot_api_key: "" # 你的 Telegram bot 的 API 密钥 + webhook_secret: "" # 你的 webhook 密钥。你可以自定义这个密钥。如果设置了这个密钥,将使用webhook的方式接收消息,否则使用轮询(Polling)的方式。 \ No newline at end of file diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 5f838067..51d07323 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -147,6 +147,10 @@ func UpdateAllChannelsBalance(c *gin.Context) { } func AutomaticallyUpdateChannels(frequency int) { + if frequency <= 0 { + return + } + for { time.Sleep(time.Duration(frequency) * time.Minute) common.SysLog("updating all channels") diff --git a/controller/channel-test.go b/controller/channel-test.go index 1c6206d3..5a8e9bda 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -217,6 +217,10 @@ func TestAllChannels(c *gin.Context) { } func AutomaticallyTestChannels(frequency int) { + if frequency <= 0 { + return + } + for { time.Sleep(time.Duration(frequency) * time.Minute) common.SysLog("testing all channels") diff --git a/go.mod b/go.mod index 5354539e..290c82dc 100644 --- a/go.mod +++ b/go.mod @@ -29,8 +29,24 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/chenzhuoyu/iasm v0.9.1 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/viper v1.18.2 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/sync v0.6.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect ) require ( @@ -38,7 +54,7 @@ require ( github.com/bytedance/sonic v1.11.3 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect @@ -64,7 +80,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.7.0 // indirect diff --git a/go.sum b/go.sum index 4d90b3ce..a426c315 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= @@ -41,6 +43,8 @@ github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cn github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= @@ -125,6 +129,8 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= @@ -169,6 +175,8 @@ github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNa github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= @@ -182,6 +190,8 @@ github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJK github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/memcachier/mc v2.0.1+incompatible/go.mod h1:7bkvFE61leUBvXz+yxsOnGBQSZpBSPIMUQSmmSHvuXc= github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -204,10 +214,26 @@ github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAc github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -224,6 +250,8 @@ github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= @@ -234,6 +262,10 @@ github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4d github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= @@ -246,6 +278,8 @@ golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= @@ -294,6 +328,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/main.go b/main.go index 5a13aa51..fe871c1f 100644 --- a/main.go +++ b/main.go @@ -4,17 +4,18 @@ import ( "embed" "fmt" "one-api/common" + "one-api/common/config" + "one-api/common/requester" "one-api/common/telegram" "one-api/controller" "one-api/middleware" "one-api/model" "one-api/router" - "os" - "strconv" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" + "github.com/spf13/viper" ) //go:embed web/build @@ -24,87 +25,66 @@ var buildFS embed.FS var indexPage []byte func main() { + config.InitConf() common.SetupLogger() common.SysLog("One API " + common.Version + " started") - if os.Getenv("GIN_MODE") != "debug" { - gin.SetMode(gin.ReleaseMode) - } - if common.DebugEnabled { - common.SysLog("running in debug mode") - } // Initialize SQL Database - err := model.InitDB() - if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) - } - defer func() { - err := model.CloseDB() - if err != nil { - common.FatalLog("failed to close database: " + err.Error()) - } - }() - + model.SetupDB() + defer model.CloseDB() // Initialize Redis - err = common.InitRedisClient() - if err != nil { - common.FatalLog("failed to initialize Redis: " + err.Error()) - } - + common.InitRedisClient() // Initialize options model.InitOptionMap() - if common.RedisEnabled { - // for compatibility with old versions - common.MemoryCacheEnabled = true - } - if common.MemoryCacheEnabled { - common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) - model.InitChannelGroup() - } - if common.MemoryCacheEnabled { - go model.SyncOptions(common.SyncFrequency) - go model.SyncChannelGroup(common.SyncFrequency) - } - if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { - frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) - if err != nil { - common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) - } - go controller.AutomaticallyUpdateChannels(frequency) - } - if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { - frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) - if err != nil { - common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) - } - go controller.AutomaticallyTestChannels(frequency) - } - if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { - common.BatchUpdateEnabled = true - common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") - model.InitBatchUpdater() - } + initMemoryCache() + initSync() + common.InitTokenEncoders() + requester.InitHttpClient() // Initialize Telegram bot telegram.InitTelegramBot() - // Initialize HTTP server + initHttpServer() +} + +func initMemoryCache() { + if viper.GetBool("MEMORY_CACHE_ENABLED") { + common.MemoryCacheEnabled = true + } + + if !common.MemoryCacheEnabled { + return + } + + syncFrequency := viper.GetInt("SYNC_FREQUENCY") + model.TokenCacheSeconds = syncFrequency + + common.SysLog("memory cache enabled") + common.SysError(fmt.Sprintf("sync frequency: %d seconds", syncFrequency)) + go model.SyncOptions(syncFrequency) +} + +func initSync() { + go controller.AutomaticallyUpdateChannels(viper.GetInt("CHANNEL_UPDATE_FREQUENCY")) + go controller.AutomaticallyTestChannels(viper.GetInt("CHANNEL_TEST_FREQUENCY")) +} + +func initHttpServer() { + if viper.GetString("gin_mode") != "debug" { + gin.SetMode(gin.ReleaseMode) + } + server := gin.New() server.Use(gin.Recovery()) - // This will cause SSE not to work!!! - //server.Use(gzip.Gzip(gzip.DefaultCompression)) server.Use(middleware.RequestId()) middleware.SetUpLogger(server) - // Initialize session store + store := cookie.NewStore([]byte(common.SessionSecret)) server.Use(sessions.Sessions("session", store)) router.SetRouter(server, buildFS, indexPage) - var port = os.Getenv("PORT") - if port == "" { - port = strconv.Itoa(*common.Port) - } - err = server.Run(":" + port) + port := viper.GetString("PORT") + + err := server.Run(":" + port) if err != nil { common.FatalLog("failed to start HTTP server: " + err.Error()) } diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index 8e5cff6c..65b358cb 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -3,16 +3,36 @@ package middleware import ( "context" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "time" + + "github.com/gin-gonic/gin" ) var timeFormat = "2006-01-02T15:04:05.000Z" var inMemoryRateLimiter common.InMemoryRateLimiter +// All duration's unit is seconds +// Shouldn't larger then RateLimitKeyExpirationDuration +var ( + GlobalApiRateLimitNum = 180 + GlobalApiRateLimitDuration int64 = 3 * 60 + + GlobalWebRateLimitNum = 100 + GlobalWebRateLimitDuration int64 = 3 * 60 + + UploadRateLimitNum = 10 + UploadRateLimitDuration int64 = 60 + + DownloadRateLimitNum = 10 + DownloadRateLimitDuration int64 = 60 + + CriticalRateLimitNum = 20 + CriticalRateLimitDuration int64 = 20 * 60 +) + func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { ctx := context.Background() rdb := common.RDB @@ -83,21 +103,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } func GlobalWebRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") + return rateLimitFactory(common.GetOrDefault("GLOBAL_WEB_RATE_LIMIT", GlobalWebRateLimitNum), GlobalWebRateLimitDuration, "GW") } func GlobalAPIRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") + return rateLimitFactory(common.GetOrDefault("GLOBAL_API_RATE_LIMIT", GlobalApiRateLimitNum), GlobalApiRateLimitDuration, "GA") } func CriticalRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") + return rateLimitFactory(CriticalRateLimitNum, CriticalRateLimitDuration, "CT") } func DownloadRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") + return rateLimitFactory(DownloadRateLimitNum, DownloadRateLimitDuration, "DW") } func UploadRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") + return rateLimitFactory(UploadRateLimitNum, UploadRateLimitDuration, "UP") } diff --git a/middleware/telegram.go b/middleware/telegram.go index d6d6559a..36a972b9 100644 --- a/middleware/telegram.go +++ b/middleware/telegram.go @@ -2,16 +2,16 @@ package middleware import ( "one-api/common/telegram" - "os" "github.com/gin-gonic/gin" + "github.com/spf13/viper" ) func Telegram() func(c *gin.Context) { return func(c *gin.Context) { token := c.Param("token") - if !telegram.TGEnabled || telegram.TGWebHookSecret == "" || token == "" || token != os.Getenv("TG_BOT_API_KEY") { + if !telegram.TGEnabled || telegram.TGWebHookSecret == "" || token == "" || token != viper.GetString("TG_BOT_API_KEY") { c.String(404, "Page not found") c.Abort() return diff --git a/model/balancer.go b/model/balancer.go index 38a90c0e..e4d03003 100644 --- a/model/balancer.go +++ b/model/balancer.go @@ -34,7 +34,7 @@ func (cc *ChannelsChooser) Cooldowns(channelId int) bool { return true } -func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel { +func (cc *ChannelsChooser) balancer(channelIds []int) *Channel { nowTime := time.Now().Unix() totalWeight := 0 @@ -67,9 +67,9 @@ func (cc *ChannelsChooser) Balancer(channelIds []int) *Channel { return nil } -func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) { +func (cc *ChannelsChooser) Next(group, modelName string) (*Channel, error) { if !common.MemoryCacheEnabled { - return GetRandomSatisfiedChannel(group, model) + return GetRandomSatisfiedChannel(group, modelName) } cc.RLock() defer cc.RUnlock() @@ -77,17 +77,17 @@ func (cc *ChannelsChooser) Next(group, model string) (*Channel, error) { return nil, errors.New("group not found") } - if _, ok := cc.Rule[group][model]; !ok { + if _, ok := cc.Rule[group][modelName]; !ok { return nil, errors.New("model not found") } - channelsPriority := cc.Rule[group][model] + channelsPriority := cc.Rule[group][modelName] if len(channelsPriority) == 0 { return nil, errors.New("channel not found") } for _, priority := range channelsPriority { - channel := cc.Balancer(priority) + channel := cc.balancer(priority) if channel != nil { return channel, nil } @@ -118,7 +118,7 @@ func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) { var ChannelGroup = ChannelsChooser{} -func InitChannelGroup() { +func (cc *ChannelsChooser) Load() { var channels []*Channel DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) @@ -160,17 +160,9 @@ func InitChannelGroup() { newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds) } - ChannelGroup.Lock() - ChannelGroup.Rule = newGroup - ChannelGroup.Channels = newChannels - ChannelGroup.Unlock() - common.SysLog("channels synced from database") -} - -func SyncChannelGroup(frequency int) { - for { - time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing channels from database") - InitChannelGroup() - } + cc.Lock() + cc.Rule = newGroup + cc.Channels = newChannels + cc.Unlock() + common.SysLog("channels Load success") } diff --git a/model/cache.go b/model/cache.go index 9edb8a64..7c3a8e14 100644 --- a/model/cache.go +++ b/model/cache.go @@ -9,10 +9,7 @@ import ( ) var ( - TokenCacheSeconds = common.SyncFrequency - UserId2GroupCacheSeconds = common.SyncFrequency - UserId2QuotaCacheSeconds = common.SyncFrequency - UserId2StatusCacheSeconds = common.SyncFrequency + TokenCacheSeconds = 0 ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -55,7 +52,7 @@ func CacheGetUserGroup(id int) (group string, err error) { if err != nil { return "", err } - err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) + err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(TokenCacheSeconds)*time.Second) if err != nil { common.SysError("Redis set user group error: " + err.Error()) } @@ -73,7 +70,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { if err != nil { return 0, err } - err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(TokenCacheSeconds)*time.Second) if err != nil { common.SysError("Redis set user quota error: " + err.Error()) } @@ -91,7 +88,7 @@ func CacheUpdateUserQuota(id int) error { if err != nil { return err } - err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(TokenCacheSeconds)*time.Second) return err } @@ -120,7 +117,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { if userEnabled { enabled = "1" } - err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(TokenCacheSeconds)*time.Second) if err != nil { common.SysError("Redis set user enabled error: " + err.Error()) } diff --git a/model/channel.go b/model/channel.go index fa033d77..7dc58115 100644 --- a/model/channel.go +++ b/model/channel.go @@ -117,6 +117,8 @@ func BatchInsertChannels(channels []Channel) error { return err } } + + go ChannelGroup.Load() return nil } @@ -130,6 +132,10 @@ func BatchUpdateChannelsAzureApi(params *BatchChannelsParams) (int64, error) { if db.Error != nil { return 0, db.Error } + + if db.RowsAffected > 0 { + go ChannelGroup.Load() + } return db.RowsAffected, nil } @@ -152,10 +158,14 @@ func BatchDelModelChannels(params *BatchChannelsParams) (int64, error) { } channel.Models = strings.Join(modelsSlice, ",") - channel.Update(false) + channel.UpdateRaw(false) count++ } + if count > 0 { + go ChannelGroup.Load() + } + return count, nil } @@ -187,10 +197,26 @@ func (channel *Channel) Insert() error { return err } err = channel.AddAbilities() + + if err == nil { + go ChannelGroup.Load() + } + return err } func (channel *Channel) Update(overwrite bool) error { + + err := channel.UpdateRaw(overwrite) + + if err == nil { + go ChannelGroup.Load() + } + + return err +} + +func (channel *Channel) UpdateRaw(overwrite bool) error { var err error if overwrite { @@ -233,6 +259,9 @@ func (channel *Channel) Delete() error { return err } err = channel.DeleteAbilities() + if err == nil { + go ChannelGroup.Load() + } return err } @@ -245,6 +274,11 @@ func UpdateChannelStatusById(id int, status int) { if err != nil { common.SysError("failed to update channel status: " + err.Error()) } + + if err == nil { + + go ChannelGroup.Load() + } } func UpdateChannelUsedQuota(id int, quota int) { diff --git a/model/main.go b/model/main.go index b4a90338..90cd1b96 100644 --- a/model/main.go +++ b/model/main.go @@ -3,10 +3,11 @@ package model import ( "fmt" "one-api/common" - "os" + "strconv" "strings" "time" + "github.com/spf13/viper" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -15,6 +16,21 @@ import ( var DB *gorm.DB +func SetupDB() { + err := InitDB() + if err != nil { + common.FatalLog("failed to initialize database: " + err.Error()) + } + ChannelGroup.Load() + + if viper.GetBool("BATCH_UPDATE_ENABLED") { + common.BatchUpdateEnabled = true + common.BatchUpdateInterval = common.GetOrDefault("BATCH_UPDATE_INTERVAL", 5) + common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + InitBatchUpdater() + } +} + func createRootAccountIfNeed() error { var user User //if user.Status != common.UserStatusEnabled { @@ -39,8 +55,8 @@ func createRootAccountIfNeed() error { } func chooseDB() (*gorm.DB, error) { - if os.Getenv("SQL_DSN") != "" { - dsn := os.Getenv("SQL_DSN") + if viper.IsSet("SQL_DSN") { + dsn := viper.GetString("SQL_DSN") if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL common.SysLog("using PostgreSQL as database") @@ -61,8 +77,8 @@ func chooseDB() (*gorm.DB, error) { // Use SQLite common.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true - config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) - return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ + config := fmt.Sprintf("?_busy_timeout=%d", common.GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)) + return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } @@ -70,7 +86,7 @@ func chooseDB() (*gorm.DB, error) { func InitDB() (err error) { db, err := chooseDB() if err == nil { - if common.DebugEnabled { + if viper.GetBool("debug") { db = db.Debug() } DB = db @@ -78,6 +94,7 @@ func InitDB() (err error) { if err != nil { return err } + sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) diff --git a/model/option.go b/model/option.go index 24e9a20b..45979df7 100644 --- a/model/option.go +++ b/model/option.go @@ -27,10 +27,6 @@ func GetOption(key string) (option Option, err error) { func InitOptionMap() { common.OptionMapRWMutex.Lock() common.OptionMap = make(map[string]string) - common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) - common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) - common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) - common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) @@ -137,18 +133,14 @@ func UpdateOption(key string, value string) error { } var optionIntMap = map[string]*int{ - "FileUploadPermission": &common.FileUploadPermission, - "FileDownloadPermission": &common.FileDownloadPermission, - "ImageUploadPermission": &common.ImageUploadPermission, - "ImageDownloadPermission": &common.ImageDownloadPermission, - "SMTPPort": &common.SMTPPort, - "QuotaForNewUser": &common.QuotaForNewUser, - "QuotaForInviter": &common.QuotaForInviter, - "QuotaForInvitee": &common.QuotaForInvitee, - "QuotaRemindThreshold": &common.QuotaRemindThreshold, - "PreConsumedQuota": &common.PreConsumedQuota, - "RetryTimes": &common.RetryTimes, - "RetryCooldownSeconds": &common.RetryCooldownSeconds, + "SMTPPort": &common.SMTPPort, + "QuotaForNewUser": &common.QuotaForNewUser, + "QuotaForInviter": &common.QuotaForInviter, + "QuotaForInvitee": &common.QuotaForInvitee, + "QuotaRemindThreshold": &common.QuotaRemindThreshold, + "PreConsumedQuota": &common.PreConsumedQuota, + "RetryTimes": &common.RetryTimes, + "RetryCooldownSeconds": &common.RetryCooldownSeconds, } var optionBoolMap = map[string]*bool{ diff --git a/controller/relay/base.go b/relay/base.go similarity index 100% rename from controller/relay/base.go rename to relay/base.go diff --git a/controller/relay/chat.go b/relay/chat.go similarity index 100% rename from controller/relay/chat.go rename to relay/chat.go diff --git a/controller/relay/utils.go b/relay/common.go similarity index 100% rename from controller/relay/utils.go rename to relay/common.go diff --git a/controller/relay/completions.go b/relay/completions.go similarity index 100% rename from controller/relay/completions.go rename to relay/completions.go diff --git a/controller/relay/embeddings.go b/relay/embeddings.go similarity index 100% rename from controller/relay/embeddings.go rename to relay/embeddings.go diff --git a/controller/relay/image-edits.go b/relay/image-edits.go similarity index 100% rename from controller/relay/image-edits.go rename to relay/image-edits.go diff --git a/controller/relay/image-generations.go b/relay/image-generations.go similarity index 100% rename from controller/relay/image-generations.go rename to relay/image-generations.go diff --git a/controller/relay/image-variationsy.go b/relay/image-variationsy.go similarity index 100% rename from controller/relay/image-variationsy.go rename to relay/image-variationsy.go diff --git a/controller/relay/main.go b/relay/main.go similarity index 92% rename from controller/relay/main.go rename to relay/main.go index 6329dcb5..5da7a340 100644 --- a/controller/relay/main.go +++ b/relay/main.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "one-api/relay/util" "one-api/types" "github.com/gin-gonic/gin" @@ -87,8 +88,8 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod relay.getProvider().SetUsage(usage) - var quotaInfo *QuotaInfo - quotaInfo, err = generateQuotaInfo(relay.getContext(), relay.getModelName(), promptTokens) + var quota *util.Quota + quota, err = util.NewQuota(relay.getContext(), relay.getModelName(), promptTokens) if err != nil { done = true return @@ -97,10 +98,10 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod err, done = relay.send() if err != nil { - quotaInfo.undo(relay.getContext()) + quota.Undo(relay.getContext()) return } - quotaInfo.consume(relay.getContext(), usage) + quota.Consume(relay.getContext(), usage) return } diff --git a/controller/model.go b/relay/model.go similarity index 99% rename from controller/model.go rename to relay/model.go index c13c31bb..74e6a636 100644 --- a/controller/model.go +++ b/relay/model.go @@ -1,4 +1,4 @@ -package controller +package relay import ( "fmt" diff --git a/controller/relay/moderations.go b/relay/moderations.go similarity index 100% rename from controller/relay/moderations.go rename to relay/moderations.go diff --git a/controller/relay/speech.go b/relay/speech.go similarity index 100% rename from controller/relay/speech.go rename to relay/speech.go diff --git a/controller/relay/transcriptions.go b/relay/transcriptions.go similarity index 100% rename from controller/relay/transcriptions.go rename to relay/transcriptions.go diff --git a/controller/relay/translations.go b/relay/translations.go similarity index 100% rename from controller/relay/translations.go rename to relay/translations.go diff --git a/controller/relay/quota.go b/relay/util/quota.go similarity index 70% rename from controller/relay/quota.go rename to relay/util/quota.go index 3343b6c5..0a7c2be0 100644 --- a/controller/relay/quota.go +++ b/relay/util/quota.go @@ -1,4 +1,4 @@ -package relay +package util import ( "context" @@ -14,7 +14,7 @@ import ( "github.com/gin-gonic/gin" ) -type QuotaInfo struct { +type Quota struct { modelName string promptTokens int preConsumedTokens int @@ -28,8 +28,8 @@ type QuotaInfo struct { HandelStatus bool } -func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) { - quotaInfo := &QuotaInfo{ +func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *types.OpenAIErrorWithStatusCode) { + quota := &Quota{ modelName: modelName, promptTokens: promptTokens, userId: c.GetInt("id"), @@ -37,17 +37,17 @@ func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*Quo tokenId: c.GetInt("token_id"), HandelStatus: false, } - quotaInfo.initQuotaInfo(c.GetString("group")) + quota.init(c.GetString("group")) - errWithCode := quotaInfo.preQuotaConsumption() + errWithCode := quota.preQuotaConsumption() if errWithCode != nil { return nil, errWithCode } - return quotaInfo, nil + return quota, nil } -func (q *QuotaInfo) initQuotaInfo(groupName string) { +func (q *Quota) init(groupName string) { modelRatio := common.GetModelRatio(q.modelName) groupRatio := common.GetGroupRatio(groupName) preConsumedTokens := common.PreConsumedQuota @@ -62,7 +62,7 @@ func (q *QuotaInfo) initQuotaInfo(groupName string) { } -func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { +func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { userQuota, err := model.CacheGetUserQuota(q.userId) if err != nil { return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -95,7 +95,7 @@ func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { return nil } -func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error { +func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error { quota := 0 completionRatio := q.modelRatio[1] * q.groupRatio promptTokens := usage.PromptTokens @@ -119,32 +119,31 @@ func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName stri if err != nil { return errors.New("error consuming token remain quota: " + err.Error()) } - if quota != 0 { - requestTime := 0 - requestStartTimeValue := ctx.Value("requestStartTime") - if requestStartTimeValue != nil { - requestStartTime, ok := requestStartTimeValue.(time.Time) - if ok { - requestTime = int(time.Since(requestStartTime).Milliseconds()) - } - } - var modelRatioStr string - if q.modelRatio[0] == q.modelRatio[1] { - modelRatioStr = fmt.Sprintf("%.2f", q.modelRatio[0]) - } else { - modelRatioStr = fmt.Sprintf("%.2f (输入)/%.2f (输出)", q.modelRatio[0], q.modelRatio[1]) - } - logContent := fmt.Sprintf("模型倍率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio) - model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime) - model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota) - model.UpdateChannelUsedQuota(q.channelId, quota) + requestTime := 0 + requestStartTimeValue := ctx.Value("requestStartTime") + if requestStartTimeValue != nil { + requestStartTime, ok := requestStartTimeValue.(time.Time) + if ok { + requestTime = int(time.Since(requestStartTime).Milliseconds()) + } } + var modelRatioStr string + if q.modelRatio[0] == q.modelRatio[1] { + modelRatioStr = fmt.Sprintf("%.2f", q.modelRatio[0]) + } else { + modelRatioStr = fmt.Sprintf("%.2f (输入)/%.2f (输出)", q.modelRatio[0], q.modelRatio[1]) + } + + logContent := fmt.Sprintf("模型倍率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio) + model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime) + model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota) + model.UpdateChannelUsedQuota(q.channelId, quota) return nil } -func (q *QuotaInfo) undo(c *gin.Context) { +func (q *Quota) Undo(c *gin.Context) { tokenId := c.GetInt("token_id") if q.HandelStatus { go func(ctx context.Context) { @@ -157,7 +156,7 @@ func (q *QuotaInfo) undo(c *gin.Context) { } } -func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) { +func (q *Quota) Consume(c *gin.Context, usage *types.Usage) { tokenName := c.GetString("token_name") // 如果没有报错,则消费配额 go func(ctx context.Context) { diff --git a/router/api-router.go b/router/api-router.go index a5ea2e5b..57e81816 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -3,6 +3,7 @@ package router import ( "one-api/controller" "one-api/middleware" + "one-api/relay" "github.com/gin-contrib/gzip" "github.com/gin-gonic/gin" @@ -43,7 +44,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) - selfRoute.GET("/models", controller.ListModels) + selfRoute.GET("/models", relay.ListModels) } adminRoute := userRoute.Group("/") @@ -73,7 +74,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.Use(middleware.AdminAuth()) { channelRoute.GET("/", controller.GetChannelsList) - channelRoute.GET("/models", controller.ListModelsForAdmin) + channelRoute.GET("/models", relay.ListModelsForAdmin) channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) diff --git a/router/main.go b/router/main.go index b8ac4055..559747fe 100644 --- a/router/main.go +++ b/router/main.go @@ -3,18 +3,19 @@ package router import ( "embed" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" - "os" "strings" + + "github.com/gin-gonic/gin" + "github.com/spf13/viper" ) func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { SetApiRouter(router) SetDashboardRouter(router) SetRelayRouter(router) - frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") + frontendBaseUrl := viper.GetString("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" common.SysLog("FRONTEND_BASE_URL is ignored on master node") diff --git a/router/relay-router.go b/router/relay-router.go index f6f548e7..824cc1b3 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -2,8 +2,8 @@ package router import ( "one-api/controller" - "one-api/controller/relay" "one-api/middleware" + "one-api/relay" "github.com/gin-gonic/gin" ) @@ -14,8 +14,8 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth(), middleware.Distribute()) { - modelsRouter.GET("", controller.ListModels) - modelsRouter.GET("/:model", controller.RetrieveModel) + modelsRouter.GET("", relay.ListModels) + modelsRouter.GET("/:model", relay.RetrieveModel) } relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) diff --git a/router/web-router.go b/router/web-router.go index 8f9c18a2..b98c74af 100644 --- a/router/web-router.go +++ b/router/web-router.go @@ -2,21 +2,21 @@ package router import ( "embed" - "github.com/gin-contrib/gzip" - "github.com/gin-contrib/static" - "github.com/gin-gonic/gin" "net/http" - "one-api/common" "one-api/controller" "one-api/middleware" "strings" + + "github.com/gin-contrib/gzip" + "github.com/gin-contrib/static" + "github.com/gin-gonic/gin" ) func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { router.Use(gzip.Gzip(gzip.DefaultCompression)) router.Use(middleware.GlobalWebRateLimit()) router.Use(middleware.Cache()) - router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build"))) + router.Use(static.Serve("/", static.EmbedFolder(buildFS, "web/build"))) router.NoRoute(func(c *gin.Context) { if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") { controller.RelayNotFound(c) diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index 65fa4d18..49f82ae1 100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -3,126 +3,147 @@ export const CHANNEL_OPTIONS = { key: 1, text: 'OpenAI', value: 1, - color: 'primary' + color: 'primary', + url: 'https://platform.openai.com/usage' }, 14: { key: 14, text: 'Anthropic Claude', value: 14, - color: 'info' + color: 'info', + url: 'https://console.anthropic.com/' }, 3: { key: 3, text: 'Azure OpenAI', value: 3, - color: 'orange' + color: 'orange', + url: 'https://oai.azure.com/' }, 11: { key: 11, text: 'Google PaLM2', value: 11, - color: 'orange' + color: 'orange', + url: 'https://aistudio.google.com/' }, 25: { key: 25, text: 'Google Gemini', value: 25, - color: 'orange' + color: 'orange', + url: 'https://aistudio.google.com/' }, 15: { key: 15, text: '百度文心千帆', value: 15, - color: 'default' + color: 'default', + url: 'https://console.bce.baidu.com/qianfan/overview' }, 17: { key: 17, text: '阿里通义千问', value: 17, - color: 'default' + color: 'default', + url: 'https://dashscope.console.aliyun.com/overview' }, 18: { key: 18, text: '讯飞星火认知', value: 18, - color: 'default' + color: 'default', + url: 'https://console.xfyun.cn/' }, 16: { key: 16, text: '智谱 ChatGLM', value: 16, - color: 'default' + color: 'default', + url: 'https://open.bigmodel.cn/overview' }, 19: { key: 19, text: '360 智脑', value: 19, - color: 'default' + color: 'default', + url: 'https://ai.360.com/open' }, 23: { key: 23, text: '腾讯混元', value: 23, - color: 'default' + color: 'default', + url: 'https://cloud.tencent.com/product/hunyuan' }, 26: { key: 26, text: '百川', value: 26, - color: 'orange' + color: 'orange', + url: 'https://platform.baichuan-ai.com/console/apikey' }, 27: { key: 27, text: 'MiniMax', value: 27, - color: 'orange' + color: 'orange', + url: 'https://www.minimaxi.com/user-center/basic-information' }, 28: { key: 28, text: 'Deepseek', value: 28, - color: 'default' + color: 'default', + url: 'https://platform.deepseek.com/usage' }, 29: { key: 29, text: 'Moonshot', value: 29, - color: 'default' + color: 'default', + url: 'https://platform.moonshot.cn/console/info' }, 30: { key: 30, text: 'Mistral', value: 30, - color: 'orange' + color: 'orange', + url: 'https://console.mistral.ai/' }, 31: { key: 31, text: 'Groq', value: 31, - color: 'primary' + color: 'primary', + url: 'https://console.groq.com/keys' }, 32: { key: 32, text: 'Amazon Bedrock', value: 32, - color: 'orange' + color: 'orange', + url: 'https://console.aws.amazon.com/bedrock/home' }, 33: { key: 33, text: '零一万物', value: 33, - color: 'primary' + color: 'primary', + url: 'https://platform.lingyiwanwu.com/details' }, 24: { key: 24, text: 'Azure Speech', value: 24, - color: 'orange' + color: 'orange', + url: 'https://portal.azure.com/' }, 8: { key: 8, text: '自定义渠道', value: 8, - color: 'primary' + color: 'primary', + url: '' } }; diff --git a/web/src/views/Channel/component/TableRow.js b/web/src/views/Channel/component/TableRow.js index c7ab9e6b..3fe204b6 100644 --- a/web/src/views/Channel/component/TableRow.js +++ b/web/src/views/Channel/component/TableRow.js @@ -34,7 +34,7 @@ import TableSwitch from 'ui-component/Switch'; import ResponseTimeLabel from './ResponseTimeLabel'; import GroupLabel from './GroupLabel'; -import { IconDotsVertical, IconEdit, IconTrash, IconPencil, IconCopy } from '@tabler/icons-react'; +import { IconDotsVertical, IconEdit, IconTrash, IconPencil, IconCopy, IconWorldWww } from '@tabler/icons-react'; import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown'; import KeyboardArrowUpIcon from '@mui/icons-material/KeyboardArrowUp'; import { copy } from 'utils/common'; @@ -252,6 +252,18 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, > 复制{' '} + {CHANNEL_OPTIONS[item.type]?.url && ( + { + handleCloseMenu(); + // 新页面打开 + window.open(CHANNEL_OPTIONS[item.type].url); + }} + > + + 官网 + + )} diff --git a/web/src/views/Log/component/TableRow.js b/web/src/views/Log/component/TableRow.js index b701b24c..202fad93 100644 --- a/web/src/views/Log/component/TableRow.js +++ b/web/src/views/Log/component/TableRow.js @@ -76,9 +76,9 @@ export default function LogTableRow({ item, userIsAdmin }) { {' '} - {item.prompt_tokens || ''} - {item.completion_tokens || ''} - {item.quota ? renderQuota(item.quota, 6) : ''} + {item.prompt_tokens || '0'} + {item.completion_tokens || '0'} + {item.quota ? renderQuota(item.quota, 6) : '0'} {item.content}