From 42451d9d02e22cb546b746b68e68cec528603e5f Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 15:39:46 +0800 Subject: [PATCH] refactor: update logging related logic --- .gitignore | 3 ++- common/constants.go | 4 ++++ common/init.go | 2 +- common/logger.go | 42 ++++++++++++++++++++++++++++++--------- common/utils.go | 9 +++++++++ controller/relay-audio.go | 7 ++++--- controller/relay-image.go | 7 ++++--- controller/relay-text.go | 18 +++++++++-------- controller/relay.go | 4 +++- main.go | 7 +++++-- middleware/auth.go | 32 ++++------------------------- middleware/distributor.go | 40 +++++-------------------------------- middleware/logger.go | 25 +++++++++++++++++++++++ middleware/request-id.go | 18 +++++++++++++++++ middleware/utils.go | 17 ++++++++++++++++ model/log.go | 7 +++++-- 16 files changed, 149 insertions(+), 93 deletions(-) create mode 100644 middleware/logger.go create mode 100644 middleware/request-id.go create mode 100644 middleware/utils.go diff --git a/.gitignore b/.gitignore index 0b2856cc..1b2cf071 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ upload *.exe *.db build -*.db-journal \ No newline at end of file +*.db-journal +logs \ No newline at end of file diff --git a/common/constants.go b/common/constants.go index 69bd12a8..794a795f 100644 --- a/common/constants.go +++ b/common/constants.go @@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU var BatchUpdateEnabled = false var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) + const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/common/init.go b/common/init.go index 0f22c69b..1e9c85ce 100644 --- a/common/init.go +++ b/common/init.go @@ -12,7 +12,7 @@ var ( Port = flag.Int("port", 3000, "the listening port") PrintVersion = flag.Bool("version", false, "print version and exit") PrintHelp = flag.Bool("help", false, "print help and exit") - LogDir = flag.String("log-dir", "", "specify the log directory") + LogDir = flag.String("log-dir", "./logs", "specify the log directory") ) func printHelp() { diff --git a/common/logger.go b/common/logger.go index 3658dbdb..780a6237 100644 --- a/common/logger.go +++ b/common/logger.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "github.com/gin-gonic/gin" "io" @@ -10,20 +11,21 @@ import ( "time" ) +const ( + loggerINFO = "INFO" + loggerWarn = "WARN" + loggerError = "ERR" +) + func SetupGinLog() { if *LogDir != "" { - commonLogPath := filepath.Join(*LogDir, "common.log") - errorLogPath := filepath.Join(*LogDir, "error.log") - commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) + fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") } - errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatal("failed to open log file") - } - gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd) - gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd) + gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) } } @@ -37,6 +39,28 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } +func LogInfo(ctx context.Context, msg string) { + logHelper(ctx, loggerINFO, msg) +} + +func LogWarn(ctx context.Context, msg string) { + logHelper(ctx, loggerWarn, msg) +} + +func LogError(ctx context.Context, msg string) { + logHelper(ctx, loggerError, msg) +} + +func logHelper(ctx context.Context, level string, msg string) { + writer := gin.DefaultErrorWriter + if level == loggerINFO { + writer = gin.DefaultWriter + } + id := ctx.Value(RequestIdKey) + now := time.Now() + _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) +} + func FatalLog(v ...any) { t := time.Now() _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) diff --git a/common/utils.go b/common/utils.go index bb9b7e0c..ab901b77 100644 --- a/common/utils.go +++ b/common/utils.go @@ -171,6 +171,11 @@ func GetTimestamp() int64 { return time.Now().Unix() } +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} + func Max(a int, b int) int { if a >= b { return a @@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int { } return num } + +func MessageWithRequestId(message string, id string) string { + return fmt.Sprintf("%s (request id: %s)", message, id) +} diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 277ab404..a7bc670b 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -91,7 +92,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } var audioResponse AudioResponse - defer func() { + defer func(ctx context.Context) { go func() { quota := countTokenText(audioResponse.Text, audioModel) quotaDelta := quota - preConsumedQuota @@ -106,13 +107,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } }() - }() + }(c.Request.Context()) responseBody, err := io.ReadAll(resp.Body) diff --git a/controller/relay-image.go b/controller/relay-image.go index de623288..b1a22570 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } var textResponse ImageResponse - defer func() { + defer func(ctx context.Context) { if consumeQuota { err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil { @@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } } - }() + }(c.Request.Context()) if consumeQuota { responseBody, err := io.ReadAll(resp.Body) diff --git a/controller/relay-text.go b/controller/relay-text.go index 2cd5598a..6d481983 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -210,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 + common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) } if consumeQuota && preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) @@ -348,13 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if resp.StatusCode != http.StatusOK { if preConsumedQuota != 0 { - go func() { + go func(ctx context.Context) { // return pre-consumed quota err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) if err != nil { - common.SysError("error return pre-consumed quota: " + err.Error()) + common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) } - }() + }(c.Request.Context()) } return relayErrorHandler(resp) } @@ -364,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { tokenName := c.GetString("token_name") channelId := c.GetInt("channel_id") - defer func() { + defer func(ctx context.Context) { // c.Writer.Flush() go func() { if consumeQuota { @@ -387,21 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } err = model.CacheUpdateUserQuota(userId) if err != nil { - common.SysError("error update user quota cache: " + err.Error()) + common.LogError(ctx, "error update user quota cache: "+err.Error()) } if quota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateChannelUsedQuota(channelId, quota) } } }() - }() + }(c.Request.Context()) switch apiType { case APITypeOpenAI: if isStream { diff --git a/controller/relay.go b/controller/relay.go index d20663f6..1926110e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -196,6 +196,7 @@ func Relay(c *gin.Context) { err = relayTextHelper(c, relayMode) } if err != nil { + requestId := c.GetString(common.RequestIdKey) retryTimesStr := c.Query("retry") retryTimes, _ := strconv.Atoi(retryTimesStr) if retryTimesStr == "" { @@ -207,12 +208,13 @@ func Relay(c *gin.Context) { if err.StatusCode == http.StatusTooManyRequests { err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" } + err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) c.JSON(err.StatusCode, gin.H{ "error": err.OpenAIError, }) } channelId := c.GetInt("channel_id") - common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { channelId := c.GetInt("channel_id") diff --git a/main.go b/main.go index 2aa52876..c7a8a2d6 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "one-api/common" "one-api/controller" + "one-api/middleware" "one-api/model" "one-api/router" "os" @@ -84,10 +85,12 @@ func main() { controller.InitTokenEncoders() // Initialize HTTP server - server := gin.Default() + server := gin.New() + server.Use(gin.Recovery()) // This will cause SSE not to work!!! //server.Use(gzip.Gzip(gzip.DefaultCompression)) - + server.Use(middleware.RequestId()) + middleware.SetUpLogger(server) // Initialize session store store := cookie.NewStore([]byte(common.SessionSecret)) server.Use(sessions.Sessions("session", store)) diff --git a/middleware/auth.go b/middleware/auth.go index 95516d6e..dfbc7dbd 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] token, err := model.ValidateUserToken(key) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } userEnabled, err := model.IsUserEnabled(token.UserId) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusInternalServerError, err.Error()) return } if !userEnabled { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "用户已被封禁", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } c.Set("id", token.UserId) @@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) } else { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "普通用户不支持指定渠道", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index e8b76596..ab374a85 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的渠道 ID", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") return } channel, err = model.GetChannelById(id, true) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的渠道 ID", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") return } if channel.Status != common.ChannelStatusEnabled { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "该渠道已被禁用", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { @@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的请求", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的请求") return } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { @@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "message": message, - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusServiceUnavailable, message) return } } diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 00000000..02f2e0a9 --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "one-api/common" +) + +func SetUpLogger(server *gin.Engine) { + server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { + var requestID string + if param.Keys != nil { + requestID = param.Keys[common.RequestIdKey].(string) + } + return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", + param.TimeStamp.Format("2006/01/02 - 15:04:05"), + requestID, + param.StatusCode, + param.Latency, + param.ClientIP, + param.Method, + param.Path, + ) + })) +} diff --git a/middleware/request-id.go b/middleware/request-id.go new file mode 100644 index 00000000..e623be7a --- /dev/null +++ b/middleware/request-id.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "context" + "github.com/gin-gonic/gin" + "one-api/common" +) + +func RequestId() func(c *gin.Context) { + return func(c *gin.Context) { + id := common.GetTimeString() + common.GetRandomString(8) + c.Set(common.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) + c.Request = c.Request.WithContext(ctx) + c.Header(common.RequestIdKey, id) + c.Next() + } +} diff --git a/middleware/utils.go b/middleware/utils.go new file mode 100644 index 00000000..536125cc --- /dev/null +++ b/middleware/utils.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "one-api/common" +) + +func abortWithMessage(c *gin.Context, statusCode int, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "type": "one_api_error", + }, + }) + c.Abort() + common.LogError(c.Request.Context(), message) +} diff --git a/model/log.go b/model/log.go index b0d6409a..b6a72c26 100644 --- a/model/log.go +++ b/model/log.go @@ -1,6 +1,8 @@ package model import ( + "context" + "fmt" "gorm.io/gorm" "one-api/common" ) @@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { +func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { + common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled { return } @@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN } err := DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + common.LogError(ctx, "failed to record log: "+err.Error()) } }