diff --git a/common/config/config.go b/common/config/config.go new file mode 100644 index 00000000..34d871fe --- /dev/null +++ b/common/config/config.go @@ -0,0 +1,127 @@ +package config + +import ( + "one-api/common/helper" + "os" + "strconv" + "sync" + "time" + + "github.com/google/uuid" +) + +var SystemName = "One API" +var ServerAddress = "http://localhost:3000" +var Footer = "" +var Logo = "" +var TopUpLink = "" +var ChatLink = "" +var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens +var DisplayInCurrencyEnabled = true +var DisplayTokenStatEnabled = true + +// Any options with "Secret", "Token" in its key won't be return by GetOptions + +var SessionSecret = uuid.New().String() + +var OptionMap map[string]string +var OptionMapRWMutex sync.RWMutex + +var ItemsPerPage = 10 +var MaxRecentItems = 100 + +var PasswordLoginEnabled = true +var PasswordRegisterEnabled = true +var EmailVerificationEnabled = false +var GitHubOAuthEnabled = false +var WeChatAuthEnabled = false +var TurnstileCheckEnabled = false +var RegisterEnabled = true + +var EmailDomainRestrictionEnabled = false +var EmailDomainWhitelist = []string{ + "gmail.com", + "163.com", + "126.com", + "qq.com", + "outlook.com", + "hotmail.com", + "icloud.com", + "yahoo.com", + "foxmail.com", +} + +var DebugEnabled = os.Getenv("DEBUG") == "true" +var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" + +var LogConsumeEnabled = true + +var SMTPServer = "" +var SMTPPort = 587 +var SMTPAccount = "" +var SMTPFrom = "" +var SMTPToken = "" + +var GitHubClientId = "" +var GitHubClientSecret = "" + +var WeChatServerAddress = "" +var WeChatServerToken = "" +var WeChatAccountQRCodeImageURL = "" + +var TurnstileSiteKey = "" +var TurnstileSecretKey = "" + +var QuotaForNewUser = 0 +var QuotaForInviter = 0 +var QuotaForInvitee = 0 +var ChannelDisableThreshold = 5.0 +var AutomaticDisableChannelEnabled = false +var AutomaticEnableChannelEnabled = false +var QuotaRemindThreshold = 1000 +var PreConsumedQuota = 500 +var ApproximateTokenEnabled = false +var RetryTimes = 0 + +var RootUserEmail = "" + +var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" + +var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) +var RequestInterval = time.Duration(requestInterval) * time.Second + +var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second + +var BatchUpdateEnabled = false +var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) + +var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second + +var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") + +var Theme = helper.GetOrDefaultEnvString("THEME", "default") +var ValidThemes = map[string]bool{ + "default": true, + "berry": true, +} + +// All duration's unit is seconds +// Shouldn't larger then RateLimitKeyExpirationDuration +var ( + GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitDuration int64 = 3 * 60 + + GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitDuration int64 = 3 * 60 + + UploadRateLimitNum = 10 + UploadRateLimitDuration int64 = 60 + + DownloadRateLimitNum = 10 + DownloadRateLimitDuration int64 = 60 + + CriticalRateLimitNum = 20 + CriticalRateLimitDuration int64 = 20 * 60 +) + +var RateLimitKeyExpirationDuration = 20 * time.Minute diff --git a/common/constants.go b/common/constants.go index a83de9c1..325454d4 100644 --- a/common/constants.go +++ b/common/constants.go @@ -1,114 +1,9 @@ package common -import ( - "os" - "strconv" - "sync" - "time" - - "github.com/google/uuid" -) +import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change -var SystemName = "One API" -var ServerAddress = "http://localhost:3000" -var Footer = "" -var Logo = "" -var TopUpLink = "" -var ChatLink = "" -var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens -var DisplayInCurrencyEnabled = true -var DisplayTokenStatEnabled = true - -// Any options with "Secret", "Token" in its key won't be return by GetOptions - -var SessionSecret = uuid.New().String() - -var OptionMap map[string]string -var OptionMapRWMutex sync.RWMutex - -var ItemsPerPage = 10 -var MaxRecentItems = 100 - -var PasswordLoginEnabled = true -var PasswordRegisterEnabled = true -var EmailVerificationEnabled = false -var GitHubOAuthEnabled = false -var WeChatAuthEnabled = false -var TurnstileCheckEnabled = false -var RegisterEnabled = true - -var EmailDomainRestrictionEnabled = false -var EmailDomainWhitelist = []string{ - "gmail.com", - "163.com", - "126.com", - "qq.com", - "outlook.com", - "hotmail.com", - "icloud.com", - "yahoo.com", - "foxmail.com", -} - -var DebugEnabled = os.Getenv("DEBUG") == "true" -var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" - -var LogConsumeEnabled = true - -var SMTPServer = "" -var SMTPPort = 587 -var SMTPAccount = "" -var SMTPFrom = "" -var SMTPToken = "" - -var GitHubClientId = "" -var GitHubClientSecret = "" - -var WeChatServerAddress = "" -var WeChatServerToken = "" -var WeChatAccountQRCodeImageURL = "" - -var TurnstileSiteKey = "" -var TurnstileSecretKey = "" - -var QuotaForNewUser = 0 -var QuotaForInviter = 0 -var QuotaForInvitee = 0 -var ChannelDisableThreshold = 5.0 -var AutomaticDisableChannelEnabled = false -var AutomaticEnableChannelEnabled = false -var QuotaRemindThreshold = 1000 -var PreConsumedQuota = 500 -var ApproximateTokenEnabled = false -var RetryTimes = 0 - -var RootUserEmail = "" - -var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" - -var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) -var RequestInterval = time.Duration(requestInterval) * time.Second - -var SyncFrequency = GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second - -var BatchUpdateEnabled = false -var BatchUpdateInterval = GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) - -var RelayTimeout = GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second - -var GeminiSafetySetting = GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") - -var Theme = GetOrDefaultEnvString("THEME", "default") -var ValidThemes = map[string]bool{ - "default": true, - "berry": true, -} - -const ( - RequestIdKey = "X-Oneapi-Request-Id" -) const ( RoleGuestUser = 0 @@ -117,34 +12,6 @@ const ( RoleRootUser = 100 ) -var ( - FileUploadPermission = RoleGuestUser - FileDownloadPermission = RoleGuestUser - ImageUploadPermission = RoleGuestUser - ImageDownloadPermission = RoleGuestUser -) - -// All duration's unit is seconds -// Shouldn't larger then RateLimitKeyExpirationDuration -var ( - GlobalApiRateLimitNum = GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) - GlobalApiRateLimitDuration int64 = 3 * 60 - - GlobalWebRateLimitNum = GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) - GlobalWebRateLimitDuration int64 = 3 * 60 - - UploadRateLimitNum = 10 - UploadRateLimitDuration int64 = 60 - - DownloadRateLimitNum = 10 - DownloadRateLimitDuration int64 = 60 - - CriticalRateLimitNum = 20 - CriticalRateLimitDuration int64 = 20 * 60 -) - -var RateLimitKeyExpirationDuration = 20 * time.Minute - const ( UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusDisabled = 2 // also don't use 0 diff --git a/common/database.go b/common/database.go index 8f659b57..4c2c2717 100644 --- a/common/database.go +++ b/common/database.go @@ -1,7 +1,9 @@ package common +import "one-api/common/helper" + var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/email.go b/common/email.go index b915f0f9..28719742 100644 --- a/common/email.go +++ b/common/email.go @@ -6,18 +6,19 @@ import ( "encoding/base64" "fmt" "net/smtp" + "one-api/common/config" "strings" "time" ) func SendEmail(subject string, receiver string, content string) error { - if SMTPFrom == "" { // for compatibility - SMTPFrom = SMTPAccount + if config.SMTPFrom == "" { // for compatibility + config.SMTPFrom = config.SMTPAccount } encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) // Extract domain from SMTPFrom - parts := strings.Split(SMTPFrom, "@") + parts := strings.Split(config.SMTPFrom, "@") var domain string if len(parts) > 1 { domain = parts[1] @@ -36,21 +37,21 @@ func SendEmail(subject string, receiver string, content string) error { "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) - auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) - addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) + receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) + auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) + addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) to := strings.Split(receiver, ";") - if SMTPPort == 465 { + if config.SMTPPort == 465 { tlsConfig := &tls.Config{ InsecureSkipVerify: true, - ServerName: SMTPServer, + ServerName: config.SMTPServer, } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) if err != nil { return err } - client, err := smtp.NewClient(conn, SMTPServer) + client, err := smtp.NewClient(conn, config.SMTPServer) if err != nil { return err } @@ -58,7 +59,7 @@ func SendEmail(subject string, receiver string, content string) error { if err = client.Auth(auth); err != nil { return err } - if err = client.Mail(SMTPFrom); err != nil { + if err = client.Mail(config.SMTPFrom); err != nil { return err } receiverEmails := strings.Split(receiver, ";") @@ -80,7 +81,7 @@ func SendEmail(subject string, receiver string, content string) error { return err } } else { - err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) + err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail) } return err } diff --git a/common/helper/helper.go b/common/helper/helper.go new file mode 100644 index 00000000..09f1df29 --- /dev/null +++ b/common/helper/helper.go @@ -0,0 +1,224 @@ +package helper + +import ( + "fmt" + "github.com/google/uuid" + "html/template" + "log" + "math/rand" + "net" + "one-api/common/logger" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "time" +) + +func OpenBrowser(url string) { + var err error + + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + } + if err != nil { + log.Println(err) + } +} + +func GetIp() (ip string) { + ips, err := net.InterfaceAddrs() + if err != nil { + log.Println(err) + return ip + } + + for _, a := range ips { + if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + ip = ipNet.IP.String() + if strings.HasPrefix(ip, "10") { + return + } + if strings.HasPrefix(ip, "172") { + return + } + if strings.HasPrefix(ip, "192.168") { + return + } + ip = "" + } + } + } + return +} + +var sizeKB = 1024 +var sizeMB = sizeKB * 1024 +var sizeGB = sizeMB * 1024 + +func Bytes2Size(num int64) string { + numStr := "" + unit := "B" + if num/int64(sizeGB) > 1 { + numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) + unit = "GB" + } else if num/int64(sizeMB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) + unit = "MB" + } else if num/int64(sizeKB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) + unit = "KB" + } else { + numStr = fmt.Sprintf("%d", num) + } + return numStr + " " + unit +} + +func Seconds2Time(num int) (time string) { + if num/31104000 > 0 { + time += strconv.Itoa(num/31104000) + " 年 " + num %= 31104000 + } + if num/2592000 > 0 { + time += strconv.Itoa(num/2592000) + " 个月 " + num %= 2592000 + } + if num/86400 > 0 { + time += strconv.Itoa(num/86400) + " 天 " + num %= 86400 + } + if num/3600 > 0 { + time += strconv.Itoa(num/3600) + " 小时 " + num %= 3600 + } + if num/60 > 0 { + time += strconv.Itoa(num/60) + " 分钟 " + num %= 60 + } + time += strconv.Itoa(num) + " 秒" + return +} + +func Interface2String(inter interface{}) string { + switch inter.(type) { + case string: + return inter.(string) + case int: + return fmt.Sprintf("%d", inter.(int)) + case float64: + return fmt.Sprintf("%f", inter.(float64)) + } + return "Not Implemented" +} + +func UnescapeHTML(x string) interface{} { + return template.HTML(x) +} + +func IntMax(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} + +const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func GenerateKey() string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, 48) + for i := 0; i < 16; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + uuid_ := GetUUID() + for i := 0; i < 32; i++ { + c := uuid_[i] + if i%2 == 0 && c >= 'a' && c <= 'z' { + c = c - 'a' + 'A' + } + key[i+16] = c + } + return string(key) +} + +func GetRandomString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func GetTimestamp() int64 { + return time.Now().Unix() +} + +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} + +func Max(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GetOrDefaultEnvInt(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) + return defaultValue + } + return num +} + +func GetOrDefaultEnvString(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} + +func AssignOrDefault(value string, defaultValue string) string { + if len(value) != 0 { + return value + } + return defaultValue +} + +func MessageWithRequestId(message string, id string) string { + return fmt.Sprintf("%s (request id: %s)", message, id) +} + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} diff --git a/common/init.go b/common/init.go index 9735c5b4..abc71108 100644 --- a/common/init.go +++ b/common/init.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "log" + "one-api/common/config" "one-api/common/logger" "os" "path/filepath" @@ -40,7 +41,7 @@ func init() { if os.Getenv("SESSION_SECRET") == "random_string" { logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") } else { - SessionSecret = os.Getenv("SESSION_SECRET") + config.SessionSecret = os.Getenv("SESSION_SECRET") } } if os.Getenv("SQLITE_PATH") != "" { @@ -58,5 +59,6 @@ func init() { log.Fatal(err) } } + logger.LogDir = *LogDir } } diff --git a/common/logger/constants.go b/common/logger/constants.go new file mode 100644 index 00000000..78d32062 --- /dev/null +++ b/common/logger/constants.go @@ -0,0 +1,7 @@ +package logger + +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) + +var LogDir string diff --git a/common/logger/logger.go b/common/logger/logger.go index 4386bc6c..b89dbdb7 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -6,7 +6,6 @@ import ( "github.com/gin-gonic/gin" "io" "log" - "one-api/common" "os" "path/filepath" "sync" @@ -26,7 +25,7 @@ var setupLogLock sync.Mutex var setupLogWorking bool func SetupLogger() { - if *common.LogDir != "" { + if LogDir != "" { ok := setupLogLock.TryLock() if !ok { log.Println("setup log is already working") @@ -36,7 +35,7 @@ func SetupLogger() { setupLogLock.Unlock() setupLogWorking = false }() - logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") @@ -85,7 +84,7 @@ func logHelper(ctx context.Context, level string, msg string) { if level == loggerINFO { writer = gin.DefaultWriter } - id := ctx.Value(common.RequestIdKey) + id := ctx.Value(RequestIdKey) now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) logCount++ // we don't need accurate count, so no lock here @@ -103,11 +102,3 @@ func FatalLog(v ...any) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) os.Exit(1) } - -func LogQuota(quota int) string { - if common.DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit) - } else { - return fmt.Sprintf("%d 点额度", quota) - } -} diff --git a/common/utils.go b/common/utils.go index 4e3312f9..41de9367 100644 --- a/common/utils.go +++ b/common/utils.go @@ -2,223 +2,13 @@ package common import ( "fmt" - "github.com/google/uuid" - "html/template" - "log" - "math/rand" - "net" - "one-api/common/logger" - "os" - "os/exec" - "runtime" - "strconv" - "strings" - "time" + "one-api/common/config" ) -func OpenBrowser(url string) { - var err error - - switch runtime.GOOS { - case "linux": - err = exec.Command("xdg-open", url).Start() - case "windows": - err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - case "darwin": - err = exec.Command("open", url).Start() - } - if err != nil { - log.Println(err) - } -} - -func GetIp() (ip string) { - ips, err := net.InterfaceAddrs() - if err != nil { - log.Println(err) - return ip - } - - for _, a := range ips { - if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { - if ipNet.IP.To4() != nil { - ip = ipNet.IP.String() - if strings.HasPrefix(ip, "10") { - return - } - if strings.HasPrefix(ip, "172") { - return - } - if strings.HasPrefix(ip, "192.168") { - return - } - ip = "" - } - } - } - return -} - -var sizeKB = 1024 -var sizeMB = sizeKB * 1024 -var sizeGB = sizeMB * 1024 - -func Bytes2Size(num int64) string { - numStr := "" - unit := "B" - if num/int64(sizeGB) > 1 { - numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) - unit = "GB" - } else if num/int64(sizeMB) > 1 { - numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) - unit = "MB" - } else if num/int64(sizeKB) > 1 { - numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) - unit = "KB" +func LogQuota(quota int) string { + if config.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) } else { - numStr = fmt.Sprintf("%d", num) - } - return numStr + " " + unit -} - -func Seconds2Time(num int) (time string) { - if num/31104000 > 0 { - time += strconv.Itoa(num/31104000) + " 年 " - num %= 31104000 - } - if num/2592000 > 0 { - time += strconv.Itoa(num/2592000) + " 个月 " - num %= 2592000 - } - if num/86400 > 0 { - time += strconv.Itoa(num/86400) + " 天 " - num %= 86400 - } - if num/3600 > 0 { - time += strconv.Itoa(num/3600) + " 小时 " - num %= 3600 - } - if num/60 > 0 { - time += strconv.Itoa(num/60) + " 分钟 " - num %= 60 - } - time += strconv.Itoa(num) + " 秒" - return -} - -func Interface2String(inter interface{}) string { - switch inter.(type) { - case string: - return inter.(string) - case int: - return fmt.Sprintf("%d", inter.(int)) - case float64: - return fmt.Sprintf("%f", inter.(float64)) - } - return "Not Implemented" -} - -func UnescapeHTML(x string) interface{} { - return template.HTML(x) -} - -func IntMax(a int, b int) int { - if a >= b { - return a - } else { - return b + return fmt.Sprintf("%d 点额度", quota) } } - -func GetUUID() string { - code := uuid.New().String() - code = strings.Replace(code, "-", "", -1) - return code -} - -const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, 48) - for i := 0; i < 16; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - uuid_ := GetUUID() - for i := 0; i < 32; i++ { - c := uuid_[i] - if i%2 == 0 && c >= 'a' && c <= 'z' { - c = c - 'a' + 'A' - } - key[i+16] = c - } - return string(key) -} - -func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - -func GetTimestamp() int64 { - return time.Now().Unix() -} - -func GetTimeString() string { - now := time.Now() - return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) -} - -func Max(a int, b int) int { - if a >= b { - return a - } else { - return b - } -} - -func GetOrDefaultEnvInt(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - -func AssignOrDefault(value string, defaultValue string) string { - if len(value) != 0 { - return value - } - return defaultValue -} - -func MessageWithRequestId(message string, id string) string { - return fmt.Sprintf("%s (request id: %s)", message, id) -} - -func String2Int(str string) int { - num, err := strconv.Atoi(str) - if err != nil { - return 0 - } - return num -} diff --git a/controller/billing.go b/controller/billing.go index e27fd614..1003bf10 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -2,7 +2,7 @@ package controller import ( "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/config" "one-api/model" "one-api/relay/channel/openai" ) @@ -13,7 +13,7 @@ func GetSubscription(c *gin.Context) { var err error var token *model.Token var expiredTime int64 - if common.DisplayTokenStatEnabled { + if config.DisplayTokenStatEnabled { tokenId := c.GetInt("token_id") token, err = model.GetTokenById(tokenId) expiredTime = token.ExpiredTime @@ -39,8 +39,8 @@ func GetSubscription(c *gin.Context) { } quota := remainQuota + usedQuota amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit } if token != nil && token.UnlimitedQuota { amount = 100000000 @@ -61,7 +61,7 @@ func GetUsage(c *gin.Context) { var quota int var err error var token *model.Token - if common.DisplayTokenStatEnabled { + if config.DisplayTokenStatEnabled { tokenId := c.GetInt("token_id") token, err = model.GetTokenById(tokenId) quota = token.UsedQuota @@ -80,8 +80,8 @@ func GetUsage(c *gin.Context) { return } amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit } usage := OpenAIUsageResponse{ Object: "list", diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 61a899a4..49c76760 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "one-api/relay/util" @@ -315,7 +316,7 @@ func updateAllChannelsBalance() error { disableChannel(channel.Id, channel.Name, "余额不足") } } - time.Sleep(common.RequestInterval) + time.Sleep(config.RequestInterval) } return nil } diff --git a/controller/channel-test.go b/controller/channel-test.go index 73ff6bb2..88d6e3f2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "one-api/relay/channel/openai" @@ -151,10 +152,10 @@ var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false func notifyRootUser(subject string, content string) { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() } - err := common.SendEmail(subject, common.RootUserEmail, content) + err := common.SendEmail(subject, config.RootUserEmail, content) if err != nil { logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } @@ -177,8 +178,8 @@ func enableChannel(channelId int, channelName string) { } func testAllChannels(notify bool) error { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() } testAllChannelsLock.Lock() if testAllChannelsRunning { @@ -192,7 +193,7 @@ func testAllChannels(notify bool) error { return err } testRequest := buildTestRequest() - var disableThreshold = int64(common.ChannelDisableThreshold * 1000) + var disableThreshold = int64(config.ChannelDisableThreshold * 1000) if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value } @@ -214,13 +215,13 @@ func testAllChannels(notify bool) error { enableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) - time.Sleep(common.RequestInterval) + time.Sleep(config.RequestInterval) } testAllChannelsLock.Lock() testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") if err != nil { logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } diff --git a/controller/channel.go b/controller/channel.go index 904abc23..6d368066 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -3,7 +3,8 @@ package controller import ( "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strconv" "strings" @@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) { if p < 0 { p = 0 } - channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) + channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) { }) return } - channel.CreatedTime = common.GetTimestamp() + channel.CreatedTime = helper.GetTimestamp() keys := strings.Split(channel.Key, "\n") channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { diff --git a/controller/github.go b/controller/github.go index 68692b9d..f6049ddb 100644 --- a/controller/github.go +++ b/controller/github.go @@ -9,6 +9,8 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/logger" "one-api/model" "strconv" @@ -31,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { if code == "" { return nil, errors.New("无效的参数") } - values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} + values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} jsonData, err := json.Marshal(values) if err != nil { return nil, err @@ -94,7 +96,7 @@ func GitHubOAuth(c *gin.Context) { return } - if !common.GitHubOAuthEnabled { + if !config.GitHubOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 GitHub 登录以及注册", @@ -123,7 +125,7 @@ func GitHubOAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) if githubUser.Name != "" { user.DisplayName = githubUser.Name @@ -161,7 +163,7 @@ func GitHubOAuth(c *gin.Context) { } func GitHubBind(c *gin.Context) { - if !common.GitHubOAuthEnabled { + if !config.GitHubOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 GitHub 登录以及注册", @@ -217,7 +219,7 @@ func GitHubBind(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) { session := sessions.Default(c) - state := common.GetRandomString(12) + state := helper.GetRandomString(12) session.Set("oauth_state", state) err := session.Save() if err != nil { diff --git a/controller/log.go b/controller/log.go index b65867fe..418c6859 100644 --- a/controller/log.go +++ b/controller/log.go @@ -3,7 +3,7 @@ package controller import ( "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" "one-api/model" "strconv" ) @@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) { tokenName := c.Query("token_name") modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) - logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) { endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") - logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/misc.go b/controller/misc.go index 2bcbb41f..df7c0728 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/model" "strings" @@ -18,55 +19,55 @@ func GetStatus(c *gin.Context) { "data": gin.H{ "version": common.Version, "start_time": common.StartTime, - "email_verification": common.EmailVerificationEnabled, - "github_oauth": common.GitHubOAuthEnabled, - "github_client_id": common.GitHubClientId, - "system_name": common.SystemName, - "logo": common.Logo, - "footer_html": common.Footer, - "wechat_qrcode": common.WeChatAccountQRCodeImageURL, - "wechat_login": common.WeChatAuthEnabled, - "server_address": common.ServerAddress, - "turnstile_check": common.TurnstileCheckEnabled, - "turnstile_site_key": common.TurnstileSiteKey, - "top_up_link": common.TopUpLink, - "chat_link": common.ChatLink, - "quota_per_unit": common.QuotaPerUnit, - "display_in_currency": common.DisplayInCurrencyEnabled, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, }, }) return } func GetNotice(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["Notice"], + "data": config.OptionMap["Notice"], }) return } func GetAbout(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["About"], + "data": config.OptionMap["About"], }) return } func GetHomePageContent(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["HomePageContent"], + "data": config.OptionMap["HomePageContent"], }) return } @@ -80,9 +81,9 @@ func SendEmailVerification(c *gin.Context) { }) return } - if common.EmailDomainRestrictionEnabled { + if config.EmailDomainRestrictionEnabled { allowed := false - for _, domain := range common.EmailDomainWhitelist { + for _, domain := range config.EmailDomainWhitelist { if strings.HasSuffix(email, "@"+domain) { allowed = true break @@ -105,10 +106,10 @@ func SendEmailVerification(c *gin.Context) { } code := common.GenerateVerificationCode(6) common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) - subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) + subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) content := fmt.Sprintf("

