From ce12558ad6eb03702312eb49ac1755041f648be7 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Wed, 29 May 2024 01:04:23 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=96=20chore:=20migration=20logger=20pa?= =?UTF-8?q?ckage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cli/export.go | 12 ++++++------ common/common.go | 11 +++++++++++ common/config/config.go | 3 ++- common/gin.go | 3 ++- common/go-channel.go | 3 ++- common/group-ratio.go | 9 ++++++--- common/{ => logger}/logger.go | 10 +--------- common/notify/notifier.go | 14 +++++++------- common/notify/send.go | 6 +++--- common/redis.go | 13 +++++++------ common/requester/ws_client.go | 4 ++-- common/storage/upload.go | 6 +++--- common/telegram/common.go | 27 ++++++++++++++------------- common/token.go | 17 +++++++++-------- controller/channel-test.go | 7 ++++--- controller/github.go | 5 +++-- controller/lark.go | 7 ++++--- controller/midjourney.go | 35 ++++++++++++++++++----------------- cron/main.go | 8 ++++---- main.go | 15 ++++++++------- middleware/logger.go | 4 ++-- middleware/recover.go | 6 +++--- middleware/request-id.go | 8 ++++---- middleware/turnstile-check.go | 5 +++-- middleware/utils.go | 6 +++--- model/balancer.go | 5 +++-- model/cache.go | 9 +++++---- model/channel.go | 11 ++++++----- model/log.go | 7 ++++--- model/main.go | 19 ++++++++++--------- model/migrate.go | 4 ++-- model/option.go | 5 +++-- model/token.go | 13 +++++++------ model/user.go | 9 +++++---- model/utils.go | 9 +++++---- providers/midjourney/base.go | 3 ++- providers/xunfei/base.go | 6 +++--- providers/zhipu/base.go | 4 ++-- relay/common.go | 7 ++++--- relay/main.go | 5 +++-- relay/midjourney/relay.go | 4 ++-- relay/relay_util/pricing.go | 9 +++++---- relay/relay_util/quota.go | 5 +++-- router/main.go | 3 ++- 44 files changed, 207 insertions(+), 174 deletions(-) create mode 100644 common/common.go rename common/{ => logger}/logger.go (92%) diff --git a/cli/export.go b/cli/export.go index 3cb2b23c..a14dd954 100644 --- a/cli/export.go +++ b/cli/export.go @@ -2,7 +2,7 @@ package cli import ( "encoding/json" - "one-api/common" + "one-api/common/logger" "one-api/relay/relay_util" "os" "sort" @@ -12,7 +12,7 @@ func ExportPrices() { prices := relay_util.GetPricesList("default") if len(prices) == 0 { - common.SysError("No prices found") + logger.SysError("No prices found") return } @@ -27,22 +27,22 @@ func ExportPrices() { // 导出到当前目录下的 prices.json 文件 file, err := os.Create("prices.json") if err != nil { - common.SysError("Failed to create file: " + err.Error()) + logger.SysError("Failed to create file: " + err.Error()) return } defer file.Close() jsonData, err := json.MarshalIndent(prices, "", " ") if err != nil { - common.SysError("Failed to encode prices: " + err.Error()) + logger.SysError("Failed to encode prices: " + err.Error()) return } _, err = file.Write(jsonData) if err != nil { - common.SysError("Failed to write to file: " + err.Error()) + logger.SysError("Failed to write to file: " + err.Error()) return } - common.SysLog("Prices exported to prices.json") + logger.SysLog("Prices exported to prices.json") } diff --git a/common/common.go b/common/common.go new file mode 100644 index 00000000..e3d6accc --- /dev/null +++ b/common/common.go @@ -0,0 +1,11 @@ +package common + +import "fmt" + +func LogQuota(quota int) string { + if DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) + } else { + return fmt.Sprintf("%d 点额度", quota) + } +} diff --git a/common/config/config.go b/common/config/config.go index cbe04c08..3998c293 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -1,6 +1,7 @@ package config import ( + "one-api/common/logger" "strings" "time" @@ -18,7 +19,7 @@ func InitConf() { setEnv() if viper.GetBool("debug") { - common.SysLog("running in debug mode") + logger.SysLog("running in debug mode") } common.IsMasterNode = viper.GetString("node_type") != "slave" diff --git a/common/gin.go b/common/gin.go index 56f3a341..a5ba5781 100644 --- a/common/gin.go +++ b/common/gin.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "one-api/common/logger" "one-api/types" "github.com/gin-gonic/gin" @@ -65,7 +66,7 @@ func AbortWithMessage(c *gin.Context, statusCode int, message string) { }, }) c.Abort() - LogError(c.Request.Context(), message) + logger.LogError(c.Request.Context(), message) } func APIRespondWithError(c *gin.Context, status int, err error) { diff --git a/common/go-channel.go b/common/go-channel.go index 4f00dff2..fa0e8ef6 100644 --- a/common/go-channel.go +++ b/common/go-channel.go @@ -2,6 +2,7 @@ package common import ( "fmt" + "one-api/common/logger" "runtime/debug" ) @@ -9,7 +10,7 @@ func SafeGoroutine(f func()) { go func() { defer func() { if r := recover(); r != nil { - SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack()))) + logger.SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack()))) } }() f() diff --git a/common/group-ratio.go b/common/group-ratio.go index 1ec73c78..86e9e1f6 100644 --- a/common/group-ratio.go +++ b/common/group-ratio.go @@ -1,6 +1,9 @@ package common -import "encoding/json" +import ( + "encoding/json" + "one-api/common/logger" +) var GroupRatio = map[string]float64{ "default": 1, @@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error { func GetGroupRatio(name string) float64 { ratio, ok := GroupRatio[name] if !ok { - SysError("group ratio not found: " + name) + logger.SysError("group ratio not found: " + name) return 1 } return ratio diff --git a/common/logger.go b/common/logger/logger.go similarity index 92% rename from common/logger.go rename to common/logger/logger.go index ba2c452d..7f272877 100644 --- a/common/logger.go +++ b/common/logger/logger.go @@ -1,4 +1,4 @@ -package common +package logger import ( "context" @@ -129,11 +129,3 @@ func FatalLog(v ...any) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) os.Exit(1) } - -func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) - } else { - return fmt.Sprintf("%d 点额度", quota) - } -} diff --git a/common/notify/notifier.go b/common/notify/notifier.go index 944d0c62..e1b24d9c 100644 --- a/common/notify/notifier.go +++ b/common/notify/notifier.go @@ -2,7 +2,7 @@ package notify import ( "context" - "one-api/common" + "one-api/common/logger" "one-api/common/notify/channel" "github.com/spf13/viper" @@ -23,13 +23,13 @@ func InitNotifier() { func InitEmailNotifier() { if viper.GetBool("notify.email.disable") { - common.SysLog("email notifier disabled") + logger.SysLog("email notifier disabled") return } smtp_to := viper.GetString("notify.email.smtp_to") emailNotifier := channel.NewEmail(smtp_to) AddNotifiers(emailNotifier) - common.SysLog("email notifier enable") + logger.SysLog("email notifier enable") } func InitDingTalkNotifier() { @@ -49,7 +49,7 @@ func InitDingTalkNotifier() { } AddNotifiers(dingTalkNotifier) - common.SysLog("dingtalk notifier enable") + logger.SysLog("dingtalk notifier enable") } func InitLarkNotifier() { @@ -69,7 +69,7 @@ func InitLarkNotifier() { } AddNotifiers(larkNotifier) - common.SysLog("lark notifier enable") + logger.SysLog("lark notifier enable") } func InitPushdeerNotifier() { @@ -81,7 +81,7 @@ func InitPushdeerNotifier() { pushdeerNotifier := channel.NewPushdeer(pushkey, viper.GetString("notify.pushdeer.url")) AddNotifiers(pushdeerNotifier) - common.SysLog("pushdeer notifier enable") + logger.SysLog("pushdeer notifier enable") } func InitTelegramNotifier() { @@ -95,5 +95,5 @@ func InitTelegramNotifier() { telegramNotifier := channel.NewTelegram(bot_token, chat_id, httpProxy) AddNotifiers(telegramNotifier) - common.SysLog("telegram notifier enable") + logger.SysLog("telegram notifier enable") } diff --git a/common/notify/send.go b/common/notify/send.go index 6246af6d..b3b80396 100644 --- a/common/notify/send.go +++ b/common/notify/send.go @@ -3,7 +3,7 @@ package notify import ( "context" "fmt" - "one-api/common" + "one-api/common/logger" ) func (n *Notify) Send(ctx context.Context, title, message string) { @@ -17,14 +17,14 @@ func (n *Notify) Send(ctx context.Context, title, message string) { } err := channel.Send(ctx, title, message) if err != nil { - common.LogError(ctx, fmt.Sprintf("%s err: %s", channelName, err.Error())) + logger.LogError(ctx, fmt.Sprintf("%s err: %s", channelName, err.Error())) } } } func Send(title, message string) { //lint:ignore SA1029 reason: 需要使用该类型作为错误处理 - ctx := context.WithValue(context.Background(), common.RequestIdKey, "NotifyTask") + ctx := context.WithValue(context.Background(), logger.RequestIdKey, "NotifyTask") notifyChannels.Send(ctx, title, message) } diff --git a/common/redis.go b/common/redis.go index 9b29f546..de537138 100644 --- a/common/redis.go +++ b/common/redis.go @@ -2,6 +2,7 @@ package common import ( "context" + "one-api/common/logger" "time" "github.com/go-redis/redis/v8" @@ -16,17 +17,17 @@ func InitRedisClient() (err error) { redisConn := viper.GetString("redis_conn_string") if redisConn == "" { - SysLog("REDIS_CONN_STRING not set, Redis is not enabled") + logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") return nil } if viper.GetInt("sync_frequency") == 0 { - SysLog("SYNC_FREQUENCY not set, Redis is disabled") + logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") return nil } - SysLog("Redis is enabled") + logger.SysLog("Redis is enabled") opt, err := redis.ParseURL(redisConn) if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) return } RDB = redis.NewClient(opt) @@ -36,7 +37,7 @@ func InitRedisClient() (err error) { _, err = RDB.Ping(ctx).Result() if err != nil { - FatalLog("Redis ping test failed: " + err.Error()) + logger.FatalLog("Redis ping test failed: " + err.Error()) } else { RedisEnabled = true // for compatibility with old versions @@ -49,7 +50,7 @@ func InitRedisClient() (err error) { func ParseRedisOption() *redis.Options { opt, err := redis.ParseURL(viper.GetString("redis_conn_string")) if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) } return opt } diff --git a/common/requester/ws_client.go b/common/requester/ws_client.go index 7b6cf937..bd33d2c1 100644 --- a/common/requester/ws_client.go +++ b/common/requester/ws_client.go @@ -5,7 +5,7 @@ import ( "net" "net/http" "net/url" - "one-api/common" + "one-api/common/logger" "one-api/common/utils" "time" @@ -21,7 +21,7 @@ func GetWSClient(proxyAddr string) *websocket.Dialer { if proxyAddr != "" { err := setWSProxy(dialer, proxyAddr) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) return dialer } } diff --git a/common/storage/upload.go b/common/storage/upload.go index a82cb455..87d484ce 100644 --- a/common/storage/upload.go +++ b/common/storage/upload.go @@ -3,7 +3,7 @@ package storage import ( "context" "fmt" - "one-api/common" + "one-api/common/logger" ) func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) string { @@ -17,7 +17,7 @@ func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) stri } url, err := drive.Upload(data, fileName) if err != nil { - common.LogError(ctx, fmt.Sprintf("%s err: %s", driveName, err.Error())) + logger.LogError(ctx, fmt.Sprintf("%s err: %s", driveName, err.Error())) } else { return url } @@ -28,7 +28,7 @@ func (s *Storage) Upload(ctx context.Context, data []byte, fileName string) stri func Upload(data []byte, fileName string) string { //lint:ignore SA1029 reason: 需要使用该类型作为错误处理 - ctx := context.WithValue(context.Background(), common.RequestIdKey, "Upload") + ctx := context.WithValue(context.Background(), logger.RequestIdKey, "Upload") return storageDrives.Upload(ctx, data, fileName) } diff --git a/common/telegram/common.go b/common/telegram/common.go index cd56193f..4a3a3160 100644 --- a/common/telegram/common.go +++ b/common/telegram/common.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/common/logger" "one-api/model" "strings" "time" @@ -29,20 +30,20 @@ var TGEnabled = false func InitTelegramBot() { if TGEnabled { - common.SysLog("Telegram bot has been started") + logger.SysLog("Telegram bot has been started") return } botKey := viper.GetString("tg.bot_api_key") if botKey == "" { - common.SysLog("Telegram bot is not enabled") + logger.SysLog("Telegram bot is not enabled") return } var err error TGBot, err = gotgbot.NewBot(botKey, getBotOpts()) if err != nil { - common.SysLog("failed to create new telegram bot: " + err.Error()) + logger.SysLog("failed to create new telegram bot: " + err.Error()) return } @@ -56,7 +57,7 @@ func StartTelegramBot() { botWebhook := viper.GetString("tg.webhook_secret") if botWebhook != "" { if common.ServerAddress == "" { - common.SysLog("Telegram bot is not enabled: Server address is not set") + logger.SysLog("Telegram bot is not enabled: Server address is not set") StopTelegramBot() return } @@ -70,7 +71,7 @@ func StartTelegramBot() { err := TGupdater.AddWebhook(TGBot, urlPath, webHookOpts) if err != nil { - common.SysLog("Telegram bot failed to add webhook:" + err.Error()) + logger.SysLog("Telegram bot failed to add webhook:" + err.Error()) return } @@ -80,7 +81,7 @@ func StartTelegramBot() { SecretToken: TGWebHookSecret, }) if err != nil { - common.SysLog("Telegram bot failed to set webhook:" + err.Error()) + logger.SysLog("Telegram bot failed to set webhook:" + err.Error()) return } } else { @@ -96,13 +97,13 @@ func StartTelegramBot() { }) if err != nil { - common.SysLog("Telegram bot failed to start polling:" + err.Error()) + logger.SysLog("Telegram bot failed to start polling:" + err.Error()) } } // Idle, to keep updates coming in, and avoid bot stopping. go TGupdater.Idle() - common.SysLog(fmt.Sprintf("Telegram bot %s has been started...:", TGBot.User.Username)) + logger.SysLog(fmt.Sprintf("Telegram bot %s has been started...:", TGBot.User.Username)) TGEnabled = true } @@ -135,7 +136,7 @@ func setDispatcher() *ext.Dispatcher { dispatcher := ext.NewDispatcher(&ext.DispatcherOpts{ // If an error is returned by a handler, log it and continue going. Error: func(b *gotgbot.Bot, ctx *ext.Context, err error) ext.DispatcherAction { - common.SysLog("telegram an error occurred while handling update: " + err.Error()) + logger.SysLog("telegram an error occurred while handling update: " + err.Error()) return ext.DispatcherActionNoop }, MaxRoutines: ext.DefaultMaxRoutines, @@ -173,7 +174,7 @@ func getMenu() []gotgbot.BotCommand { customMenu, err := model.GetTelegramMenus() if err != nil { - common.SysLog("Failed to get custom menu, error: " + err.Error()) + logger.SysLog("Failed to get custom menu, error: " + err.Error()) } if len(customMenu) > 0 { @@ -234,7 +235,7 @@ func getHttpClient() (httpClient *http.Client) { proxyURL, err := url.Parse(proxyAddr) if err != nil { - common.SysLog("failed to parse TG proxy URL: " + err.Error()) + logger.SysLog("failed to parse TG proxy URL: " + err.Error()) return } switch proxyURL.Scheme { @@ -247,7 +248,7 @@ func getHttpClient() (httpClient *http.Client) { case "socks5": dialer, err := proxy.FromURL(proxyURL, proxy.Direct) if err != nil { - common.SysLog("failed to create TG SOCKS5 dialer: " + err.Error()) + logger.SysLog("failed to create TG SOCKS5 dialer: " + err.Error()) return } httpClient = &http.Client{ @@ -258,7 +259,7 @@ func getHttpClient() (httpClient *http.Client) { }, } default: - common.SysLog("unknown TG proxy type: " + proxyAddr) + logger.SysLog("unknown TG proxy type: " + proxyAddr) } return diff --git a/common/token.go b/common/token.go index 65ec7243..df125193 100644 --- a/common/token.go +++ b/common/token.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math" + "one-api/common/logger" "strings" "one-api/common/image" @@ -21,27 +22,27 @@ var gpt4oTokenEncoder *tiktoken.Tiktoken func InitTokenEncoders() { if viper.GetBool("disable_token_encoders") { DISABLE_TOKEN_ENCODERS = true - SysLog("token encoders disabled") + logger.SysLog("token encoders disabled") return } - SysLog("initializing token encoders") + logger.SysLog("initializing token encoders") var err error gpt35TokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") if err != nil { - FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) + logger.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) } gpt4TokenEncoder, err = tiktoken.EncodingForModel("gpt-4") if err != nil { - FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) + logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } gpt4oTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o") if err != nil { - FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) + logger.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) } - SysLog("token encoders initialized") + logger.SysLog("token encoders initialized") } func getTokenEncoder(model string) *tiktoken.Tiktoken { @@ -64,7 +65,7 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { var err error tokenEncoder, err = tiktoken.EncodingForModel(model) if err != nil { - SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) + logger.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) tokenEncoder = gpt35TokenEncoder } } @@ -119,7 +120,7 @@ func CountTokenMessages(messages []types.ChatCompletionMessage, model string) in imageTokens, err := countImageTokens(url, detail) if err != nil { //Due to the excessive length of the error information, only extract and record the most critical part. - SysError("error counting image tokens: " + err.Error()) + logger.SysError("error counting image tokens: " + err.Error()) } else { tokenNum += imageTokens } diff --git a/controller/channel-test.go b/controller/channel-test.go index b52340d0..ee380893 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "one-api/common" + "one-api/common/logger" "one-api/common/notify" "one-api/common/utils" "one-api/model" @@ -70,7 +71,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr // 转换为JSON字符串 jsonBytes, _ := json.Marshal(response) - common.SysLog(fmt.Sprintf("测试渠道 %s : %s 返回内容为:%s", channel.Name, request.Model, string(jsonBytes))) + logger.SysLog(fmt.Sprintf("测试渠道 %s : %s 返回内容为:%s", channel.Name, request.Model, string(jsonBytes))) return nil, nil } @@ -233,8 +234,8 @@ func AutomaticallyTestChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("testing all channels") + logger.SysLog("testing all channels") _ = testAllChannels(false) - common.SysLog("channel test finished") + logger.SysLog("channel test finished") } } diff --git a/controller/github.go b/controller/github.go index a6d923d0..2c452bc0 100644 --- a/controller/github.go +++ b/controller/github.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/logger" "one-api/common/utils" "one-api/model" "strconv" @@ -48,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res.Body.Close() @@ -64,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res2.Body.Close() diff --git a/controller/lark.go b/controller/lark.go index 328a5b83..b8e8d06f 100644 --- a/controller/lark.go +++ b/controller/lark.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "strconv" "time" @@ -58,7 +59,7 @@ func getLarkAppAccessToken() (string, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return "", errors.New("无法连接至飞书服务器,请稍后重试!") } defer res.Body.Close() @@ -100,7 +101,7 @@ func getLarkUserAccessToken(code string) (string, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return "", errors.New("无法连接至飞书服务器,请稍后重试!") } defer res.Body.Close() @@ -135,7 +136,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) { } res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至飞书服务器,请稍后重试!") } var larkUser LarkUser diff --git a/controller/midjourney.go b/controller/midjourney.go index 8286807a..93e80715 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/common/requester" "one-api/model" provider "one-api/providers/midjourney" @@ -45,9 +46,9 @@ func ActivateUpdateMidjourneyTaskBulk() { } func UpdateMidjourneyTaskBulk() { - ctx := context.WithValue(context.Background(), common.RequestIdKey, "MidjourneyTask") + ctx := context.WithValue(context.Background(), logger.RequestIdKey, "MidjourneyTask") for { - common.LogInfo(ctx, "running") + logger.LogInfo(ctx, "running") tasks := model.GetAllUnFinishTasks() @@ -56,11 +57,11 @@ func UpdateMidjourneyTaskBulk() { for len(activeMidjourneyTask) > 0 { <-activeMidjourneyTask } - common.LogInfo(ctx, "no tasks, waiting...") + logger.LogInfo(ctx, "no tasks, waiting...") return } - common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) + logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) taskChannelM := make(map[int][]string) taskM := make(map[string]*model.Midjourney) nullTaskIds := make([]int, 0) @@ -79,9 +80,9 @@ func UpdateMidjourneyTaskBulk() { "progress": "100%", }) if err != nil { - common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) } else { - common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) + logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) } } if len(taskChannelM) == 0 { @@ -89,7 +90,7 @@ func UpdateMidjourneyTaskBulk() { } for channelId, taskIds := range taskChannelM { - common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) if len(taskIds) == 0 { continue } @@ -100,7 +101,7 @@ func UpdateMidjourneyTaskBulk() { "status": "FAILURE", "progress": "100%", }) - common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) + logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) continue } requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL) @@ -110,7 +111,7 @@ func UpdateMidjourneyTaskBulk() { }) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) continue } // 设置超时时间 @@ -122,22 +123,22 @@ func UpdateMidjourneyTaskBulk() { req.Header.Set("mj-api-secret", midjourneyChannel.Key) resp, err := requester.HTTPClient.Do(req) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) continue } if resp.StatusCode != http.StatusOK { - common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) continue } responseBody, err := io.ReadAll(resp.Body) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) continue } var responseItems []provider.MidjourneyDto err = json.Unmarshal(responseBody, &responseItems) if err != nil { - common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) + logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) continue } resp.Body.Close() @@ -176,17 +177,17 @@ func UpdateMidjourneyTaskBulk() { } if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { - common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) + logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" err = model.CacheUpdateUserQuota(task.UserId) if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) + logger.LogError(ctx, "error update user quota cache: "+err.Error()) } else { quota := task.Quota if quota != 0 { err = model.IncreaseUserQuota(task.UserId, quota) if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) + logger.LogError(ctx, "fail to increase user quota: "+err.Error()) } logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) model.RecordLog(task.UserId, model.LogTypeSystem, logContent) @@ -195,7 +196,7 @@ func UpdateMidjourneyTaskBulk() { } err = task.Update() if err != nil { - common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) } } } diff --git a/cron/main.go b/cron/main.go index 8e98da0f..198d7540 100644 --- a/cron/main.go +++ b/cron/main.go @@ -1,7 +1,7 @@ package cron import ( - "one-api/common" + "one-api/common/logger" "one-api/model" "time" @@ -11,7 +11,7 @@ import ( func InitCron() { scheduler, err := gocron.NewScheduler() if err != nil { - common.SysLog("Cron scheduler error: " + err.Error()) + logger.SysLog("Cron scheduler error: " + err.Error()) return } @@ -24,12 +24,12 @@ func InitCron() { )), gocron.NewTask(func() { model.RemoveChatCache(time.Now().Unix()) - common.SysLog("删除过期缓存数据") + logger.SysLog("删除过期缓存数据") }), ) if err != nil { - common.SysLog("Cron job error: " + err.Error()) + logger.SysLog("Cron job error: " + err.Error()) return } diff --git a/main.go b/main.go index 973fe2f1..ef0c1984 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/common/config" + "one-api/common/logger" "one-api/common/notify" "one-api/common/requester" "one-api/common/storage" @@ -31,8 +32,8 @@ var indexPage []byte func main() { config.InitConf() - common.SetupLogger() - common.SysLog("One API " + common.Version + " started") + logger.SetupLogger() + logger.SysLog("One API " + common.Version + " started") // Initialize SQL Database model.SetupDB() defer model.CloseDB() @@ -69,8 +70,8 @@ func initMemoryCache() { syncFrequency := viper.GetInt("sync_frequency") model.TokenCacheSeconds = syncFrequency - common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", syncFrequency)) + logger.SysLog("memory cache enabled") + logger.SysError(fmt.Sprintf("sync frequency: %d seconds", syncFrequency)) go model.SyncOptions(syncFrequency) go SyncChannelCache(syncFrequency) } @@ -98,19 +99,19 @@ func initHttpServer() { err := server.Run(":" + port) if err != nil { - common.FatalLog("failed to start HTTP server: " + err.Error()) + logger.FatalLog("failed to start HTTP server: " + err.Error()) } } func SyncChannelCache(frequency int) { // 只有 从 服务器端获取数据的时候才会用到 if common.IsMasterNode { - common.SysLog("master node does't synchronize the channel") + logger.SysLog("master node does't synchronize the channel") return } for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing channels from database") + logger.SysLog("syncing channels from database") model.ChannelGroup.Load() relay_util.PricingInstance.Init() } diff --git a/middleware/logger.go b/middleware/logger.go index 02f2e0a9..ca372c52 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,14 +3,14 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "one-api/common" + "one-api/common/logger" ) func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[common.RequestIdKey].(string) + requestID = param.Keys[logger.RequestIdKey].(string) } return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), diff --git a/middleware/recover.go b/middleware/recover.go index 6f1c1aed..1b85cc9a 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -3,7 +3,7 @@ package middleware import ( "fmt" "net/http" - "one-api/common" + "one-api/common/logger" "runtime/debug" "github.com/gin-gonic/gin" @@ -13,8 +13,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - common.SysError(fmt.Sprintf("panic detected: %v", err)) - common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.SysError(fmt.Sprintf("panic detected: %v", err)) + logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/MartialBE/one-api", err), diff --git a/middleware/request-id.go b/middleware/request-id.go index edca8c6f..e6a587d5 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -2,7 +2,7 @@ package middleware import ( "context" - "one-api/common" + "one-api/common/logger" "one-api/common/utils" "time" @@ -12,11 +12,11 @@ import ( func RequestId() func(c *gin.Context) { return func(c *gin.Context) { id := utils.GetTimeString() + utils.GetRandomString(8) - c.Set(common.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) + c.Set(logger.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) ctx = context.WithValue(ctx, "requestStartTime", time.Now()) c.Request = c.Request.WithContext(ctx) - c.Header(common.RequestIdKey, id) + c.Header(logger.RequestIdKey, id) c.Next() } } diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 26688810..6f295864 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/common/logger" ) type turnstileCheckResponse struct { @@ -37,7 +38,7 @@ func TurnstileCheck() gin.HandlerFunc { "remoteip": {c.ClientIP()}, }) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/middleware/utils.go b/middleware/utils.go index bfa58881..04c78aeb 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,7 +1,7 @@ package middleware import ( - "one-api/common" + "one-api/common/logger" "one-api/common/utils" "github.com/gin-gonic/gin" @@ -10,10 +10,10 @@ import ( func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": utils.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "message": utils.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), "type": "one_api_error", }, }) c.Abort() - common.LogError(c.Request.Context(), message) + logger.LogError(c.Request.Context(), message) } diff --git a/model/balancer.go b/model/balancer.go index cdcf906f..82cd0289 100644 --- a/model/balancer.go +++ b/model/balancer.go @@ -4,6 +4,7 @@ import ( "errors" "math/rand" "one-api/common" + "one-api/common/logger" "one-api/common/utils" "strings" "sync" @@ -162,7 +163,7 @@ func (cc *ChannelsChooser) Load() { abilities, err := GetAbilityChannelGroup() if err != nil { - common.SysLog("get enabled abilities failed: " + err.Error()) + logger.SysLog("get enabled abilities failed: " + err.Error()) return } @@ -216,5 +217,5 @@ func (cc *ChannelsChooser) Load() { cc.Channels = newChannels cc.Match = newMatchList cc.Unlock() - common.SysLog("channels Load success") + logger.SysLog("channels Load success") } diff --git a/model/cache.go b/model/cache.go index 7c3a8e14..61832ce0 100644 --- a/model/cache.go +++ b/model/cache.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "one-api/common" + "one-api/common/logger" "strconv" "time" ) @@ -34,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { } err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set token error: " + err.Error()) + logger.SysError("Redis set token error: " + err.Error()) } return &token, nil } @@ -54,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) { } err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(TokenCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user group error: " + err.Error()) + logger.SysError("Redis set user group error: " + err.Error()) } } return group, err @@ -72,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { } err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(TokenCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user quota error: " + err.Error()) + logger.SysError("Redis set user quota error: " + err.Error()) } return quota, err } @@ -119,7 +120,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { } err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(TokenCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) + logger.SysError("Redis set user enabled error: " + err.Error()) } return userEnabled, err } diff --git a/model/channel.go b/model/channel.go index 7e6b2481..b2733ae7 100644 --- a/model/channel.go +++ b/model/channel.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/logger" "one-api/common/utils" "strings" @@ -240,7 +241,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - common.SysError("failed to update response time: " + err.Error()) + logger.SysError("failed to update response time: " + err.Error()) } } @@ -250,7 +251,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - common.SysError("failed to update balance: " + err.Error()) + logger.SysError("failed to update balance: " + err.Error()) } } @@ -283,11 +284,11 @@ func (channel *Channel) StatusToStr() string { func UpdateChannelStatusById(id int, status int) { err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) + logger.SysError("failed to update ability status: " + err.Error()) } err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + logger.SysError("failed to update channel status: " + err.Error()) } if err == nil { @@ -307,7 +308,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysError("failed to update channel used quota: " + err.Error()) + logger.SysError("failed to update channel used quota: " + err.Error()) } } diff --git a/model/log.go b/model/log.go index b507ed22..6c8b2008 100644 --- a/model/log.go +++ b/model/log.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "one-api/common" + "one-api/common/logger" "one-api/common/utils" "gorm.io/gorm" @@ -48,12 +49,12 @@ func RecordLog(userId int, logType int, content string) { } err := DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + logger.SysError("failed to record log: " + err.Error()) } } func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, requestTime int) { - common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) + logger.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled { return } @@ -73,7 +74,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke } err := DB.Create(log).Error if err != nil { - common.LogError(ctx, "failed to record log: "+err.Error()) + logger.LogError(ctx, "failed to record log: "+err.Error()) } } diff --git a/model/main.go b/model/main.go index 637d6ba9..a81e1dc3 100644 --- a/model/main.go +++ b/model/main.go @@ -3,6 +3,7 @@ package model import ( "fmt" "one-api/common" + "one-api/common/logger" "one-api/common/utils" "strconv" "strings" @@ -20,7 +21,7 @@ var DB *gorm.DB func SetupDB() { err := InitDB() if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) + logger.FatalLog("failed to initialize database: " + err.Error()) } ChannelGroup.Load() common.RootUserEmail = GetRootUserEmail() @@ -28,7 +29,7 @@ func SetupDB() { if viper.GetBool("batch_update_enabled") { common.BatchUpdateEnabled = true common.BatchUpdateInterval = utils.GetOrDefault("batch_update_interval", 5) - common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") InitBatchUpdater() } } @@ -37,7 +38,7 @@ func createRootAccountIfNeed() error { var user User //if user.Status != common.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - common.SysLog("no user exists, create a root user for you: username is root, password is 123456") + logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -61,7 +62,7 @@ func chooseDB() (*gorm.DB, error) { dsn := viper.GetString("sql_dsn") if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL - common.SysLog("using PostgreSQL as database") + logger.SysLog("using PostgreSQL as database") common.UsingPostgreSQL = true return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, @@ -71,13 +72,13 @@ func chooseDB() (*gorm.DB, error) { }) } // Use MySQL - common.SysLog("using MySQL as database") + logger.SysLog("using MySQL as database") return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } // Use SQLite - common.SysLog("SQL_DSN not set, using SQLite as database") + logger.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true config := fmt.Sprintf("?_busy_timeout=%d", utils.GetOrDefault("sqlite_busy_timeout", 3000)) return gorm.Open(sqlite.Open(viper.GetString("sqlite_path")+config), &gorm.Config{ @@ -104,7 +105,7 @@ func InitDB() (err error) { if !common.IsMasterNode { return nil } - common.SysLog("database migration started") + logger.SysLog("database migration started") migration(DB) @@ -152,11 +153,11 @@ func InitDB() (err error) { if err != nil { return err } - common.SysLog("database migrated") + logger.SysLog("database migrated") err = createRootAccountIfNeed() return err } else { - common.FatalLog(err) + logger.FatalLog(err) } return err } diff --git a/model/migrate.go b/model/migrate.go index 148e4880..0918269b 100644 --- a/model/migrate.go +++ b/model/migrate.go @@ -1,7 +1,7 @@ package model import ( - "one-api/common" + "one-api/common/logger" "github.com/go-gormigrate/gormigrate/v2" "gorm.io/gorm" @@ -22,7 +22,7 @@ func removeKeyIndexMigration() *gormigrate.Migration { err := tx.Migrator().DropIndex(&Channel{}, "idx_channels_key") if err != nil { - common.SysLog("remove idx_channels_key Failure: " + err.Error()) + logger.SysLog("remove idx_channels_key Failure: " + err.Error()) } return nil }, diff --git a/model/option.go b/model/option.go index 1af71811..4e75f3f3 100644 --- a/model/option.go +++ b/model/option.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/logger" "strconv" "strings" "time" @@ -90,7 +91,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - common.SysError("failed to update option map: " + err.Error()) + logger.SysError("failed to update option map: " + err.Error()) } } } @@ -98,7 +99,7 @@ func loadOptionsFromDatabase() { func SyncOptions(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") + logger.SysLog("syncing options from database") loadOptionsFromDatabase() } } diff --git a/model/token.go b/model/token.go index 2e89185d..883b6d4c 100644 --- a/model/token.go +++ b/model/token.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/common/logger" "one-api/common/stmp" "one-api/common/utils" @@ -58,7 +59,7 @@ func ValidateUserToken(key string) (token *Token, err error) { } token, err = CacheGetTokenByKey(key) if err != nil { - common.SysError("CacheGetTokenByKey failed: " + err.Error()) + logger.SysError("CacheGetTokenByKey failed: " + err.Error()) if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("无效的令牌") } @@ -77,7 +78,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } return nil, errors.New("该令牌已过期") @@ -88,7 +89,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } return nil, errors.New("该令牌额度已用尽") @@ -254,12 +255,12 @@ func sendQuotaWarningEmail(userId int, userQuota int, noMoreQuota bool) { user := User{Id: userId} if err := user.FillUserById(); err != nil { - common.SysError("failed to fetch user email: " + err.Error()) + logger.SysError("failed to fetch user email: " + err.Error()) return } if user.Email == "" { - common.SysError("user email is empty") + logger.SysError("user email is empty") return } @@ -271,7 +272,7 @@ func sendQuotaWarningEmail(userId int, userQuota int, noMoreQuota bool) { err := stmp.SendQuotaWarningCodeEmail(userName, user.Email, userQuota, noMoreQuota) if err != nil { - common.SysError("failed to send email" + err.Error()) + logger.SysError("failed to send email" + err.Error()) } } diff --git a/model/user.go b/model/user.go index 311c0702..f7f13ec8 100644 --- a/model/user.go +++ b/model/user.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "one-api/common" + "one-api/common/logger" "one-api/common/utils" "strings" @@ -306,7 +307,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) + logger.SysError("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -415,7 +416,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota and request count: " + err.Error()) + logger.SysError("failed to update user used quota and request count: " + err.Error()) } } @@ -426,14 +427,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - common.SysError("failed to update user used quota: " + err.Error()) + logger.SysError("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - common.SysError("failed to update user request count: " + err.Error()) + logger.SysError("failed to update user request count: " + err.Error()) } } diff --git a/model/utils.go b/model/utils.go index 1c28340b..e4797a78 100644 --- a/model/utils.go +++ b/model/utils.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/common/logger" "sync" "time" ) @@ -45,7 +46,7 @@ func addNewRecord(type_ int, id int, value int) { } func batchUpdate() { - common.SysLog("batch update started") + logger.SysLog("batch update started") for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] @@ -57,12 +58,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - common.SysError("failed to batch update user quota: " + err.Error()) + logger.SysError("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - common.SysError("failed to batch update token quota: " + err.Error()) + logger.SysError("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) @@ -73,5 +74,5 @@ func batchUpdate() { } } } - common.SysLog("batch update finished") + logger.SysLog("batch update finished") } diff --git a/providers/midjourney/base.go b/providers/midjourney/base.go index 5b6a2569..434457f9 100644 --- a/providers/midjourney/base.go +++ b/providers/midjourney/base.go @@ -8,6 +8,7 @@ import ( "log" "net/http" "one-api/common" + "one-api/common/logger" "one-api/common/requester" "one-api/model" "one-api/providers/base" @@ -71,7 +72,7 @@ func (p *MidjourneyProvider) Send(timeout int, requestURL string) (*MidjourneyRe resp, errWith := p.Requester.SendRequestRaw(req) if errWith != nil { - common.SysError("do request failed: " + errWith.Error()) + logger.SysError("do request failed: " + errWith.Error()) return MidjourneyErrorWithStatusCodeWrapper(MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err } statusCode := resp.StatusCode diff --git a/providers/xunfei/base.go b/providers/xunfei/base.go index e6a7975d..1145cd46 100644 --- a/providers/xunfei/base.go +++ b/providers/xunfei/base.go @@ -6,7 +6,7 @@ import ( "encoding/base64" "fmt" "net/url" - "one-api/common" + "one-api/common/logger" "one-api/common/requester" "one-api/model" "one-api/providers/base" @@ -94,7 +94,7 @@ func (p *XunfeiProvider) getAPIVersion(modelName string) string { } apiVersion = "v1.1" - common.SysLog("api_version not found, use default: " + apiVersion) + logger.SysLog("api_version not found, use default: " + apiVersion) return apiVersion } @@ -130,7 +130,7 @@ func (p *XunfeiProvider) buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret st } ul, err := url.Parse(hostUrl) if err != nil { - common.SysError("url parse error: " + err.Error()) + logger.SysError("url parse error: " + err.Error()) return "" } date := time.Now().UTC().Format(time.RFC1123) diff --git a/providers/zhipu/base.go b/providers/zhipu/base.go index c5b696f9..f5e30262 100644 --- a/providers/zhipu/base.go +++ b/providers/zhipu/base.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" "net/http" - "one-api/common" + "one-api/common/logger" "one-api/common/requester" "one-api/model" "one-api/providers/base" @@ -95,7 +95,7 @@ func (p *ZhipuProvider) getZhipuToken() string { split := strings.Split(apikey, ".") if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) + logger.SysError("invalid zhipu key: " + apikey) return "" } diff --git a/relay/common.go b/relay/common.go index 45907804..a6618e2f 100644 --- a/relay/common.go +++ b/relay/common.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/common/logger" "one-api/common/requester" "one-api/common/utils" "one-api/controller" @@ -115,7 +116,7 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, erro if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName) if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } return nil, errors.New(message) @@ -250,14 +251,14 @@ func shouldRetry(c *gin.Context, statusCode int) bool { } func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) { - common.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message)) + logger.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message)) if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) { controller.DisableChannel(channelId, channelName, err.Message, true) } } func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) { - requestId := c.GetString(common.RequestIdKey) + requestId := c.GetString(logger.RequestIdKey) err.OpenAIError.Message = utils.MessageWithRequestId(err.OpenAIError.Message, requestId) c.JSON(err.StatusCode, gin.H{ "error": err.OpenAIError, diff --git a/relay/main.go b/relay/main.go index 6ed4d385..a035a025 100644 --- a/relay/main.go +++ b/relay/main.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "one-api/relay/relay_util" "one-api/types" @@ -51,7 +52,7 @@ func Relay(c *gin.Context) { retryTimes := common.RetryTimes if done || !shouldRetry(c, apiErr.StatusCode) { - common.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode)) + logger.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode)) retryTimes = 0 } @@ -63,7 +64,7 @@ func Relay(c *gin.Context) { } channel = relay.getProvider().GetChannel() - common.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i)) + logger.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i)) apiErr, done = RelayHandler(relay) if apiErr == nil { return diff --git a/relay/midjourney/relay.go b/relay/midjourney/relay.go index e4ce2f15..c3386663 100644 --- a/relay/midjourney/relay.go +++ b/relay/midjourney/relay.go @@ -6,7 +6,7 @@ package midjourney import ( "fmt" "net/http" - "one-api/common" + "one-api/common/logger" provider "one-api/providers/midjourney" "strings" @@ -46,7 +46,7 @@ func RelayMidjourney(c *gin.Context) { "code": err.Code, }) channelId := c.GetInt("channel_id") - common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) + logger.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result))) } } diff --git a/relay/relay_util/pricing.go b/relay/relay_util/pricing.go index 67c46e54..b2cab192 100644 --- a/relay/relay_util/pricing.go +++ b/relay/relay_util/pricing.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "one-api/common" + "one-api/common/logger" "one-api/common/utils" "one-api/model" "sort" @@ -30,7 +31,7 @@ type BatchPrices struct { // NewPricing creates a new Pricing instance func NewPricing() { - common.SysLog("Initializing Pricing") + logger.SysLog("Initializing Pricing") PricingInstance = &Pricing{ Prices: make(map[string]*model.Price), @@ -40,16 +41,16 @@ func NewPricing() { err := PricingInstance.Init() if err != nil { - common.SysError("Failed to initialize Pricing:" + err.Error()) + logger.SysError("Failed to initialize Pricing:" + err.Error()) return } // 初始化时,需要检测是否有更新 if viper.GetBool("auto_price_updates") || len(PricingInstance.Prices) == 0 { - common.SysLog("Checking for pricing updates") + logger.SysLog("Checking for pricing updates") prices := model.GetDefaultPrice() PricingInstance.SyncPricing(prices, false) - common.SysLog("Pricing initialized") + logger.SysLog("Pricing initialized") } } diff --git a/relay/relay_util/quota.go b/relay/relay_util/quota.go index c859361e..48f6b094 100644 --- a/relay/relay_util/quota.go +++ b/relay/relay_util/quota.go @@ -7,6 +7,7 @@ import ( "math" "net/http" "one-api/common" + "one-api/common/logger" "one-api/model" "one-api/types" "time" @@ -154,7 +155,7 @@ func (q *Quota) Undo(c *gin.Context) { // return pre-consumed quota err := model.PostConsumeTokenQuota(tokenId, -q.preConsumedQuota) if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) + logger.LogError(ctx, "error return pre-consumed quota: "+err.Error()) } }(c.Request.Context()) } @@ -166,7 +167,7 @@ func (q *Quota) Consume(c *gin.Context, usage *types.Usage) { go func(ctx context.Context) { err := q.completedQuotaConsumption(usage, tokenName, ctx) if err != nil { - common.LogError(ctx, err.Error()) + logger.LogError(ctx, err.Error()) } }(c.Request.Context()) } diff --git a/router/main.go b/router/main.go index e26118d7..6702a7c5 100644 --- a/router/main.go +++ b/router/main.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/logger" "strings" "github.com/gin-gonic/gin" @@ -18,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { frontendBaseUrl := viper.GetString("frontend_base_url") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" - common.SysLog("FRONTEND_BASE_URL is ignored on master node") + logger.SysLog("FRONTEND_BASE_URL is ignored on master node") } if frontendBaseUrl == "" { SetWebRouter(router, buildFS, indexPage)