diff --git a/cli/flag.go b/cli/flag.go index 804f6cb4..2dcfd5a3 100644 --- a/cli/flag.go +++ b/cli/flag.go @@ -3,7 +3,8 @@ package cli import ( "flag" "fmt" - "one-api/common" + "one-api/common/config" + "one-api/common/utils" "os" "github.com/spf13/viper" @@ -18,11 +19,11 @@ var ( export = flag.Bool("export", false, "Exports prices to a JSON file.") ) -func FlagConfig() { +func InitCli() { flag.Parse() if *printVersion { - fmt.Println(common.Version) + fmt.Println(config.Version) os.Exit(0) } @@ -44,10 +45,19 @@ func FlagConfig() { os.Exit(0) } + if Config != nil && !utils.IsFileExist(*Config) { + panic("Config file not found") + } + + viper.SetConfigFile(*Config) + if err := viper.ReadInConfig(); err != nil { + panic(err) + } + } func help() { - fmt.Println("One API " + common.Version + " - All in one API service for OpenAI API.") + fmt.Println("One API " + config.Version + " - All in one API service for OpenAI API.") fmt.Println("Copyright (C) 2024 MartialBE. All rights reserved.") fmt.Println("Original copyright holder: JustSong") fmt.Println("GitHub: https://github.com/MartialBE/one-api") diff --git a/common/common.go b/common/common.go index e3d6accc..41de9367 100644 --- a/common/common.go +++ b/common/common.go @@ -1,10 +1,13 @@ package common -import "fmt" +import ( + "fmt" + "one-api/common/config" +) func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) + if config.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) } else { return fmt.Sprintf("%d 点额度", quota) } diff --git a/common/config/config.go b/common/config/config.go index 3998c293..da6110c9 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -5,37 +5,22 @@ import ( "strings" "time" - "one-api/cli" - "one-api/common" "one-api/common/utils" "github.com/spf13/viper" ) func InitConf() { - cli.FlagConfig() defaultConfig() - setConfigFile() setEnv() if viper.GetBool("debug") { logger.SysLog("running in debug mode") } - common.IsMasterNode = viper.GetString("node_type") != "slave" - common.RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second - common.SessionSecret = utils.GetOrDefault("session_secret", common.SessionSecret) -} - -func setConfigFile() { - if !utils.IsFileExist(*cli.Config) { - return - } - - viper.SetConfigFile(*cli.Config) - if err := viper.ReadInConfig(); err != nil { - panic(err) - } + IsMasterNode = viper.GetString("node_type") != "slave" + RequestInterval = time.Duration(viper.GetInt("polling_interval")) * time.Second + SessionSecret = utils.GetOrDefault("session_secret", SessionSecret) } func setEnv() { diff --git a/common/constants.go b/common/config/constants.go similarity index 99% rename from common/constants.go rename to common/config/constants.go index 1e0c5605..345c8b9e 100644 --- a/common/constants.go +++ b/common/config/constants.go @@ -1,4 +1,4 @@ -package common +package config import ( "sync" diff --git a/common/notify/channel/email.go b/common/notify/channel/email.go index 6fde98a5..05425b1a 100644 --- a/common/notify/channel/email.go +++ b/common/notify/channel/email.go @@ -3,7 +3,7 @@ package channel import ( "context" "errors" - "one-api/common" + "one-api/common/config" "one-api/common/stmp" "github.com/gomarkdown/markdown" @@ -28,10 +28,10 @@ func (e *Email) Name() string { func (e *Email) Send(ctx context.Context, title, message string) error { to := e.To if to == "" { - to = common.RootUserEmail + to = config.RootUserEmail } - if common.SMTPServer == "" || common.SMTPAccount == "" || common.SMTPToken == "" || to == "" { + if config.SMTPServer == "" || config.SMTPAccount == "" || config.SMTPToken == "" || to == "" { return errors.New("smtp config is not set, skip send email notifier") } @@ -44,7 +44,7 @@ func (e *Email) Send(ctx context.Context, title, message string) error { body := markdown.Render(doc, renderer) - emailClient := stmp.NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom) + emailClient := stmp.NewStmp(config.SMTPServer, config.SMTPPort, config.SMTPAccount, config.SMTPToken, config.SMTPFrom) return emailClient.Send(to, title, string(body)) } diff --git a/common/redis.go b/common/redis.go index de537138..26fa3c31 100644 --- a/common/redis.go +++ b/common/redis.go @@ -2,6 +2,7 @@ package common import ( "context" + "one-api/common/config" "one-api/common/logger" "time" @@ -41,7 +42,7 @@ func InitRedisClient() (err error) { } else { RedisEnabled = true // for compatibility with old versions - MemoryCacheEnabled = true + config.MemoryCacheEnabled = true } return err diff --git a/common/stmp/email.go b/common/stmp/email.go index 1f8b6a03..8d3292c9 100644 --- a/common/stmp/email.go +++ b/common/stmp/email.go @@ -3,6 +3,7 @@ package stmp import ( "fmt" "one-api/common" + "one-api/common/config" "one-api/common/utils" "strings" @@ -38,7 +39,7 @@ func (s *StmpConfig) Send(to, subject, body string) error { message.Subject(subject) message.SetGenHeader("References", s.getReferences()) message.SetBodyString(mail.TypeTextHTML, body) - message.SetUserAgent(fmt.Sprintf("One API %s // https://github.com/MartialBE/one-api", common.Version)) + message.SetUserAgent(fmt.Sprintf("One API %s // https://github.com/MartialBE/one-api", config.Version)) client, err := mail.NewClient( s.Host, @@ -78,11 +79,11 @@ func (s *StmpConfig) Render(to, subject, content string) error { } func GetSystemStmp() (*StmpConfig, error) { - if common.SMTPServer == "" || common.SMTPPort == 0 || common.SMTPAccount == "" || common.SMTPToken == "" { + if config.SMTPServer == "" || config.SMTPPort == 0 || config.SMTPAccount == "" || config.SMTPToken == "" { return nil, fmt.Errorf("SMTP 信息未配置") } - return NewStmp(common.SMTPServer, common.SMTPPort, common.SMTPAccount, common.SMTPToken, common.SMTPFrom), nil + return NewStmp(config.SMTPServer, config.SMTPPort, config.SMTPAccount, config.SMTPToken, config.SMTPFrom), nil } func SendPasswordResetEmail(userName, email, link string) error { @@ -106,7 +107,7 @@ func SendPasswordResetEmail(userName, email, link string) error {

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

` - subject := fmt.Sprintf("%s密码重置", common.SystemName) + subject := fmt.Sprintf("%s密码重置", config.SystemName) content := fmt.Sprintf(contentTemp, userName, link, link, common.VerificationValidMinutes) return stmp.Render(email, subject, content) @@ -132,7 +133,7 @@ func SendVerificationCodeEmail(email, code string) error { 验证码 %d 分钟内有效,如果不是本人操作,请忽略。

` - subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) + subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes) return stmp.Render(email, subject, content) @@ -162,7 +163,7 @@ func SendQuotaWarningCodeEmail(userName, email string, quota int, noMoreQuota bo if noMoreQuota { subject = "您的额度已用尽" } - topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) + topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) content := fmt.Sprintf(contentTemp, userName, subject, quota, topUpLink, topUpLink) diff --git a/common/stmp/email_test.go b/common/stmp/email_test.go index 778c6ce9..081de2f2 100644 --- a/common/stmp/email_test.go +++ b/common/stmp/email_test.go @@ -2,6 +2,7 @@ package stmp_test import ( "fmt" + "one-api/common/config" "testing" "one-api/common" @@ -56,7 +57,7 @@ func TestSend(t *testing.T) { 验证码 %d 分钟内有效,如果不是本人操作,请忽略。

` - subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) + subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) content := fmt.Sprintf(contentTemp, code, common.VerificationValidMinutes) err := stmpClient.Render(email, subject, content) diff --git a/common/stmp/template.go b/common/stmp/template.go index 8ac44147..1eb884d9 100644 --- a/common/stmp/template.go +++ b/common/stmp/template.go @@ -1,17 +1,17 @@ package stmp import ( - "one-api/common" + "one-api/common/config" ) func getLogo() string { - if common.Logo == "" { + if config.Logo == "" { return "" } return ` @@ -19,11 +19,11 @@ func getLogo() string { } func getSystemName() string { - if common.SystemName == "" { + if config.SystemName == "" { return "One API" } - return common.SystemName + return config.SystemName } func getDefaultTemplate(content string) string { diff --git a/common/telegram/command_aff.go b/common/telegram/command_aff.go index 68733169..a7d3188b 100644 --- a/common/telegram/command_aff.go +++ b/common/telegram/command_aff.go @@ -1,7 +1,7 @@ package telegram import ( - "one-api/common" + "one-api/common/config" "one-api/common/utils" "strings" @@ -24,8 +24,8 @@ func commandAffStart(b *gotgbot.Bot, ctx *ext.Context) error { } messae := "您可以通过分享您的邀请码来邀请朋友,每次成功邀请将获得奖励。\n\n您的邀请码是: " + user.AffCode - if common.ServerAddress != "" { - serverAddress := strings.TrimSuffix(common.ServerAddress, "/") + if config.ServerAddress != "" { + serverAddress := strings.TrimSuffix(config.ServerAddress, "/") messae += "\n\n页面地址:" + serverAddress + "/register?aff=" + user.AffCode } diff --git a/common/telegram/command_apikey.go b/common/telegram/command_apikey.go index cd6295db..c0062b3c 100644 --- a/common/telegram/command_apikey.go +++ b/common/telegram/command_apikey.go @@ -3,7 +3,7 @@ package telegram import ( "fmt" "net/url" - "one-api/common" + "one-api/common/config" "one-api/model" "strings" @@ -56,7 +56,7 @@ func getApikeyList(userId, page int) (message string, pageParams *paginationPara } chatUrlTmp := "" - if common.ServerAddress != "" { + if config.ServerAddress != "" { chatUrlTmp = getChatUrl() } @@ -75,11 +75,11 @@ func getApikeyList(userId, page int) (message string, pageParams *paginationPara } func getChatUrl() string { - serverAddress := strings.TrimSuffix(common.ServerAddress, "/") + serverAddress := strings.TrimSuffix(config.ServerAddress, "/") chatNextUrl := fmt.Sprintf(`{"key":"setToken","url":"%s"}`, serverAddress) chatNextUrl = "https://chat.oneapi.pro/#/?settings=" + url.QueryEscape(chatNextUrl) - if common.ChatLink != "" { - chatLink := strings.TrimSuffix(common.ChatLink, "/") + if config.ChatLink != "" { + chatLink := strings.TrimSuffix(config.ChatLink, "/") chatNextUrl = strings.ReplaceAll(chatNextUrl, `https://chat.oneapi.pro`, chatLink) } diff --git a/common/telegram/common.go b/common/telegram/common.go index 4a3a3160..9b039b76 100644 --- a/common/telegram/common.go +++ b/common/telegram/common.go @@ -7,7 +7,7 @@ import ( "net" "net/http" "net/url" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "strings" @@ -56,13 +56,13 @@ func InitTelegramBot() { func StartTelegramBot() { botWebhook := viper.GetString("tg.webhook_secret") if botWebhook != "" { - if common.ServerAddress == "" { + if config.ServerAddress == "" { logger.SysLog("Telegram bot is not enabled: Server address is not set") StopTelegramBot() return } TGWebHookSecret = botWebhook - serverAddress := strings.TrimSuffix(common.ServerAddress, "/") + serverAddress := strings.TrimSuffix(config.ServerAddress, "/") urlPath := fmt.Sprintf("/api/telegram/%s", viper.GetString("tg.bot_api_key")) webHookOpts := &ext.AddWebhookOpts{ diff --git a/common/token.go b/common/token.go index df125193..8d5090c3 100644 --- a/common/token.go +++ b/common/token.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math" + "one-api/common/config" "one-api/common/logger" "strings" @@ -21,7 +22,7 @@ var gpt4oTokenEncoder *tiktoken.Tiktoken func InitTokenEncoders() { if viper.GetBool("disable_token_encoders") { - DISABLE_TOKEN_ENCODERS = true + config.DISABLE_TOKEN_ENCODERS = true logger.SysLog("token encoders disabled") return } @@ -46,7 +47,7 @@ func InitTokenEncoders() { } func getTokenEncoder(model string) *tiktoken.Tiktoken { - if DISABLE_TOKEN_ENCODERS { + if config.DISABLE_TOKEN_ENCODERS { return nil } @@ -75,7 +76,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { } func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { - if DISABLE_TOKEN_ENCODERS || ApproximateTokenEnabled { + if config.DISABLE_TOKEN_ENCODERS || config.ApproximateTokenEnabled { return int(float64(len(text)) * 0.38) } return len(tokenEncoder.Encode(text, nil, nil)) diff --git a/controller/billing.go b/controller/billing.go index 3de7c847..6e6192b8 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -1,7 +1,7 @@ package controller import ( - "one-api/common" + "one-api/common/config" "one-api/model" "one-api/types" @@ -14,7 +14,7 @@ func GetSubscription(c *gin.Context) { var err error var token *model.Token var expiredTime int64 - if common.DisplayTokenStatEnabled { + if config.DisplayTokenStatEnabled { tokenId := c.GetInt("token_id") token, err = model.GetTokenById(tokenId) expiredTime = token.ExpiredTime @@ -50,8 +50,8 @@ func GetSubscription(c *gin.Context) { } quota := remainQuota + usedQuota amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit } if token != nil && token.UnlimitedQuota { amount = 100000000 @@ -71,7 +71,7 @@ func GetUsage(c *gin.Context) { var quota int var err error var token *model.Token - if common.DisplayTokenStatEnabled { + if config.DisplayTokenStatEnabled { tokenId := c.GetInt("token_id") token, err = model.GetTokenById(tokenId) quota = token.UsedQuota @@ -90,8 +90,8 @@ func GetUsage(c *gin.Context) { return } amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit } usage := OpenAIUsageResponse{ Object: "list", diff --git a/controller/channel-billing.go b/controller/channel-billing.go index ffb53fbd..2f3254d1 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,7 +4,7 @@ import ( "errors" "net/http" "net/http/httptest" - "one-api/common" + "one-api/common/config" "one-api/model" "one-api/providers" providersBase "one-api/providers/base" @@ -109,11 +109,11 @@ func updateAllChannelsBalance() error { return err } for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != config.ChannelStatusEnabled { continue } // TODO: support Azure - if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { + if channel.Type != config.ChannelTypeOpenAI && channel.Type != config.ChannelTypeCustom { continue } balance, err := updateChannelBalance(channel) @@ -125,7 +125,7 @@ func updateAllChannelsBalance() error { DisableChannel(channel.Id, channel.Name, "余额不足", true) } } - time.Sleep(common.RequestInterval) + time.Sleep(config.RequestInterval) } return nil } diff --git a/controller/channel-test.go b/controller/channel-test.go index ee380893..677d48c7 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -6,7 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/notify" "one-api/common/utils" @@ -145,16 +145,16 @@ func testAllChannels(isNotify bool) error { if err != nil { return err } - var disableThreshold = int64(common.ChannelDisableThreshold * 1000) + var disableThreshold = int64(config.ChannelDisableThreshold * 1000) if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value } go func() { var sendMessage string for _, channel := range channels { - time.Sleep(common.RequestInterval) + time.Sleep(config.RequestInterval) - isChannelEnabled := channel.Status == common.ChannelStatusEnabled + isChannelEnabled := channel.Status == config.ChannelStatusEnabled sendMessage += fmt.Sprintf("**通道 %s - #%d - %s** : \n\n", utils.EscapeMarkdownText(channel.Name), channel.Id, channel.StatusToStr()) tik := time.Now() err, openaiErr := testChannel(channel, "") @@ -173,7 +173,7 @@ func testAllChannels(isNotify bool) error { // 如果已被禁用,但是请求成功,需要判断是否需要恢复 // 手动禁用的通道,不会自动恢复 if shouldEnableChannel(err, openaiErr) { - if channel.Status == common.ChannelStatusAutoDisabled { + if channel.Status == config.ChannelStatusAutoDisabled { EnableChannel(channel.Id, channel.Name, false) sendMessage += "- 已被启用 \n\n" } else { diff --git a/controller/common.go b/controller/common.go index bad43a3e..2a3e17f3 100644 --- a/controller/common.go +++ b/controller/common.go @@ -3,7 +3,7 @@ package controller import ( "fmt" "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/notify" "one-api/model" "one-api/types" @@ -13,7 +13,7 @@ import ( ) func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { - if !common.AutomaticEnableChannelEnabled { + if !config.AutomaticEnableChannelEnabled { return false } if err != nil { @@ -26,7 +26,7 @@ func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { } func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool { - if !common.AutomaticDisableChannelEnabled { + if !config.AutomaticDisableChannelEnabled { return false } @@ -74,7 +74,7 @@ func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool { // disable & notify func DisableChannel(channelId int, channelName string, reason string, sendNotify bool) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + model.UpdateChannelStatusById(channelId, config.ChannelStatusAutoDisabled) if !sendNotify { return } @@ -86,7 +86,7 @@ func DisableChannel(channelId int, channelName string, reason string, sendNotify // enable & notify func EnableChannel(channelId int, channelName string, sendNotify bool) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + model.UpdateChannelStatusById(channelId, config.ChannelStatusEnabled) if !sendNotify { return } diff --git a/controller/github.go b/controller/github.go index 2c452bc0..5c83a89e 100644 --- a/controller/github.go +++ b/controller/github.go @@ -6,7 +6,7 @@ import ( "errors" "fmt" "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/utils" "one-api/model" @@ -33,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { if code == "" { return nil, errors.New("无效的参数") } - values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} + values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} jsonData, err := json.Marshal(values) if err != nil { return nil, err @@ -96,7 +96,7 @@ func GitHubOAuth(c *gin.Context) { return } - if !common.GitHubOAuthEnabled { + if !config.GitHubOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 GitHub 登录以及注册", @@ -125,7 +125,7 @@ func GitHubOAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) if githubUser.Name != "" { user.DisplayName = githubUser.Name @@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) { user.DisplayName = "GitHub User" } user.Email = githubUser.Email - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = config.RoleCommonUser + user.Status = config.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -152,7 +152,7 @@ func GitHubOAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != config.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, @@ -163,7 +163,7 @@ func GitHubOAuth(c *gin.Context) { } func GitHubBind(c *gin.Context) { - if !common.GitHubOAuthEnabled { + if !config.GitHubOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 GitHub 登录以及注册", diff --git a/controller/lark.go b/controller/lark.go index b8e8d06f..7fc88951 100644 --- a/controller/lark.go +++ b/controller/lark.go @@ -6,7 +6,7 @@ import ( "errors" "fmt" "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "strconv" @@ -41,8 +41,8 @@ type LarkUser struct { func getLarkAppAccessToken() (string, error) { values := map[string]string{ - "app_id": common.LarkClientId, - "app_secret": common.LarkClientSecret, + "app_id": config.LarkClientId, + "app_secret": config.LarkClientSecret, } jsonData, err := json.Marshal(values) if err != nil { @@ -148,7 +148,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { } func LarkOAuth(c *gin.Context) { - if !common.LarkAuthEnabled { + if !config.LarkAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过飞书登录以及注册", "success": false, @@ -191,15 +191,15 @@ func LarkOAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) if larkUser.Data.Name != "" { user.DisplayName = larkUser.Data.Name } else { user.DisplayName = "Lark User" } - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = config.RoleCommonUser + user.Status = config.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -217,7 +217,7 @@ func LarkOAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != config.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, @@ -228,7 +228,7 @@ func LarkOAuth(c *gin.Context) { } func LarkBind(c *gin.Context) { - if !common.LarkAuthEnabled { + if !config.LarkAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过飞书登录以及注册", "success": false, diff --git a/controller/misc.go b/controller/misc.go index e2c37e50..f026262e 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/stmp" "one-api/common/telegram" "one-api/model" @@ -23,60 +24,60 @@ func GetStatus(c *gin.Context) { "success": true, "message": "", "data": gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": common.EmailVerificationEnabled, - "github_oauth": common.GitHubOAuthEnabled, - "github_client_id": common.GitHubClientId, - "lark_login": common.LarkAuthEnabled, - "lark_client_id": common.LarkClientId, - "system_name": common.SystemName, - "logo": common.Logo, - "footer_html": common.Footer, - "wechat_qrcode": common.WeChatAccountQRCodeImageURL, - "wechat_login": common.WeChatAuthEnabled, - "server_address": common.ServerAddress, - "turnstile_check": common.TurnstileCheckEnabled, - "turnstile_site_key": common.TurnstileSiteKey, - "top_up_link": common.TopUpLink, - "chat_link": common.ChatLink, - "quota_per_unit": common.QuotaPerUnit, - "display_in_currency": common.DisplayInCurrencyEnabled, + "version": config.Version, + "start_time": config.StartTime, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "lark_login": config.LarkAuthEnabled, + "lark_client_id": config.LarkClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, "telegram_bot": telegram_bot, - "mj_notify_enabled": common.MjNotifyEnabled, - "chat_cache_enabled": common.ChatCacheEnabled, - "chat_links": common.ChatLinks, + "mj_notify_enabled": config.MjNotifyEnabled, + "chat_cache_enabled": config.ChatCacheEnabled, + "chat_links": config.ChatLinks, }, }) } func GetNotice(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["Notice"], + "data": config.OptionMap["Notice"], }) } func GetAbout(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["About"], + "data": config.OptionMap["About"], }) } func GetHomePageContent(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["HomePageContent"], + "data": config.OptionMap["HomePageContent"], }) } @@ -89,9 +90,9 @@ func SendEmailVerification(c *gin.Context) { }) return } - if common.EmailDomainRestrictionEnabled { + if config.EmailDomainRestrictionEnabled { allowed := false - for _, domain := range common.EmailDomainWhitelist { + for _, domain := range config.EmailDomainWhitelist { if strings.HasSuffix(email, "@"+domain) { allowed = true break @@ -157,7 +158,7 @@ func SendPasswordResetEmail(c *gin.Context) { code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) - link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) err := stmp.SendPasswordResetEmail(userName, email, link) if err != nil { diff --git a/controller/option.go b/controller/option.go index 99f36607..438688f0 100644 --- a/controller/option.go +++ b/controller/option.go @@ -3,7 +3,7 @@ package controller import ( "encoding/json" "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/utils" "one-api/model" "strings" @@ -13,8 +13,8 @@ import ( func GetOptions(c *gin.Context) { var options []*model.Option - common.OptionMapRWMutex.Lock() - for k, v := range common.OptionMap { + config.OptionMapRWMutex.Lock() + for k, v := range config.OptionMap { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { continue } @@ -23,7 +23,7 @@ func GetOptions(c *gin.Context) { Value: utils.Interface2String(v), }) } - common.OptionMapRWMutex.Unlock() + config.OptionMapRWMutex.Unlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -44,7 +44,7 @@ func UpdateOption(c *gin.Context) { } switch option.Key { case "GitHubOAuthEnabled": - if option.Value == "true" && common.GitHubClientId == "" { + if option.Value == "true" && config.GitHubClientId == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", @@ -52,7 +52,7 @@ func UpdateOption(c *gin.Context) { return } case "EmailDomainRestrictionEnabled": - if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { + if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", @@ -60,7 +60,7 @@ func UpdateOption(c *gin.Context) { return } case "WeChatAuthEnabled": - if option.Value == "true" && common.WeChatServerAddress == "" { + if option.Value == "true" && config.WeChatServerAddress == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用微信登录,请先填入微信登录相关配置信息!", @@ -68,7 +68,7 @@ func UpdateOption(c *gin.Context) { return } case "TurnstileCheckEnabled": - if option.Value == "true" && common.TurnstileSiteKey == "" { + if option.Value == "true" && config.TurnstileSiteKey == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", diff --git a/controller/token.go b/controller/token.go index 9f43b111..43eed435 100644 --- a/controller/token.go +++ b/controller/token.go @@ -3,6 +3,7 @@ package controller import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/common/utils" "one-api/model" "strconv" @@ -199,15 +200,15 @@ func UpdateToken(c *gin.Context) { }) return } - if token.Status == common.TokenStatusEnabled { - if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 { + if token.Status == config.TokenStatusEnabled { + if cleanToken.Status == config.TokenStatusExpired && cleanToken.ExpiredTime <= utils.GetTimestamp() && cleanToken.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", }) return } - if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { + if cleanToken.Status == config.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", diff --git a/controller/user.go b/controller/user.go index 4973314d..82b4a556 100644 --- a/controller/user.go +++ b/controller/user.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/utils" "one-api/model" "strconv" @@ -20,7 +21,7 @@ type LoginRequest struct { } func Login(c *gin.Context) { - if !common.PasswordLoginEnabled { + if !config.PasswordLoginEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了密码登录", "success": false, @@ -107,14 +108,14 @@ func Logout(c *gin.Context) { } func Register(c *gin.Context) { - if !common.RegisterEnabled { + if !config.RegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了新用户注册", "success": false, }) return } - if !common.PasswordRegisterEnabled { + if !config.PasswordRegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", "success": false, @@ -137,7 +138,7 @@ func Register(c *gin.Context) { }) return } - if common.EmailVerificationEnabled { + if config.EmailVerificationEnabled { if user.Email == "" || user.VerificationCode == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -161,7 +162,7 @@ func Register(c *gin.Context) { DisplayName: user.Username, InviterId: inviterId, } - if common.EmailVerificationEnabled { + if config.EmailVerificationEnabled { cleanUser.Email = user.Email } if err := cleanUser.Insert(inviterId); err != nil { @@ -214,7 +215,7 @@ func GetUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权获取同级或更高等级用户的信息", @@ -360,14 +361,14 @@ func UpdateUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= originUser.Role && myRole != common.RoleRootUser { + if myRole <= originUser.Role && myRole != config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", }) return } - if myRole <= updatedUser.Role && myRole != common.RoleRootUser { + if myRole <= updatedUser.Role && myRole != config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权将其他用户权限等级提升到大于等于自己的权限等级", @@ -479,7 +480,7 @@ func DeleteSelf(c *gin.Context) { id := c.GetInt("id") user, _ := model.GetUserById(id, false) - if user.Role == common.RoleRootUser { + if user.Role == config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "不能删除超级管理员账户", @@ -579,7 +580,7 @@ func ManageUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", @@ -588,8 +589,8 @@ func ManageUser(c *gin.Context) { } switch req.Action { case "disable": - user.Status = common.UserStatusDisabled - if user.Role == common.RoleRootUser { + user.Status = config.UserStatusDisabled + if user.Role == config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法禁用超级管理员用户", @@ -597,9 +598,9 @@ func ManageUser(c *gin.Context) { return } case "enable": - user.Status = common.UserStatusEnabled + user.Status = config.UserStatusEnabled case "delete": - if user.Role == common.RoleRootUser { + if user.Role == config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法删除超级管理员用户", @@ -614,37 +615,37 @@ func ManageUser(c *gin.Context) { return } case "promote": - if myRole != common.RoleRootUser { + if myRole != config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "普通管理员用户无法提升其他用户为管理员", }) return } - if user.Role >= common.RoleAdminUser { + if user.Role >= config.RoleAdminUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是管理员", }) return } - user.Role = common.RoleAdminUser + user.Role = config.RoleAdminUser case "demote": - if user.Role == common.RoleRootUser { + if user.Role == config.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法降级超级管理员用户", }) return } - if user.Role == common.RoleCommonUser { + if user.Role == config.RoleCommonUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是普通用户", }) return } - user.Role = common.RoleCommonUser + user.Role = config.RoleCommonUser } if err := user.Update(false); err != nil { diff --git a/controller/wechat.go b/controller/wechat.go index fbd7d2bd..2ff9140b 100644 --- a/controller/wechat.go +++ b/controller/wechat.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "net/http" - "one-api/common" + "one-api/common/config" "one-api/model" "strconv" "time" @@ -23,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) { if code == "" { return "", errors.New("无效的参数") } - req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) + req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) if err != nil { return "", err } - req.Header.Set("Authorization", common.WeChatServerToken) + req.Header.Set("Authorization", config.WeChatServerToken) client := http.Client{ Timeout: 5 * time.Second, } @@ -51,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) { } func WeChatAuth(c *gin.Context) { - if !common.WeChatAuthEnabled { + if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", "success": false, @@ -80,11 +80,11 @@ func WeChatAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = "WeChat User" - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = config.RoleCommonUser + user.Status = config.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -102,7 +102,7 @@ func WeChatAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != config.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, @@ -113,7 +113,7 @@ func WeChatAuth(c *gin.Context) { } func WeChatBind(c *gin.Context) { - if !common.WeChatAuthEnabled { + if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", "success": false, diff --git a/main.go b/main.go index ef0c1984..cc68a8a4 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "embed" "fmt" + "one-api/cli" "one-api/common" "one-api/common/config" "one-api/common/logger" @@ -31,9 +32,10 @@ var buildFS embed.FS var indexPage []byte func main() { + cli.InitCli() config.InitConf() logger.SetupLogger() - logger.SysLog("One API " + common.Version + " started") + logger.SysLog("One API " + config.Version + " started") // Initialize SQL Database model.SetupDB() defer model.CloseDB() @@ -60,10 +62,10 @@ func main() { func initMemoryCache() { if viper.GetBool("memory_cache_enabled") { - common.MemoryCacheEnabled = true + config.MemoryCacheEnabled = true } - if !common.MemoryCacheEnabled { + if !config.MemoryCacheEnabled { return } @@ -91,7 +93,7 @@ func initHttpServer() { server.Use(middleware.RequestId()) middleware.SetUpLogger(server) - store := cookie.NewStore([]byte(common.SessionSecret)) + store := cookie.NewStore([]byte(config.SessionSecret)) server.Use(sessions.Sessions("session", store)) router.SetRouter(server, buildFS, indexPage) @@ -105,7 +107,7 @@ func initHttpServer() { func SyncChannelCache(frequency int) { // 只有 从 服务器端获取数据的时候才会用到 - if common.IsMasterNode { + if config.IsMasterNode { logger.SysLog("master node does't synchronize the channel") return } diff --git a/middleware/auth.go b/middleware/auth.go index 697acacf..e3834a4b 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -2,7 +2,7 @@ package middleware import ( "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/utils" "one-api/model" "strings" @@ -44,7 +44,7 @@ func authHelper(c *gin.Context, minRole int) { return } } - if status.(int) == common.UserStatusDisabled { + if status.(int) == config.UserStatusDisabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已被封禁", @@ -68,19 +68,19 @@ func authHelper(c *gin.Context, minRole int) { func UserAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleCommonUser) + authHelper(c, config.RoleCommonUser) } } func AdminAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleAdminUser) + authHelper(c, config.RoleAdminUser) } } func RootAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleRootUser) + authHelper(c, config.RoleRootUser) } } diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index dae6639c..27f2f190 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/utils" "time" @@ -45,7 +46,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st } if listLength < int64(maxRequestNum) { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } else { oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) @@ -64,14 +65,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st // time.Since will return negative number! // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows if int64(nowTime.Sub(oldTime).Seconds()) < duration { - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) c.Status(http.StatusTooManyRequests) c.Abort() return } else { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } } } @@ -92,7 +93,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } } else { // It's safe to call multi times. - inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) return func(c *gin.Context) { memoryRateLimiter(c, maxRequestNum, duration, mark) } diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 6f295864..629395e7 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -6,7 +6,7 @@ import ( "github.com/gin-gonic/gin" "net/http" "net/url" - "one-api/common" + "one-api/common/config" "one-api/common/logger" ) @@ -16,7 +16,7 @@ type turnstileCheckResponse struct { func TurnstileCheck() gin.HandlerFunc { return func(c *gin.Context) { - if common.TurnstileCheckEnabled { + if config.TurnstileCheckEnabled { session := sessions.Default(c) turnstileChecked := session.Get("turnstile") if turnstileChecked != nil { @@ -33,7 +33,7 @@ func TurnstileCheck() gin.HandlerFunc { return } rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ - "secret": {common.TurnstileSecretKey}, + "secret": {config.TurnstileSecretKey}, "response": {response}, "remoteip": {c.ClientIP()}, }) diff --git a/model/ability.go b/model/ability.go index 8aeb2811..2b973e80 100644 --- a/model/ability.go +++ b/model/ability.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/config" "strings" ) @@ -66,7 +67,7 @@ func (channel *Channel) AddAbilities() error { Group: group, Model: model, ChannelId: channel.Id, - Enabled: channel.Status == common.ChannelStatusEnabled, + Enabled: channel.Status == config.ChannelStatusEnabled, Priority: channel.Priority, Weight: channel.Weight, } diff --git a/model/balancer.go b/model/balancer.go index 82cd0289..9873d90f 100644 --- a/model/balancer.go +++ b/model/balancer.go @@ -3,7 +3,7 @@ package model import ( "errors" "math/rand" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/utils" "strings" @@ -38,7 +38,7 @@ func FilterOnlyChat() ChannelsFilterFunc { } func (cc *ChannelsChooser) Cooldowns(channelId int) bool { - if common.RetryCooldownSeconds == 0 { + if config.RetryCooldownSeconds == 0 { return false } cc.Lock() @@ -47,7 +47,7 @@ func (cc *ChannelsChooser) Cooldowns(channelId int) bool { return false } - cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(common.RetryCooldownSeconds) + cc.Channels[channelId].CooldownsTime = time.Now().Unix() + int64(config.RetryCooldownSeconds) return true } @@ -159,7 +159,7 @@ var ChannelGroup = ChannelsChooser{} func (cc *ChannelsChooser) Load() { var channels []*Channel - DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + DB.Where("status = ?", config.ChannelStatusEnabled).Find(&channels) abilities, err := GetAbilityChannelGroup() if err != nil { @@ -173,7 +173,7 @@ func (cc *ChannelsChooser) Load() { for _, channel := range channels { if *channel.Weight == 0 { - channel.Weight = &common.DefaultChannelWeight + channel.Weight = &config.DefaultChannelWeight } newChannels[channel.Id] = &ChannelChoice{ Channel: channel, diff --git a/model/channel.go b/model/channel.go index b2733ae7..ab38cefc 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,7 +1,7 @@ package model import ( - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/utils" "strings" @@ -270,11 +270,11 @@ func (channel *Channel) Delete() error { func (channel *Channel) StatusToStr() string { switch channel.Status { - case common.ChannelStatusEnabled: + case config.ChannelStatusEnabled: return "启用" - case common.ChannelStatusAutoDisabled: + case config.ChannelStatusAutoDisabled: return "自动禁用" - case common.ChannelStatusManuallyDisabled: + case config.ChannelStatusManuallyDisabled: return "手动禁用" } @@ -282,7 +282,7 @@ func (channel *Channel) StatusToStr() string { } func UpdateChannelStatusById(id int, status int) { - err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + err := UpdateAbilityStatus(id, status == config.ChannelStatusEnabled) if err != nil { logger.SysError("failed to update ability status: " + err.Error()) } @@ -298,7 +298,7 @@ func UpdateChannelStatusById(id int, status int) { } func UpdateChannelUsedQuota(id int, quota int) { - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return } @@ -318,7 +318,7 @@ func DeleteChannelByStatus(status int64) (int64, error) { } func DeleteDisabledChannel() (int64, error) { - result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) + result := DB.Where("status = ? or status = ?", config.ChannelStatusAutoDisabled, config.ChannelStatusManuallyDisabled).Delete(&Channel{}) // 同时删除Ability DB.Where("enabled = ?", false).Delete(&Ability{}) return result.RowsAffected, result.Error diff --git a/model/common.go b/model/common.go index c85c19ab..e5a77d09 100644 --- a/model/common.go +++ b/model/common.go @@ -3,6 +3,7 @@ package model import ( "fmt" "one-api/common" + "one-api/common/config" "strings" "gorm.io/gorm" @@ -43,11 +44,11 @@ func PaginateAndOrder[T modelable](db *gorm.DB, params *PaginationParams, result params.Page = 1 } if params.Size < 1 { - params.Size = common.ItemsPerPage + params.Size = config.ItemsPerPage } - if params.Size > common.MaxRecentItems { - return nil, fmt.Errorf("size 参数不能超过 %d", common.MaxRecentItems) + if params.Size > config.MaxRecentItems { + return nil, fmt.Errorf("size 参数不能超过 %d", config.MaxRecentItems) } offset := (params.Page - 1) * params.Size diff --git a/model/log.go b/model/log.go index 6c8b2008..9bac2cc7 100644 --- a/model/log.go +++ b/model/log.go @@ -3,7 +3,7 @@ package model import ( "context" "fmt" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/utils" @@ -37,7 +37,7 @@ const ( ) func RecordLog(userId int, logType int, content string) { - if logType == LogTypeConsume && !common.LogConsumeEnabled { + if logType == LogTypeConsume && !config.LogConsumeEnabled { return } log := &Log{ @@ -55,7 +55,7 @@ func RecordLog(userId int, logType int, content string) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, requestTime int) { logger.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) - if !common.LogConsumeEnabled { + if !config.LogConsumeEnabled { return } log := &Log{ @@ -156,12 +156,12 @@ func GetUserLogsList(userId int, params *LogsListParams) (*DataResult[Log], erro } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error + err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error + err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } diff --git a/model/main.go b/model/main.go index a81e1dc3..e9d5ac14 100644 --- a/model/main.go +++ b/model/main.go @@ -3,6 +3,7 @@ package model import ( "fmt" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/utils" "strconv" @@ -24,12 +25,12 @@ func SetupDB() { logger.FatalLog("failed to initialize database: " + err.Error()) } ChannelGroup.Load() - common.RootUserEmail = GetRootUserEmail() + config.RootUserEmail = GetRootUserEmail() if viper.GetBool("batch_update_enabled") { - common.BatchUpdateEnabled = true - common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5) - logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + config.BatchUpdateEnabled = true + config.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5) + logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") InitBatchUpdater() } } @@ -46,8 +47,8 @@ func createRootAccountIfNeed() error { rootUser := User{ Username: "root", Password: hashedPassword, - Role: common.RoleRootUser, - Status: common.UserStatusEnabled, + Role: config.RoleRootUser, + Status: config.UserStatusEnabled, DisplayName: "Root User", AccessToken: utils.GetUUID(), Quota: 100000000, @@ -102,7 +103,7 @@ func InitDB() (err error) { sqlDB.SetMaxOpenConns(utils.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(utils.GetOrDefault("SQL_MAX_LIFETIME", 60))) - if !common.IsMasterNode { + if !config.IsMasterNode { return nil } logger.SysLog("database migration started") diff --git a/model/option.go b/model/option.go index 4e75f3f3..0dc6db0d 100644 --- a/model/option.go +++ b/model/option.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/config" "one-api/common/logger" "strconv" "strings" @@ -26,63 +27,63 @@ func GetOption(key string) (option Option, err error) { } func InitOptionMap() { - common.OptionMapRWMutex.Lock() - common.OptionMap = make(map[string]string) - common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) - common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) - common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) - common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) - common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) - common.OptionMap["LarkAuthEnabled"] = strconv.FormatBool(common.LarkAuthEnabled) - common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) - common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) - common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) - common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) - common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) - common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) - common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) - common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) - common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) - common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) - common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") - common.OptionMap["SMTPServer"] = "" - common.OptionMap["SMTPFrom"] = "" - common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) - common.OptionMap["SMTPAccount"] = "" - common.OptionMap["SMTPToken"] = "" - common.OptionMap["Notice"] = "" - common.OptionMap["About"] = "" - common.OptionMap["HomePageContent"] = "" - common.OptionMap["Footer"] = common.Footer - common.OptionMap["SystemName"] = common.SystemName - common.OptionMap["Logo"] = common.Logo - common.OptionMap["ServerAddress"] = "" - common.OptionMap["GitHubClientId"] = "" - common.OptionMap["GitHubClientSecret"] = "" - common.OptionMap["WeChatServerAddress"] = "" - common.OptionMap["WeChatServerToken"] = "" - common.OptionMap["WeChatAccountQRCodeImageURL"] = "" - common.OptionMap["TurnstileSiteKey"] = "" - common.OptionMap["TurnstileSecretKey"] = "" - common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) - common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) - common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) - common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) - common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() - common.OptionMap["TopUpLink"] = common.TopUpLink - common.OptionMap["ChatLink"] = common.ChatLink - common.OptionMap["ChatLinks"] = common.ChatLinks - common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) - common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) - common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds) + config.OptionMapRWMutex.Lock() + config.OptionMap = make(map[string]string) + config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) + config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) + config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) + config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) + config.OptionMap["LarkAuthEnabled"] = strconv.FormatBool(config.LarkAuthEnabled) + config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) + config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) + config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) + config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) + config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) + config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) + config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) + config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) + config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) + config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled) + config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") + config.OptionMap["SMTPServer"] = "" + config.OptionMap["SMTPFrom"] = "" + config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) + config.OptionMap["SMTPAccount"] = "" + config.OptionMap["SMTPToken"] = "" + config.OptionMap["Notice"] = "" + config.OptionMap["About"] = "" + config.OptionMap["HomePageContent"] = "" + config.OptionMap["Footer"] = config.Footer + config.OptionMap["SystemName"] = config.SystemName + config.OptionMap["Logo"] = config.Logo + config.OptionMap["ServerAddress"] = "" + config.OptionMap["GitHubClientId"] = "" + config.OptionMap["GitHubClientSecret"] = "" + config.OptionMap["WeChatServerAddress"] = "" + config.OptionMap["WeChatServerToken"] = "" + config.OptionMap["WeChatAccountQRCodeImageURL"] = "" + config.OptionMap["TurnstileSiteKey"] = "" + config.OptionMap["TurnstileSecretKey"] = "" + config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) + config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) + config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) + config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) + config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) + config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() + config.OptionMap["TopUpLink"] = config.TopUpLink + config.OptionMap["ChatLink"] = config.ChatLink + config.OptionMap["ChatLinks"] = config.ChatLinks + config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) + config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) + config.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(config.RetryCooldownSeconds) - common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled) + config.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(config.MjNotifyEnabled) - common.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(common.ChatCacheEnabled) - common.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(common.ChatCacheExpireMinute) + config.OptionMap["ChatCacheEnabled"] = strconv.FormatBool(config.ChatCacheEnabled) + config.OptionMap["ChatCacheExpireMinute"] = strconv.Itoa(config.ChatCacheExpireMinute) - common.OptionMapRWMutex.Unlock() + config.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -121,64 +122,64 @@ func UpdateOption(key string, value string) error { } var optionIntMap = map[string]*int{ - "SMTPPort": &common.SMTPPort, - "QuotaForNewUser": &common.QuotaForNewUser, - "QuotaForInviter": &common.QuotaForInviter, - "QuotaForInvitee": &common.QuotaForInvitee, - "QuotaRemindThreshold": &common.QuotaRemindThreshold, - "PreConsumedQuota": &common.PreConsumedQuota, - "RetryTimes": &common.RetryTimes, - "RetryCooldownSeconds": &common.RetryCooldownSeconds, - "ChatCacheExpireMinute": &common.ChatCacheExpireMinute, + "SMTPPort": &config.SMTPPort, + "QuotaForNewUser": &config.QuotaForNewUser, + "QuotaForInviter": &config.QuotaForInviter, + "QuotaForInvitee": &config.QuotaForInvitee, + "QuotaRemindThreshold": &config.QuotaRemindThreshold, + "PreConsumedQuota": &config.PreConsumedQuota, + "RetryTimes": &config.RetryTimes, + "RetryCooldownSeconds": &config.RetryCooldownSeconds, + "ChatCacheExpireMinute": &config.ChatCacheExpireMinute, } var optionBoolMap = map[string]*bool{ - "PasswordRegisterEnabled": &common.PasswordRegisterEnabled, - "PasswordLoginEnabled": &common.PasswordLoginEnabled, - "EmailVerificationEnabled": &common.EmailVerificationEnabled, - "GitHubOAuthEnabled": &common.GitHubOAuthEnabled, - "WeChatAuthEnabled": &common.WeChatAuthEnabled, - "LarkAuthEnabled": &common.LarkAuthEnabled, - "TurnstileCheckEnabled": &common.TurnstileCheckEnabled, - "RegisterEnabled": &common.RegisterEnabled, - "EmailDomainRestrictionEnabled": &common.EmailDomainRestrictionEnabled, - "AutomaticDisableChannelEnabled": &common.AutomaticDisableChannelEnabled, - "AutomaticEnableChannelEnabled": &common.AutomaticEnableChannelEnabled, - "ApproximateTokenEnabled": &common.ApproximateTokenEnabled, - "LogConsumeEnabled": &common.LogConsumeEnabled, - "DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled, - "DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled, - "MjNotifyEnabled": &common.MjNotifyEnabled, - "ChatCacheEnabled": &common.ChatCacheEnabled, + "PasswordRegisterEnabled": &config.PasswordRegisterEnabled, + "PasswordLoginEnabled": &config.PasswordLoginEnabled, + "EmailVerificationEnabled": &config.EmailVerificationEnabled, + "GitHubOAuthEnabled": &config.GitHubOAuthEnabled, + "WeChatAuthEnabled": &config.WeChatAuthEnabled, + "LarkAuthEnabled": &config.LarkAuthEnabled, + "TurnstileCheckEnabled": &config.TurnstileCheckEnabled, + "RegisterEnabled": &config.RegisterEnabled, + "EmailDomainRestrictionEnabled": &config.EmailDomainRestrictionEnabled, + "AutomaticDisableChannelEnabled": &config.AutomaticDisableChannelEnabled, + "AutomaticEnableChannelEnabled": &config.AutomaticEnableChannelEnabled, + "ApproximateTokenEnabled": &config.ApproximateTokenEnabled, + "LogConsumeEnabled": &config.LogConsumeEnabled, + "DisplayInCurrencyEnabled": &config.DisplayInCurrencyEnabled, + "DisplayTokenStatEnabled": &config.DisplayTokenStatEnabled, + "MjNotifyEnabled": &config.MjNotifyEnabled, + "ChatCacheEnabled": &config.ChatCacheEnabled, } var optionStringMap = map[string]*string{ - "SMTPServer": &common.SMTPServer, - "SMTPAccount": &common.SMTPAccount, - "SMTPFrom": &common.SMTPFrom, - "SMTPToken": &common.SMTPToken, - "ServerAddress": &common.ServerAddress, - "GitHubClientId": &common.GitHubClientId, - "GitHubClientSecret": &common.GitHubClientSecret, - "Footer": &common.Footer, - "SystemName": &common.SystemName, - "Logo": &common.Logo, - "WeChatServerAddress": &common.WeChatServerAddress, - "WeChatServerToken": &common.WeChatServerToken, - "WeChatAccountQRCodeImageURL": &common.WeChatAccountQRCodeImageURL, - "TurnstileSiteKey": &common.TurnstileSiteKey, - "TurnstileSecretKey": &common.TurnstileSecretKey, - "TopUpLink": &common.TopUpLink, - "ChatLink": &common.ChatLink, - "ChatLinks": &common.ChatLinks, - "LarkClientId": &common.LarkClientId, - "LarkClientSecret": &common.LarkClientSecret, + "SMTPServer": &config.SMTPServer, + "SMTPAccount": &config.SMTPAccount, + "SMTPFrom": &config.SMTPFrom, + "SMTPToken": &config.SMTPToken, + "ServerAddress": &config.ServerAddress, + "GitHubClientId": &config.GitHubClientId, + "GitHubClientSecret": &config.GitHubClientSecret, + "Footer": &config.Footer, + "SystemName": &config.SystemName, + "Logo": &config.Logo, + "WeChatServerAddress": &config.WeChatServerAddress, + "WeChatServerToken": &config.WeChatServerToken, + "WeChatAccountQRCodeImageURL": &config.WeChatAccountQRCodeImageURL, + "TurnstileSiteKey": &config.TurnstileSiteKey, + "TurnstileSecretKey": &config.TurnstileSecretKey, + "TopUpLink": &config.TopUpLink, + "ChatLink": &config.ChatLink, + "ChatLinks": &config.ChatLinks, + "LarkClientId": &config.LarkClientId, + "LarkClientSecret": &config.LarkClientSecret, } func updateOptionMap(key string, value string) (err error) { - common.OptionMapRWMutex.Lock() - defer common.OptionMapRWMutex.Unlock() - common.OptionMap[key] = value + config.OptionMapRWMutex.Lock() + defer config.OptionMapRWMutex.Unlock() + config.OptionMap[key] = value if ptr, ok := optionIntMap[key]; ok { *ptr, _ = strconv.Atoi(value) return @@ -196,13 +197,13 @@ func updateOptionMap(key string, value string) (err error) { switch key { case "EmailDomainWhitelist": - common.EmailDomainWhitelist = strings.Split(value, ",") + config.EmailDomainWhitelist = strings.Split(value, ",") case "GroupRatio": err = common.UpdateGroupRatioByJSONString(value) case "ChannelDisableThreshold": - common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) case "QuotaPerUnit": - common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) } return err } diff --git a/model/price.go b/model/price.go index c1f42123..709a7c66 100644 --- a/model/price.go +++ b/model/price.go @@ -1,7 +1,7 @@ package model import ( - "one-api/common" + "one-api/common/config" "github.com/shopspring/decimal" "gorm.io/gorm" @@ -114,211 +114,211 @@ type ModelType struct { func GetDefaultPrice() []*Price { ModelTypes := map[string]ModelType{ // $0.03 / 1K tokens $0.06 / 1K tokens - "gpt-4": {[]float64{15, 30}, common.ChannelTypeOpenAI}, - "gpt-4-0314": {[]float64{15, 30}, common.ChannelTypeOpenAI}, - "gpt-4-0613": {[]float64{15, 30}, common.ChannelTypeOpenAI}, + "gpt-4": {[]float64{15, 30}, config.ChannelTypeOpenAI}, + "gpt-4-0314": {[]float64{15, 30}, config.ChannelTypeOpenAI}, + "gpt-4-0613": {[]float64{15, 30}, config.ChannelTypeOpenAI}, // $0.06 / 1K tokens $0.12 / 1K tokens - "gpt-4-32k": {[]float64{30, 60}, common.ChannelTypeOpenAI}, - "gpt-4-32k-0314": {[]float64{30, 60}, common.ChannelTypeOpenAI}, - "gpt-4-32k-0613": {[]float64{30, 60}, common.ChannelTypeOpenAI}, + "gpt-4-32k": {[]float64{30, 60}, config.ChannelTypeOpenAI}, + "gpt-4-32k-0314": {[]float64{30, 60}, config.ChannelTypeOpenAI}, + "gpt-4-32k-0613": {[]float64{30, 60}, config.ChannelTypeOpenAI}, // $0.01 / 1K tokens $0.03 / 1K tokens - "gpt-4-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, - "gpt-4-turbo": {[]float64{5, 15}, common.ChannelTypeOpenAI}, - "gpt-4-turbo-2024-04-09": {[]float64{5, 15}, common.ChannelTypeOpenAI}, - "gpt-4-1106-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, - "gpt-4-0125-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, - "gpt-4-turbo-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, - "gpt-4-vision-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, + "gpt-4-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI}, + "gpt-4-turbo": {[]float64{5, 15}, config.ChannelTypeOpenAI}, + "gpt-4-turbo-2024-04-09": {[]float64{5, 15}, config.ChannelTypeOpenAI}, + "gpt-4-1106-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI}, + "gpt-4-0125-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI}, + "gpt-4-turbo-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI}, + "gpt-4-vision-preview": {[]float64{5, 15}, config.ChannelTypeOpenAI}, // $0.005 / 1K tokens $0.015 / 1K tokens - "gpt-4o": {[]float64{2.5, 7.5}, common.ChannelTypeOpenAI}, + "gpt-4o": {[]float64{2.5, 7.5}, config.ChannelTypeOpenAI}, // $0.0005 / 1K tokens $0.0015 / 1K tokens - "gpt-3.5-turbo": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI}, - "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI}, + "gpt-3.5-turbo": {[]float64{0.25, 0.75}, config.ChannelTypeOpenAI}, + "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, config.ChannelTypeOpenAI}, // $0.0015 / 1K tokens $0.002 / 1K tokens - "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, - "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, - "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, + "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, config.ChannelTypeOpenAI}, + "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, config.ChannelTypeOpenAI}, + "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, config.ChannelTypeOpenAI}, // $0.003 / 1K tokens $0.004 / 1K tokens - "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, common.ChannelTypeOpenAI}, - "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, common.ChannelTypeOpenAI}, + "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, config.ChannelTypeOpenAI}, + "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, config.ChannelTypeOpenAI}, // $0.001 / 1K tokens $0.002 / 1K tokens - "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, common.ChannelTypeOpenAI}, + "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, config.ChannelTypeOpenAI}, // $0.0020 / 1K tokens - "davinci-002": {[]float64{1, 1}, common.ChannelTypeOpenAI}, + "davinci-002": {[]float64{1, 1}, config.ChannelTypeOpenAI}, // $0.0004 / 1K tokens - "babbage-002": {[]float64{0.2, 0.2}, common.ChannelTypeOpenAI}, + "babbage-002": {[]float64{0.2, 0.2}, config.ChannelTypeOpenAI}, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens - "whisper-1": {[]float64{15, 15}, common.ChannelTypeOpenAI}, + "whisper-1": {[]float64{15, 15}, config.ChannelTypeOpenAI}, // $0.015 / 1K characters - "tts-1": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI}, - "tts-1-1106": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI}, + "tts-1": {[]float64{7.5, 7.5}, config.ChannelTypeOpenAI}, + "tts-1-1106": {[]float64{7.5, 7.5}, config.ChannelTypeOpenAI}, // $0.030 / 1K characters - "tts-1-hd": {[]float64{15, 15}, common.ChannelTypeOpenAI}, - "tts-1-hd-1106": {[]float64{15, 15}, common.ChannelTypeOpenAI}, - "text-embedding-ada-002": {[]float64{0.05, 0.05}, common.ChannelTypeOpenAI}, + "tts-1-hd": {[]float64{15, 15}, config.ChannelTypeOpenAI}, + "tts-1-hd-1106": {[]float64{15, 15}, config.ChannelTypeOpenAI}, + "text-embedding-ada-002": {[]float64{0.05, 0.05}, config.ChannelTypeOpenAI}, // $0.00002 / 1K tokens - "text-embedding-3-small": {[]float64{0.01, 0.01}, common.ChannelTypeOpenAI}, + "text-embedding-3-small": {[]float64{0.01, 0.01}, config.ChannelTypeOpenAI}, // $0.00013 / 1K tokens - "text-embedding-3-large": {[]float64{0.065, 0.065}, common.ChannelTypeOpenAI}, - "text-moderation-stable": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI}, - "text-moderation-latest": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI}, + "text-embedding-3-large": {[]float64{0.065, 0.065}, config.ChannelTypeOpenAI}, + "text-moderation-stable": {[]float64{0.1, 0.1}, config.ChannelTypeOpenAI}, + "text-moderation-latest": {[]float64{0.1, 0.1}, config.ChannelTypeOpenAI}, // $0.016 - $0.020 / image - "dall-e-2": {[]float64{8, 8}, common.ChannelTypeOpenAI}, + "dall-e-2": {[]float64{8, 8}, config.ChannelTypeOpenAI}, // $0.040 - $0.120 / image - "dall-e-3": {[]float64{20, 20}, common.ChannelTypeOpenAI}, + "dall-e-3": {[]float64{20, 20}, config.ChannelTypeOpenAI}, // $0.80/million tokens $2.40/million tokens - "claude-instant-1.2": {[]float64{0.4, 1.2}, common.ChannelTypeAnthropic}, + "claude-instant-1.2": {[]float64{0.4, 1.2}, config.ChannelTypeAnthropic}, // $8.00/million tokens $24.00/million tokens - "claude-2.0": {[]float64{4, 12}, common.ChannelTypeAnthropic}, - "claude-2.1": {[]float64{4, 12}, common.ChannelTypeAnthropic}, + "claude-2.0": {[]float64{4, 12}, config.ChannelTypeAnthropic}, + "claude-2.1": {[]float64{4, 12}, config.ChannelTypeAnthropic}, // $15 / M $75 / M - "claude-3-opus-20240229": {[]float64{7.5, 22.5}, common.ChannelTypeAnthropic}, + "claude-3-opus-20240229": {[]float64{7.5, 22.5}, config.ChannelTypeAnthropic}, // $3 / M $15 / M - "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, common.ChannelTypeAnthropic}, + "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, config.ChannelTypeAnthropic}, // $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens - "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, common.ChannelTypeAnthropic}, + "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, config.ChannelTypeAnthropic}, // ¥0.004 / 1k tokens ¥0.008 / 1k tokens - "ERNIE-Speed": {[]float64{0.2857, 0.5714}, common.ChannelTypeBaidu}, + "ERNIE-Speed": {[]float64{0.2857, 0.5714}, config.ChannelTypeBaidu}, // ¥0.012 / 1k tokens ¥0.012 / 1k tokens - "ERNIE-Bot": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu}, - "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu}, + "ERNIE-Bot": {[]float64{0.8572, 0.8572}, config.ChannelTypeBaidu}, + "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, config.ChannelTypeBaidu}, // 0.024元/千tokens 0.048元/千tokens - "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, common.ChannelTypeBaidu}, + "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, config.ChannelTypeBaidu}, // ¥0.008 / 1k tokens ¥0.008 / 1k tokens - "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaidu}, + "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeBaidu}, // ¥0.12 / 1k tokens ¥0.12 / 1k tokens - "ERNIE-Bot-4": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu}, - "ERNIE-4.0": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu}, + "ERNIE-Bot-4": {[]float64{8.572, 8.572}, config.ChannelTypeBaidu}, + "ERNIE-4.0": {[]float64{8.572, 8.572}, config.ChannelTypeBaidu}, // ¥0.002 / 1k tokens - "Embedding-V1": {[]float64{0.1429, 0.1429}, common.ChannelTypeBaidu}, + "Embedding-V1": {[]float64{0.1429, 0.1429}, config.ChannelTypeBaidu}, // ¥0.004 / 1k tokens - "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, common.ChannelTypeBaidu}, + "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, config.ChannelTypeBaidu}, - "PaLM-2": {[]float64{1, 1}, common.ChannelTypePaLM}, + "PaLM-2": {[]float64{1, 1}, config.ChannelTypePaLM}, // $0.50 / 1 million tokens $1.50 / 1 million tokens // 0.0005$ / 1k tokens 0.0015$ / 1k tokens - "gemini-pro": {[]float64{0.25, 0.75}, common.ChannelTypeGemini}, - "gemini-pro-vision": {[]float64{0.25, 0.75}, common.ChannelTypeGemini}, - "gemini-1.0-pro": {[]float64{0.25, 0.75}, common.ChannelTypeGemini}, + "gemini-pro": {[]float64{0.25, 0.75}, config.ChannelTypeGemini}, + "gemini-pro-vision": {[]float64{0.25, 0.75}, config.ChannelTypeGemini}, + "gemini-1.0-pro": {[]float64{0.25, 0.75}, config.ChannelTypeGemini}, // $7 / 1 million tokens $21 / 1 million tokens - "gemini-1.5-pro": {[]float64{1.75, 5.25}, common.ChannelTypeGemini}, - "gemini-1.5-pro-latest": {[]float64{1.75, 5.25}, common.ChannelTypeGemini}, - "gemini-1.5-flash": {[]float64{0.175, 0.265}, common.ChannelTypeGemini}, - "gemini-1.5-flash-latest": {[]float64{0.175, 0.265}, common.ChannelTypeGemini}, - "gemini-ultra": {[]float64{1, 1}, common.ChannelTypeGemini}, + "gemini-1.5-pro": {[]float64{1.75, 5.25}, config.ChannelTypeGemini}, + "gemini-1.5-pro-latest": {[]float64{1.75, 5.25}, config.ChannelTypeGemini}, + "gemini-1.5-flash": {[]float64{0.175, 0.265}, config.ChannelTypeGemini}, + "gemini-1.5-flash-latest": {[]float64{0.175, 0.265}, config.ChannelTypeGemini}, + "gemini-ultra": {[]float64{1, 1}, config.ChannelTypeGemini}, // ¥0.005 / 1k tokens - "glm-3-turbo": {[]float64{0.3572, 0.3572}, common.ChannelTypeZhipu}, + "glm-3-turbo": {[]float64{0.3572, 0.3572}, config.ChannelTypeZhipu}, // ¥0.1 / 1k tokens - "glm-4": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu}, - "glm-4v": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu}, + "glm-4": {[]float64{7.143, 7.143}, config.ChannelTypeZhipu}, + "glm-4v": {[]float64{7.143, 7.143}, config.ChannelTypeZhipu}, // ¥0.0005 / 1k tokens - "embedding-2": {[]float64{0.0357, 0.0357}, common.ChannelTypeZhipu}, + "embedding-2": {[]float64{0.0357, 0.0357}, config.ChannelTypeZhipu}, // ¥0.25 / 1张图片 - "cogview-3": {[]float64{17.8571, 17.8571}, common.ChannelTypeZhipu}, + "cogview-3": {[]float64{17.8571, 17.8571}, config.ChannelTypeZhipu}, // ¥0.008 / 1k tokens - "qwen-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli}, + "qwen-turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeAli}, // ¥0.02 / 1k tokens - "qwen-plus": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli}, - "qwen-vl-max": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli}, + "qwen-plus": {[]float64{1.4286, 1.4286}, config.ChannelTypeAli}, + "qwen-vl-max": {[]float64{1.4286, 1.4286}, config.ChannelTypeAli}, // 0.12元/1,000tokens - "qwen-max": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli}, - "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli}, + "qwen-max": {[]float64{8.5714, 8.5714}, config.ChannelTypeAli}, + "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, config.ChannelTypeAli}, // 0.008元/1,000tokens - "qwen-vl-plus": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli}, + "qwen-vl-plus": {[]float64{0.5715, 0.5715}, config.ChannelTypeAli}, // ¥0.0007 / 1k tokens - "text-embedding-v1": {[]float64{0.05, 0.05}, common.ChannelTypeAli}, + "text-embedding-v1": {[]float64{0.05, 0.05}, config.ChannelTypeAli}, // ¥0.018 / 1k tokens - "SparkDesk": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, - "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, - "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, - "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, - "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, + "SparkDesk": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei}, + "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei}, + "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei}, + "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei}, + "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, config.ChannelTypeXunfei}, // ¥0.012 / 1k tokens - "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, common.ChannelType360}, + "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, config.ChannelType360}, // ¥0.001 / 1k tokens - "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, - "embedding_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, - "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, + "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, config.ChannelType360}, + "embedding_s1_v1": {[]float64{0.0715, 0.0715}, config.ChannelType360}, + "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, config.ChannelType360}, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 - "hunyuan": {[]float64{7.143, 7.143}, common.ChannelTypeTencent}, + "hunyuan": {[]float64{7.143, 7.143}, config.ChannelTypeTencent}, // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 // ¥0.01 / 1k tokens - "ChatStd": {[]float64{0.7143, 0.7143}, common.ChannelTypeTencent}, + "ChatStd": {[]float64{0.7143, 0.7143}, config.ChannelTypeTencent}, //¥0.1 / 1k tokens - "ChatPro": {[]float64{7.143, 7.143}, common.ChannelTypeTencent}, + "ChatPro": {[]float64{7.143, 7.143}, config.ChannelTypeTencent}, - "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens - "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, common.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens - "Baichuan2-53B": {[]float64{1.4286, 1.4286}, common.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens - "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, common.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens + "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, config.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens + "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, config.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens + "Baichuan2-53B": {[]float64{1.4286, 1.4286}, config.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens + "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, config.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens - "abab5.5s-chat": {[]float64{0.3572, 0.3572}, common.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens - "abab5.5-chat": {[]float64{1.0714, 1.0714}, common.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens - "abab6-chat": {[]float64{14.2857, 14.2857}, common.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens - "embo-01": {[]float64{0.0357, 0.0357}, common.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens + "abab5.5s-chat": {[]float64{0.3572, 0.3572}, config.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens + "abab5.5-chat": {[]float64{1.0714, 1.0714}, config.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens + "abab6-chat": {[]float64{14.2857, 14.2857}, config.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens + "embo-01": {[]float64{0.0357, 0.0357}, config.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens - "deepseek-coder": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens - "deepseek-chat": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens + "deepseek-coder": {[]float64{0.75, 0.75}, config.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens + "deepseek-chat": {[]float64{0.75, 0.75}, config.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens - "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, common.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens - "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, common.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens - "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, common.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens + "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, config.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens + "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, config.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens + "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, config.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens - "open-mistral-7b": {[]float64{0.125, 0.125}, common.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens - "open-mixtral-8x7b": {[]float64{0.35, 0.35}, common.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens - "mistral-small-latest": {[]float64{1, 3}, common.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens - "mistral-medium-latest": {[]float64{1.35, 4.05}, common.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens - "mistral-large-latest": {[]float64{4, 12}, common.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens - "mistral-embed": {[]float64{0.05, 0.05}, common.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens + "open-mistral-7b": {[]float64{0.125, 0.125}, config.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens + "open-mixtral-8x7b": {[]float64{0.35, 0.35}, config.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens + "mistral-small-latest": {[]float64{1, 3}, config.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens + "mistral-medium-latest": {[]float64{1.35, 4.05}, config.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens + "mistral-large-latest": {[]float64{4, 12}, config.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens + "mistral-embed": {[]float64{0.05, 0.05}, config.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens // $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens - "llama2-70b-4096": {[]float64{0.35, 0.4}, common.ChannelTypeGroq}, + "llama2-70b-4096": {[]float64{0.35, 0.4}, config.ChannelTypeGroq}, // $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens - "llama2-7b-2048": {[]float64{0.05, 0.05}, common.ChannelTypeGroq}, - "gemma-7b-it": {[]float64{0.05, 0.05}, common.ChannelTypeGroq}, + "llama2-7b-2048": {[]float64{0.05, 0.05}, config.ChannelTypeGroq}, + "gemma-7b-it": {[]float64{0.05, 0.05}, config.ChannelTypeGroq}, // $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens - "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, common.ChannelTypeGroq}, + "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, config.ChannelTypeGroq}, // 2.5 元 / 1M tokens 0.0025 / 1k tokens - "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, common.ChannelTypeLingyi}, + "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, config.ChannelTypeLingyi}, // 12 元 / 1M tokens 0.012 / 1k tokens - "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, common.ChannelTypeLingyi}, + "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, config.ChannelTypeLingyi}, // 6 元 / 1M tokens 0.006 / 1k tokens - "yi-vl-plus": {[]float64{0.4286, 0.4286}, common.ChannelTypeLingyi}, + "yi-vl-plus": {[]float64{0.4286, 0.4286}, config.ChannelTypeLingyi}, - "@cf/stabilityai/stable-diffusion-xl-base-1.0": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, - "@cf/lykon/dreamshaper-8-lcm": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, - "@cf/bytedance/stable-diffusion-xl-lightning": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, - "@cf/qwen/qwen1.5-7b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, - "@cf/qwen/qwen1.5-14b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, - "@hf/thebloke/deepseek-coder-6.7b-base-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, - "@hf/google/gemma-7b-it": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, - "@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, - "@cf/openai/whisper": {[]float64{0, 0}, common.ChannelTypeCloudflareAI}, + "@cf/stabilityai/stable-diffusion-xl-base-1.0": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, + "@cf/lykon/dreamshaper-8-lcm": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, + "@cf/bytedance/stable-diffusion-xl-lightning": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, + "@cf/qwen/qwen1.5-7b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, + "@cf/qwen/qwen1.5-14b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, + "@hf/thebloke/deepseek-coder-6.7b-base-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, + "@hf/google/gemma-7b-it": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, + "@hf/thebloke/llama-2-13b-chat-awq": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, + "@cf/openai/whisper": {[]float64{0, 0}, config.ChannelTypeCloudflareAI}, //$0.50 /1M TOKENS $1.50/1M TOKENS - "command-r": {[]float64{0.25, 0.75}, common.ChannelTypeCohere}, + "command-r": {[]float64{0.25, 0.75}, config.ChannelTypeCohere}, //$3 /1M TOKENS $15/1M TOKENS - "command-r-plus": {[]float64{1.5, 7.5}, common.ChannelTypeCohere}, + "command-r-plus": {[]float64{1.5, 7.5}, config.ChannelTypeCohere}, // 0.065 - "sd3": {[]float64{32.5, 32.5}, common.ChannelTypeStabilityAI}, + "sd3": {[]float64{32.5, 32.5}, config.ChannelTypeStabilityAI}, // 0.04 - "sd3-turbo": {[]float64{20, 20}, common.ChannelTypeStabilityAI}, + "sd3-turbo": {[]float64{20, 20}, config.ChannelTypeStabilityAI}, // 0.03 - "stable-image-core": {[]float64{15, 15}, common.ChannelTypeStabilityAI}, + "stable-image-core": {[]float64{15, 15}, config.ChannelTypeStabilityAI}, // hunyuan - "hunyuan-lite": {[]float64{0, 0}, common.ChannelTypeHunyuan}, - "hunyuan-standard": {[]float64{0.3214, 0.3571}, common.ChannelTypeHunyuan}, - "hunyuan-standard-256k": {[]float64{1.0714, 4.2857}, common.ChannelTypeHunyuan}, - "hunyuan-pro": {[]float64{2.1429, 7.1429}, common.ChannelTypeHunyuan}, + "hunyuan-lite": {[]float64{0, 0}, config.ChannelTypeHunyuan}, + "hunyuan-standard": {[]float64{0.3214, 0.3571}, config.ChannelTypeHunyuan}, + "hunyuan-standard-256k": {[]float64{1.0714, 4.2857}, config.ChannelTypeHunyuan}, + "hunyuan-pro": {[]float64{2.1429, 7.1429}, config.ChannelTypeHunyuan}, } var prices []*Price @@ -355,7 +355,7 @@ func GetDefaultPrice() []*Price { prices = append(prices, &Price{ Model: model, Type: TimesPriceType, - ChannelType: common.ChannelTypeMidjourney, + ChannelType: config.ChannelTypeMidjourney, Input: mjPrice, Output: mjPrice, }) diff --git a/model/redemption.go b/model/redemption.go index 821e244b..da6f524a 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/common/config" "one-api/common/utils" "gorm.io/gorm" @@ -69,7 +70,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return errors.New("无效的兑换码") } - if redemption.Status != common.RedemptionCodeStatusEnabled { + if redemption.Status != config.RedemptionCodeStatusEnabled { return errors.New("该兑换码已被使用") } err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error @@ -77,7 +78,7 @@ func Redeem(key string, userId int) (quota int, err error) { return err } redemption.RedeemedTime = utils.GetTimestamp() - redemption.Status = common.RedemptionCodeStatusUsed + redemption.Status = config.RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err }) diff --git a/model/token.go b/model/token.go index 883b6d4c..a6281af0 100644 --- a/model/token.go +++ b/model/token.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/stmp" "one-api/common/utils" @@ -49,7 +50,7 @@ func GetUserTokensList(userId int, params *GenericParams) (*DataResult[Token], e // 获取状态为可用的令牌 func GetUserEnabledTokens(userId int) (tokens []*Token, err error) { - err = DB.Where("user_id = ? and status = ?", userId, common.TokenStatusEnabled).Find(&tokens).Error + err = DB.Where("user_id = ? and status = ?", userId, config.TokenStatusEnabled).Find(&tokens).Error return tokens, err } @@ -65,17 +66,17 @@ func ValidateUserToken(key string) (token *Token, err error) { } return nil, errors.New("令牌验证失败") } - if token.Status == common.TokenStatusExhausted { + if token.Status == config.TokenStatusExhausted { return nil, errors.New("该令牌额度已用尽") - } else if token.Status == common.TokenStatusExpired { + } else if token.Status == config.TokenStatusExpired { return nil, errors.New("该令牌已过期") } - if token.Status != common.TokenStatusEnabled { + if token.Status != config.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < utils.GetTimestamp() { if !common.RedisEnabled { - token.Status = common.TokenStatusExpired + token.Status = config.TokenStatusExpired err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -86,7 +87,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !common.RedisEnabled { // in this case, we can make sure the token is exhausted - token.Status = common.TokenStatusExhausted + token.Status = config.TokenStatusExhausted err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -128,7 +129,7 @@ func GetTokenByName(name string, user_id int) (*Token, error) { } func (token *Token) Insert() error { - if token.ChatCache && !common.ChatCacheEnabled { + if token.ChatCache && !config.ChatCacheEnabled { token.ChatCache = false } @@ -138,7 +139,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { - if token.ChatCache && !common.ChatCacheEnabled { + if token.ChatCache && !config.ChatCacheEnabled { token.ChatCache = false } @@ -178,7 +179,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, quota) return nil } @@ -200,7 +201,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) return nil } @@ -236,7 +237,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { if userQuota < quota { return errors.New("用户额度不足") } - quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold + quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold noMoreQuota := userQuota-quota <= 0 if quotaTooLow || noMoreQuota { go sendQuotaWarningEmail(token.UserId, userQuota, noMoreQuota) diff --git a/model/user.go b/model/user.go index f7f13ec8..2393ac5e 100644 --- a/model/user.go +++ b/model/user.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/utils" "strings" @@ -116,7 +117,7 @@ func (user *User) Insert(inviterId int) error { return err } } - user.Quota = common.QuotaForNewUser + user.Quota = config.QuotaForNewUser user.AccessToken = utils.GetUUID() user.AffCode = utils.GetRandomString(4) user.CreatedTime = utils.GetTimestamp() @@ -124,17 +125,17 @@ func (user *User) Insert(inviterId int) error { if result.Error != nil { return result.Error } - if common.QuotaForNewUser > 0 { - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) + if config.QuotaForNewUser > 0 { + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) } if inviterId != 0 { - if common.QuotaForInvitee > 0 { - _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) - RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) + if config.QuotaForInvitee > 0 { + _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) } - if common.QuotaForInviter > 0 { - _ = IncreaseUserQuota(inviterId, common.QuotaForInviter) - RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) + if config.QuotaForInviter > 0 { + _ = IncreaseUserQuota(inviterId, config.QuotaForInviter) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) } } return nil @@ -150,8 +151,8 @@ func (user *User) Update(updatePassword bool) error { } err = DB.Model(user).Updates(user).Error - if err == nil && user.Role == common.RoleRootUser { - common.RootUserEmail = user.Email + if err == nil && user.Role == config.RoleRootUser { + config.RootUserEmail = user.Email } return err @@ -196,7 +197,7 @@ func (user *User) ValidateAndFill() (err error) { } } okay := common.ValidatePasswordAndHash(password, user.Password) - if !okay || user.Status != common.UserStatusEnabled { + if !okay || user.Status != config.UserStatusEnabled { return errors.New("用户名或密码错误,或用户已被封禁") } return nil @@ -310,7 +311,7 @@ func IsAdmin(userId int) bool { logger.SysError("no such user " + err.Error()) return false } - return user.Role >= common.RoleAdminUser + return user.Role >= config.RoleAdminUser } func IsUserEnabled(userId int) (bool, error) { @@ -322,7 +323,7 @@ func IsUserEnabled(userId int) (bool, error) { if err != nil { return false, err } - return user.Status == common.UserStatusEnabled, nil + return user.Status == config.UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { @@ -366,7 +367,7 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, quota) return nil } @@ -382,7 +383,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUserQuota, id, -quota) return nil } @@ -395,12 +396,12 @@ func decreaseUserQuota(id int, quota int) (err error) { } func GetRootUserEmail() (email string) { - DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) + DB.Model(&User{}).Where("role = ?", config.RoleRootUser).Select("email").Find(&email) return email } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeRequestCount, id, 1) return diff --git a/model/utils.go b/model/utils.go index e4797a78..e0826e0d 100644 --- a/model/utils.go +++ b/model/utils.go @@ -1,7 +1,7 @@ package model import ( - "one-api/common" + "one-api/common/config" "one-api/common/logger" "sync" "time" @@ -29,7 +29,7 @@ func init() { func InitBatchUpdater() { go func() { for { - time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) + time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second) batchUpdate() } }() diff --git a/providers/ali/ali_test.go b/providers/ali/ali_test.go index c5cbc26f..a3f93ce5 100644 --- a/providers/ali/ali_test.go +++ b/providers/ali/ali_test.go @@ -2,7 +2,7 @@ package ali_test import ( "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/test" "one-api/model" ) @@ -20,5 +20,5 @@ func setupAliTestServer() (baseUrl string, server *test.ServerTest, teardown fun } func getAliChannel(baseUrl string) model.Channel { - return test.GetChannel(common.ChannelTypeAli, baseUrl, "", "", "") + return test.GetChannel(config.ChannelTypeAli, baseUrl, "", "", "") } diff --git a/providers/ali/chat.go b/providers/ali/chat.go index 8efc839a..f30a1472 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/common/utils" "one-api/types" @@ -55,7 +56,7 @@ func (p *AliProvider) CreateChatCompletionStream(request *types.ChatCompletionRe } func (p *AliProvider) getAliChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/ali/embeddings.go b/providers/ali/embeddings.go index fdbe6ca6..c64961aa 100644 --- a/providers/ali/embeddings.go +++ b/providers/ali/embeddings.go @@ -3,11 +3,12 @@ package ali import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *AliProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/azure/image_generations.go b/providers/azure/image_generations.go index 882d8268..f34db2f9 100644 --- a/providers/azure/image_generations.go +++ b/providers/azure/image_generations.go @@ -4,6 +4,7 @@ import ( "errors" "net/http" "one-api/common" + "one-api/common/config" "one-api/providers/openai" "one-api/types" "time" @@ -14,7 +15,7 @@ func (p *AzureProvider) CreateImageGenerations(request *types.ImageRequest) (*ty return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest) } - req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeImagesGenerations, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/azureSpeech/speech.go b/providers/azureSpeech/speech.go index 9d03dbed..2210e5d6 100644 --- a/providers/azureSpeech/speech.go +++ b/providers/azureSpeech/speech.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/types" "strings" ) @@ -82,7 +83,7 @@ func (p *AzureSpeechProvider) getRequestBody(request *types.SpeechAudioRequest) } func (p *AzureSpeechProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeAudioSpeech) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeAudioSpeech) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/baichuan/chat.go b/providers/baichuan/chat.go index 36ed4f33..3cacb6db 100644 --- a/providers/baichuan/chat.go +++ b/providers/baichuan/chat.go @@ -2,7 +2,7 @@ package baichuan import ( "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/providers/openai" "one-api/types" @@ -11,7 +11,7 @@ import ( func (p *BaichuanProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { requestBody := p.getChatRequestBody(request) - req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, requestBody) + req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, requestBody) if errWithCode != nil { return nil, errWithCode } @@ -51,7 +51,7 @@ func (p *BaichuanProvider) CreateChatCompletionStream(request *types.ChatComplet request.StreamOptions = nil } - req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index e93d1b1a..6b4551fa 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" "strings" @@ -54,7 +55,7 @@ func (p *BaiduProvider) CreateChatCompletionStream(request *types.ChatCompletion } func (p *BaiduProvider) getBaiduChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/baidu/embeddings.go b/providers/baidu/embeddings.go index 5e13215e..a62fdf2f 100644 --- a/providers/baidu/embeddings.go +++ b/providers/baidu/embeddings.go @@ -3,11 +3,12 @@ package baidu import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *BaiduProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/base/common.go b/providers/base/common.go index 37429880..9b952b64 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/model" "one-api/types" @@ -110,25 +111,25 @@ func (p *BaseProvider) ModelMappingHandler(modelName string) (string, error) { func (p *BaseProvider) GetAPIUri(relayMode int) string { switch relayMode { - case common.RelayModeChatCompletions: + case config.RelayModeChatCompletions: return p.Config.ChatCompletions - case common.RelayModeCompletions: + case config.RelayModeCompletions: return p.Config.Completions - case common.RelayModeEmbeddings: + case config.RelayModeEmbeddings: return p.Config.Embeddings - case common.RelayModeAudioSpeech: + case config.RelayModeAudioSpeech: return p.Config.AudioSpeech - case common.RelayModeAudioTranscription: + case config.RelayModeAudioTranscription: return p.Config.AudioTranscriptions - case common.RelayModeAudioTranslation: + case config.RelayModeAudioTranslation: return p.Config.AudioTranslations - case common.RelayModeModerations: + case config.RelayModeModerations: return p.Config.Moderation - case common.RelayModeImagesGenerations: + case config.RelayModeImagesGenerations: return p.Config.ImagesGenerations - case common.RelayModeImagesEdits: + case config.RelayModeImagesEdits: return p.Config.ImagesEdit - case common.RelayModeImagesVariations: + case config.RelayModeImagesVariations: return p.Config.ImagesVariations default: return "" diff --git a/providers/bedrock/chat.go b/providers/bedrock/chat.go index de24cdd2..654026e1 100644 --- a/providers/bedrock/chat.go +++ b/providers/bedrock/chat.go @@ -3,6 +3,7 @@ package bedrock import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/providers/bedrock/category" "one-api/types" @@ -48,7 +49,7 @@ func (p *BedrockProvider) getChatRequest(request *types.ChatCompletionRequest) ( return nil, common.StringErrorWrapper("bedrock provider not found", "bedrock_err", http.StatusInternalServerError) } - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/claude/chat.go b/providers/claude/chat.go index 7c3d53e9..7024ab20 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/image" "one-api/common/requester" "one-api/common/utils" @@ -64,7 +65,7 @@ func (p *ClaudeProvider) CreateChatCompletionStream(request *types.ChatCompletio } func (p *ClaudeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/cohere/chat.go b/providers/cohere/chat.go index ae416ab6..921a7ddc 100644 --- a/providers/cohere/chat.go +++ b/providers/cohere/chat.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/common/utils" "one-api/providers/base" @@ -56,7 +57,7 @@ func (p *CohereProvider) CreateChatCompletionStream(request *types.ChatCompletio } func (p *CohereProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/coze/chat.go b/providers/coze/chat.go index 3913e28f..8f58775c 100644 --- a/providers/coze/chat.go +++ b/providers/coze/chat.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/common/utils" "one-api/types" @@ -56,7 +57,7 @@ func (p *CozeProvider) CreateChatCompletionStream(request *types.ChatCompletionR } func (p *CozeProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/groq/chat.go b/providers/groq/chat.go index d41ccb30..ee3660ab 100644 --- a/providers/groq/chat.go +++ b/providers/groq/chat.go @@ -2,7 +2,7 @@ package groq import ( "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/providers/openai" "one-api/types" @@ -11,7 +11,7 @@ import ( func (p *GroqProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { p.getChatRequestBody(request) - req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } @@ -51,7 +51,7 @@ func (p *GroqProvider) CreateChatCompletionStream(request *types.ChatCompletionR request.StreamOptions = nil } p.getChatRequestBody(request) - req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/hunyuan/chat.go b/providers/hunyuan/chat.go index d5be9f06..f588d425 100644 --- a/providers/hunyuan/chat.go +++ b/providers/hunyuan/chat.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" "strings" @@ -53,7 +54,7 @@ func (p *HunyuanProvider) CreateChatCompletionStream(request *types.ChatCompleti } func (p *HunyuanProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - action, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + action, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/midjourney/base.go b/providers/midjourney/base.go index 434457f9..21afe682 100644 --- a/providers/midjourney/base.go +++ b/providers/midjourney/base.go @@ -7,7 +7,7 @@ import ( "io" "log" "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/requester" "one-api/model" @@ -48,7 +48,7 @@ func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyRe return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err } delete(mapResult, "accountFilter") - if !common.MjNotifyEnabled { + if !config.MjNotifyEnabled { delete(mapResult, "notifyHook") } } diff --git a/providers/minimax/chat.go b/providers/minimax/chat.go index 88d95d71..f2a612c0 100644 --- a/providers/minimax/chat.go +++ b/providers/minimax/chat.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" "strings" @@ -55,7 +56,7 @@ func (p *MiniMaxProvider) CreateChatCompletionStream(request *types.ChatCompleti } func (p *MiniMaxProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/minimax/embeddings.go b/providers/minimax/embeddings.go index 0eb84d5e..f1a2f9e7 100644 --- a/providers/minimax/embeddings.go +++ b/providers/minimax/embeddings.go @@ -3,11 +3,12 @@ package minimax import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *MiniMaxProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/mistral/chat.go b/providers/mistral/chat.go index d9eaa122..a02f30f5 100644 --- a/providers/mistral/chat.go +++ b/providers/mistral/chat.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" "strings" @@ -56,7 +57,7 @@ func (p *MistralProvider) CreateChatCompletionStream(request *types.ChatCompleti } func (p *MistralProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/mistral/embeddings.go b/providers/mistral/embeddings.go index 8bcc9fbb..bb49735f 100644 --- a/providers/mistral/embeddings.go +++ b/providers/mistral/embeddings.go @@ -3,11 +3,12 @@ package mistral import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *MistralProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/moonshot/chat.go b/providers/moonshot/chat.go index e45ba2e2..9c711439 100644 --- a/providers/moonshot/chat.go +++ b/providers/moonshot/chat.go @@ -3,6 +3,7 @@ package moonshot import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/providers/openai" "one-api/types" @@ -10,7 +11,7 @@ import ( func (p *MoonshotProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { request.ClearEmptyMessages() - req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } @@ -51,7 +52,7 @@ func (p *MoonshotProvider) CreateChatCompletion(request *types.ChatCompletionReq func (p *MoonshotProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) { request.ClearEmptyMessages() - req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/ollama/chat.go b/providers/ollama/chat.go index 5b01cf79..99ac9f2f 100644 --- a/providers/ollama/chat.go +++ b/providers/ollama/chat.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/image" "one-api/common/requester" "one-api/common/utils" @@ -56,7 +57,7 @@ func (p *OllamaProvider) CreateChatCompletionStream(request *types.ChatCompletio } func (p *OllamaProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/ollama/embeddings.go b/providers/ollama/embeddings.go index 7231f9e4..2943f9e5 100644 --- a/providers/ollama/embeddings.go +++ b/providers/ollama/embeddings.go @@ -3,11 +3,12 @@ package ollama import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *OllamaProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/base.go b/providers/openai/base.go index 34d5744d..b586345e 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/model" "one-api/types" @@ -32,11 +33,11 @@ func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInter // 创建 OpenAIProvider // https://platform.openai.com/docs/api-reference/introduction func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider { - config := getOpenAIConfig(baseURL) + openaiConfig := getOpenAIConfig(baseURL) OpenAIProvider := &OpenAIProvider{ BaseProvider: base.BaseProvider{ - Config: config, + Config: openaiConfig, Channel: channel, Requester: requester.NewHTTPRequester(*channel.Proxy, RequestErrorHandle), }, @@ -44,7 +45,7 @@ func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvide BalanceAction: true, } - if channel.Type == common.ChannelTypeOpenAI { + if channel.Type == config.ChannelTypeOpenAI { OpenAIProvider.SupportStreamOptions = true } @@ -109,7 +110,7 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) requestURL = fmt.Sprintf("/openai%s?api-version=%s", requestURL, apiVersion) } - } else if p.Channel.Type == common.ChannelTypeCustom && p.Channel.Other != "" { + } else if p.Channel.Type == config.ChannelTypeCustom && p.Channel.Other != "" { requestURL = strings.Replace(requestURL, "v1", p.Channel.Other, 1) } diff --git a/providers/openai/chat.go b/providers/openai/chat.go index e57b7d04..8857b6fa 100644 --- a/providers/openai/chat.go +++ b/providers/openai/chat.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" "strings" @@ -17,7 +18,7 @@ type OpenAIStreamHandler struct { } func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } @@ -67,7 +68,7 @@ func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletio // 避免误传导致报错 request.StreamOptions = nil } - req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeChatCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/completion.go b/providers/openai/completion.go index e52120e7..3a775605 100644 --- a/providers/openai/completion.go +++ b/providers/openai/completion.go @@ -5,13 +5,14 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" "strings" ) func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } @@ -50,7 +51,7 @@ func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest // 避免误传导致报错 request.StreamOptions = nil } - req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeCompletions, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/embeddings.go b/providers/openai/embeddings.go index aa7b48a0..325780d3 100644 --- a/providers/openai/embeddings.go +++ b/providers/openai/embeddings.go @@ -2,12 +2,12 @@ package openai import ( "net/http" - "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.GetRequestTextBody(common.RelayModeEmbeddings, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeEmbeddings, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/image_edits.go b/providers/openai/image_edits.go index 376459bc..5943d0cc 100644 --- a/providers/openai/image_edits.go +++ b/providers/openai/image_edits.go @@ -5,12 +5,13 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" ) func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.getRequestImageBody(common.RelayModeEdits, request.Model, request) + req, errWithCode := p.getRequestImageBody(config.RelayModeEdits, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/image_generations.go b/providers/openai/image_generations.go index 53d4796e..c3e505b0 100644 --- a/providers/openai/image_generations.go +++ b/providers/openai/image_generations.go @@ -3,6 +3,7 @@ package openai import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/types" ) @@ -11,7 +12,7 @@ func (p *OpenAIProvider) CreateImageGenerations(request *types.ImageRequest) (*t return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest) } - req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeImagesGenerations, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/image_variations.go b/providers/openai/image_variations.go index 4160d28d..6ad5bd54 100644 --- a/providers/openai/image_variations.go +++ b/providers/openai/image_variations.go @@ -2,12 +2,12 @@ package openai import ( "net/http" - "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.getRequestImageBody(common.RelayModeImagesVariations, request.Model, request) + req, errWithCode := p.getRequestImageBody(config.RelayModeImagesVariations, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/moderation.go b/providers/openai/moderation.go index 0cbd1bad..972e88b9 100644 --- a/providers/openai/moderation.go +++ b/providers/openai/moderation.go @@ -2,13 +2,13 @@ package openai import ( "net/http" - "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.GetRequestTextBody(common.RelayModeModerations, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeModerations, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/speech.go b/providers/openai/speech.go index 5d3a6f65..077f0950 100644 --- a/providers/openai/speech.go +++ b/providers/openai/speech.go @@ -2,13 +2,13 @@ package openai import ( "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" ) func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.GetRequestTextBody(common.RelayModeAudioSpeech, request.Model, request) + req, errWithCode := p.GetRequestTextBody(config.RelayModeAudioSpeech, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/transcriptions.go b/providers/openai/transcriptions.go index 8be9aeb4..b24dd9fa 100644 --- a/providers/openai/transcriptions.go +++ b/providers/openai/transcriptions.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" "regexp" @@ -15,7 +16,7 @@ import ( ) func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranscription, request.Model, request) + req, errWithCode := p.getRequestAudioBody(config.RelayModeAudioTranscription, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/openai/translations.go b/providers/openai/translations.go index 1bc4b581..6f993e89 100644 --- a/providers/openai/translations.go +++ b/providers/openai/translations.go @@ -4,11 +4,12 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) { - req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranslation, request.Model, request) + req, errWithCode := p.getRequestAudioBody(config.RelayModeAudioTranslation, request.Model, request) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/palm/chat.go b/providers/palm/chat.go index 41577c96..754e7d91 100644 --- a/providers/palm/chat.go +++ b/providers/palm/chat.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/common/utils" "one-api/types" @@ -55,7 +56,7 @@ func (p *PalmProvider) CreateChatCompletionStream(request *types.ChatCompletionR } func (p *PalmProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/providers.go b/providers/providers.go index fdbe6ad8..59127894 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -1,7 +1,7 @@ package providers import ( - "one-api/common" + "one-api/common/config" "one-api/model" "one-api/providers/ali" "one-api/providers/azure" @@ -44,32 +44,32 @@ var providerFactories = make(map[int]ProviderFactory) // 在程序启动时,添加所有的供应商工厂 func init() { - providerFactories[common.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{} - providerFactories[common.ChannelTypeAzure] = azure.AzureProviderFactory{} - providerFactories[common.ChannelTypeAli] = ali.AliProviderFactory{} - providerFactories[common.ChannelTypeTencent] = tencent.TencentProviderFactory{} - providerFactories[common.ChannelTypeBaidu] = baidu.BaiduProviderFactory{} - providerFactories[common.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{} - providerFactories[common.ChannelTypePaLM] = palm.PalmProviderFactory{} - providerFactories[common.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{} - providerFactories[common.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{} - providerFactories[common.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{} - providerFactories[common.ChannelTypeGemini] = gemini.GeminiProviderFactory{} - providerFactories[common.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{} - providerFactories[common.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{} - providerFactories[common.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{} - providerFactories[common.ChannelTypeMistral] = mistral.MistralProviderFactory{} - providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{} - providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{} - providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{} - providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{} - providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{} - providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{} - providerFactories[common.ChannelTypeCoze] = coze.CozeProviderFactory{} - providerFactories[common.ChannelTypeOllama] = ollama.OllamaProviderFactory{} - providerFactories[common.ChannelTypeMoonshot] = moonshot.MoonshotProviderFactory{} - providerFactories[common.ChannelTypeLingyi] = lingyi.LingyiProviderFactory{} - providerFactories[common.ChannelTypeHunyuan] = hunyuan.HunyuanProviderFactory{} + providerFactories[config.ChannelTypeOpenAI] = openai.OpenAIProviderFactory{} + providerFactories[config.ChannelTypeAzure] = azure.AzureProviderFactory{} + providerFactories[config.ChannelTypeAli] = ali.AliProviderFactory{} + providerFactories[config.ChannelTypeTencent] = tencent.TencentProviderFactory{} + providerFactories[config.ChannelTypeBaidu] = baidu.BaiduProviderFactory{} + providerFactories[config.ChannelTypeAnthropic] = claude.ClaudeProviderFactory{} + providerFactories[config.ChannelTypePaLM] = palm.PalmProviderFactory{} + providerFactories[config.ChannelTypeZhipu] = zhipu.ZhipuProviderFactory{} + providerFactories[config.ChannelTypeXunfei] = xunfei.XunfeiProviderFactory{} + providerFactories[config.ChannelTypeAzureSpeech] = azurespeech.AzureSpeechProviderFactory{} + providerFactories[config.ChannelTypeGemini] = gemini.GeminiProviderFactory{} + providerFactories[config.ChannelTypeBaichuan] = baichuan.BaichuanProviderFactory{} + providerFactories[config.ChannelTypeMiniMax] = minimax.MiniMaxProviderFactory{} + providerFactories[config.ChannelTypeDeepseek] = deepseek.DeepseekProviderFactory{} + providerFactories[config.ChannelTypeMistral] = mistral.MistralProviderFactory{} + providerFactories[config.ChannelTypeGroq] = groq.GroqProviderFactory{} + providerFactories[config.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{} + providerFactories[config.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{} + providerFactories[config.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{} + providerFactories[config.ChannelTypeCohere] = cohere.CohereProviderFactory{} + providerFactories[config.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{} + providerFactories[config.ChannelTypeCoze] = coze.CozeProviderFactory{} + providerFactories[config.ChannelTypeOllama] = ollama.OllamaProviderFactory{} + providerFactories[config.ChannelTypeMoonshot] = moonshot.MoonshotProviderFactory{} + providerFactories[config.ChannelTypeLingyi] = lingyi.LingyiProviderFactory{} + providerFactories[config.ChannelTypeHunyuan] = hunyuan.HunyuanProviderFactory{} } @@ -79,7 +79,7 @@ func GetProvider(channel *model.Channel, c *gin.Context) base.ProviderInterface var provider base.ProviderInterface if !ok { // 处理未找到的供应商工厂 - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := config.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } diff --git a/providers/stabilityAI/image_generations.go b/providers/stabilityAI/image_generations.go index e79a46bd..37a2c7be 100644 --- a/providers/stabilityAI/image_generations.go +++ b/providers/stabilityAI/image_generations.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/storage" "one-api/common/utils" "one-api/types" @@ -20,7 +21,7 @@ func convertModelName(modelName string) string { } func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeImagesGenerations) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go index 25625cf1..fb4ef287 100644 --- a/providers/tencent/chat.go +++ b/providers/tencent/chat.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/common/utils" "one-api/types" @@ -55,7 +56,7 @@ func (p *TencentProvider) CreateChatCompletionStream(request *types.ChatCompleti } func (p *TencentProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go index ee8d289f..fc740076 100644 --- a/providers/xunfei/chat.go +++ b/providers/xunfei/chat.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/common/utils" "one-api/types" @@ -59,7 +60,7 @@ func (p *XunfeiProvider) CreateChatCompletionStream(request *types.ChatCompletio } func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (*websocket.Conn, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index fc2ee6a5..536f68d9 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/requester" "one-api/types" "strings" @@ -54,7 +55,7 @@ func (p *ZhipuProvider) CreateChatCompletionStream(request *types.ChatCompletion } func (p *ZhipuProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeChatCompletions) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeChatCompletions) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/zhipu/embeddings.go b/providers/zhipu/embeddings.go index e33e1da5..681f8fea 100644 --- a/providers/zhipu/embeddings.go +++ b/providers/zhipu/embeddings.go @@ -3,11 +3,12 @@ package zhipu import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/types" ) func (p *ZhipuProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeEmbeddings) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeEmbeddings) if errWithCode != nil { return nil, errWithCode } diff --git a/providers/zhipu/image_generations.go b/providers/zhipu/image_generations.go index 5bed657b..52622cf6 100644 --- a/providers/zhipu/image_generations.go +++ b/providers/zhipu/image_generations.go @@ -3,12 +3,13 @@ package zhipu import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/types" "time" ) func (p *ZhipuProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) { - url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations) + url, errWithCode := p.GetSupportedAPIUri(config.RelayModeImagesGenerations) if errWithCode != nil { return nil, errWithCode } diff --git a/relay/common.go b/relay/common.go index a6618e2f..97a68b4c 100644 --- a/relay/common.go +++ b/relay/common.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/requester" "one-api/common/utils" @@ -93,7 +94,7 @@ func fetchChannelById(channelId int) (*model.Channel, error) { if err != nil { return nil, errors.New("无效的渠道 Id") } - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != config.ChannelStatusEnabled { return nil, errors.New("该渠道已被禁用") } diff --git a/relay/main.go b/relay/main.go index a035a025..26fb1d24 100644 --- a/relay/main.go +++ b/relay/main.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "one-api/relay/relay_util" @@ -50,7 +51,7 @@ func Relay(c *gin.Context) { channel := relay.getProvider().GetChannel() go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr) - retryTimes := common.RetryTimes + retryTimes := config.RetryTimes if done || !shouldRetry(c, apiErr.StatusCode) { logger.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode)) retryTimes = 0 diff --git a/relay/midjourney/relay-mj.go b/relay/midjourney/relay-mj.go index 078e0ff4..0d5d3533 100644 --- a/relay/midjourney/relay-mj.go +++ b/relay/midjourney/relay-mj.go @@ -10,6 +10,7 @@ import ( "log" "net/http" "one-api/common" + "one-api/common/config" "one-api/controller" "one-api/model" provider "one-api/providers/midjourney" @@ -112,7 +113,7 @@ func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provid midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.ImageUrl = "" if originTask.ImageUrl != "" { - midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId + midjourneyTask.ImageUrl = config.ServerAddress + "/mj/image/" + originTask.MjId if originTask.Status != "SUCCESS" { midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) } diff --git a/relay/relay.go b/relay/relay.go index a0b493b8..c6e3889e 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -3,6 +3,7 @@ package relay import ( "net/http" "one-api/common" + "one-api/common/config" "one-api/model" "one-api/providers/azure" "one-api/providers/openai" @@ -20,7 +21,7 @@ func RelayOnly(c *gin.Context) { } channel := provider.GetChannel() - if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeAzure { + if channel.Type != config.ChannelTypeOpenAI && channel.Type != config.ChannelTypeAzure { common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai") return } diff --git a/relay/relay_util/cache.go b/relay/relay_util/cache.go index 8315e7a3..5130c753 100644 --- a/relay/relay_util/cache.go +++ b/relay/relay_util/cache.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "one-api/common" + "one-api/common/config" "one-api/common/utils" "one-api/model" @@ -57,7 +58,7 @@ func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps { return props } - if common.ChatCacheEnabled && c.GetBool("chat_cache") { + if config.ChatCacheEnabled && c.GetBool("chat_cache") { props.Cache = true } @@ -113,7 +114,7 @@ func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens in p.CompletionTokens = completionTokens p.ModelName = modelName - return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute)) + return p.Driver.Set(p.getHash(), p, int64(config.ChatCacheExpireMinute)) } func (p *ChatCacheProps) GetCache() *ChatCacheProps { @@ -125,7 +126,7 @@ func (p *ChatCacheProps) GetCache() *ChatCacheProps { } func (p *ChatCacheProps) needCache() bool { - return common.ChatCacheEnabled && p.Cache + return config.ChatCacheEnabled && p.Cache } func (p *ChatCacheProps) getHash() string { diff --git a/relay/relay_util/pricing.go b/relay/relay_util/pricing.go index b2cab192..72ee290e 100644 --- a/relay/relay_util/pricing.go +++ b/relay/relay_util/pricing.go @@ -3,7 +3,7 @@ package relay_util import ( "encoding/json" "errors" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/common/utils" "one-api/model" @@ -107,7 +107,7 @@ func (p *Pricing) GetPrice(modelName string) *model.Price { return &model.Price{ Type: model.TokensPriceType, - ChannelType: common.ChannelTypeUnknown, + ChannelType: config.ChannelTypeUnknown, Input: model.DefaultPrice, Output: model.DefaultPrice, } diff --git a/relay/relay_util/quota.go b/relay/relay_util/quota.go index 48f6b094..e0e0add2 100644 --- a/relay/relay_util/quota.go +++ b/relay/relay_util/quota.go @@ -7,6 +7,7 @@ import ( "math" "net/http" "one-api/common" + "one-api/common/config" "one-api/common/logger" "one-api/model" "one-api/types" @@ -45,7 +46,7 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type if quota.price.Type == model.TimesPriceType { quota.preConsumedQuota = int(1000 * quota.inputRatio) } else { - quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio) + quota.preConsumedQuota = int(float64(quota.promptTokens+config.PreConsumedQuota) * quota.inputRatio) } errWithCode := quota.preQuotaConsumption() diff --git a/relay/relay_util/type.go b/relay/relay_util/type.go index 9c2d72d3..3233b02c 100644 --- a/relay/relay_util/type.go +++ b/relay/relay_util/type.go @@ -1,35 +1,37 @@ package relay_util -import "one-api/common" +import ( + "one-api/common/config" +) var UnknownOwnedBy = "未知" var ModelOwnedBy map[int]string func init() { ModelOwnedBy = map[int]string{ - common.ChannelTypeOpenAI: "OpenAI", - common.ChannelTypeAnthropic: "Anthropic", - common.ChannelTypeBaidu: "Baidu", - common.ChannelTypePaLM: "Google PaLM", - common.ChannelTypeGemini: "Google Gemini", - common.ChannelTypeZhipu: "Zhipu", - common.ChannelTypeAli: "Ali", - common.ChannelTypeXunfei: "Xunfei", - common.ChannelType360: "360", - common.ChannelTypeTencent: "Tencent", - common.ChannelTypeBaichuan: "Baichuan", - common.ChannelTypeMiniMax: "MiniMax", - common.ChannelTypeDeepseek: "Deepseek", - common.ChannelTypeMoonshot: "Moonshot", - common.ChannelTypeMistral: "Mistral", - common.ChannelTypeGroq: "Groq", - common.ChannelTypeLingyi: "Lingyiwanwu", - common.ChannelTypeMidjourney: "Midjourney", - common.ChannelTypeCloudflareAI: "Cloudflare AI", - common.ChannelTypeCohere: "Cohere", - common.ChannelTypeStabilityAI: "Stability AI", - common.ChannelTypeCoze: "Coze", - common.ChannelTypeOllama: "Ollama", - common.ChannelTypeHunyuan: "Hunyuan", + config.ChannelTypeOpenAI: "OpenAI", + config.ChannelTypeAnthropic: "Anthropic", + config.ChannelTypeBaidu: "Baidu", + config.ChannelTypePaLM: "Google PaLM", + config.ChannelTypeGemini: "Google Gemini", + config.ChannelTypeZhipu: "Zhipu", + config.ChannelTypeAli: "Ali", + config.ChannelTypeXunfei: "Xunfei", + config.ChannelType360: "360", + config.ChannelTypeTencent: "Tencent", + config.ChannelTypeBaichuan: "Baichuan", + config.ChannelTypeMiniMax: "MiniMax", + config.ChannelTypeDeepseek: "Deepseek", + config.ChannelTypeMoonshot: "Moonshot", + config.ChannelTypeMistral: "Mistral", + config.ChannelTypeGroq: "Groq", + config.ChannelTypeLingyi: "Lingyiwanwu", + config.ChannelTypeMidjourney: "Midjourney", + config.ChannelTypeCloudflareAI: "Cloudflare AI", + config.ChannelTypeCohere: "Cohere", + config.ChannelTypeStabilityAI: "Stability AI", + config.ChannelTypeCoze: "Coze", + config.ChannelTypeOllama: "Ollama", + config.ChannelTypeHunyuan: "Hunyuan", } } diff --git a/router/main.go b/router/main.go index 6702a7c5..b0ed0a64 100644 --- a/router/main.go +++ b/router/main.go @@ -4,7 +4,7 @@ import ( "embed" "fmt" "net/http" - "one-api/common" + "one-api/common/config" "one-api/common/logger" "strings" @@ -17,7 +17,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { SetDashboardRouter(router) SetRelayRouter(router) frontendBaseUrl := viper.GetString("frontend_base_url") - if common.IsMasterNode && frontendBaseUrl != "" { + if config.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" logger.SysLog("FRONTEND_BASE_URL is ignored on master node") }