您好,你正在进行%s邮箱验证。

"+ "

您的验证码为: %s

"+ - "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", common.SystemName, code, common.VerificationValidMinutes) + "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, code, common.VerificationValidMinutes) err := common.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -142,12 +143,12 @@ func SendPasswordResetEmail(c *gin.Context) { } code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) - link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) - subject := fmt.Sprintf("%s密码重置", common.SystemName) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) + subject := fmt.Sprintf("%s密码重置", config.SystemName) content := fmt.Sprintf("

您好,你正在进行%s密码重置。

"+ "

点击 此处 进行密码重置。

"+ "

如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s

"+ - "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", common.SystemName, link, link, common.VerificationValidMinutes) + "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", config.SystemName, link, link, common.VerificationValidMinutes) err := common.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/controller/option.go b/controller/option.go index 3b1cbad2..593ac2ae 100644 --- a/controller/option.go +++ b/controller/option.go @@ -3,7 +3,8 @@ package controller import ( "encoding/json" "net/http" - "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strings" @@ -12,17 +13,17 @@ import ( func GetOptions(c *gin.Context) { var options []*model.Option - common.OptionMapRWMutex.Lock() - for k, v := range common.OptionMap { + config.OptionMapRWMutex.Lock() + for k, v := range config.OptionMap { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { continue } options = append(options, &model.Option{ Key: k, - Value: common.Interface2String(v), + Value: helper.Interface2String(v), }) } - common.OptionMapRWMutex.Unlock() + config.OptionMapRWMutex.Unlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -43,7 +44,7 @@ func UpdateOption(c *gin.Context) { } switch option.Key { case "Theme": - if !common.ValidThemes[option.Value] { + if !config.ValidThemes[option.Value] { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的主题", @@ -51,7 +52,7 @@ func UpdateOption(c *gin.Context) { return } case "GitHubOAuthEnabled": - if option.Value == "true" && common.GitHubClientId == "" { + if option.Value == "true" && config.GitHubClientId == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", @@ -59,7 +60,7 @@ func UpdateOption(c *gin.Context) { return } case "EmailDomainRestrictionEnabled": - if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { + if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", @@ -67,7 +68,7 @@ func UpdateOption(c *gin.Context) { return } case "WeChatAuthEnabled": - if option.Value == "true" && common.WeChatServerAddress == "" { + if option.Value == "true" && config.WeChatServerAddress == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用微信登录,请先填入微信登录相关配置信息!", @@ -75,7 +76,7 @@ func UpdateOption(c *gin.Context) { return } case "TurnstileCheckEnabled": - if option.Value == "true" && common.TurnstileSiteKey == "" { + if option.Value == "true" && config.TurnstileSiteKey == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", diff --git a/controller/redemption.go b/controller/redemption.go index 0f656be0..9eeeb943 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -3,7 +3,8 @@ package controller import ( "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strconv" ) @@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) { if p < 0 { p = 0 } - redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) + redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) { } var keys []string for i := 0; i < redemption.Count; i++ { - key := common.GetUUID() + key := helper.GetUUID() cleanRedemption := model.Redemption{ UserId: c.GetInt("id"), Name: redemption.Name, Key: key, - CreatedTime: common.GetTimestamp(), + CreatedTime: helper.GetTimestamp(), Quota: redemption.Quota, } err = cleanRedemption.Insert() diff --git a/controller/relay.go b/controller/relay.go index e390ae75..46fedc7e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -4,7 +4,8 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" @@ -31,11 +32,11 @@ func Relay(c *gin.Context) { err = controller.RelayTextHelper(c, relayMode) } if err != nil { - requestId := c.GetString(common.RequestIdKey) + requestId := c.GetString(logger.RequestIdKey) retryTimesStr := c.Query("retry") retryTimes, _ := strconv.Atoi(retryTimesStr) if retryTimesStr == "" { - retryTimes = common.RetryTimes + retryTimes = config.RetryTimes } if retryTimes > 0 { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) @@ -43,7 +44,7 @@ func Relay(c *gin.Context) { if err.StatusCode == http.StatusTooManyRequests { err.Error.Message = "当前分组上游负载已饱和,请稍后再试" } - err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId) + err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId) c.JSON(err.StatusCode, gin.H{ "error": err.Error, }) diff --git a/controller/token.go b/controller/token.go index 8642122c..d6554abe 100644 --- a/controller/token.go +++ b/controller/token.go @@ -4,6 +4,8 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strconv" ) @@ -14,7 +16,7 @@ func GetAllTokens(c *gin.Context) { if p < 0 { p = 0 } - tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage) + tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -119,9 +121,9 @@ func AddToken(c *gin.Context) { cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, - Key: common.GenerateKey(), - CreatedTime: common.GetTimestamp(), - AccessedTime: common.GetTimestamp(), + Key: helper.GenerateKey(), + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, @@ -187,7 +189,7 @@ func UpdateToken(c *gin.Context) { return } if token.Status == common.TokenStatusEnabled { - if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { + if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", diff --git a/controller/user.go b/controller/user.go index d39bba3b..d13acddd 100644 --- a/controller/user.go +++ b/controller/user.go @@ -5,7 +5,8 @@ import ( "fmt" "net/http" "one-api/common" - "one-api/common/logger" + "one-api/common/config" + "one-api/common/helper" "one-api/model" "strconv" "time" @@ -20,7 +21,7 @@ type LoginRequest struct { } func Login(c *gin.Context) { - if !common.PasswordLoginEnabled { + if !config.PasswordLoginEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了密码登录", "success": false, @@ -107,14 +108,14 @@ func Logout(c *gin.Context) { } func Register(c *gin.Context) { - if !common.RegisterEnabled { + if !config.RegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了新用户注册", "success": false, }) return } - if !common.PasswordRegisterEnabled { + if !config.PasswordRegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", "success": false, @@ -137,7 +138,7 @@ func Register(c *gin.Context) { }) return } - if common.EmailVerificationEnabled { + if config.EmailVerificationEnabled { if user.Email == "" || user.VerificationCode == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -161,7 +162,7 @@ func Register(c *gin.Context) { DisplayName: user.Username, InviterId: inviterId, } - if common.EmailVerificationEnabled { + if config.EmailVerificationEnabled { cleanUser.Email = user.Email } if err := cleanUser.Insert(inviterId); err != nil { @@ -183,7 +184,7 @@ func GetAllUsers(c *gin.Context) { if p < 0 { p = 0 } - users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage) + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -283,7 +284,7 @@ func GenerateAccessToken(c *gin.Context) { }) return } - user.AccessToken = common.GetUUID() + user.AccessToken = helper.GetUUID() if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { c.JSON(http.StatusOK, gin.H{ @@ -320,7 +321,7 @@ func GetAffCode(c *gin.Context) { return } if user.AffCode == "" { - user.AffCode = common.GetRandomString(4) + user.AffCode = helper.GetRandomString(4) if err := user.Update(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -410,7 +411,7 @@ func UpdateUser(c *gin.Context) { return } if originUser.Quota != updatedUser.Quota { - model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota))) + model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) } c.JSON(http.StatusOK, gin.H{ "success": true, @@ -727,7 +728,7 @@ func EmailBind(c *gin.Context) { return } if user.Role == common.RoleRootUser { - common.RootUserEmail = email + config.RootUserEmail = email } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/controller/wechat.go b/controller/wechat.go index ff4c9fb6..fc231a24 100644 --- a/controller/wechat.go +++ b/controller/wechat.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" "one-api/model" "strconv" "time" @@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) { if code == "" { return "", errors.New("无效的参数") } - req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) + req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) if err != nil { return "", err } - req.Header.Set("Authorization", common.WeChatServerToken) + req.Header.Set("Authorization", config.WeChatServerToken) client := http.Client{ Timeout: 5 * time.Second, } @@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) { } func WeChatAuth(c *gin.Context) { - if !common.WeChatAuthEnabled { + if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", "success": false, @@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = "WeChat User" user.Role = common.RoleCommonUser @@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) { } func WeChatBind(c *gin.Context) { - if !common.WeChatAuthEnabled { + if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", "success": false, diff --git a/main.go b/main.go index 9e4a88f2..b79c7bf7 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/controller" "one-api/middleware" @@ -26,7 +27,7 @@ func main() { if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) } - if common.DebugEnabled { + if config.DebugEnabled { logger.SysLog("running in debug mode") } // Initialize SQL Database @@ -49,19 +50,19 @@ func main() { // Initialize options model.InitOptionMap() - logger.SysLog(fmt.Sprintf("using theme %s", common.Theme)) + logger.SysLog(fmt.Sprintf("using theme %s", config.Theme)) if common.RedisEnabled { // for compatibility with old versions - common.MemoryCacheEnabled = true + config.MemoryCacheEnabled = true } - if common.MemoryCacheEnabled { + if config.MemoryCacheEnabled { logger.SysLog("memory cache enabled") - logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) model.InitChannelCache() } - if common.MemoryCacheEnabled { - go model.SyncOptions(common.SyncFrequency) - go model.SyncChannelCache(common.SyncFrequency) + if config.MemoryCacheEnabled { + go model.SyncOptions(config.SyncFrequency) + go model.SyncChannelCache(config.SyncFrequency) } if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) @@ -78,8 +79,8 @@ func main() { go controller.AutomaticallyTestChannels(frequency) } if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { - common.BatchUpdateEnabled = true - logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + config.BatchUpdateEnabled = true + logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") model.InitBatchUpdater() } openai.InitTokenEncoders() @@ -92,7 +93,7 @@ func main() { server.Use(middleware.RequestId()) middleware.SetUpLogger(server) // Initialize session store - store := cookie.NewStore([]byte(common.SessionSecret)) + store := cookie.NewStore([]byte(config.SessionSecret)) server.Use(sessions.Sessions("session", store)) router.SetRouter(server, buildFS) diff --git a/middleware/logger.go b/middleware/logger.go index 02f2e0a9..ca372c52 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,14 +3,14 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/logger" ) func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[common.RequestIdKey].(string) + requestID = param.Keys[logger.RequestIdKey].(string) } return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index 8e5cff6c..e89ef8d6 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" "time" ) @@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st } if listLength < int64(maxRequestNum) { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } else { oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) @@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st // time.Since will return negative number! // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows if int64(nowTime.Sub(oldTime).Seconds()) < duration { - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) c.Status(http.StatusTooManyRequests) c.Abort() return } else { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } } } @@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } } else { // It's safe to call multi times. - inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) return func(c *gin.Context) { memoryRateLimiter(c, maxRequestNum, duration, mark) } @@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } func GlobalWebRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") + return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW") } func GlobalAPIRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") + return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA") } func CriticalRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") + return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT") } func DownloadRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") + return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW") } func UploadRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") + return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP") } diff --git a/middleware/request-id.go b/middleware/request-id.go index e623be7a..811a7ad8 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -3,16 +3,17 @@ package middleware import ( "context" "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/helper" + "one-api/common/logger" ) func RequestId() func(c *gin.Context) { return func(c *gin.Context) { - id := common.GetTimeString() + common.GetRandomString(8) - c.Set(common.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) + id := helper.GetTimeString() + helper.GetRandomString(8) + c.Set(logger.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) - c.Header(common.RequestIdKey, id) + c.Header(logger.RequestIdKey, id) c.Next() } } diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 6f295864..629395e7 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -6,7 +6,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "net/url" - "one-api/common" + "one-api/common/config" "one-api/common/logger" ) @@ -16,7 +16,7 @@ type turnstileCheckResponse struct { func TurnstileCheck() gin.HandlerFunc { return func(c *gin.Context) { - if common.TurnstileCheckEnabled { + if config.TurnstileCheckEnabled { session := sessions.Default(c) turnstileChecked := session.Get("turnstile") if turnstileChecked != nil { @@ -33,7 +33,7 @@ func TurnstileCheck() gin.HandlerFunc { return } rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ - "secret": {common.TurnstileSecretKey}, + "secret": {config.TurnstileSecretKey}, "response": {response}, "remoteip": {c.ClientIP()}, }) diff --git a/middleware/utils.go b/middleware/utils.go index 31620bf2..d866d75b 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -2,14 +2,14 @@ package middleware import ( "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/helper" "one-api/common/logger" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), "type": "one_api_error", }, }) diff --git a/model/cache.go b/model/cache.go index eaed5bba..a81bdddd 100644 --- a/model/cache.go +++ b/model/cache.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "one-api/common" + "one-api/common/config" "one-api/common/logger" "sort" "strconv" @@ -15,10 +16,10 @@ import ( ) var ( - TokenCacheSeconds = common.SyncFrequency - UserId2GroupCacheSeconds = common.SyncFrequency - UserId2QuotaCacheSeconds = common.SyncFrequency - UserId2StatusCacheSeconds = common.SyncFrequency + TokenCacheSeconds = config.SyncFrequency + UserId2GroupCacheSeconds = config.SyncFrequency + UserId2QuotaCacheSeconds = config.SyncFrequency + UserId2StatusCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -191,7 +192,7 @@ func SyncChannelCache(frequency int) { } func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - if !common.MemoryCacheEnabled { + if !config.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model) } channelSyncLock.RLock() diff --git a/model/channel.go b/model/channel.go index d89d1666..966b7c39 100644 --- a/model/channel.go +++ b/model/channel.go @@ -5,6 +5,8 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/logger" ) @@ -45,7 +47,7 @@ func SearchChannels(keyword string) (channels []*Channel, err error) { if common.UsingPostgreSQL { keyCol = `"key"` } - err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error + err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error return channels, err } @@ -125,7 +127,7 @@ func (channel *Channel) Update() error { func (channel *Channel) UpdateResponseTime(responseTime int64) { err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ - TestTime: common.GetTimestamp(), + TestTime: helper.GetTimestamp(), ResponseTime: int(responseTime), }).Error if err != nil { @@ -135,7 +137,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { func (channel *Channel) UpdateBalance(balance float64) { err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ - BalanceUpdatedTime: common.GetTimestamp(), + BalanceUpdatedTime: helper.GetTimestamp(), Balance: balance, }).Error if err != nil { @@ -165,7 +167,7 @@ func UpdateChannelStatusById(id int, status int) { } func UpdateChannelUsedQuota(id int, quota int) { - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return } diff --git a/model/log.go b/model/log.go index 728c4b17..78b6d9b3 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/logger" "gorm.io/gorm" @@ -33,13 +35,13 @@ const ( ) func RecordLog(userId int, logType int, content string) { - if logType == LogTypeConsume && !common.LogConsumeEnabled { + if logType == LogTypeConsume && !config.LogConsumeEnabled { return } log := &Log{ UserId: userId, Username: GetUsernameById(userId), - CreatedAt: common.GetTimestamp(), + CreatedAt: helper.GetTimestamp(), Type: logType, Content: content, } @@ -51,13 +53,13 @@ func RecordLog(userId int, logType int, content string) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) - if !common.LogConsumeEnabled { + if !config.LogConsumeEnabled { return } log := &Log{ UserId: userId, Username: GetUsernameById(userId), - CreatedAt: common.GetTimestamp(), + CreatedAt: helper.GetTimestamp(), Type: LogTypeConsume, Content: content, PromptTokens: promptTokens, @@ -126,12 +128,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error + err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error + err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } diff --git a/model/main.go b/model/main.go index 0b9c4f2b..2ed6f0e3 100644 --- a/model/main.go +++ b/model/main.go @@ -7,6 +7,8 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/logger" "os" "strings" @@ -30,7 +32,7 @@ func createRootAccountIfNeed() error { Role: common.RoleRootUser, Status: common.UserStatusEnabled, DisplayName: "Root User", - AccessToken: common.GetUUID(), + AccessToken: helper.GetUUID(), Quota: 100000000, } DB.Create(&rootUser) @@ -70,7 +72,7 @@ func chooseDB() (*gorm.DB, error) { func InitDB() (err error) { db, err := chooseDB() if err == nil { - if common.DebugEnabled { + if config.DebugEnabled { db = db.Debug() } DB = db @@ -78,11 +80,11 @@ func InitDB() (err error) { if err != nil { return err } - sqlDB.SetMaxIdleConns(common.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(common.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) - if !common.IsMasterNode { + if !config.IsMasterNode { return nil } logger.SysLog("database migration started") diff --git a/model/option.go b/model/option.go index 80abff20..e211264c 100644 --- a/model/option.go +++ b/model/option.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/config" "one-api/common/logger" "strconv" "strings" @@ -21,60 +22,56 @@ func AllOption() ([]*Option, error) { } func InitOptionMap() { - common.OptionMapRWMutex.Lock() - common.OptionMap = make(map[string]string) - common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) - common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) - common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) - common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) - common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) - common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) - common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) - common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) - common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) - common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) - common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) - common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) - common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) - common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) - common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) - common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) - common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) - common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) - common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) - common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") - common.OptionMap["SMTPServer"] = "" - common.OptionMap["SMTPFrom"] = "" - common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) - common.OptionMap["SMTPAccount"] = "" - common.OptionMap["SMTPToken"] = "" - common.OptionMap["Notice"] = "" - common.OptionMap["About"] = "" - common.OptionMap["HomePageContent"] = "" - common.OptionMap["Footer"] = common.Footer - common.OptionMap["SystemName"] = common.SystemName - common.OptionMap["Logo"] = common.Logo - common.OptionMap["ServerAddress"] = "" - common.OptionMap["GitHubClientId"] = "" - common.OptionMap["GitHubClientSecret"] = "" - common.OptionMap["WeChatServerAddress"] = "" - common.OptionMap["WeChatServerToken"] = "" - common.OptionMap["WeChatAccountQRCodeImageURL"] = "" - common.OptionMap["TurnstileSiteKey"] = "" - common.OptionMap["TurnstileSecretKey"] = "" - common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) - common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) - common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) - common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) - common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() - common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() - common.OptionMap["TopUpLink"] = common.TopUpLink - common.OptionMap["ChatLink"] = common.ChatLink - common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) - common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) - common.OptionMap["Theme"] = common.Theme - common.OptionMapRWMutex.Unlock() + config.OptionMapRWMutex.Lock() + config.OptionMap = make(map[string]string) + config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) + config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) + config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) + config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) + config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) + config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) + config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) + config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) + config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) + config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) + config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) + config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) + config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) + config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled) + config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") + config.OptionMap["SMTPServer"] = "" + config.OptionMap["SMTPFrom"] = "" + config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) + config.OptionMap["SMTPAccount"] = "" + config.OptionMap["SMTPToken"] = "" + config.OptionMap["Notice"] = "" + config.OptionMap["About"] = "" + config.OptionMap["HomePageContent"] = "" + config.OptionMap["Footer"] = config.Footer + config.OptionMap["SystemName"] = config.SystemName + config.OptionMap["Logo"] = config.Logo + config.OptionMap["ServerAddress"] = "" + config.OptionMap["GitHubClientId"] = "" + config.OptionMap["GitHubClientSecret"] = "" + config.OptionMap["WeChatServerAddress"] = "" + config.OptionMap["WeChatServerToken"] = "" + config.OptionMap["WeChatAccountQRCodeImageURL"] = "" + config.OptionMap["TurnstileSiteKey"] = "" + config.OptionMap["TurnstileSecretKey"] = "" + config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) + config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) + config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) + config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) + config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) + config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() + config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() + config.OptionMap["TopUpLink"] = config.TopUpLink + config.OptionMap["ChatLink"] = config.ChatLink + config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) + config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) + config.OptionMap["Theme"] = config.Theme + config.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -113,117 +110,104 @@ func UpdateOption(key string, value string) error { } func updateOptionMap(key string, value string) (err error) { - common.OptionMapRWMutex.Lock() - defer common.OptionMapRWMutex.Unlock() - common.OptionMap[key] = value - if strings.HasSuffix(key, "Permission") { - intValue, _ := strconv.Atoi(value) - switch key { - case "FileUploadPermission": - common.FileUploadPermission = intValue - case "FileDownloadPermission": - common.FileDownloadPermission = intValue - case "ImageUploadPermission": - common.ImageUploadPermission = intValue - case "ImageDownloadPermission": - common.ImageDownloadPermission = intValue - } - } + config.OptionMapRWMutex.Lock() + defer config.OptionMapRWMutex.Unlock() + config.OptionMap[key] = value if strings.HasSuffix(key, "Enabled") { boolValue := value == "true" switch key { case "PasswordRegisterEnabled": - common.PasswordRegisterEnabled = boolValue + config.PasswordRegisterEnabled = boolValue case "PasswordLoginEnabled": - common.PasswordLoginEnabled = boolValue + config.PasswordLoginEnabled = boolValue case "EmailVerificationEnabled": - common.EmailVerificationEnabled = boolValue + config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": - common.GitHubOAuthEnabled = boolValue + config.GitHubOAuthEnabled = boolValue case "WeChatAuthEnabled": - common.WeChatAuthEnabled = boolValue + config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": - common.TurnstileCheckEnabled = boolValue + config.TurnstileCheckEnabled = boolValue case "RegisterEnabled": - common.RegisterEnabled = boolValue + config.RegisterEnabled = boolValue case "EmailDomainRestrictionEnabled": - common.EmailDomainRestrictionEnabled = boolValue + config.EmailDomainRestrictionEnabled = boolValue case "AutomaticDisableChannelEnabled": - common.AutomaticDisableChannelEnabled = boolValue + config.AutomaticDisableChannelEnabled = boolValue case "AutomaticEnableChannelEnabled": - common.AutomaticEnableChannelEnabled = boolValue + config.AutomaticEnableChannelEnabled = boolValue case "ApproximateTokenEnabled": - common.ApproximateTokenEnabled = boolValue + config.ApproximateTokenEnabled = boolValue case "LogConsumeEnabled": - common.LogConsumeEnabled = boolValue + config.LogConsumeEnabled = boolValue case "DisplayInCurrencyEnabled": - common.DisplayInCurrencyEnabled = boolValue + config.DisplayInCurrencyEnabled = boolValue case "DisplayTokenStatEnabled": - common.DisplayTokenStatEnabled = boolValue + config.DisplayTokenStatEnabled = boolValue } } switch key { case "EmailDomainWhitelist": - common.EmailDomainWhitelist = strings.Split(value, ",") + config.EmailDomainWhitelist = strings.Split(value, ",") case "SMTPServer": - common.SMTPServer = value + config.SMTPServer = value case "SMTPPort": intValue, _ := strconv.Atoi(value) - common.SMTPPort = intValue + config.SMTPPort = intValue case "SMTPAccount": - common.SMTPAccount = value + config.SMTPAccount = value case "SMTPFrom": - common.SMTPFrom = value + config.SMTPFrom = value case "SMTPToken": - common.SMTPToken = value + config.SMTPToken = value case "ServerAddress": - common.ServerAddress = value + config.ServerAddress = value case "GitHubClientId": - common.GitHubClientId = value + config.GitHubClientId = value case "GitHubClientSecret": - common.GitHubClientSecret = value + config.GitHubClientSecret = value case "Footer": - common.Footer = value + config.Footer = value case "SystemName": - common.SystemName = value + config.SystemName = value case "Logo": - common.Logo = value + config.Logo = value case "WeChatServerAddress": - common.WeChatServerAddress = value + config.WeChatServerAddress = value case "WeChatServerToken": - common.WeChatServerToken = value + config.WeChatServerToken = value case "WeChatAccountQRCodeImageURL": - common.WeChatAccountQRCodeImageURL = value + config.WeChatAccountQRCodeImageURL = value case "TurnstileSiteKey": - common.TurnstileSiteKey = value + config.TurnstileSiteKey = value case "TurnstileSecretKey": - common.TurnstileSecretKey = value + config.TurnstileSecretKey = value case "QuotaForNewUser": - common.QuotaForNewUser, _ = strconv.Atoi(value) + config.QuotaForNewUser, _ = strconv.Atoi(value) case "QuotaForInviter": - common.QuotaForInviter, _ = strconv.Atoi(value) + config.QuotaForInviter, _ = strconv.Atoi(value) case "QuotaForInvitee": - common.QuotaForInvitee, _ = strconv.Atoi(value) + config.QuotaForInvitee, _ = strconv.Atoi(value) case "QuotaRemindThreshold": - common.QuotaRemindThreshold, _ = strconv.Atoi(value) + config.QuotaRemindThreshold, _ = strconv.Atoi(value) case "PreConsumedQuota": - common.PreConsumedQuota, _ = strconv.Atoi(value) + config.PreConsumedQuota, _ = strconv.Atoi(value) case "RetryTimes": - common.RetryTimes, _ = strconv.Atoi(value) + config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": err = common.UpdateGroupRatioByJSONString(value) case "TopUpLink": - common.TopUpLink = value + config.TopUpLink = value case "ChatLink": - common.ChatLink = value + config.ChatLink = value case "ChannelDisableThreshold": - common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) case "QuotaPerUnit": - common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) case "Theme": - common.Theme = value + config.Theme = value } return err } diff --git a/model/redemption.go b/model/redemption.go index ba1e1077..026794e0 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -5,7 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" - "one-api/common/logger" + "one-api/common/helper" ) type Redemption struct { @@ -68,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return err } - redemption.RedeemedTime = common.GetTimestamp() + redemption.RedeemedTime = helper.GetTimestamp() redemption.Status = common.RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err @@ -76,7 +76,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return 0, errors.New("兑换失败," + err.Error()) } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", logger.LogQuota(redemption.Quota))) + RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) return redemption.Quota, nil } diff --git a/model/token.go b/model/token.go index 570de47d..2087225b 100644 --- a/model/token.go +++ b/model/token.go @@ -5,6 +5,8 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/logger" ) @@ -54,7 +56,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if token.Status != common.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } - if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { + if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { if !common.RedisEnabled { token.Status = common.TokenStatusExpired err := token.SelectUpdate() @@ -139,7 +141,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, quota) return nil } @@ -151,7 +153,7 @@ func increaseTokenQuota(id int, quota int) (err error) { map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota), - "accessed_time": common.GetTimestamp(), + "accessed_time": helper.GetTimestamp(), }, ).Error return err @@ -161,7 +163,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) return nil } @@ -173,7 +175,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota), - "accessed_time": common.GetTimestamp(), + "accessed_time": helper.GetTimestamp(), }, ).Error return err @@ -197,7 +199,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { if userQuota < quota { return errors.New("用户额度不足") } - quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold + quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold noMoreQuota := userQuota-quota <= 0 if quotaTooLow || noMoreQuota { go func() { @@ -210,7 +212,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { prompt = "您的额度已用尽" } if email != "" { - topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) + topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) err = common.SendEmail(prompt, email, fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) if err != nil { diff --git a/model/user.go b/model/user.go index 17f94d9f..82e9707b 100644 --- a/model/user.go +++ b/model/user.go @@ -5,6 +5,8 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/logger" "strings" ) @@ -90,24 +92,24 @@ func (user *User) Insert(inviterId int) error { return err } } - user.Quota = common.QuotaForNewUser - user.AccessToken = common.GetUUID() - user.AffCode = common.GetRandomString(4) + user.Quota = config.QuotaForNewUser + user.AccessToken = helper.GetUUID() + user.AffCode = helper.GetRandomString(4) result := DB.Create(user) if result.Error != nil { return result.Error } - if common.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) + if config.QuotaForNewUser > 0 { + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) } if inviterId != 0 { - if common.QuotaForInvitee > 0 { - _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) + if config.QuotaForInvitee > 0 { + _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) } - if common.QuotaForInviter > 0 { - _ = IncreaseUserQuota(inviterId, common.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) + if config.QuotaForInviter > 0 { + _ = IncreaseUserQuota(inviterId, config.QuotaForInviter) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) } } return nil @@ -292,7 +294,7 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, quota) return nil } @@ -308,7 +310,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, -quota) return nil } @@ -326,7 +328,7 @@ func GetRootUserEmail() (email string) { } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeRequestCount, id, 1) return diff --git a/model/utils.go b/model/utils.go index e4797a78..e0826e0d 100644 --- a/model/utils.go +++ b/model/utils.go @@ -1,7 +1,7 @@ package model import ( - "one-api/common" + "one-api/common/config" "one-api/common/logger" "sync" "time" @@ -29,7 +29,7 @@ func init() { func InitBatchUpdater() { go func() { for { - time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) + time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second) batchUpdate() } }() diff --git a/relay/channel/aiproxy/main.go b/relay/channel/aiproxy/main.go index 63fef55e..af9cd6f6 100644 --- a/relay/channel/aiproxy/main.go +++ b/relay/channel/aiproxy/main.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" @@ -51,9 +52,9 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon FinishReason: "stop", } fullTextResponse := openai.TextResponse{ - Id: common.GetUUID(), + Id: helper.GetUUID(), Object: "chat.completion", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, } return &fullTextResponse @@ -64,9 +65,9 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion choice.Delta.Content = aiProxyDocuments2Markdown(documents) choice.FinishReason = &constant.StopFinishReason return &openai.ChatCompletionsStreamResponse{ - Id: common.GetUUID(), + Id: helper.GetUUID(), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "", Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } @@ -76,9 +77,9 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = response.Content return &openai.ChatCompletionsStreamResponse{ - Id: common.GetUUID(), + Id: helper.GetUUID(), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: response.Model, Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index c5ada0d7..81dc5370 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" "one-api/common/logger" "one-api/relay/channel/openai" "strings" @@ -119,7 +120,7 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ Id: response.RequestId, Object: "chat.completion", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, Usage: openai.Usage{ PromptTokens: response.Usage.InputTokens, @@ -140,7 +141,7 @@ func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletions response := openai.ChatCompletionsStreamResponse{ Id: aliResponse.RequestId, Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "qwen", Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } diff --git a/relay/channel/anthropic/main.go b/relay/channel/anthropic/main.go index 006779b2..060fcde8 100644 --- a/relay/channel/anthropic/main.go +++ b/relay/channel/anthropic/main.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" "one-api/common/logger" "one-api/relay/channel/openai" "strings" @@ -79,9 +80,9 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Object: "chat.completion", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, } return &fullTextResponse @@ -89,8 +90,8 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() + responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) + createdTime := helper.GetTimestamp() scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { diff --git a/relay/channel/google/gemini.go b/relay/channel/google/gemini.go index 0f4e606c..3adc3fdd 100644 --- a/relay/channel/google/gemini.go +++ b/relay/channel/google/gemini.go @@ -7,6 +7,8 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" + "one-api/common/helper" "one-api/common/image" "one-api/common/logger" "one-api/relay/channel/openai" @@ -29,19 +31,19 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe SafetySettings: []GeminiChatSafetySettings{ { Category: "HARM_CATEGORY_HARASSMENT", - Threshold: common.GeminiSafetySetting, + Threshold: config.GeminiSafetySetting, }, { Category: "HARM_CATEGORY_HATE_SPEECH", - Threshold: common.GeminiSafetySetting, + Threshold: config.GeminiSafetySetting, }, { Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", - Threshold: common.GeminiSafetySetting, + Threshold: config.GeminiSafetySetting, }, { Category: "HARM_CATEGORY_DANGEROUS_CONTENT", - Threshold: common.GeminiSafetySetting, + Threshold: config.GeminiSafetySetting, }, }, GenerationConfig: GeminiChatGenerationConfig{ @@ -152,9 +154,9 @@ type GeminiChatPromptFeedback struct { func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Object: "chat.completion", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { @@ -230,9 +232,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = dummy.Content response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "gemini-pro", Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } diff --git a/relay/channel/google/palm.go b/relay/channel/google/palm.go index c2518a07..3c86a432 100644 --- a/relay/channel/google/palm.go +++ b/relay/channel/google/palm.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" @@ -72,8 +73,8 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompl func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() + responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) + createdTime := helper.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) go func() { diff --git a/relay/channel/openai/token.go b/relay/channel/openai/token.go index b398c220..6803770e 100644 --- a/relay/channel/openai/token.go +++ b/relay/channel/openai/token.go @@ -6,6 +6,7 @@ import ( "github.com/pkoukk/tiktoken-go" "math" "one-api/common" + "one-api/common/config" "one-api/common/image" "one-api/common/logger" "strings" @@ -56,7 +57,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { } func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { - if common.ApproximateTokenEnabled { + if config.ApproximateTokenEnabled { return int(float64(len(text)) * 0.38) } return len(tokenEncoder.Encode(text, nil, nil)) diff --git a/relay/channel/tencent/main.go b/relay/channel/tencent/main.go index 9203249a..64091d62 100644 --- a/relay/channel/tencent/main.go +++ b/relay/channel/tencent/main.go @@ -12,6 +12,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" @@ -47,9 +48,9 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { stream = 1 } return &ChatRequest{ - Timestamp: common.GetTimestamp(), - Expired: common.GetTimestamp() + 24*60*60, - QueryID: common.GetUUID(), + Timestamp: helper.GetTimestamp(), + Expired: helper.GetTimestamp() + 24*60*60, + QueryID: helper.GetUUID(), Temperature: request.Temperature, TopP: request.TopP, Stream: stream, @@ -60,7 +61,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest { func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ Object: "chat.completion", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Usage: response.Usage, } if len(response.Choices) > 0 { @@ -80,7 +81,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { response := openai.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "tencent-hunyuan", } if len(TencentResponse.Choices) > 0 { diff --git a/relay/channel/xunfei/main.go b/relay/channel/xunfei/main.go index 1c55cc09..906d2844 100644 --- a/relay/channel/xunfei/main.go +++ b/relay/channel/xunfei/main.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/common/helper" "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" @@ -69,7 +70,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse { } fullTextResponse := openai.TextResponse{ Object: "chat.completion", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, Usage: response.Payload.Usage.Text, } @@ -91,7 +92,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl } response := openai.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "SparkDesk", Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } diff --git a/relay/channel/zhipu/main.go b/relay/channel/zhipu/main.go index c818c80e..d831f57a 100644 --- a/relay/channel/zhipu/main.go +++ b/relay/channel/zhipu/main.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" "one-api/common/logger" "one-api/relay/channel/openai" "one-api/relay/constant" @@ -102,7 +103,7 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse { fullTextResponse := openai.TextResponse{ Id: response.Data.TaskId, Object: "chat.completion", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)), Usage: response.Data.Usage, } @@ -128,7 +129,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStr choice.Delta.Content = zhipuResponse response := openai.ChatCompletionsStreamResponse{ Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "chatglm", Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } @@ -142,7 +143,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai. response := openai.ChatCompletionsStreamResponse{ Id: zhipuResponse.RequestId, Object: "chat.completion.chunk", - Created: common.GetTimestamp(), + Created: helper.GetTimestamp(), Model: "chatglm", Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, } diff --git a/relay/controller/audio.go b/relay/controller/audio.go index d8a896de..822d7e39 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "one-api/relay/channel/openai" @@ -54,7 +55,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: - preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) + preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) } userQuota, err := model.CacheGetUserQuota(userId) if err != nil { diff --git a/relay/controller/text.go b/relay/controller/text.go index 968cc751..68354628 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -8,6 +8,7 @@ import ( "math" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "one-api/relay/channel/openai" @@ -52,7 +53,7 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode case constant.RelayModeModerations: promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model) } - preConsumedTokens := common.PreConsumedQuota + preConsumedTokens := config.PreConsumedQuota if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + textRequest.MaxTokens } diff --git a/relay/controller/util.go b/relay/controller/util.go index cdb10dbf..02f1b30f 100644 --- a/relay/controller/util.go +++ b/relay/controller/util.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/helper" "one-api/relay/channel/aiproxy" "one-api/relay/channel/ali" "one-api/relay/channel/anthropic" @@ -66,7 +67,7 @@ func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.Rel case constant.APITypePaLM: fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL) case constant.APITypeGemini: - version := common.AssignOrDefault(meta.APIVersion, "v1") + version := helper.AssignOrDefault(meta.APIVersion, "v1") action := "generateContent" if textRequest.Stream { action = "streamGenerateContent" diff --git a/relay/util/common.go b/relay/util/common.go index d7596188..be31857b 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "one-api/relay/channel/openai" @@ -17,7 +18,7 @@ import ( ) func ShouldDisableChannel(err *openai.Error, statusCode int) bool { - if !common.AutomaticDisableChannelEnabled { + if !config.AutomaticDisableChannelEnabled { return false } if err == nil { @@ -33,7 +34,7 @@ func ShouldDisableChannel(err *openai.Error, statusCode int) bool { } func ShouldEnableChannel(err error, openAIErr *openai.Error) bool { - if !common.AutomaticEnableChannelEnabled { + if !config.AutomaticEnableChannelEnabled { return false } if err != nil { diff --git a/relay/util/init.go b/relay/util/init.go index d308d900..62d44d15 100644 --- a/relay/util/init.go +++ b/relay/util/init.go @@ -2,7 +2,7 @@ package util import ( "net/http" - "one-api/common" + "one-api/common/config" "time" ) @@ -10,11 +10,11 @@ var HTTPClient *http.Client var ImpatientHTTPClient *http.Client func init() { - if common.RelayTimeout == 0 { + if config.RelayTimeout == 0 { HTTPClient = &http.Client{} } else { HTTPClient = &http.Client{ - Timeout: time.Duration(common.RelayTimeout) * time.Second, + Timeout: time.Duration(config.RelayTimeout) * time.Second, } } diff --git a/router/main.go b/router/main.go index 733a1033..6504b312 100644 --- a/router/main.go +++ b/router/main.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "os" "strings" @@ -16,7 +16,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS) { SetDashboardRouter(router) SetRelayRouter(router) frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") - if common.IsMasterNode && frontendBaseUrl != "" { + if config.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" logger.SysLog("FRONTEND_BASE_URL is ignored on master node") } diff --git a/router/web-router.go b/router/web-router.go index 7328c7a3..95d8fcb9 100644 --- a/router/web-router.go +++ b/router/web-router.go @@ -8,17 +8,18 @@ import ( "github.com/gin-gonic/gin" "net/http" "one-api/common" + "one-api/common/config" "one-api/controller" "one-api/middleware" "strings" ) func SetWebRouter(router *gin.Engine, buildFS embed.FS) { - indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", common.Theme)) + indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", config.Theme)) router.Use(gzip.Gzip(gzip.DefaultCompression)) router.Use(middleware.GlobalWebRateLimit()) router.Use(middleware.Cache()) - router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", common.Theme)))) + router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", config.Theme)))) router.NoRoute(func(c *gin.Context) { if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") { controller.RelayNotFound(c)