refactor: refactor config part

This commit is contained in:
JustSong 2024-01-21 23:18:32 +08:00
parent b373882814
commit 42569c83c0
53 changed files with 745 additions and 708 deletions

127
common/config/config.go Normal file
View File

@ -0,0 +1,127 @@
package config
import (
"one-api/common/helper"
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = helper.GetOrDefaultEnvString("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
}
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute

View File

@ -1,114 +1,9 @@
package common package common
import ( import "time"
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = GetOrDefaultEnvString("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
}
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const ( const (
RoleGuestUser = 0 RoleGuestUser = 0
@ -117,34 +12,6 @@ const (
RoleRootUser = 100 RoleRootUser = 100
) )
var (
FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser
ImageUploadPermission = RoleGuestUser
ImageDownloadPermission = RoleGuestUser
)
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
const ( const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0 UserStatusDisabled = 2 // also don't use 0

View File

@ -1,7 +1,9 @@
package common package common
import "one-api/common/helper"
var UsingSQLite = false var UsingSQLite = false
var UsingPostgreSQL = false var UsingPostgreSQL = false
var SQLitePath = "one-api.db" var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)

View File

@ -6,18 +6,19 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/smtp" "net/smtp"
"one-api/common/config"
"strings" "strings"
"time" "time"
) )
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
if SMTPFrom == "" { // for compatibility if config.SMTPFrom == "" { // for compatibility
SMTPFrom = SMTPAccount config.SMTPFrom = config.SMTPAccount
} }
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom // Extract domain from SMTPFrom
parts := strings.Split(SMTPFrom, "@") parts := strings.Split(config.SMTPFrom, "@")
var domain string var domain string
if len(parts) > 1 { if len(parts) > 1 {
domain = parts[1] domain = parts[1]
@ -36,21 +37,21 @@ func SendEmail(subject string, receiver string, content string) error {
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+ "Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
if SMTPPort == 465 { if config.SMTPPort == 465 {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: SMTPServer, ServerName: config.SMTPServer,
} }
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig)
if err != nil { if err != nil {
return err return err
} }
client, err := smtp.NewClient(conn, SMTPServer) client, err := smtp.NewClient(conn, config.SMTPServer)
if err != nil { if err != nil {
return err return err
} }
@ -58,7 +59,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err = client.Auth(auth); err != nil { if err = client.Auth(auth); err != nil {
return err return err
} }
if err = client.Mail(SMTPFrom); err != nil { if err = client.Mail(config.SMTPFrom); err != nil {
return err return err
} }
receiverEmails := strings.Split(receiver, ";") receiverEmails := strings.Split(receiver, ";")
@ -80,7 +81,7 @@ func SendEmail(subject string, receiver string, content string) error {
return err return err
} }
} else { } else {
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail)
} }
return err return err
} }

224
common/helper/helper.go Normal file
View File

@ -0,0 +1,224 @@
package helper
import (
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"one-api/common/logger"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
func OpenBrowser(url string) {
var err error
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter.(type) {
case string:
return inter.(string)
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return fmt.Sprintf("%f", inter.(float64))
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetOrDefaultEnvInt(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultEnvString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value
}
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@ -4,6 +4,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"os" "os"
"path/filepath" "path/filepath"
@ -40,7 +41,7 @@ func init() {
if os.Getenv("SESSION_SECRET") == "random_string" { if os.Getenv("SESSION_SECRET") == "random_string" {
logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else { } else {
SessionSecret = os.Getenv("SESSION_SECRET") config.SessionSecret = os.Getenv("SESSION_SECRET")
} }
} }
if os.Getenv("SQLITE_PATH") != "" { if os.Getenv("SQLITE_PATH") != "" {
@ -58,5 +59,6 @@ func init() {
log.Fatal(err) log.Fatal(err)
} }
} }
logger.LogDir = *LogDir
} }
} }

View File

@ -0,0 +1,7 @@
package logger
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var LogDir string

View File

@ -6,7 +6,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"log" "log"
"one-api/common"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
@ -26,7 +25,7 @@ var setupLogLock sync.Mutex
var setupLogWorking bool var setupLogWorking bool
func SetupLogger() { func SetupLogger() {
if *common.LogDir != "" { if LogDir != "" {
ok := setupLogLock.TryLock() ok := setupLogLock.TryLock()
if !ok { if !ok {
log.Println("setup log is already working") log.Println("setup log is already working")
@ -36,7 +35,7 @@ func SetupLogger() {
setupLogLock.Unlock() setupLogLock.Unlock()
setupLogWorking = false setupLogWorking = false
}() }()
logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
log.Fatal("failed to open log file") log.Fatal("failed to open log file")
@ -85,7 +84,7 @@ func logHelper(ctx context.Context, level string, msg string) {
if level == loggerINFO { if level == loggerINFO {
writer = gin.DefaultWriter writer = gin.DefaultWriter
} }
id := ctx.Value(common.RequestIdKey) id := ctx.Value(RequestIdKey)
now := time.Now() now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here logCount++ // we don't need accurate count, so no lock here
@ -103,11 +102,3 @@ func FatalLog(v ...any) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1) os.Exit(1)
} }
func LogQuota(quota int) string {
if common.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/common.QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
}

View File

@ -2,223 +2,13 @@ package common
import ( import (
"fmt" "fmt"
"github.com/google/uuid" "one-api/common/config"
"html/template"
"log"
"math/rand"
"net"
"one-api/common/logger"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
) )
func OpenBrowser(url string) { func LogQuota(quota int) string {
var err error if config.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/config.QuotaPerUnit)
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else { } else {
numStr = fmt.Sprintf("%d", num) return fmt.Sprintf("%d 点额度", quota)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter.(type) {
case string:
return inter.(string)
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return fmt.Sprintf("%f", inter.(float64))
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
} }
} }
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetOrDefaultEnvInt(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultEnvString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value
}
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@ -2,7 +2,7 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common/config"
"one-api/model" "one-api/model"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
) )
@ -13,7 +13,7 @@ func GetSubscription(c *gin.Context) {
var err error var err error
var token *model.Token var token *model.Token
var expiredTime int64 var expiredTime int64
if common.DisplayTokenStatEnabled { if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime expiredTime = token.ExpiredTime
@ -39,8 +39,8 @@ func GetSubscription(c *gin.Context) {
} }
quota := remainQuota + usedQuota quota := remainQuota + usedQuota
amount := float64(quota) amount := float64(quota)
if common.DisplayInCurrencyEnabled { if config.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit amount /= config.QuotaPerUnit
} }
if token != nil && token.UnlimitedQuota { if token != nil && token.UnlimitedQuota {
amount = 100000000 amount = 100000000
@ -61,7 +61,7 @@ func GetUsage(c *gin.Context) {
var quota int var quota int
var err error var err error
var token *model.Token var token *model.Token
if common.DisplayTokenStatEnabled { if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota quota = token.UsedQuota
@ -80,8 +80,8 @@ func GetUsage(c *gin.Context) {
return return
} }
amount := float64(quota) amount := float64(quota)
if common.DisplayInCurrencyEnabled { if config.DisplayInCurrencyEnabled {
amount /= common.QuotaPerUnit amount /= config.QuotaPerUnit
} }
usage := OpenAIUsageResponse{ usage := OpenAIUsageResponse{
Object: "list", Object: "list",

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"one-api/relay/util" "one-api/relay/util"
@ -315,7 +316,7 @@ func updateAllChannelsBalance() error {
disableChannel(channel.Id, channel.Name, "余额不足") disableChannel(channel.Id, channel.Name, "余额不足")
} }
} }
time.Sleep(common.RequestInterval) time.Sleep(config.RequestInterval)
} }
return nil return nil
} }

