diff --git a/.gitignore b/.gitignore
index 0b2856cc..1b2cf071 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@ upload
 *.exe
 *.db
 build
-*.db-journal
\ No newline at end of file
+*.db-journal
+logs
\ No newline at end of file
diff --git a/README.md b/README.md
index 1788dc69..9a3fd949 100644
--- a/README.md
+++ b/README.md
@@ -290,6 +290,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
 
 注意,具体的 API Base 的格式取决于你所使用的客户端。
 
+例如对于 OpenAI 的官方库:
+```bash
+OPENAI_API_KEY="sk-xxxxxx"
+OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
+```
+
 ```mermaid
 graph LR
     A(用户)
@@ -346,7 +352,7 @@ graph LR
 ### 命令行参数
 1. `--port <port>`: 指定服务器监听的端口号,默认为 `3000`。
    + 例子:`--port 3000`
-2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。
+2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。
    + 例子:`--log-dir ./logs`
 3. `--version`: 打印系统版本号并退出。
 4. `--help`: 查看命令的使用帮助和参数说明。
diff --git a/common/constants.go b/common/constants.go
index f812dcdb..2fdff70f 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -106,6 +106,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
 var BatchUpdateEnabled = false
 var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 
+const (
+	RequestIdKey = "X-Oneapi-Request-Id"
+)
+
 const (
 	RoleGuestUser  = 0
 	RoleCommonUser = 1
diff --git a/common/init.go b/common/init.go
index 0f22c69b..1e9c85ce 100644
--- a/common/init.go
+++ b/common/init.go
@@ -12,7 +12,7 @@ var (
 	Port         = flag.Int("port", 3000, "the listening port")
 	PrintVersion = flag.Bool("version", false, "print version and exit")
 	PrintHelp    = flag.Bool("help", false, "print help and exit")
-	LogDir       = flag.String("log-dir", "", "specify the log directory")
+	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
 )
 
 func printHelp() {
diff --git a/common/logger.go b/common/logger.go
index 3658dbdb..61627217 100644
--- a/common/logger.go
+++ b/common/logger.go
@@ -1,29 +1,47 @@
 package common
 
 import (
+	"context"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
 	"log"
 	"os"
 	"path/filepath"
+	"sync"
 	"time"
 )
 
-func SetupGinLog() {
+const (
+	loggerINFO  = "INFO"
+	loggerWarn  = "WARN"
+	loggerError = "ERR"
+)
+
+const maxLogCount = 1000000
+
+var logCount int
+var setupLogLock sync.Mutex
+var setupLogWorking bool
+
+func SetupLogger() {
 	if *LogDir != "" {
-		commonLogPath := filepath.Join(*LogDir, "common.log")
-		errorLogPath := filepath.Join(*LogDir, "error.log")
-		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+		ok := setupLogLock.TryLock()
+		if !ok {
+			log.Println("setup log is already working")
+			return
+		}
+		defer func() {
+			setupLogLock.Unlock()
+			setupLogWorking = false
+		}()
+		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 		if err != nil {
 			log.Fatal("failed to open log file")
 		}
-		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
-		if err != nil {
-			log.Fatal("failed to open log file")
-		}
-		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
-		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
+		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
+		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
 	}
 }
 
@@ -37,6 +55,36 @@ func SysError(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }
 
+func LogInfo(ctx context.Context, msg string) {
+	logHelper(ctx, loggerINFO, msg)
+}
+
+func LogWarn(ctx context.Context,
msg string) { + logHelper(ctx, loggerWarn, msg) +} + +func LogError(ctx context.Context, msg string) { + logHelper(ctx, loggerError, msg) +} + +func logHelper(ctx context.Context, level string, msg string) { + writer := gin.DefaultErrorWriter + if level == loggerINFO { + writer = gin.DefaultWriter + } + id := ctx.Value(RequestIdKey) + now := time.Now() + _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) + logCount++ // we don't need accurate count, so no lock here + if logCount > maxLogCount && !setupLogWorking { + logCount = 0 + setupLogWorking = true + go func() { + SetupLogger() + }() + } +} + func FatalLog(v ...any) { t := time.Now() _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) diff --git a/common/utils.go b/common/utils.go index bb9b7e0c..ab901b77 100644 --- a/common/utils.go +++ b/common/utils.go @@ -171,6 +171,11 @@ func GetTimestamp() int64 { return time.Now().Unix() } +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} + func Max(a int, b int) int { if a >= b { return a @@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int { } return num } + +func MessageWithRequestId(message string, id string) string { + return fmt.Sprintf("%s (request id: %s)", message, id) +} diff --git a/controller/billing.go b/controller/billing.go index 79eae1e2..42e86aea 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) { if err != nil { openAIError := OpenAIError{ Message: err.Error(), - Type: "one_api_error", + Type: "upstream_error", } c.JSON(200, gin.H{ "error": openAIError, diff --git a/controller/log.go b/controller/log.go index ba043349..b65867fe 100644 --- a/controller/log.go +++ b/controller/log.go @@ -2,6 +2,7 @@ package controller import ( "github.com/gin-gonic/gin" + "net/http" "one-api/common" "one-api/model" "strconv" @@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) { username := c.Query("username") tokenName := c.Query("token_name") modelName := c.Query("model_name") - logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + channel, _ := strconv.Atoi(c.Query("channel")) + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func GetUserLogs(c *gin.Context) { @@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) { modelName := c.Query("model_name") logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func SearchAllLogs(c *gin.Context) { keyword := c.Query("keyword") logs, err := model.SearchAllLogs(keyword) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ 
"success": true, "message": "", "data": logs, }) + return } func SearchUserLogs(c *gin.Context) { @@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) { userId := c.GetInt("id") logs, err := model.SearchUserLogs(userId, keyword) if err != nil { - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), }) return } - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": logs, }) + return } func GetLogsStat(c *gin.Context) { @@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) { tokenName := c.Query("token_name") username := c.Query("username") modelName := c.Query("model_name") - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + channel, _ := strconv.Atoi(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ @@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) { //"token": tokenNum, }, }) + return } func GetLogsSelfStat(c *gin.Context) { @@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) { endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + channel, _ := strconv.Atoi(c.Query("channel")) + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) - c.JSON(200, gin.H{ + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ @@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) { //"token": tokenNum, }, }) + return +} + +func DeleteHistoryLogs(c *gin.Context) { + targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64) + if targetTimestamp == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "target timestamp is required", + }) + return + } + count, err := model.DeleteOldLog(targetTimestamp) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": count, + }) + return } diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 277ab404..e6f54f01 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -17,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") userId := c.GetInt("id") group := c.GetString("group") @@ -91,7 +93,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } var audioResponse AudioResponse - defer func() { + defer func(ctx context.Context) { go func() { quota := countTokenText(audioResponse.Text, audioModel) quotaDelta := quota - preConsumedQuota @@ -106,13 +108,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 
%.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } }() - }() + }(c.Request.Context()) responseBody, err := io.ReadAll(resp.Body) diff --git a/controller/relay-image.go b/controller/relay-image.go index de623288..fb30895c 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") userId := c.GetInt("id") consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") @@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } var textResponse ImageResponse - defer func() { + defer func(ctx context.Context) { if consumeQuota { err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil { @@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) } } - }() + }(c.Request.Context()) if consumeQuota { responseBody, err := io.ReadAll(resp.Body) diff --git a/controller/relay-text.go b/controller/relay-text.go index f483e010..60fb2e2d 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -37,6 +38,7 @@ func init() { func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") userId := c.GetInt("id") consumeQuota := c.GetBool("consume_quota") @@ -210,6 +212,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 + common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) } if consumeQuota && preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) @@ -348,13 +351,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if resp.StatusCode != http.StatusOK { if preConsumedQuota != 0 { - go func() { + go func(ctx context.Context) { // return pre-consumed quota err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) if err != nil { - common.SysError("error return pre-consumed quota: " + err.Error()) + common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) } - }() + }(c.Request.Context()) } return relayErrorHandler(resp) } @@ -381,9 +384,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { var textResponse TextResponse 
tokenName := c.GetString("token_name") - channelId := c.GetInt("channel_id") - defer func() { + defer func(ctx context.Context) { // c.Writer.Flush() go func() { if consumeQuota { @@ -406,21 +408,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } err = model.CacheUpdateUserQuota(userId) if err != nil { - common.SysError("error update user quota cache: " + err.Error()) + common.LogError(ctx, "error update user quota cache: "+err.Error()) } if quota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateChannelUsedQuota(channelId, quota) } } }() - }() + }(c.Request.Context()) switch apiType { case APITypeOpenAI: if isStream { @@ -558,24 +560,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return nil } case APITypeXunfei: - if isStream { - auth := c.Request.Header.Get("Authorization") - auth = strings.TrimPrefix(auth, "Bearer ") - splits := strings.Split(auth, "|") - if len(splits) != 3 { - return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - } - err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) + auth := c.Request.Header.Get("Authorization") + auth = strings.TrimPrefix(auth, "Bearer ") + splits := strings.Split(auth, "|") + if len(splits) != 3 { + return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) } + var err *OpenAIErrorWithStatusCode + var usage *Usage + if isStream { + err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) + } else { + err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) + } + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil case APITypeAIProxyLibrary: if isStream { err, usage := aiProxyLibraryStreamHandler(c, resp) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 3773dbbb..133e64d1 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -147,7 +147,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr StatusCode: resp.StatusCode, OpenAIError: OpenAIError{ Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "one_api_error", + Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), }, diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 3b6fe5a0..ff6bf065 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { Role: "assistant", Content: response.Payload.Choices.Text[0].Content, }, + FinishReason: 
stopFinishReason, } fullTextResponse := OpenAITextResponse{ Object: "chat.completion", @@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { } func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { + domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil + } + setEventStreamHeaders(c) var usage Usage - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") + c.Stream(func(w io.Writer) bool { + select { + case xunfeiResponse := <-dataChan: + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := streamResponseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + return nil, &usage +} + +func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { + domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret) + dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) + if err != nil { + return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil } - if apiVersion == "" { - apiVersion = "v1.1" - common.SysLog("api_version not found, use default: " + apiVersion) + var usage Usage + var content string + var xunfeiResponse XunfeiChatResponse + stop := false + for !stop { + select { + case xunfeiResponse = <-dataChan: + content += xunfeiResponse.Payload.Choices.Text[0].Content + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + case stop = <-stopChan: + } } - domain := "general" - if apiVersion == "v2.1" { - domain = "generalv2" + + xunfeiResponse.Payload.Choices.Text[0].Content = content + + response := responseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } - hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) + c.Writer.Header().Set("Content-Type", "application/json") + _, _ = c.Writer.Write(jsonResponse) + return nil, &usage +} + +func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) { d := websocket.Dialer{ HandshakeTimeout: 5 * time.Second, } - conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) + conn, resp, err := d.Dial(authUrl, nil) if err != nil || resp.StatusCode != 101 { - return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil + return nil, nil, err } data := 
requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { - return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil + return nil, nil, err } + dataChan := make(chan XunfeiChatResponse) stopChan := make(chan bool) go func() { @@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId } stopChan <- true }() - setEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case xunfeiResponse := <-dataChan: - usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens - usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens - usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens - response := streamResponseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - return false - } - }) - return nil, &usage + + return dataChan, stopChan, nil } -func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { - var xunfeiResponse XunfeiChatResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil +func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + if apiVersion == "" { + apiVersion = "v1.1" + common.SysLog("api_version not found, use default: " + apiVersion) } - err = json.Unmarshal(responseBody, &xunfeiResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + domain := "general" + if apiVersion == "v2.1" { + domain = "generalv2" } - if xunfeiResponse.Header.Code != 0 { - return &OpenAIErrorWithStatusCode{ - OpenAIError: OpenAIError{ - Message: xunfeiResponse.Header.Message, - Type: "xunfei_error", - Param: "", - Code: xunfeiResponse.Header.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil - } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - return nil, &fullTextResponse.Usage + authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) + return domain, authUrl } diff --git a/controller/relay.go b/controller/relay.go index e6e45214..7e60afc0 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -197,6 +197,7 @@ func Relay(c *gin.Context) { err = relayTextHelper(c, relayMode) } if err != nil { + requestId := c.GetString(common.RequestIdKey) retryTimesStr := c.Query("retry") retryTimes, _ := strconv.Atoi(retryTimesStr) if retryTimesStr == "" { @@ -208,12 +209,13 @@ func Relay(c *gin.Context) { if err.StatusCode == http.StatusTooManyRequests { 
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" } + err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) c.JSON(err.StatusCode, gin.H{ "error": err.OpenAIError, }) } channelId := c.GetInt("channel_id") - common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { channelId := c.GetInt("channel_id") diff --git a/main.go b/main.go index ec98f40c..8af262a4 100644 --- a/main.go +++ b/main.go @@ -42,7 +42,7 @@ func main() { common.SysLog("Sentry initialized") } - common.SetupGinLog() + common.SetupLogger() common.SysLog("One API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) @@ -108,11 +108,12 @@ func main() { } // Initialize HTTP server - server := gin.Default() + server := gin.New() + server.Use(gin.Recovery()) // This will cause SSE not to work!!! //server.Use(gzip.Gzip(gzip.DefaultCompression)) - server.Use(middleware.CORS()) - + server.Use(middleware.RequestId()) + middleware.SetUpLogger(server) // Initialize session store store := cookie.NewStore([]byte(common.SessionSecret)) server.Use(sessions.Sessions("session", store)) diff --git a/middleware/auth.go b/middleware/auth.go index 95516d6e..dfbc7dbd 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] token, err := model.ValidateUserToken(key) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } userEnabled, err := model.IsUserEnabled(token.UserId) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusInternalServerError, err.Error()) return } if !userEnabled { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "用户已被封禁", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } c.Set("id", token.UserId) @@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) } else { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "message": "普通用户不支持指定渠道", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index f41b71f6..cbf8eeda 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -26,34 +26,16 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的渠道 ID", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") return } channel, err = model.GetChannelById(id, true) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的渠道 ID", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") return } if channel.Status != common.ChannelStatusEnabled { - c.JSON(http.StatusForbidden, gin.H{ 
- "error": gin.H{ - "message": "该渠道已被禁用", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { @@ -64,13 +46,7 @@ func Distribute() func(c *gin.Context) { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "message": "无效的请求", - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusBadRequest, "无效的请求") return } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { @@ -100,13 +76,7 @@ func Distribute() func(c *gin.Context) { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } - c.JSON(http.StatusServiceUnavailable, gin.H{ - "error": gin.H{ - "message": message, - "type": "one_api_error", - }, - }) - c.Abort() + abortWithMessage(c, http.StatusServiceUnavailable, message) return } } diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 00000000..02f2e0a9 --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,25 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "one-api/common" +) + +func SetUpLogger(server *gin.Engine) { + server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { + var requestID string + if param.Keys != nil { + requestID = param.Keys[common.RequestIdKey].(string) + } + return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", + param.TimeStamp.Format("2006/01/02 - 15:04:05"), + requestID, + param.StatusCode, + param.Latency, + param.ClientIP, + param.Method, + param.Path, + ) + })) +} diff --git a/middleware/request-id.go b/middleware/request-id.go new file mode 100644 index 00000000..e623be7a --- /dev/null +++ b/middleware/request-id.go @@ -0,0 +1,18 @@ +package middleware + +import ( + "context" + "github.com/gin-gonic/gin" + "one-api/common" +) + +func RequestId() func(c *gin.Context) { + return func(c *gin.Context) { + id := common.GetTimeString() + common.GetRandomString(8) + c.Set(common.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) + c.Request = c.Request.WithContext(ctx) + c.Header(common.RequestIdKey, id) + c.Next() + } +} diff --git a/middleware/utils.go b/middleware/utils.go new file mode 100644 index 00000000..536125cc --- /dev/null +++ b/middleware/utils.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" + "one-api/common" +) + +func abortWithMessage(c *gin.Context, statusCode int, message string) { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "type": "one_api_error", + }, + }) + c.Abort() + common.LogError(c.Request.Context(), message) +} diff --git a/model/ability.go b/model/ability.go index 372427ff..831f9793 100644 --- a/model/ability.go +++ b/model/ability.go @@ -13,6 +13,7 @@ type Ability struct { Enabled bool `json:"enabled"` AllowStreaming int `json:"allow_streaming" gorm:"default:1"` AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"` + Priority int64 `json:"priority" gorm:"bigint;default:0"` } func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) { @@ -33,9 +34,9 @@ func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channe } if common.UsingSQLite || common.UsingPostgreSQL { - err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error + err = DB.Where(cmd, group, model).Order("CASE WHEN priority <> 0 THEN 
priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error } else { - err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error + err = DB.Where(cmd, group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error } if err != nil { return nil, err @@ -59,6 +60,7 @@ func (channel *Channel) AddAbilities() error { Enabled: channel.Status == common.ChannelStatusEnabled, AllowStreaming: channel.AllowStreaming, AllowNonStreaming: channel.AllowNonStreaming, + Priority: channel.Priority, } abilities = append(abilities, ability) } diff --git a/model/cache.go b/model/cache.go index 24f045ef..4d561a0a 100644 --- a/model/cache.go +++ b/model/cache.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "one-api/common" + "sort" "strconv" "strings" "sync" @@ -165,6 +166,17 @@ func InitChannelCache() { } } } + + // sort by priority + for group, model2channels := range newGroup2model2channels { + for model, channels := range model2channels { + sort.Slice(channels, func(i, j int) bool { + return channels[i].Priority > channels[j].Priority + }) + newGroup2model2channels[group][model] = channels + } + } + channelSyncLock.Lock() group2model2channels = newGroup2model2channels channelSyncLock.Unlock() @@ -197,6 +209,12 @@ func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*C } } + // choose by priority + firstChannel := filteredChannels[0] + if firstChannel.Priority > 0 { + return firstChannel, nil + } + idx := rand.Intn(len(filteredChannels)) return filteredChannels[idx], nil } diff --git a/model/channel.go b/model/channel.go index de8d35e9..9589918b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -26,6 +26,7 @@ type Channel struct { ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` AllowStreaming int `json:"allow_streaming" gorm:"default:1"` AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"` + Priority int64 `json:"priority" gorm:"bigint;default:0"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/model/log.go b/model/log.go index b0d6409a..1c0a2dc6 100644 --- a/model/log.go +++ b/model/log.go @@ -1,6 +1,8 @@ package model import ( + "context" + "fmt" "gorm.io/gorm" "one-api/common" ) @@ -17,6 +19,7 @@ type Log struct { Quota int `json:"quota" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"` + Channel int `json:"channel" gorm:"default:0"` } const ( @@ -44,7 +47,9 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { + +func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { + common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled { return } @@ -59,14 +64,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN TokenName: tokenName, ModelName: modelName, Quota: quota, + Channel: channelId, } err := DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + 
common.LogError(ctx, "failed to record log: "+err.Error()) } } -func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) { +func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { tx = DB @@ -88,6 +94,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName if endTimestamp != 0 { tx = tx.Where("created_at <= ?", endTimestamp) } + if channel != 0 { + tx = tx.Where("channel = ?", channel) + } err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error return logs, err } @@ -125,7 +134,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) { +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { tx := DB.Table("logs").Select("sum(quota)") if username != "" { tx = tx.Where("username = ?", username) @@ -142,6 +151,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa if modelName != "" { tx = tx.Where("model_name = ?", modelName) } + if channel != 0 { + tx = tx.Where("channel = ?", channel) + } tx.Where("type = ?", LogTypeConsume).Scan("a) return quota } @@ -166,3 +178,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa tx.Where("type = ?", LogTypeConsume).Scan(&token) return token } + +func DeleteOldLog(targetTimestamp int64) (int64, error) { + result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + return result.RowsAffected, result.Error +} diff --git a/router/api-router.go b/router/api-router.go index 7e64a3bf..a2bcc8b6 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -100,6 +100,7 @@ func SetApiRouter(router *gin.Engine) { } logRoute := apiRouter.Group("/log") logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) + logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs) logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) diff --git a/router/relay-router.go b/router/relay-router.go index a76e42cf..e84f02db 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -8,6 +8,7 @@ import ( ) func SetRelayRouter(router *gin.Engine) { + router.Use(middleware.CORS()) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 5eb39783..7c8457d0 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; -import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; +import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react'; import { Link } from 'react-router-dom'; -import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; +import { API, 
showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { renderGroup, renderNumber } from '../helpers/render'; @@ -24,7 +24,7 @@ function renderType(type) { } type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; } - return ; + return ; } function renderBalance(type, balance) { @@ -96,7 +96,7 @@ const ChannelsTable = () => { }); }, []); - const manageChannel = async (id, action, idx) => { + const manageChannel = async (id, action, idx, priority) => { let data = { id }; let res; switch (action) { @@ -111,6 +111,13 @@ const ChannelsTable = () => { data.status = 2; res = await API.put('/api/channel/', data); break; + case 'priority': + if (priority === '') { + return; + } + data.priority = parseInt(priority); + res = await API.put('/api/channel/', data); + break; } const { success, message } = res.data; if (success) { @@ -195,6 +202,7 @@ const ChannelsTable = () => { showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); + showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。") } }; @@ -334,6 +342,14 @@ const ChannelsTable = () => { > 余额 + { + sortChannel('priority'); + }} + > + 优先级 + 操作 @@ -372,6 +388,22 @@ const ChannelsTable = () => { basic /> + + { + manageChannel( + channel.id, + 'priority', + idx, + event.target.value, + ); + }}> + + } + content='渠道选择优先级,越高越优先' + basic + /> +
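The backend changes above wire two pieces together: `middleware/request-id.go` stores a per-request id in both the gin context and the request's `context.Context` (and echoes it as the `X-Oneapi-Request-Id` header), while `common/logger.go` reads that id back in `logHelper` so every `LogInfo`/`LogWarn`/`LogError` line can be correlated with the response. The sketch below is a minimal, self-contained illustration of that flow under assumed names (`requestID`, `logInfo`, a plain `log.Printf` sink); it mirrors the pattern in this diff but is not the project's actual code.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
)

// requestIDKey mirrors common.RequestIdKey ("X-Oneapi-Request-Id") from the diff.
const requestIDKey = "X-Oneapi-Request-Id"

// requestID is a stand-in for middleware.RequestId(): it generates an id, stores it
// in the gin context and the request's context.Context, and echoes it as a header.
func requestID() gin.HandlerFunc {
	return func(c *gin.Context) {
		// The real code builds the id from GetTimeString() + GetRandomString(8);
		// a nanosecond timestamp is enough for this sketch.
		id := fmt.Sprintf("%d", time.Now().UnixNano())
		c.Set(requestIDKey, id)
		ctx := context.WithValue(c.Request.Context(), requestIDKey, id)
		c.Request = c.Request.WithContext(ctx)
		c.Header(requestIDKey, id)
		c.Next()
	}
}

// logInfo is a stand-in for common.LogInfo: it pulls the id back out of the context
// so the log line can be matched against the response header.
func logInfo(ctx context.Context, msg string) {
	log.Printf("[INFO] %v | %s", ctx.Value(requestIDKey), msg)
}

func main() {
	server := gin.New()
	server.Use(gin.Recovery(), requestID())
	server.GET("/ping", func(c *gin.Context) {
		logInfo(c.Request.Context(), "handling ping")
		c.String(http.StatusOK, "pong")
	})
	_ = server.Run(":3000")
}
```

Running this and calling `curl -i localhost:3000/ping` should show the same id in the `X-Oneapi-Request-Id` response header and in the log line, which is what makes the `(request id: ...)` suffix added by `MessageWithRequestId` traceable back to a server-side log entry.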
diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js index 50dd5fec..cf0bbe2c 100644 --- a/web/src/components/LoginForm.js +++ b/web/src/components/LoginForm.js @@ -4,7 +4,7 @@ import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { UserContext } from '../context/User'; import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; import Turnstile from 'react-turnstile'; -import { getOAuthState, onGitHubOAuthClicked } from './utils';stream/main +import { getOAuthState, onGitHubOAuthClicked } from './utils'; const LoginForm = () => { const [inputs, setInputs] = useState({ @@ -47,12 +47,6 @@ const LoginForm = () => { ); }; - const onGitHubOAuthClicked = () => { - window.open( - `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` - ); - }; - const onDiscordOAuthClicked = () => { window.open( `https://discord.com/oauth2/authorize?response_type=code&client_id=${status.discord_client_id}&redirect_uri=${window.location.origin}/oauth/discord&scope=identify`, diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index c981e261..e266d79a 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -56,9 +56,10 @@ const LogsTable = () => { token_name: '', model_name: '', start_timestamp: timestamp2string(0), - end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), + channel: '' }); - const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs; + const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs; const [stat, setStat] = useState({ quota: 0, @@ -84,7 +85,7 @@ const LogsTable = () => { const getLogStat = async () => { let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000; - let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`); const { success, message, data } = res.data; if (success) { setStat(data); @@ -109,7 +110,7 @@ const LogsTable = () => { let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000; if (isAdminUser) { - url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; } else { url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; } @@ -205,16 +206,9 @@ const LogsTable = () => {
- { - isAdminUser && ( - - ) - } - - { onChange={handleInputChange} /> 查询 + { + isAdminUser && <> + + + + + + + }
@@ -238,6 +245,17 @@ const LogsTable = () => { > 时间 + { + isAdminUser && { + sortLog('channel'); + }} + width={1} + > + 渠道 + + } { isAdminUser && { onClick={() => { sortLog('quota'); }} - width={2} + width={1} > - 消耗额度 + 额度 { sortLog('content'); }} - width={isAdminUser ? 4 : 5} + width={isAdminUser ? 4 : 6} > 详情 @@ -326,6 +344,11 @@ const LogsTable = () => { return ( {renderTimestamp(log.created_at)} + { + isAdminUser && ( + {log.channel ? : ''} + ) + } { isAdminUser && ( {log.username ? : ''} @@ -345,7 +368,7 @@ const LogsTable = () => { - +