refactor: update logging related logic

This commit is contained in:
JustSong 2023-09-17 15:39:46 +08:00
parent 25c4c111ab
commit 42451d9d02
16 changed files with 149 additions and 93 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ upload
*.db *.db
build build
*.db-journal *.db-journal
logs

View File

@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const ( const (
RoleGuestUser = 0 RoleGuestUser = 0
RoleCommonUser = 1 RoleCommonUser = 1

View File

@ -12,7 +12,7 @@ var (
Port = flag.Int("port", 3000, "the listening port") Port = flag.Int("port", 3000, "the listening port")
PrintVersion = flag.Bool("version", false, "print version and exit") PrintVersion = flag.Bool("version", false, "print version and exit")
PrintHelp = flag.Bool("help", false, "print help and exit") PrintHelp = flag.Bool("help", false, "print help and exit")
LogDir = flag.String("log-dir", "", "specify the log directory") LogDir = flag.String("log-dir", "./logs", "specify the log directory")
) )
func printHelp() { func printHelp() {

View File

@ -1,6 +1,7 @@
package common package common
import ( import (
"context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
@ -10,20 +11,21 @@ import (
"time" "time"
) )
const (
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
)
func SetupGinLog() { func SetupGinLog() {
if *LogDir != "" { if *LogDir != "" {
commonLogPath := filepath.Join(*LogDir, "common.log") logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
errorLogPath := filepath.Join(*LogDir, "error.log") fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
log.Fatal("failed to open log file") log.Fatal("failed to open log file")
} }
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
if err != nil { gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
} }
} }
@ -37,6 +39,28 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func LogInfo(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func LogWarn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func LogError(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
}
func FatalLog(v ...any) { func FatalLog(v ...any) {
t := time.Now() t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)

View File

@ -171,6 +171,11 @@ func GetTimestamp() int64 {
return time.Now().Unix() return time.Now().Unix()
} }
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int { func Max(a int, b int) int {
if a >= b { if a >= b {
return a return a
@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
} }
return num return num
} }
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -91,7 +92,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
var audioResponse AudioResponse var audioResponse AudioResponse
defer func() { defer func(ctx context.Context) {
go func() { go func() {
quota := countTokenText(audioResponse.Text, audioModel) quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
@ -106,13 +107,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 { if quota != 0 {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
}() }()
}() }(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
var textResponse ImageResponse var textResponse ImageResponse
defer func() { defer func(ctx context.Context) {
if consumeQuota { if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota) err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil { if err != nil {
@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 { if quota != 0 {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
} }
}() }(c.Request.Context())
if consumeQuota { if consumeQuota {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -210,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// in this case, we do not pre-consume quota // in this case, we do not pre-consume quota
// because the user has enough quota // because the user has enough quota
preConsumedQuota = 0 preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
} }
if consumeQuota && preConsumedQuota > 0 { if consumeQuota && preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
@ -348,13 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 { if preConsumedQuota != 0 {
go func() { go func(ctx context.Context) {
// return pre-consumed quota // return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil { if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error()) common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
} }
}() }(c.Request.Context())
} }
return relayErrorHandler(resp) return relayErrorHandler(resp)
} }
@ -364,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
defer func() { defer func(ctx context.Context) {
// c.Writer.Flush() // c.Writer.Flush()
go func() { go func() {
if consumeQuota { if consumeQuota {
@ -387,21 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta) err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.LogError(ctx, "error consuming token remain quota: "+err.Error())
} }
err = model.CacheUpdateUserQuota(userId) err = model.CacheUpdateUserQuota(userId)
if err != nil { if err != nil {
common.SysError("error update user quota cache: " + err.Error()) common.LogError(ctx, "error update user quota cache: "+err.Error())
} }
if quota != 0 { if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
} }
}() }()
}() }(c.Request.Context())
switch apiType { switch apiType {
case APITypeOpenAI: case APITypeOpenAI:
if isStream { if isStream {

View File

@ -196,6 +196,7 @@ func Relay(c *gin.Context) {
err = relayTextHelper(c, relayMode) err = relayTextHelper(c, relayMode)
} }
if err != nil { if err != nil {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry") retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr) retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" { if retryTimesStr == "" {
@ -207,12 +208,13 @@ func Relay(c *gin.Context) {
if err.StatusCode == http.StatusTooManyRequests { if err.StatusCode == http.StatusTooManyRequests {
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
} }
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{ c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError, "error": err.OpenAIError,
}) })
} }
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")

View File

@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/controller" "one-api/controller"
"one-api/middleware"
"one-api/model" "one-api/model"
"one-api/router" "one-api/router"
"os" "os"
@ -84,10 +85,12 @@ func main() {
controller.InitTokenEncoders() controller.InitTokenEncoders()
// Initialize HTTP server // Initialize HTTP server
server := gin.Default() server := gin.New()
server.Use(gin.Recovery())
// This will cause SSE not to work!!! // This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression)) //server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
// Initialize session store // Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret)) store := cookie.NewStore([]byte(common.SessionSecret))
server.Use(sessions.Sessions("session", store)) server.Use(sessions.Sessions("session", store))

View File

@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0] key = parts[0]
token, err := model.ValidateUserToken(key) token, err := model.ValidateUserToken(key)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{ abortWithMessage(c, http.StatusUnauthorized, err.Error())
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
return return
} }
userEnabled, err := model.IsUserEnabled(token.UserId) userEnabled, err := model.IsUserEnabled(token.UserId)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{ abortWithMessage(c, http.StatusInternalServerError, err.Error())
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
return return
} }
if !userEnabled { if !userEnabled {
c.JSON(http.StatusForbidden, gin.H{ abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
"error": gin.H{
"message": "用户已被封禁",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
c.Set("id", token.UserId) c.Set("id", token.UserId)
@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])
} else { } else {
c.JSON(http.StatusForbidden, gin.H{ abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
"error": gin.H{
"message": "普通用户不支持指定渠道",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} }

View File

@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{ abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
channel, err = model.GetChannelById(id, true) channel, err = model.GetChannelById(id, true)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{ abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
if channel.Status != common.ChannelStatusEnabled { if channel.Status != common.ChannelStatusEnabled {
c.JSON(http.StatusForbidden, gin.H{ abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
"error": gin.H{
"message": "该渠道已被禁用",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} else { } else {
@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
err = common.UnmarshalBodyReusable(c, &modelRequest) err = common.UnmarshalBodyReusable(c, &modelRequest)
} }
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{ abortWithMessage(c, http.StatusBadRequest, "无效的请求")
"error": gin.H{
"message": "无效的请求",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员" message = "数据库一致性已被破坏,请联系管理员"
} }
c.JSON(http.StatusServiceUnavailable, gin.H{ abortWithMessage(c, http.StatusServiceUnavailable, message)
"error": gin.H{
"message": message,
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} }

25
middleware/logger.go Normal file
View File

@ -0,0 +1,25 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
requestID,
param.StatusCode,
param.Latency,
param.ClientIP,
param.Method,
param.Path,
)
}))
}

18
middleware/request-id.go Normal file
View File

@ -0,0 +1,18 @@
package middleware
import (
"context"
"github.com/gin-gonic/gin"
"one-api/common"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8)
c.Set(common.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id)
c.Next()
}
}

17
middleware/utils.go Normal file
View File

@ -0,0 +1,17 @@
package middleware
import (
"github.com/gin-gonic/gin"
"one-api/common"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
}

View File

@ -1,6 +1,8 @@
package model package model
import ( import (
"context"
"fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
) )
@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) {
} }
} }
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled { if !common.LogConsumeEnabled {
return return
} }
@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
} }
err := DB.Create(log).Error err := DB.Create(log).Error
if err != nil { if err != nil {
common.SysError("failed to record log: " + err.Error()) common.LogError(ctx, "failed to record log: "+err.Error())
} }
} }