View File

@ -8,6 +8,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
@ -151,10 +152,10 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false var testAllChannelsRunning bool = false
func notifyRootUser(subject string, content string) { func notifyRootUser(subject string, content string) {
if common.RootUserEmail == "" { if config.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail() config.RootUserEmail = model.GetRootUserEmail()
} }
err := common.SendEmail(subject, common.RootUserEmail, content) err := common.SendEmail(subject, config.RootUserEmail, content)
if err != nil { if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
} }
@ -177,8 +178,8 @@ func enableChannel(channelId int, channelName string) {
} }
func testAllChannels(notify bool) error { func testAllChannels(notify bool) error {
if common.RootUserEmail == "" { if config.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail() config.RootUserEmail = model.GetRootUserEmail()
} }
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
if testAllChannelsRunning { if testAllChannelsRunning {
@ -192,7 +193,7 @@ func testAllChannels(notify bool) error {
return err return err
} }
testRequest := buildTestRequest() testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000) var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
} }
@ -214,13 +215,13 @@ func testAllChannels(notify bool) error {
enableChannel(channel.Id, channel.Name) enableChannel(channel.Id, channel.Name)
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval) time.Sleep(config.RequestInterval)
} }
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
testAllChannelsRunning = false testAllChannelsRunning = false
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
if notify { if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil { if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
} }

View File

