From 959bcdef88935a5df8624ab68c7198bffe43c462 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 11:30:20 +0800 Subject: [PATCH 01/10] chore: update error code --- controller/billing.go | 2 +- controller/relay-utils.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/controller/billing.go b/controller/billing.go index 79eae1e2..42e86aea 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) { if err != nil { openAIError := OpenAIError{ Message: err.Error(), - Type: "one_api_error", + Type: "upstream_error", } c.JSON(200, gin.H{ "error": openAIError, diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 9010d275..3d5948fc 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -146,7 +146,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr StatusCode: resp.StatusCode, OpenAIError: OpenAIError{ Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "one_api_error", + Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), }, From 0d50ad4b2b4d51b5c19b38ef3213376325ab3c02 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 11:34:06 +0800 Subject: [PATCH 02/10] chore: update channel test prompt --- web/src/components/ChannelsTable.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 5eb39783..f712f11a 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; import { Link } from 'react-router-dom'; -import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; +import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { renderGroup, renderNumber } from '../helpers/render'; @@ -195,6 +195,7 @@ const ChannelsTable = () => { showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); + showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。") } }; From 25c4c111abc94996cd6c9607abb6d335ab042153 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 11:44:38 +0800 Subject: [PATCH 03/10] fix: only enable cors for relay routers to avoid csrf attack --- main.go | 2 -- router/relay-router.go | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/main.go b/main.go index 8c5f2f31..2aa52876 100644 --- a/main.go +++ b/main.go @@ -7,7 +7,6 @@ import ( "github.com/gin-gonic/gin" "one-api/common" "one-api/controller" - "one-api/middleware" "one-api/model" "one-api/router" "os" @@ -88,7 +87,6 @@ func main() { server := gin.Default() // This will cause SSE not to work!!! //server.Use(gzip.Gzip(gzip.DefaultCompression)) - server.Use(middleware.CORS()) // Initialize session store store := cookie.NewStore([]byte(common.SessionSecret)) diff --git a/router/relay-router.go b/router/relay-router.go index a76e42cf..e84f02db 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -8,6 +8,7 @@ import ( ) func SetRelayRouter(router *gin.Engine) { + router.Use(middleware.CORS()) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) From 42451d9d02e22cb546b746b68e68cec528603e5f Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 15:39:46 +0800 Subject: [PATCH 04/10] refactor: update logging related logic --- .gitignore | 3 ++- common/constants.go | 4 ++++ common/init.go | 2 +- common/logger.go | 42 ++++++++++++++++++++++++++++++--------- common/utils.go | 9 +++++++++ controller/relay-audio.go | 7 ++++--- controller/relay-image.go | 7 ++++--- controller/relay-text.go | 18 +++++++++-------- controller/relay.go | 4 +++- main.go | 7 +++++-- middleware/auth.go | 32 ++++------------------------- middleware/distributor.go | 40 +++++-------------------------------- middleware/logger.go | 25 +++++++++++++++++++++++ middleware/request-id.go | 18 +++++++++++++++++ middleware/utils.go | 17 ++++++++++++++++ model/log.go | 7 +++++-- 16 files changed, 149 insertions(+), 93 deletions(-) create mode 100644 middleware/logger.go create mode 100644 middleware/request-id.go create mode 100644 middleware/utils.go diff --git a/.gitignore b/.gitignore index 0b2856cc..1b2cf071 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ upload *.exe *.db build -*.db-journal \ No newline at end of file +*.db-journal +logs \ No newline at end of file diff --git a/common/constants.go b/common/constants.go index 69bd12a8..794a795f 100644 --- a/common/constants.go +++ b/common/constants.go @@ -97,6 +97,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU var BatchUpdateEnabled = false var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) + const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/common/init.go b/common/init.go index 0f22c69b..1e9c85ce 100644 --- a/common/init.go +++ b/common/init.go @@ -12,7 +12,7 @@ var ( Port = flag.Int("port", 3000, "the listening port") PrintVersion = flag.Bool("version", false, "print version and exit") PrintHelp = flag.Bool("help", false, "print help and exit") - LogDir = flag.String("log-dir", "", "specify the log directory") + LogDir = flag.String("log-dir", "./logs", "specify the log directory") ) func printHelp() { diff --git a/common/logger.go b/common/logger.go index 3658dbdb..780a6237 100644 --- a/common/logger.go +++ b/common/logger.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "github.com/gin-gonic/gin" "io" @@ -10,20 +11,21 @@ import ( "time" ) +const ( + loggerINFO = "INFO" + loggerWarn = "WARN" + loggerError = "ERR" +) + func SetupGinLog() { if *LogDir != "" { - commonLogPath := filepath.Join(*LogDir, "common.log") - errorLogPath := filepath.Join(*LogDir, "error.log") - commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) + fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") } - errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - log.Fatal("failed to open log file") - } - gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd) - gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd) + gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) } } @@ -37,6 +39,28 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } +func LogInfo(ctx context.Context, msg string) { + logHelper(ctx, loggerINFO, msg) +} + +func LogWarn(ctx context.Context, msg string) { + logHelper(ctx, loggerWarn, msg) +} + +func LogError(ctx context.Context, msg string) { + logHelper(ctx, loggerError, msg) +} + +func logHelper(ctx context.Context, level string, msg string) { + writer := gin.DefaultErrorWriter + if level == loggerINFO { + writer = gin.DefaultWriter + } + id := ctx.Value(RequestIdKey) + now := time.Now() + _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) +} + func FatalLog(v ...any) { t := time.Now() _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) diff --git a/common/utils.go b/common/utils.go index bb9b7e0c..ab901b77 100644 --- a/common/utils.go +++ b/common/utils.go @@ -171,6 +171,11 @@ func GetTimestamp() int64 { return time.Now().Unix() } +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} + func Max(a int, b int) int { if a >= b { return a @@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int { } return num } + +func MessageWithRequestId(message string, id string) string { + return fmt.Sprintf("%s (request id: %s)", message, id) +} diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 277ab404..a7bc670b 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -91,7 +92,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } var audioResponse AudioResponse - defer func() { + defer func(ctx context.Context) { go func() { quota := countTokenText(audioResponse.Text, audioModel) quotaDelta := quota - preConsumedQuota @@ -106,13 +107,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } }() - }() + }(c.Request.Context()) responseBody, err := io.ReadAll(resp.Body) diff --git a/controller/relay-image.go b/controller/relay-image.go index de623288..b1a22570 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -124,7 +125,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } var textResponse ImageResponse - defer func() { + defer func(ctx context.Context) { if consumeQuota { err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil { @@ -137,13 +138,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } } - }() + }(c.Request.Context()) if consumeQuota { responseBody, err := io.ReadAll(resp.Body) diff --git a/controller/relay-text.go b/controller/relay-text.go index 2cd5598a..6d481983 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -210,6 +211,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 + common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) } if consumeQuota && preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) @@ -348,13 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if resp.StatusCode != http.StatusOK { if preConsumedQuota != 0 { - go func() { + go func(ctx context.Context) { // return pre-consumed quota err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) if err != nil { - common.SysError("error return pre-consumed quota: " + err.Error()) + common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) } - }() + }(c.Request.Context()) } return relayErrorHandler(resp) } @@ -364,7 +366,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { tokenName := c.GetString("token_name") channelId := c.GetInt("channel_id") - defer func() { + defer func(ctx context.Context) { // c.Writer.Flush() go func() { if consumeQuota { @@ -387,21 +389,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } err = model.CacheUpdateUserQuota(userId) if err != nil { - common.SysError("error update user quota cache: " + err.Error()) + common.LogError(ctx, "error update user quota cache: "+err.Error()) } if quota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateChannelUsedQuota(channelId, quota) } } }() - }() + }(c.Request.Context()) switch apiType { case APITypeOpenAI: if isStream { diff --git a/controller/relay.go b/controller/relay.go index d20663f6..1926110e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -196,6 +196,7 @@ func Relay(c *gin.Context) { err = relayTextHelper(c, relayMode) } if err != nil { + requestId := c.GetString(common.RequestIdKey) retryTimesStr := c.Query("retry") retryTimes, _ := strconv.Atoi(retryTimesStr) if retryTimesStr == "" { @@ -207,12 +208,13 @@ func Relay(c *gin.Context) { if err.StatusCode == http.StatusTooManyRequests { err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" } + err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) c.JSON(err.StatusCode, gin.H{ "error": err.OpenAIError, }) } channelId := c.GetInt("channel_id") - common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { channelId := c.GetInt("channel_id") diff --git a/main.go b/main.go index 2aa52876..c7a8a2d6 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "one-api/common" "one-api/controller" + "one-api/middleware" "one-api/model" "one-api/router" "os" @@ -84,10 +85,12 @@ func main() { controller.InitTokenEncoders() // Initialize HTTP server - server := gin.Default() + server := gin.New() + server.Use(gin.Recovery()) // This will cause SSE not to work!!! //server.Use(gzip.Gzip(gzip.DefaultCompression)) - + server.Use(middleware.RequestId()) + middleware.SetUpLogger(server) // Initialize session store store := cookie.NewStore([]byte(common.SessionSecret)) server.Use(sessions.Sessions("session", store)) diff --git a/middleware/auth.go b/middleware/auth.go index 95516d6e..dfbc7dbd 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] token, err := model.ValidateUserToken(key) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } userEnabled, err := model.IsUserEnabled(token.UserId) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusInternalServerError, err.Error()) return } if !userEnabled { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "用户已被封禁", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } c.Set("id", token.UserId) @@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) } else { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "普通用户不支持指定渠道", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index e8b76596..ab374a85 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -25,34 +25,16 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的渠道 ID", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") return } channel, err = model.GetChannelById(id, true) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的渠道 ID", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") return } if channel.Status != common.ChannelStatusEnabled { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "该渠道已被禁用", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { @@ -63,13 +45,7 @@ func Distribute() func(c *gin.Context) { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的请求", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的请求") return } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { @@ -99,13 +75,7 @@ func Distribute() func(c *gin.Context) { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "message": message, - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusServiceUnavailable, message) return } } diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 00000000..02f2e0a9 --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "one-api/common" +) + +func SetUpLogger(server *gin.Engine) { + server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { + var requestID string + if param.Keys != nil { + requestID = param.Keys[common.RequestIdKey].(string) + } + return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", + param.TimeStamp.Format("2006/01/02 - 15:04:05"), + requestID, + param.StatusCode, + param.Latency, + param.ClientIP, + param.Method, + param.Path, + ) + })) +} diff --git a/middleware/request-id.go b/middleware/request-id.go new file mode 100644 index 00000000..e623be7a --- /dev/null +++ b/middleware/request-id.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "context" + "github.com/gin-gonic/gin" + "one-api/common" +) + +func RequestId() func(c *gin.Context) { + return func(c *gin.Context) { + id := common.GetTimeString() + common.GetRandomString(8) + c.Set(common.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) + c.Request = c.Request.WithContext(ctx) + c.Header(common.RequestIdKey, id) + c.Next() + } +} diff --git a/middleware/utils.go b/middleware/utils.go new file mode 100644 index 00000000..536125cc --- /dev/null +++ b/middleware/utils.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "one-api/common" +) + +func abortWithMessage(c *gin.Context, statusCode int, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "type": "one_api_error", + }, + }) + c.Abort() + common.LogError(c.Request.Context(), message) +} diff --git a/model/log.go b/model/log.go index b0d6409a..b6a72c26 100644 --- a/model/log.go +++ b/model/log.go @@ -1,6 +1,8 @@ package model import ( + "context" + "fmt" "gorm.io/gorm" "one-api/common" ) @@ -44,7 +46,8 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { +func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { + common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled { return } @@ -62,7 +65,7 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN } err := DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + common.LogError(ctx, "failed to record log: "+err.Error()) } } From fe26a1448d05e4fdd4b4daea1129db19cb33fe29 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 15:41:01 +0800 Subject: [PATCH 05/10] docs: update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 138a15e0..02412ec8 100644 --- a/README.md +++ b/README.md @@ -325,7 +325,7 @@ graph LR ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 + 例子:`--port 3000` -2. `--log-dir `: 指定日志文件夹,如果没有设置,日志将不会被保存。 +2. `--log-dir `: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。 + 例子:`--log-dir ./logs` 3. `--version`: 打印系统版本号并退出。 4. `--help`: 查看命令的使用帮助和参数说明。 From 4335f005a6a4b6c2be1b2d6df68c2f1c5f0ebef6 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 16:35:30 +0800 Subject: [PATCH 06/10] feat: create new log file when too many logs recorded --- common/logger.go | 28 ++++++++++++++++++++++++++-- main.go | 2 +- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/common/logger.go b/common/logger.go index 780a6237..61627217 100644 --- a/common/logger.go +++ b/common/logger.go @@ -8,6 +8,7 @@ import ( "log" "os" "path/filepath" + "sync" "time" ) @@ -17,9 +18,24 @@ const ( loggerError = "ERR" ) -func SetupGinLog() { +const maxLogCount = 1000000 + +var logCount int +var setupLogLock sync.Mutex +var setupLogWorking bool + +func SetupLogger() { if *LogDir != "" { - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) + ok := setupLogLock.TryLock() + if !ok { + log.Println("setup log is already working") + return + } + defer func() { + setupLogLock.Unlock() + setupLogWorking = false + }() + logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") @@ -59,6 +75,14 @@ func logHelper(ctx context.Context, level string, msg string) { id := ctx.Value(RequestIdKey) now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) + logCount++ // we don't need accurate count, so no lock here + if logCount > maxLogCount && !setupLogWorking { + logCount = 0 + setupLogWorking = true + go func() { + SetupLogger() + }() + } } func FatalLog(v ...any) { diff --git a/main.go b/main.go index c7a8a2d6..e8ef4c20 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,7 @@ var buildFS embed.FS var indexPage []byte func main() { - common.SetupGinLog() + common.SetupLogger() common.SysLog("One API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) From 328aa6825596868920db2034701afb8fef52bb7c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 17:09:56 +0800 Subject: [PATCH 07/10] feat: able to delete logs now (close #486) --- controller/log.go | 52 +++++++++++++++++++++----- model/log.go | 5 +++ router/api-router.go | 1 + web/src/components/OperationSetting.js | 45 ++++++++++++++++++---- 4 files changed, 85 insertions(+), 18 deletions(-) diff --git a/controller/log.go b/controller/log.go index ba043349..870ce396 100644 --- a/controller/log.go +++ b/controller/log.go @@ -2,6 +2,7 @@ package controller import ( "github.com/gin-gonic/gin" + "net/http" "one-api/common" "one-api/model" "strconv" @@ -20,17 +21,18 @@ func GetAllLogs(c *gin.Context) { modelName := c.Query("model_name") logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func GetUserLogs(c *gin.Context) { @@ -46,34 +48,36 @@ func GetUserLogs(c *gin.Context) { modelName := c.Query("model_name") logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func SearchAllLogs(c *gin.Context) { keyword := c.Query("keyword") logs, err := model.SearchAllLogs(keyword) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func SearchUserLogs(c *gin.Context) { @@ -81,17 +85,18 @@ func SearchUserLogs(c *gin.Context) { userId := c.GetInt("id") logs, err := model.SearchUserLogs(userId, keyword) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func GetLogsStat(c *gin.Context) { @@ -103,7 +108,7 @@ func GetLogsStat(c *gin.Context) { modelName := c.Query("model_name") quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ @@ -111,6 +116,7 @@ func GetLogsStat(c *gin.Context) { //"token": tokenNum, }, }) + return } func GetLogsSelfStat(c *gin.Context) { @@ -122,7 +128,7 @@ func GetLogsSelfStat(c *gin.Context) { modelName := c.Query("model_name") quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ @@ -130,4 +136,30 @@ func GetLogsSelfStat(c *gin.Context) { //"token": tokenNum, }, }) + return +} + +func DeleteHistoryLogs(c *gin.Context) { + targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64) + if targetTimestamp == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "target timestamp is required", + }) + return + } + count, err := model.DeleteOldLog(targetTimestamp) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": count, + }) + return } diff --git a/model/log.go b/model/log.go index b6a72c26..551cfda7 100644 --- a/model/log.go +++ b/model/log.go @@ -169,3 +169,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa tx.Where("type = ?", LogTypeConsume).Scan(&token) return token } + +func DeleteOldLog(targetTimestamp int64) (int64, error) { + result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + return result.RowsAffected, result.Error +} diff --git a/router/api-router.go b/router/api-router.go index 7ad48871..d12bc54b 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -98,6 +98,7 @@ func SetApiRouter(router *gin.Engine) { } logRoute := apiRouter.Group("/log") logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) + logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs) logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 2adc7fa4..bf8b5ffd 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -1,8 +1,9 @@ import React, { useEffect, useState } from 'react'; import { Divider, Form, Grid, Header } from 'semantic-ui-react'; -import { API, showError, verifyJSON } from '../helpers'; +import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers'; const OperationSetting = () => { + let now = new Date(); let [inputs, setInputs] = useState({ QuotaForNewUser: 0, QuotaForInviter: 0, @@ -20,10 +21,11 @@ const OperationSetting = () => { DisplayInCurrencyEnabled: '', DisplayTokenStatEnabled: '', ApproximateTokenEnabled: '', - RetryTimes: 0, + RetryTimes: 0 }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); + let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago const getOptions = async () => { const res = await API.get('/api/option/'); @@ -130,6 +132,17 @@ const OperationSetting = () => { } }; + const deleteHistoryLogs = async () => { + console.log(inputs); + const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`${data} 条日志已清理!`); + return; + } + showError('日志清理失败:' + message); + }; + return ( @@ -179,12 +192,6 @@ const OperationSetting = () => { /> - { submitConfig('general').then(); }}>保存通用设置 +
+ 日志设置 +
+ + + + + { + setHistoryTimestamp(value); + }} /> + + { + deleteHistoryLogs().then(); + }}>清理历史日志 +
监控设置
From 12ef9679a7dd3e2f10e57a056f8f7a46f9fced46 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Sep 2023 17:19:12 +0800 Subject: [PATCH 08/10] fix: fix url not passing when using custom chat_link --- web/src/components/TokensTable.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js index b45f07df..c7ec9b48 100644 --- a/web/src/components/TokensTable.js +++ b/web/src/components/TokensTable.js @@ -96,7 +96,7 @@ const TokensTable = () => { let nextUrl; if (nextLink) { - nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`; + nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; } else { nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; } From 24df3e5f62f99fefb7bc0f6ef09555622260e4fc Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Sun, 17 Sep 2023 18:16:12 +0800 Subject: [PATCH 09/10] feat: support non-stream mode for xunfei (#498) * feat:xunfei suport none stream * fix:join content ignore seq --------- Co-authored-by: igophper --- controller/relay-text.go | 36 ++++----- controller/relay-xunfei.go | 145 ++++++++++++++++++++----------------- 2 files changed, 98 insertions(+), 83 deletions(-) diff --git a/controller/relay-text.go b/controller/relay-text.go index 6d481983..4481c652 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -541,24 +541,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return nil } case APITypeXunfei: - if isStream { - auth := c.Request.Header.Get("Authorization") - auth = strings.TrimPrefix(auth, "Bearer ") - splits := strings.Split(auth, "|") - if len(splits) != 3 { - return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - } - err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) + auth := c.Request.Header.Get("Authorization") + auth = strings.TrimPrefix(auth, "Bearer ") + splits := strings.Split(auth, "|") + if len(splits) != 3 { + return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) } + var err *OpenAIErrorWithStatusCode + var usage *Usage + if isStream { + err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) + } else { + err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) + } + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil case APITypeAIProxyLibrary: if isStream { err, usage := aiProxyLibraryStreamHandler(c, resp) diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 3b6fe5a0..ff6bf065 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { Role: "assistant", Content: response.Payload.Choices.Text[0].Content, }, + FinishReason: stopFinishReason, } fullTextResponse := OpenAITextResponse{ Object: "chat.completion", @@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { } func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { + domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + } + setEventStreamHeaders(c) var usage Usage - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") + c.Stream(func(w io.Writer) bool { + select { + case xunfeiResponse := <-dataChan: + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := streamResponseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + return nil, &usage +} + +func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { + domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - if apiVersion == "" { - apiVersion = "v1.1" - common.SysLog("api_version not found, use default: " + apiVersion) + var usage Usage + var content string + var xunfeiResponse XunfeiChatResponse + stop := false + for !stop { + select { + case xunfeiResponse = <-dataChan: + content += xunfeiResponse.Payload.Choices.Text[0].Content + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + case stop = <-stopChan: + } } - domain := "general" - if apiVersion == "v2.1" { - domain = "generalv2" + + xunfeiResponse.Payload.Choices.Text[0].Content = content + + response := responseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } - hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) + c.Writer.Header().Set("Content-Type", "application/json") + _, _ = c.Writer.Write(jsonResponse) + return nil, &usage +} + +func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) + conn, resp, err := d.Dial(authUrl, nil) if err != nil || resp.StatusCode != 101 { - return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil + return nil, nil, err } data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { - return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil + return nil, nil, err } + dataChan := make(chan XunfeiChatResponse) stopChan := make(chan bool) go func() { @@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId } stopChan <- true }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case xunfeiResponse := <-dataChan: - usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens - usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens - usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens - response := streamResponseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - return nil, &usage + + return dataChan, stopChan, nil } -func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var xunfeiResponse XunfeiChatResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil +func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + if apiVersion == "" { + apiVersion = "v1.1" + common.SysLog("api_version not found, use default: " + apiVersion) } - err = json.Unmarshal(responseBody, &xunfeiResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + domain := "general" + if apiVersion == "v2.1" { + domain = "generalv2" } - if xunfeiResponse.Header.Code != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: xunfeiResponse.Header.Message, - Type: "xunfei_error", - Param: "", - Code: xunfeiResponse.Header.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage + authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + return domain, authUrl } From ecf8a6d87583446834c4f58622c7368d493687f6 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Sun, 17 Sep 2023 19:18:16 +0800 Subject: [PATCH 10/10] feat: supprt channel priority now & record channel id in log (#484) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: 支持设置渠道优先级 & 日志中显示使用的渠道ID * fix: 设置渠道优先级未更新 ability * chore: update implementation --------- Co-authored-by: Xiangyuan Liu Co-authored-by: JustSong Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com> --- controller/log.go | 9 +++-- controller/relay-audio.go | 3 +- controller/relay-image.go | 3 +- controller/relay-text.go | 4 +- model/ability.go | 6 ++- model/cache.go | 17 +++++++++ model/channel.go | 1 + model/log.go | 17 +++++++-- web/src/components/ChannelsTable.js | 39 ++++++++++++++++++-- web/src/components/LogsTable.js | 57 ++++++++++++++++++++--------- 10 files changed, 122 insertions(+), 34 deletions(-) diff --git a/controller/log.go b/controller/log.go index 870ce396..b65867fe 100644 --- a/controller/log.go +++ b/controller/log.go @@ -19,7 +19,8 @@ func GetAllLogs(c *gin.Context) { username := c.Query("username") tokenName := c.Query("token_name") modelName := c.Query("model_name") - logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + channel, _ := strconv.Atoi(c.Query("channel")) + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -106,7 +107,8 @@ func GetLogsStat(c *gin.Context) { tokenName := c.Query("token_name") username := c.Query("username") modelName := c.Query("model_name") - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + channel, _ := strconv.Atoi(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") c.JSON(http.StatusOK, gin.H{ "success": true, @@ -126,7 +128,8 @@ func GetLogsSelfStat(c *gin.Context) { endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + channel, _ := strconv.Atoi(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/controller/relay-audio.go b/controller/relay-audio.go index a7bc670b..e6f54f01 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -18,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") userId := c.GetInt("id") group := c.GetString("group") @@ -107,7 +108,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) diff --git a/controller/relay-image.go b/controller/relay-image.go index b1a22570..fb30895c 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -19,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") userId := c.GetInt("id") consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") @@ -138,7 +139,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) diff --git a/controller/relay-text.go b/controller/relay-text.go index 4481c652..5a5f355b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -38,6 +38,7 @@ func init() { func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") userId := c.GetInt("id") consumeQuota := c.GetBool("consume_quota") @@ -364,7 +365,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { var textResponse TextResponse tokenName := c.GetString("token_name") - channelId := c.GetInt("channel_id") defer func(ctx context.Context) { // c.Writer.Flush() @@ -397,7 +397,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } if quota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateChannelUsedQuota(channelId, quota) } diff --git a/model/ability.go b/model/ability.go index e87c3940..eb68fa0d 100644 --- a/model/ability.go +++ b/model/ability.go @@ -10,15 +10,16 @@ type Ability struct { Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` + Priority int64 `json:"priority" gorm:"bigint;default:0"` } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { ability := Ability{} var err error = nil if common.UsingSQLite { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error + err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error } else { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error + err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error } if err != nil { return nil, err @@ -40,6 +41,7 @@ func (channel *Channel) AddAbilities() error { Model: model, ChannelId: channel.Id, Enabled: channel.Status == common.ChannelStatusEnabled, + Priority: channel.Priority, } abilities = append(abilities, ability) } diff --git a/model/cache.go b/model/cache.go index c28952b5..1b547842 100644 --- a/model/cache.go +++ b/model/cache.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "one-api/common" + "sort" "strconv" "strings" "sync" @@ -159,6 +160,17 @@ func InitChannelCache() { } } } + + // sort by priority + for group, model2channels := range newGroup2model2channels { + for model, channels := range model2channels { + sort.Slice(channels, func(i, j int) bool { + return channels[i].Priority > channels[j].Priority + }) + newGroup2model2channels[group][model] = channels + } + } + channelSyncLock.Lock() group2model2channels = newGroup2model2channels channelSyncLock.Unlock() @@ -183,6 +195,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error if len(channels) == 0 { return nil, errors.New("channel not found") } + // choose by priority + firstChannel := channels[0] + if firstChannel.Priority > 0 { + return firstChannel, nil + } idx := rand.Intn(len(channels)) return channels[idx], nil } diff --git a/model/channel.go b/model/channel.go index 5c495bab..d146193b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -23,6 +23,7 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + Priority int64 `json:"priority" gorm:"bigint;default:0"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/model/log.go b/model/log.go index 551cfda7..1c0a2dc6 100644 --- a/model/log.go +++ b/model/log.go @@ -19,6 +19,7 @@ type Log struct { Quota int `json:"quota" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"` + Channel int `json:"channel" gorm:"default:0"` } const ( @@ -46,8 +47,9 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { - common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content)) + +func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { + common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled { return } @@ -62,6 +64,7 @@ func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, complet TokenName: tokenName, ModelName: modelName, Quota: quota, + Channel: channelId, } err := DB.Create(log).Error if err != nil { @@ -69,7 +72,7 @@ func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, complet } } -func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) { +func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { tx = DB @@ -91,6 +94,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName if endTimestamp != 0 { tx = tx.Where("created_at <= ?", endTimestamp) } + if channel != 0 { + tx = tx.Where("channel = ?", channel) + } err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error return logs, err } @@ -128,7 +134,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) { +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { tx := DB.Table("logs").Select("sum(quota)") if username != "" { tx = tx.Where("username = ?", username) @@ -145,6 +151,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa if modelName != "" { tx = tx.Where("model_name = ?", modelName) } + if channel != 0 { + tx = tx.Where("channel = ?", channel) + } tx.Where("type = ?", LogTypeConsume).Scan("a) return quota } diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index f712f11a..7c8457d0 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -1,5 +1,5 @@ import React, { useEffect, useState } from 'react'; -import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; +import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react'; import { Link } from 'react-router-dom'; import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; @@ -24,7 +24,7 @@ function renderType(type) { } type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; } - return ; + return ; } function renderBalance(type, balance) { @@ -96,7 +96,7 @@ const ChannelsTable = () => { }); }, []); - const manageChannel = async (id, action, idx) => { + const manageChannel = async (id, action, idx, priority) => { let data = { id }; let res; switch (action) { @@ -111,6 +111,13 @@ const ChannelsTable = () => { data.status = 2; res = await API.put('/api/channel/', data); break; + case 'priority': + if (priority === '') { + return; + } + data.priority = parseInt(priority); + res = await API.put('/api/channel/', data); + break; } const { success, message } = res.data; if (success) { @@ -335,6 +342,14 @@ const ChannelsTable = () => { > 余额 + { + sortChannel('priority'); + }} + > + 优先级 + 操作 @@ -373,6 +388,22 @@ const ChannelsTable = () => { basic /> + + { + manageChannel( + channel.id, + 'priority', + idx, + event.target.value, + ); + }}> + + } + content='渠道选择优先级,越高越优先' + basic + /> +
diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index c981e261..e266d79a 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -56,9 +56,10 @@ const LogsTable = () => { token_name: '', model_name: '', start_timestamp: timestamp2string(0), - end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), + channel: '' }); - const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs; + const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs; const [stat, setStat] = useState({ quota: 0, @@ -84,7 +85,7 @@ const LogsTable = () => { const getLogStat = async () => { let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000; - let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`); const { success, message, data } = res.data; if (success) { setStat(data); @@ -109,7 +110,7 @@ const LogsTable = () => { let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000; if (isAdminUser) { - url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; } else { url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; } @@ -205,16 +206,9 @@ const LogsTable = () => {
- { - isAdminUser && ( - - ) - } - - { onChange={handleInputChange} /> 查询 + { + isAdminUser && <> + + + + + + + }
@@ -238,6 +245,17 @@ const LogsTable = () => { > 时间 + { + isAdminUser && { + sortLog('channel'); + }} + width={1} + > + 渠道 + + } { isAdminUser && { onClick={() => { sortLog('quota'); }} - width={2} + width={1} > - 消耗额度 + 额度 { sortLog('content'); }} - width={isAdminUser ? 4 : 5} + width={isAdminUser ? 4 : 6} > 详情 @@ -326,6 +344,11 @@ const LogsTable = () => { return ( {renderTimestamp(log.created_at)} + { + isAdminUser && ( + {log.channel ? : ''} + ) + } { isAdminUser && ( {log.username ? : ''} @@ -345,7 +368,7 @@ const LogsTable = () => { - +