refactor: update logging related logic
This commit is contained in:
parent
25c4c111ab
commit
42451d9d02
1
.gitignore
vendored
1
.gitignore
vendored
@ -5,3 +5,4 @@ upload
|
|||||||
*.db
|
*.db
|
||||||
build
|
build
|
||||||
*.db-journal
|
*.db-journal
|
||||||
|
logs
|
@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
|
|||||||
var BatchUpdateEnabled = false
|
var BatchUpdateEnabled = false
|
||||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RequestIdKey = "X-Oneapi-Request-Id"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RoleGuestUser = 0
|
RoleGuestUser = 0
|
||||||
RoleCommonUser = 1
|
RoleCommonUser = 1
|
||||||
|
@ -12,7 +12,7 @@ var (
|
|||||||
Port = flag.Int("port", 3000, "the listening port")
|
Port = flag.Int("port", 3000, "the listening port")
|
||||||
PrintVersion = flag.Bool("version", false, "print version and exit")
|
PrintVersion = flag.Bool("version", false, "print version and exit")
|
||||||
PrintHelp = flag.Bool("help", false, "print help and exit")
|
PrintHelp = flag.Bool("help", false, "print help and exit")
|
||||||
LogDir = flag.String("log-dir", "", "specify the log directory")
|
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
|
||||||
)
|
)
|
||||||
|
|
||||||
func printHelp() {
|
func printHelp() {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
@ -10,20 +11,21 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
loggerINFO = "INFO"
|
||||||
|
loggerWarn = "WARN"
|
||||||
|
loggerError = "ERR"
|
||||||
|
)
|
||||||
|
|
||||||
func SetupGinLog() {
|
func SetupGinLog() {
|
||||||
if *LogDir != "" {
|
if *LogDir != "" {
|
||||||
commonLogPath := filepath.Join(*LogDir, "common.log")
|
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
|
||||||
errorLogPath := filepath.Join(*LogDir, "error.log")
|
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to open log file")
|
log.Fatal("failed to open log file")
|
||||||
}
|
}
|
||||||
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
||||||
if err != nil {
|
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
||||||
log.Fatal("failed to open log file")
|
|
||||||
}
|
|
||||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
|
|
||||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,6 +39,28 @@ func SysError(s string) {
|
|||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LogInfo(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerINFO, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogWarn(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerWarn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogError(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerError, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func logHelper(ctx context.Context, level string, msg string) {
|
||||||
|
writer := gin.DefaultErrorWriter
|
||||||
|
if level == loggerINFO {
|
||||||
|
writer = gin.DefaultWriter
|
||||||
|
}
|
||||||
|
id := ctx.Value(RequestIdKey)
|
||||||
|
now := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||||
|
}
|
||||||
|
|
||||||
func FatalLog(v ...any) {
|
func FatalLog(v ...any) {
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||||
|
@ -171,6 +171,11 @@ func GetTimestamp() int64 {
|
|||||||
return time.Now().Unix()
|
return time.Now().Unix()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTimeString() string {
|
||||||
|
now := time.Now()
|
||||||
|
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
||||||
|
}
|
||||||
|
|
||||||
func Max(a int, b int) int {
|
func Max(a int, b int) int {
|
||||||
if a >= b {
|
if a >= b {
|
||||||
return a
|
return a
|
||||||
@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
|
|||||||
}
|
}
|
||||||
return num
|
return num
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MessageWithRequestId(message string, id string) string {
|
||||||
|
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -91,7 +92,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
var audioResponse AudioResponse
|
var audioResponse AudioResponse
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
quota := countTokenText(audioResponse.Text, audioModel)
|
quota := countTokenText(audioResponse.Text, audioModel)
|
||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
@ -106,13 +107,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}()
|
}(c.Request.Context())
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
var textResponse ImageResponse
|
var textResponse ImageResponse
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}(c.Request.Context())
|
||||||
|
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -210,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
// in this case, we do not pre-consume quota
|
// in this case, we do not pre-consume quota
|
||||||
// because the user has enough quota
|
// because the user has enough quota
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||||
}
|
}
|
||||||
if consumeQuota && preConsumedQuota > 0 {
|
if consumeQuota && preConsumedQuota > 0 {
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
@ -348,13 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
if preConsumedQuota != 0 {
|
if preConsumedQuota != 0 {
|
||||||
go func() {
|
go func(ctx context.Context) {
|
||||||
// return pre-consumed quota
|
// return pre-consumed quota
|
||||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||||
}
|
}
|
||||||
}()
|
}(c.Request.Context())
|
||||||
}
|
}
|
||||||
return relayErrorHandler(resp)
|
return relayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
@ -364,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
// c.Writer.Flush()
|
// c.Writer.Flush()
|
||||||
go func() {
|
go func() {
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
@ -387,21 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||||
}
|
}
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
}
|
}
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}()
|
}(c.Request.Context())
|
||||||
switch apiType {
|
switch apiType {
|
||||||
case APITypeOpenAI:
|
case APITypeOpenAI:
|
||||||
if isStream {
|
if isStream {
|
||||||
|
@ -196,6 +196,7 @@ func Relay(c *gin.Context) {
|
|||||||
err = relayTextHelper(c, relayMode)
|
err = relayTextHelper(c, relayMode)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
retryTimesStr := c.Query("retry")
|
retryTimesStr := c.Query("retry")
|
||||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||||
if retryTimesStr == "" {
|
if retryTimesStr == "" {
|
||||||
@ -207,12 +208,13 @@ func Relay(c *gin.Context) {
|
|||||||
if err.StatusCode == http.StatusTooManyRequests {
|
if err.StatusCode == http.StatusTooManyRequests {
|
||||||
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
}
|
}
|
||||||
|
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||||
c.JSON(err.StatusCode, gin.H{
|
c.JSON(err.StatusCode, gin.H{
|
||||||
"error": err.OpenAIError,
|
"error": err.OpenAIError,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
7
main.go
7
main.go
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/controller"
|
"one-api/controller"
|
||||||
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/router"
|
"one-api/router"
|
||||||
"os"
|
"os"
|
||||||
@ -84,10 +85,12 @@ func main() {
|
|||||||
controller.InitTokenEncoders()
|
controller.InitTokenEncoders()
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.Default()
|
server := gin.New()
|
||||||
|
server.Use(gin.Recovery())
|
||||||
// This will cause SSE not to work!!!
|
// This will cause SSE not to work!!!
|
||||||
//server.Use(gzip.Gzip(gzip.DefaultCompression))
|
//server.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||||
|
server.Use(middleware.RequestId())
|
||||||
|
middleware.SetUpLogger(server)
|
||||||
// Initialize session store
|
// Initialize session store
|
||||||
store := cookie.NewStore([]byte(common.SessionSecret))
|
store := cookie.NewStore([]byte(common.SessionSecret))
|
||||||
server.Use(sessions.Sessions("session", store))
|
server.Use(sessions.Sessions("session", store))
|
||||||
|
@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
key = parts[0]
|
key = parts[0]
|
||||||
token, err := model.ValidateUserToken(key)
|
token, err := model.ValidateUserToken(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
"error": gin.H{
|
|
||||||
"message": err.Error(),
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userEnabled, err := model.IsUserEnabled(token.UserId)
|
userEnabled, err := model.IsUserEnabled(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
"error": gin.H{
|
|
||||||
"message": err.Error(),
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !userEnabled {
|
if !userEnabled {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
"error": gin.H{
|
|
||||||
"message": "用户已被封禁",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("channelId", parts[1])
|
||||||
} else {
|
} else {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
"error": gin.H{
|
|
||||||
"message": "普通用户不支持指定渠道",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的渠道 ID",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err = model.GetChannelById(id, true)
|
channel, err = model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的渠道 ID",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||||
"error": gin.H{
|
|
||||||
"message": "该渠道已被禁用",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的请求",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||||
"error": gin.H{
|
|
||||||
"message": message,
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
25
middleware/logger.go
Normal file
25
middleware/logger.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetUpLogger(server *gin.Engine) {
|
||||||
|
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||||
|
var requestID string
|
||||||
|
if param.Keys != nil {
|
||||||
|
requestID = param.Keys[common.RequestIdKey].(string)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||||
|
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||||
|
requestID,
|
||||||
|
param.StatusCode,
|
||||||
|
param.Latency,
|
||||||
|
param.ClientIP,
|
||||||
|
param.Method,
|
||||||
|
param.Path,
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
}
|
18
middleware/request-id.go
Normal file
18
middleware/request-id.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RequestId() func(c *gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
id := common.GetTimeString() + common.GetRandomString(8)
|
||||||
|
c.Set(common.RequestIdKey, id)
|
||||||
|
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
c.Header(common.RequestIdKey, id)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
17
middleware/utils.go
Normal file
17
middleware/utils.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
|
"type": "one_api_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
common.LogError(c.Request.Context(), message)
|
||||||
|
}
|
@ -1,6 +1,8 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
)
|
)
|
||||||
@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
if !common.LogConsumeEnabled {
|
if !common.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
|
|||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to record log: " + err.Error())
|
common.LogError(ctx, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user