@ -3,7 +3,8 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/helper"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) {
}) })
return return
} }
channel.CreatedTime = common.GetTimestamp() channel.CreatedTime = helper.GetTimestamp()
keys := strings.Split(channel.Key, "\n") keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys)) channels := make([]model.Channel, 0, len(keys))
for _, key := range keys { for _, key := range keys {

View File

@ -9,6 +9,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"strconv" "strconv"
@ -31,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" { if code == "" {
return nil, errors.New("无效的参数") return nil, errors.New("无效的参数")
} }
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values) jsonData, err := json.Marshal(values)
if err != nil { if err != nil {
return nil, err return nil, err
@ -94,7 +96,7 @@ func GitHubOAuth(c *gin.Context) {
return return
} }
if !common.GitHubOAuthEnabled { if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "管理员未开启通过 GitHub 登录以及注册", "message": "管理员未开启通过 GitHub 登录以及注册",
@ -123,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
return return
} }
} else { } else {
if common.RegisterEnabled { if config.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" { if githubUser.Name != "" {
user.DisplayName = githubUser.Name user.DisplayName = githubUser.Name
@ -161,7 +163,7 @@ func GitHubOAuth(c *gin.Context) {
} }
func GitHubBind(c *gin.Context) { func GitHubBind(c *gin.Context) {
if !common.GitHubOAuthEnabled { if !config.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "管理员未开启通过 GitHub 登录以及注册", "message": "管理员未开启通过 GitHub 登录以及注册",
@ -217,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
state := common.GetRandomString(12) state := helper.GetRandomString(12)
session.Set("oauth_state", state) session.Set("oauth_state", state)
err := session.Save() err := session.Save()
if err != nil { if err != nil {

View File

@ -3,7 +3,7 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/model" "one-api/model"
"strconv" "strconv"
) )
@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/model" "one-api/model"
"strings" "strings"
@ -18,55 +19,55 @@ func GetStatus(c *gin.Context) {
"data": gin.H{ "data": gin.H{
"version": common.Version, "version": common.Version,
"start_time": common.StartTime, "start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled, "email_verification": config.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled, "github_oauth": config.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId, "github_client_id": config.GitHubClientId,
"system_name": common.SystemName, "system_name": config.SystemName,
"logo": common.Logo, "logo": config.Logo,
"footer_html": common.Footer, "footer_html": config.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": config.WeChatAuthEnabled,
"server_address": common.ServerAddress, "server_address": config.ServerAddress,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": common.TopUpLink, "top_up_link": config.TopUpLink,
"chat_link": common.ChatLink, "chat_link": config.ChatLink,
"quota_per_unit": common.QuotaPerUnit, "quota_per_unit": config.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled, "display_in_currency": config.DisplayInCurrencyEnabled,
}, },
}) })
return return
} }
func GetNotice(c *gin.Context) { func GetNotice(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["Notice"], "data": config.OptionMap["Notice"],
}) })
return return
} }
func GetAbout(c *gin.Context) { func GetAbout(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["About"], "data": config.OptionMap["About"],
}) })
return return
} }
func GetHomePageContent(c *gin.Context) { func GetHomePageContent(c *gin.Context) {
common.OptionMapRWMutex.RLock() config.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock() defer config.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": common.OptionMap["HomePageContent"], "data": config.OptionMap["HomePageContent"],
}) })
return return
} }
@ -80,9 +81,9 @@ func SendEmailVerification(c *gin.Context) {
}) })
return return
} }
if common.EmailDomainRestrictionEnabled { if config.EmailDomainRestrictionEnabled {
allowed := false allowed := false
for _, domain := range common.EmailDomainWhitelist { for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) { if strings.HasSuffix(email, "@"+domain) {
allowed = true allowed = true
break break
@ -105,10 +106,10 @@ func SendEmailVerification(c *gin.Context) {
} }
code := common.GenerateVerificationCode(6) code := common.GenerateVerificationCode(6)
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
"<p>您的验证码为: <strong>%s</strong></p>"+ "<p>您的验证码为: <strong>%s</strong></p>"+
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes) "<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content) err := common.SendEmail(subject, email, content)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -142,12 +143,12 @@ func SendPasswordResetEmail(c *gin.Context) {
} }
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName) subject := fmt.Sprintf("%s密码重置", config.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ "<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) "<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content) err := common.SendEmail(subject, email, content)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@ -3,7 +3,8 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/helper"
"one-api/model" "one-api/model"
"strings" "strings"
@ -12,17 +13,17 @@ import (
func GetOptions(c *gin.Context) { func GetOptions(c *gin.Context) {
var options []*model.Option var options []*model.Option
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap { for k, v := range config.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue continue
} }
options = append(options, &model.Option{ options = append(options, &model.Option{
Key: k, Key: k,
Value: common.Interface2String(v), Value: helper.Interface2String(v),
}) })
} }
common.OptionMapRWMutex.Unlock() config.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@ -43,7 +44,7 @@ func UpdateOption(c *gin.Context) {
} }
switch option.Key { switch option.Key {
case "Theme": case "Theme":
if !common.ValidThemes[option.Value] { if !config.ValidThemes[option.Value] {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无效的主题", "message": "无效的主题",
@ -51,7 +52,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
if option.Value == "true" && common.GitHubClientId == "" { if option.Value == "true" && config.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret", "message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret",
@ -59,7 +60,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
@ -67,7 +68,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
if option.Value == "true" && common.WeChatServerAddress == "" { if option.Value == "true" && config.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!", "message": "无法启用微信登录,请先填入微信登录相关配置信息!",
@ -75,7 +76,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
if option.Value == "true" && common.TurnstileSiteKey == "" { if option.Value == "true" && config.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",

View File

@ -3,7 +3,8 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/helper"
"one-api/model" "one-api/model"
"strconv" "strconv"
) )
@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) {
} }
var keys []string var keys []string
for i := 0; i < redemption.Count; i++ { for i := 0; i < redemption.Count; i++ {
key := common.GetUUID() key := helper.GetUUID()
cleanRedemption := model.Redemption{ cleanRedemption := model.Redemption{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: redemption.Name, Name: redemption.Name,
Key: key, Key: key,
CreatedTime: common.GetTimestamp(), CreatedTime: helper.GetTimestamp(),
Quota: redemption.Quota, Quota: redemption.Quota,
} }
err = cleanRedemption.Insert() err = cleanRedemption.Insert()

View File

@ -4,7 +4,8 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/relay/constant" "one-api/relay/constant"
@ -31,11 +32,11 @@ func Relay(c *gin.Context) {
err = controller.RelayTextHelper(c, relayMode) err = controller.RelayTextHelper(c, relayMode)
} }
if err != nil { if err != nil {
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(logger.RequestIdKey)
retryTimesStr := c.Query("retry") retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr) retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" { if retryTimesStr == "" {
retryTimes = common.RetryTimes retryTimes = config.RetryTimes
} }
if retryTimes > 0 { if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
@ -43,7 +44,7 @@ func Relay(c *gin.Context) {
if err.StatusCode == http.StatusTooManyRequests { if err.StatusCode == http.StatusTooManyRequests {
err.Error.Message = "当前分组上游负载已饱和,请稍后再试" err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
} }
err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId) err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId)
c.JSON(err.StatusCode, gin.H{ c.JSON(err.StatusCode, gin.H{
"error": err.Error, "error": err.Error,
}) })

View File

@ -4,6 +4,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/helper"
"one-api/model" "one-api/model"
"strconv" "strconv"
) )
@ -14,7 +16,7 @@ func GetAllTokens(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage) tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -119,9 +121,9 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: common.GenerateKey(), Key: helper.GenerateKey(),
CreatedTime: common.GetTimestamp(), CreatedTime: helper.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota, RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
@ -187,7 +189,7 @@ func UpdateToken(c *gin.Context) {
return return
} }
if token.Status == common.TokenStatusEnabled { if token.Status == common.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",

View File

@ -5,7 +5,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/logger" "one-api/common/config"
"one-api/common/helper"
"one-api/model" "one-api/model"
"strconv" "strconv"
"time" "time"
@ -20,7 +21,7 @@ type LoginRequest struct {
} }
func Login(c *gin.Context) { func Login(c *gin.Context) {
if !common.PasswordLoginEnabled { if !config.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录", "message": "管理员关闭了密码登录",
"success": false, "success": false,
@ -107,14 +108,14 @@ func Logout(c *gin.Context) {
} }
func Register(c *gin.Context) { func Register(c *gin.Context) {
if !common.RegisterEnabled { if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册", "message": "管理员关闭了新用户注册",
"success": false, "success": false,
}) })
return return
} }
if !common.PasswordRegisterEnabled { if !config.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false, "success": false,
@ -137,7 +138,7 @@ func Register(c *gin.Context) {
}) })
return return
} }
if common.EmailVerificationEnabled { if config.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" { if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -161,7 +162,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username, DisplayName: user.Username,
InviterId: inviterId, InviterId: inviterId,
} }
if common.EmailVerificationEnabled { if config.EmailVerificationEnabled {
cleanUser.Email = user.Email cleanUser.Email = user.Email
} }
if err := cleanUser.Insert(inviterId); err != nil { if err := cleanUser.Insert(inviterId); err != nil {
@ -183,7 +184,7 @@ func GetAllUsers(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage) users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -283,7 +284,7 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
user.AccessToken = common.GetUUID() user.AccessToken = helper.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@ -320,7 +321,7 @@ func GetAffCode(c *gin.Context) {
return return
} }
if user.AffCode == "" { if user.AffCode == "" {
user.AffCode = common.GetRandomString(4) user.AffCode = helper.GetRandomString(4)
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -410,7 +411,7 @@ func UpdateUser(c *gin.Context) {
return return
} }
if originUser.Quota != updatedUser.Quota { if originUser.Quota != updatedUser.Quota {
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota))) model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
@ -727,7 +728,7 @@ func EmailBind(c *gin.Context) {
return return
} }
if user.Role == common.RoleRootUser { if user.Role == common.RoleRootUser {
common.RootUserEmail = email config.RootUserEmail = email
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,

View File

@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/model" "one-api/model"
"strconv" "strconv"
"time" "time"
@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" { if code == "" {
return "", errors.New("无效的参数") return "", errors.New("无效的参数")
} }
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
if err != nil { if err != nil {
return "", err return "", err
} }
req.Header.Set("Authorization", common.WeChatServerToken) req.Header.Set("Authorization", config.WeChatServerToken)
client := http.Client{ client := http.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) {
} }
func WeChatAuth(c *gin.Context) { func WeChatAuth(c *gin.Context) {
if !common.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
"success": false, "success": false,
@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) {
return return
} }
} else { } else {
if common.RegisterEnabled { if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User" user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) {
} }
func WeChatBind(c *gin.Context) { func WeChatBind(c *gin.Context) {
if !common.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
"success": false, "success": false,

23
main.go
View File

@ -7,6 +7,7 @@ import (
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/controller" "one-api/controller"
"one-api/middleware" "one-api/middleware"
@ -26,7 +27,7 @@ func main() {
if os.Getenv("GIN_MODE") != "debug" { if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
if common.DebugEnabled { if config.DebugEnabled {
logger.SysLog("running in debug mode") logger.SysLog("running in debug mode")
} }
// Initialize SQL Database // Initialize SQL Database
@ -49,19 +50,19 @@ func main() {
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
logger.SysLog(fmt.Sprintf("using theme %s", common.Theme)) logger.SysLog(fmt.Sprintf("using theme %s", config.Theme))
if common.RedisEnabled { if common.RedisEnabled {
// for compatibility with old versions // for compatibility with old versions
common.MemoryCacheEnabled = true config.MemoryCacheEnabled = true
} }
if common.MemoryCacheEnabled { if config.MemoryCacheEnabled {
logger.SysLog("memory cache enabled") logger.SysLog("memory cache enabled")
logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
model.InitChannelCache() model.InitChannelCache()
} }
if common.MemoryCacheEnabled { if config.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency) go model.SyncOptions(config.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency) go model.SyncChannelCache(config.SyncFrequency)
} }
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
@ -78,8 +79,8 @@ func main() {
go controller.AutomaticallyTestChannels(frequency) go controller.AutomaticallyTestChannels(frequency)
} }
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true config.BatchUpdateEnabled = true
logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
model.InitBatchUpdater() model.InitBatchUpdater()
} }
openai.InitTokenEncoders() openai.InitTokenEncoders()
@ -92,7 +93,7 @@ func main() {
server.Use(middleware.RequestId()) server.Use(middleware.RequestId())
middleware.SetUpLogger(server) middleware.SetUpLogger(server)
// Initialize session store // Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret)) store := cookie.NewStore([]byte(config.SessionSecret))
server.Use(sessions.Sessions("session", store)) server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS) router.SetRouter(server, buildFS)

View File

@ -3,14 +3,14 @@ package middleware
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common/logger"
) )
func SetUpLogger(server *gin.Engine) { func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string var requestID string
if param.Keys != nil { if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string) requestID = param.Keys[logger.RequestIdKey].(string)
} }
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"), param.TimeStamp.Format("2006/01/02 - 15:04:05"),

View File

@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"time" "time"
) )
@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
} }
if listLength < int64(maxRequestNum) { if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} else { } else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr) oldTime, err := time.Parse(timeFormat, oldTimeStr)
@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number! // time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration { if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests) c.Status(http.StatusTooManyRequests)
c.Abort() c.Abort()
return return
} else { } else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
} }
} }
} }
@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
} }
} else { } else {
// It's safe to call multi times. // It's safe to call multi times.
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
return func(c *gin.Context) { return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark) memoryRateLimiter(c, maxRequestNum, duration, mark)
} }
@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
} }
func GlobalWebRateLimit() func(c *gin.Context) { func GlobalWebRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW")
} }
func GlobalAPIRateLimit() func(c *gin.Context) { func GlobalAPIRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA")
} }
func CriticalRateLimit() func(c *gin.Context) { func CriticalRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT")
} }
func DownloadRateLimit() func(c *gin.Context) { func DownloadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW")
} }
func UploadRateLimit() func(c *gin.Context) { func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP")
} }

View File

@ -3,16 +3,17 @@ package middleware
import ( import (
"context" "context"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common/helper"
"one-api/common/logger"
) )
func RequestId() func(c *gin.Context) { func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8) id := helper.GetTimeString() + helper.GetRandomString(8)
c.Set(common.RequestIdKey, id) c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id) c.Header(logger.RequestIdKey, id)
c.Next() c.Next()
} }
} }

View File

@ -6,7 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
) )
@ -16,7 +16,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc { func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if common.TurnstileCheckEnabled { if config.TurnstileCheckEnabled {
session := sessions.Default(c) session := sessions.Default(c)
turnstileChecked := session.Get("turnstile") turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil { if turnstileChecked != nil {
@ -33,7 +33,7 @@ func TurnstileCheck() gin.HandlerFunc {
return return
} }
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
"secret": {common.TurnstileSecretKey}, "secret": {config.TurnstileSecretKey},
"response": {response}, "response": {response},
"remoteip": {c.ClientIP()}, "remoteip": {c.ClientIP()},
}) })

View File

@ -2,14 +2,14 @@ package middleware
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
) )
func abortWithMessage(c *gin.Context, statusCode int, message string) { func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"type": "one_api_error", "type": "one_api_error",
}, },
}) })

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"sort" "sort"
"strconv" "strconv"
@ -15,10 +16,10 @@ import (
) )
var ( var (
TokenCacheSeconds = common.SyncFrequency TokenCacheSeconds = config.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency UserId2StatusCacheSeconds = config.SyncFrequency
) )
func CacheGetTokenByKey(key string) (*Token, error) { func CacheGetTokenByKey(key string) (*Token, error) {
@ -191,7 +192,7 @@ func SyncChannelCache(frequency int) {
} }
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if !common.MemoryCacheEnabled { if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model) return GetRandomSatisfiedChannel(group, model)
} }
channelSyncLock.RLock() channelSyncLock.RLock()

View File

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
) )
@ -45,7 +47,7 @@ func SearchChannels(keyword string) (channels []*Channel, err error) {
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
keyCol = `"key"` keyCol = `"key"`
} }
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
return channels, err return channels, err
} }
@ -125,7 +127,7 @@ func (channel *Channel) Update() error {
func (channel *Channel) UpdateResponseTime(responseTime int64) { func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
TestTime: common.GetTimestamp(), TestTime: helper.GetTimestamp(),
ResponseTime: int(responseTime), ResponseTime: int(responseTime),
}).Error }).Error
if err != nil { if err != nil {
@ -135,7 +137,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
func (channel *Channel) UpdateBalance(balance float64) { func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
BalanceUpdatedTime: common.GetTimestamp(), BalanceUpdatedTime: helper.GetTimestamp(),
Balance: balance, Balance: balance,
}).Error }).Error
if err != nil { if err != nil {
@ -165,7 +167,7 @@ func UpdateChannelStatusById(id int, status int) {
} }
func UpdateChannelUsedQuota(id int, quota int) { func UpdateChannelUsedQuota(id int, quota int) {
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return return
} }

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
@ -33,13 +35,13 @@ const (
) )
func RecordLog(userId int, logType int, content string) { func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled { if logType == LogTypeConsume && !config.LogConsumeEnabled {
return return
} }
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(), CreatedAt: helper.GetTimestamp(),
Type: logType, Type: logType,
Content: content, Content: content,
} }
@ -51,13 +53,13 @@ func RecordLog(userId int, logType int, content string) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled { if !config.LogConsumeEnabled {
return return
} }
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
CreatedAt: common.GetTimestamp(), CreatedAt: helper.GetTimestamp(),
Type: LogTypeConsume, Type: LogTypeConsume,
Content: content, Content: content,
PromptTokens: promptTokens, PromptTokens: promptTokens,
@ -126,12 +128,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) { func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err return logs, err
} }
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err return logs, err
} }

View File

@ -7,6 +7,8 @@ import (
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"os" "os"
"strings" "strings"
@ -30,7 +32,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser, Role: common.RoleRootUser,
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: common.GetUUID(), AccessToken: helper.GetUUID(),
Quota: 100000000, Quota: 100000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@ -70,7 +72,7 @@ func chooseDB() (*gorm.DB, error) {
func InitDB() (err error) { func InitDB() (err error) {
db, err := chooseDB() db, err := chooseDB()
if err == nil { if err == nil {
if common.DebugEnabled { if config.DebugEnabled {
db = db.Debug() db = db.Debug()
} }
DB = db DB = db
@ -78,11 +80,11 @@ func InitDB() (err error) {
if err != nil { if err != nil {
return err return err
} }
sqlDB.SetMaxIdleConns(common.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(common.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
if !common.IsMasterNode { if !config.IsMasterNode {
return nil return nil
} }
logger.SysLog("database migration started") logger.SysLog("database migration started")

View File

@ -2,6 +2,7 @@ package model
import ( import (
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"strconv" "strconv"
"strings" "strings"
@ -21,60 +22,56 @@ func AllOption() ([]*Option, error) {
} }
func InitOptionMap() { func InitOptionMap() {
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
common.OptionMap = make(map[string]string) config.OptionMap = make(map[string]string)
common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled)
common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) config.OptionMap["SMTPServer"] = ""
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) config.OptionMap["SMTPFrom"] = ""
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") config.OptionMap["SMTPAccount"] = ""
common.OptionMap["SMTPServer"] = "" config.OptionMap["SMTPToken"] = ""
common.OptionMap["SMTPFrom"] = "" config.OptionMap["Notice"] = ""
common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) config.OptionMap["About"] = ""
common.OptionMap["SMTPAccount"] = "" config.OptionMap["HomePageContent"] = ""
common.OptionMap["SMTPToken"] = "" config.OptionMap["Footer"] = config.Footer
common.OptionMap["Notice"] = "" config.OptionMap["SystemName"] = config.SystemName
common.OptionMap["About"] = "" config.OptionMap["Logo"] = config.Logo
common.OptionMap["HomePageContent"] = "" config.OptionMap["ServerAddress"] = ""
common.OptionMap["Footer"] = common.Footer config.OptionMap["GitHubClientId"] = ""
common.OptionMap["SystemName"] = common.SystemName config.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["Logo"] = common.Logo config.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["ServerAddress"] = "" config.OptionMap["WeChatServerToken"] = ""
common.OptionMap["GitHubClientId"] = "" config.OptionMap["WeChatAccountQRCodeImageURL"] = ""
common.OptionMap["GitHubClientSecret"] = "" config.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["WeChatServerAddress"] = "" config.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["WeChatServerToken"] = "" config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser)
common.OptionMap["WeChatAccountQRCodeImageURL"] = "" config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter)
common.OptionMap["TurnstileSiteKey"] = "" config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee)
common.OptionMap["TurnstileSecretKey"] = "" config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold)
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota)
common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) config.OptionMap["TopUpLink"] = config.TopUpLink
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) config.OptionMap["ChatLink"] = config.ChatLink
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes)
common.OptionMap["TopUpLink"] = common.TopUpLink config.OptionMap["Theme"] = config.Theme
common.OptionMap["ChatLink"] = common.ChatLink config.OptionMapRWMutex.Unlock()
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["Theme"] = common.Theme
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
} }
@ -113,117 +110,104 @@ func UpdateOption(key string, value string) error {
} }
func updateOptionMap(key string, value string) (err error) { func updateOptionMap(key string, value string) (err error) {
common.OptionMapRWMutex.Lock() config.OptionMapRWMutex.Lock()
defer common.OptionMapRWMutex.Unlock() defer config.OptionMapRWMutex.Unlock()
common.OptionMap[key] = value config.OptionMap[key] = value
if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value)
switch key {
case "FileUploadPermission":
common.FileUploadPermission = intValue
case "FileDownloadPermission":
common.FileDownloadPermission = intValue
case "ImageUploadPermission":
common.ImageUploadPermission = intValue
case "ImageDownloadPermission":
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") { if strings.HasSuffix(key, "Enabled") {
boolValue := value == "true" boolValue := value == "true"
switch key { switch key {
case "PasswordRegisterEnabled": case "PasswordRegisterEnabled":
common.PasswordRegisterEnabled = boolValue config.PasswordRegisterEnabled = boolValue
case "PasswordLoginEnabled": case "PasswordLoginEnabled":
common.PasswordLoginEnabled = boolValue config.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled": case "EmailVerificationEnabled":
common.EmailVerificationEnabled = boolValue config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue config.GitHubOAuthEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
common.TurnstileCheckEnabled = boolValue config.TurnstileCheckEnabled = boolValue
case "RegisterEnabled": case "RegisterEnabled":
common.RegisterEnabled = boolValue config.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled": case "EmailDomainRestrictionEnabled":
common.EmailDomainRestrictionEnabled = boolValue config.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled": case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue config.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled": case "AutomaticEnableChannelEnabled":
common.AutomaticEnableChannelEnabled = boolValue config.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled": case "ApproximateTokenEnabled":
common.ApproximateTokenEnabled = boolValue config.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled": case "LogConsumeEnabled":
common.LogConsumeEnabled = boolValue config.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled": case "DisplayInCurrencyEnabled":
common.DisplayInCurrencyEnabled = boolValue config.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled": case "DisplayTokenStatEnabled":
common.DisplayTokenStatEnabled = boolValue config.DisplayTokenStatEnabled = boolValue
} }
} }
switch key { switch key {
case "EmailDomainWhitelist": case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",") config.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer": case "SMTPServer":
common.SMTPServer = value config.SMTPServer = value
case "SMTPPort": case "SMTPPort":
intValue, _ := strconv.Atoi(value) intValue, _ := strconv.Atoi(value)
common.SMTPPort = intValue config.SMTPPort = intValue
case "SMTPAccount": case "SMTPAccount":
common.SMTPAccount = value config.SMTPAccount = value
case "SMTPFrom": case "SMTPFrom":
common.SMTPFrom = value config.SMTPFrom = value
case "SMTPToken": case "SMTPToken":
common.SMTPToken = value config.SMTPToken = value
case "ServerAddress": case "ServerAddress":
common.ServerAddress = value config.ServerAddress = value
case "GitHubClientId": case "GitHubClientId":
common.GitHubClientId = value config.GitHubClientId = value
case "GitHubClientSecret": case "GitHubClientSecret":
common.GitHubClientSecret = value config.GitHubClientSecret = value
case "Footer": case "Footer":
common.Footer = value config.Footer = value
case "SystemName": case "SystemName":
common.SystemName = value config.SystemName = value
case "Logo": case "Logo":
common.Logo = value config.Logo = value
case "WeChatServerAddress": case "WeChatServerAddress":
common.WeChatServerAddress = value config.WeChatServerAddress = value
case "WeChatServerToken": case "WeChatServerToken":
common.WeChatServerToken = value config.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL": case "WeChatAccountQRCodeImageURL":
common.WeChatAccountQRCodeImageURL = value config.WeChatAccountQRCodeImageURL = value
case "TurnstileSiteKey": case "TurnstileSiteKey":
common.TurnstileSiteKey = value config.TurnstileSiteKey = value
case "TurnstileSecretKey": case "TurnstileSecretKey":
common.TurnstileSecretKey = value config.TurnstileSecretKey = value
case "QuotaForNewUser": case "QuotaForNewUser":
common.QuotaForNewUser, _ = strconv.Atoi(value) config.QuotaForNewUser, _ = strconv.Atoi(value)
case "QuotaForInviter": case "QuotaForInviter":
common.QuotaForInviter, _ = strconv.Atoi(value) config.QuotaForInviter, _ = strconv.Atoi(value)
case "QuotaForInvitee": case "QuotaForInvitee":
common.QuotaForInvitee, _ = strconv.Atoi(value) config.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold": case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value) config.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota": case "PreConsumedQuota":
common.PreConsumedQuota, _ = strconv.Atoi(value) config.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes": case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value) config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio": case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = common.UpdateGroupRatioByJSONString(value)
case "TopUpLink": case "TopUpLink":
common.TopUpLink = value config.TopUpLink = value
case "ChatLink": case "ChatLink":
common.ChatLink = value config.ChatLink = value
case "ChannelDisableThreshold": case "ChannelDisableThreshold":
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit": case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "Theme": case "Theme":
common.Theme = value config.Theme = value
} }
return err return err
} }

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/common/logger" "one-api/common/helper"
) )
type Redemption struct { type Redemption struct {
@ -68,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil { if err != nil {
return err return err
} }
redemption.RedeemedTime = common.GetTimestamp() redemption.RedeemedTime = helper.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error err = tx.Save(redemption).Error
return err return err
@ -76,7 +76,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil { if err != nil {
return 0, errors.New("兑换失败," + err.Error()) return 0, errors.New("兑换失败," + err.Error())
} }
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", logger.LogQuota(redemption.Quota))) RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
return redemption.Quota, nil return redemption.Quota, nil
} }

View File

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
) )
@ -54,7 +56,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status != common.TokenStatusEnabled { if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用") return nil, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
token.Status = common.TokenStatusExpired token.Status = common.TokenStatusExpired
err := token.SelectUpdate() err := token.SelectUpdate()
@ -139,7 +141,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota) addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil return nil
} }
@ -151,7 +153,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota), "remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": common.GetTimestamp(), "accessed_time": helper.GetTimestamp(),
}, },
).Error ).Error
return err return err
@ -161,7 +163,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil return nil
} }
@ -173,7 +175,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota), "remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": common.GetTimestamp(), "accessed_time": helper.GetTimestamp(),
}, },
).Error ).Error
return err return err
@ -197,7 +199,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota { if userQuota < quota {
return errors.New("用户额度不足") return errors.New("用户额度不足")
} }
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0 noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
go func() { go func() {
@ -210,7 +212,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
prompt = "您的额度已用尽" prompt = "您的额度已用尽"
} }
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress)
err = common.SendEmail(prompt, email, err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {

View File

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"strings" "strings"
) )
@ -90,24 +92,24 @@ func (user *User) Insert(inviterId int) error {
return err return err
} }
} }
user.Quota = common.QuotaForNewUser user.Quota = config.QuotaForNewUser
user.AccessToken = common.GetUUID() user.AccessToken = helper.GetUUID()
user.AffCode = common.GetRandomString(4) user.AffCode = helper.GetRandomString(4)
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
if common.QuotaForNewUser > 0 { if config.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
} }
if inviterId != 0 { if inviterId != 0 {
if common.QuotaForInvitee > 0 { if config.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
} }
if common.QuotaForInviter > 0 { if config.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter))) RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
} }
} }
return nil return nil
@ -292,7 +294,7 @@ func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota) addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil return nil
} }
@ -308,7 +310,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota) addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil return nil
} }
@ -326,7 +328,7 @@ func GetRootUserEmail() (email string) {
} }
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled { if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1) addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return return

View File

@ -1,7 +1,7 @@
package model package model
import ( import (
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"sync" "sync"
"time" "time"
@ -29,7 +29,7 @@ func init() {
func InitBatchUpdater() { func InitBatchUpdater() {
go func() { go func() {
for { for {
time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate() batchUpdate()
} }
}() }()

View File

@ -8,6 +8,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/relay/constant" "one-api/relay/constant"
@ -51,9 +52,9 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
FinishReason: "stop", FinishReason: "stop",
} }
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: common.GetUUID(), Id: helper.GetUUID(),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
} }
return &fullTextResponse return &fullTextResponse
@ -64,9 +65,9 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
choice.Delta.Content = aiProxyDocuments2Markdown(documents) choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{ return &openai.ChatCompletionsStreamResponse{
Id: common.GetUUID(), Id: helper.GetUUID(),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "", Model: "",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }
@ -76,9 +77,9 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{ return &openai.ChatCompletionsStreamResponse{
Id: common.GetUUID(), Id: helper.GetUUID(),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: response.Model, Model: response.Model,
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"strings" "strings"
@ -119,7 +120,7 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: response.RequestId, Id: response.RequestId,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
Usage: openai.Usage{ Usage: openai.Usage{
PromptTokens: response.Usage.InputTokens, PromptTokens: response.Usage.InputTokens,
@ -140,7 +141,7 @@ func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletions
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId, Id: aliResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "qwen", Model: "qwen",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }

View File

@ -8,6 +8,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"strings" "strings"
@ -79,9 +80,9 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
} }
return &fullTextResponse return &fullTextResponse
@ -89,8 +90,8 @@ func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
responseText := "" responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := common.GetTimestamp() createdTime := helper.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {

View File

@ -7,6 +7,8 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/helper"
"one-api/common/image" "one-api/common/image"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
@ -29,19 +31,19 @@ func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRe
SafetySettings: []GeminiChatSafetySettings{ SafetySettings: []GeminiChatSafetySettings{
{ {
Category: "HARM_CATEGORY_HARASSMENT", Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting, Threshold: config.GeminiSafetySetting,
}, },
{ {
Category: "HARM_CATEGORY_HATE_SPEECH", Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting, Threshold: config.GeminiSafetySetting,
}, },
{ {
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting, Threshold: config.GeminiSafetySetting,
}, },
{ {
Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting, Threshold: config.GeminiSafetySetting,
}, },
}, },
GenerationConfig: GeminiChatGenerationConfig{ GenerationConfig: GeminiChatGenerationConfig{
@ -152,9 +154,9 @@ type GeminiChatPromptFeedback struct {
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse { func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
} }
for i, candidate := range response.Candidates { for i, candidate := range response.Candidates {
@ -230,9 +232,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatus
var choice openai.ChatCompletionsStreamResponseChoice var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content choice.Delta.Content = dummy.Content
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "gemini-pro", Model: "gemini-pro",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/relay/constant" "one-api/relay/constant"
@ -72,8 +73,8 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompl
func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) { func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
responseText := "" responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
createdTime := common.GetTimestamp() createdTime := helper.GetTimestamp()
dataChan := make(chan string) dataChan := make(chan string)
stopChan := make(chan bool) stopChan := make(chan bool)
go func() { go func() {

View File

@ -6,6 +6,7 @@ import (
"github.com/pkoukk/tiktoken-go" "github.com/pkoukk/tiktoken-go"
"math" "math"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/image" "one-api/common/image"
"one-api/common/logger" "one-api/common/logger"
"strings" "strings"
@ -56,7 +57,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken {
} }
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if common.ApproximateTokenEnabled { if config.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38) return int(float64(len(text)) * 0.38)
} }
return len(tokenEncoder.Encode(text, nil, nil)) return len(tokenEncoder.Encode(text, nil, nil))

View File

@ -12,6 +12,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/relay/constant" "one-api/relay/constant"
@ -47,9 +48,9 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
stream = 1 stream = 1
} }
return &ChatRequest{ return &ChatRequest{
Timestamp: common.GetTimestamp(), Timestamp: helper.GetTimestamp(),
Expired: common.GetTimestamp() + 24*60*60, Expired: helper.GetTimestamp() + 24*60*60,
QueryID: common.GetUUID(), QueryID: helper.GetUUID(),
Temperature: request.Temperature, Temperature: request.Temperature,
TopP: request.TopP, TopP: request.TopP,
Stream: stream, Stream: stream,
@ -60,7 +61,7 @@ func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Usage: response.Usage, Usage: response.Usage,
} }
if len(response.Choices) > 0 { if len(response.Choices) > 0 {
@ -80,7 +81,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "tencent-hunyuan", Model: "tencent-hunyuan",
} }
if len(TencentResponse.Choices) > 0 { if len(TencentResponse.Choices) > 0 {

View File

@ -12,6 +12,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/relay/constant" "one-api/relay/constant"
@ -69,7 +70,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
} }
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []openai.TextResponseChoice{choice},
Usage: response.Payload.Usage.Text, Usage: response.Payload.Usage.Text,
} }
@ -91,7 +92,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
} }
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "SparkDesk", Model: "SparkDesk",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }

View File

@ -8,6 +8,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/helper"
"one-api/common/logger" "one-api/common/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/relay/constant" "one-api/relay/constant"
@ -102,7 +103,7 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: response.Data.TaskId, Id: response.Data.TaskId,
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)), Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
Usage: response.Data.Usage, Usage: response.Data.Usage,
} }
@ -128,7 +129,7 @@ func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStr
choice.Delta.Content = zhipuResponse choice.Delta.Content = zhipuResponse
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "chatglm", Model: "chatglm",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }
@ -142,7 +143,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.
response := openai.ChatCompletionsStreamResponse{ response := openai.ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId, Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: helper.GetTimestamp(),
Model: "chatglm", Model: "chatglm",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
} }

View File

@ -11,6 +11,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
@ -54,7 +55,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota quota = preConsumedQuota
default: default:
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio)
} }
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {

View File

@ -8,6 +8,7 @@ import (
"math" "math"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
@ -52,7 +53,7 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
case constant.RelayModeModerations: case constant.RelayModeModerations:
promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model) promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
} }
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := config.PreConsumedQuota
if textRequest.MaxTokens != 0 { if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens preConsumedTokens = promptTokens + textRequest.MaxTokens
} }

View File

@ -9,6 +9,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/helper"
"one-api/relay/channel/aiproxy" "one-api/relay/channel/aiproxy"
"one-api/relay/channel/ali" "one-api/relay/channel/ali"
"one-api/relay/channel/anthropic" "one-api/relay/channel/anthropic"
@ -66,7 +67,7 @@ func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.Rel
case constant.APITypePaLM: case constant.APITypePaLM:
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL)
case constant.APITypeGemini: case constant.APITypeGemini:
version := common.AssignOrDefault(meta.APIVersion, "v1") version := helper.AssignOrDefault(meta.APIVersion, "v1")
action := "generateContent" action := "generateContent"
if textRequest.Stream { if textRequest.Stream {
action = "streamGenerateContent" action = "streamGenerateContent"

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"one-api/model" "one-api/model"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
@ -17,7 +18,7 @@ import (
) )
func ShouldDisableChannel(err *openai.Error, statusCode int) bool { func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled { if !config.AutomaticDisableChannelEnabled {
return false return false
} }
if err == nil { if err == nil {
@ -33,7 +34,7 @@ func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
} }
func ShouldEnableChannel(err error, openAIErr *openai.Error) bool { func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
if !common.AutomaticEnableChannelEnabled { if !config.AutomaticEnableChannelEnabled {
return false return false
} }
if err != nil { if err != nil {

View File

@ -2,7 +2,7 @@ package util
import ( import (
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"time" "time"
) )
@ -10,11 +10,11 @@ var HTTPClient *http.Client
var ImpatientHTTPClient *http.Client var ImpatientHTTPClient *http.Client
func init() { func init() {
if common.RelayTimeout == 0 { if config.RelayTimeout == 0 {
HTTPClient = &http.Client{} HTTPClient = &http.Client{}
} else { } else {
HTTPClient = &http.Client{ HTTPClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second, Timeout: time.Duration(config.RelayTimeout) * time.Second,
} }
} }

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common/config"
"one-api/common/logger" "one-api/common/logger"
"os" "os"
"strings" "strings"
@ -16,7 +16,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS) {
SetDashboardRouter(router) SetDashboardRouter(router)
SetRelayRouter(router) SetRelayRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" { if config.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = "" frontendBaseUrl = ""
logger.SysLog("FRONTEND_BASE_URL is ignored on master node") logger.SysLog("FRONTEND_BASE_URL is ignored on master node")
} }

View File

@ -8,17 +8,18 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/config"
"one-api/controller" "one-api/controller"
"one-api/middleware" "one-api/middleware"
"strings" "strings"
) )
func SetWebRouter(router *gin.Engine, buildFS embed.FS) { func SetWebRouter(router *gin.Engine, buildFS embed.FS) {
indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", common.Theme)) indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", config.Theme))
router.Use(gzip.Gzip(gzip.DefaultCompression)) router.Use(gzip.Gzip(gzip.DefaultCompression))
router.Use(middleware.GlobalWebRateLimit()) router.Use(middleware.GlobalWebRateLimit())
router.Use(middleware.Cache()) router.Use(middleware.Cache())
router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", common.Theme)))) router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", config.Theme))))
router.NoRoute(func(c *gin.Context) { router.NoRoute(func(c *gin.Context) {
if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") { if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") {
controller.RelayNotFound(c) controller.RelayNotFound(c)