Merge remote-tracking branch 'origin/main'

# Conflicts:
#	controller/log.go
#	controller/relay-audio.go
#	controller/relay-image.go
#	controller/relay-text.go
#	controller/relay.go
#	middleware/distributor.go
#	model/log.go
#	web/src/components/OperationSetting.js
This commit is contained in:
CaIon 2023-09-17 20:59:12 +08:00
commit 985e26fd1b
30 changed files with 561 additions and 319 deletions

3
.gitignore vendored
View File

@ -4,4 +4,5 @@ upload
*.exe *.exe
*.db *.db
build build
*.db-journal *.db-journal
logs

View File

@ -340,7 +340,7 @@ graph LR
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000` 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`
+ 例子:`--port 3000` + 例子:`--port 3000`
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下
+ 例子:`--log-dir ./logs` + 例子:`--log-dir ./logs`
3. `--version`: 打印系统版本号并退出。 3. `--version`: 打印系统版本号并退出。
4. `--help`: 查看命令的使用帮助和参数说明。 4. `--help`: 查看命令的使用帮助和参数说明。

View File

@ -101,6 +101,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const ( const (
RoleGuestUser = 0 RoleGuestUser = 0
RoleCommonUser = 1 RoleCommonUser = 1

View File

@ -12,7 +12,7 @@ var (
Port = flag.Int("port", 3000, "the listening port") Port = flag.Int("port", 3000, "the listening port")
PrintVersion = flag.Bool("version", false, "print version and exit") PrintVersion = flag.Bool("version", false, "print version and exit")
PrintHelp = flag.Bool("help", false, "print help and exit") PrintHelp = flag.Bool("help", false, "print help and exit")
LogDir = flag.String("log-dir", "", "specify the log directory") LogDir = flag.String("log-dir", "./logs", "specify the log directory")
) )
func printHelp() { func printHelp() {

View File

@ -1,29 +1,47 @@
package common package common
import ( import (
"context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"time" "time"
) )
func SetupGinLog() { const (
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
)
const maxLogCount = 1000000
var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
if *LogDir != "" { if *LogDir != "" {
commonLogPath := filepath.Join(*LogDir, "common.log") ok := setupLogLock.TryLock()
errorLogPath := filepath.Join(*LogDir, "error.log") if !ok {
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) log.Println("setup log is already working")
return
}
defer func() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
log.Fatal("failed to open log file") log.Fatal("failed to open log file")
} }
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
if err != nil { gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
} }
} }
@ -37,6 +55,36 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func LogInfo(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func LogWarn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func LogError(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
SetupLogger()
}()
}
}
func FatalLog(v ...any) { func FatalLog(v ...any) {
t := time.Now() t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)

View File

@ -171,6 +171,11 @@ func GetTimestamp() int64 {
return time.Now().Unix() return time.Now().Unix()
} }
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int { func Max(a int, b int) int {
if a >= b { if a >= b {
return a return a
@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
} }
return num return num
} }
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}

View File

@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
if err != nil { if err != nil {
openAIError := OpenAIError{ openAIError := OpenAIError{
Message: err.Error(), Message: err.Error(),
Type: "one_api_error", Type: "upstream_error",
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"error": openAIError, "error": openAIError,

View File

@ -18,19 +18,21 @@ func GetAllLogs(c *gin.Context) {
username := c.Query("username") username := c.Query("username")
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) channel, _ := strconv.Atoi(c.Query("channel"))
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": logs,
}) })
return
} }
func GetUserLogs(c *gin.Context) { func GetUserLogs(c *gin.Context) {
@ -46,34 +48,36 @@ func GetUserLogs(c *gin.Context) {
modelName := c.Query("model_name") modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": logs,
}) })
return
} }
func SearchAllLogs(c *gin.Context) { func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword") keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword) logs, err := model.SearchAllLogs(keyword)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": logs,
}) })
return
} }
func SearchUserLogs(c *gin.Context) { func SearchUserLogs(c *gin.Context) {
@ -81,17 +85,18 @@ func SearchUserLogs(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword) logs, err := model.SearchUserLogs(userId, keyword)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": logs,
}) })
return
} }
func GetLogByKey(c *gin.Context) { func GetLogByKey(c *gin.Context) {
@ -118,9 +123,9 @@ func GetLogsStat(c *gin.Context) {
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
username := c.Query("username") username := c.Query("username")
modelName := c.Query("model_name") modelName := c.Query("model_name")
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(200, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": gin.H{
@ -129,6 +134,7 @@ func GetLogsStat(c *gin.Context) {
"tpm": stat.Tpm, "tpm": stat.Tpm,
}, },
}) })
return
} }
func GetLogsSelfStat(c *gin.Context) { func GetLogsSelfStat(c *gin.Context) {
@ -138,7 +144,8 @@ func GetLogsSelfStat(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
@ -150,4 +157,30 @@ func GetLogsSelfStat(c *gin.Context) {
//"token": tokenNum, //"token": tokenNum,
}, },
}) })
return
}
func DeleteHistoryLogs(c *gin.Context) {
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
if targetTimestamp == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "target timestamp is required",
})
return
}
count, err := model.DeleteOldLog(targetTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
return
} }

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -17,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
@ -91,7 +93,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
var audioResponse AudioResponse var audioResponse AudioResponse
defer func() { defer func(ctx context.Context) {
go func() { go func() {
quota := countTokenText(audioResponse.Text, audioModel) quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
@ -106,13 +108,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 { if quota != 0 {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent, tokenId) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
}() }()
}() }(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
userId := c.GetInt("id") userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota") consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group") group := c.GetString("group")
@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
var textResponse ImageResponse var textResponse ImageResponse
defer func() { defer func(ctx context.Context) {
if consumeQuota { if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota) err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil { if err != nil {
@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 { if quota != 0 {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
} }
}() }(c.Request.Context())
if consumeQuota { if consumeQuota {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -37,6 +38,7 @@ func init() {
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
userId := c.GetInt("id") userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota") consumeQuota := c.GetBool("consume_quota")
@ -108,7 +110,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case common.ChannelTypeAIProxyLibrary: case common.ChannelTypeAIProxyLibrary:
apiType = APITypeAIProxyLibrary apiType = APITypeAIProxyLibrary
} }
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
@ -211,6 +212,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// in this case, we do not pre-consume quota // in this case, we do not pre-consume quota
// because the user has enough quota // because the user has enough quota
preConsumedQuota = 0 preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
} }
if consumeQuota && preConsumedQuota > 0 { if consumeQuota && preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
@ -339,9 +341,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//req.Header.Set("Connection", c.Request.Header.Get("Connection")) //req.Header.Set("Connection", c.Request.Header.Get("Connection"))
req.Close = true
resp, err = httpClient.Do(req) resp, err = httpClient.Do(req)
if err != nil { if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
@ -358,13 +358,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
if preConsumedQuota != 0 { if preConsumedQuota != 0 {
go func() { go func(ctx context.Context) {
// return pre-consumed quota // return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil { if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error()) common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
} }
}() }(c.Request.Context())
} }
return relayErrorHandler(resp) return relayErrorHandler(resp)
} }
@ -374,7 +374,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
defer func() { defer func(ctx context.Context) {
// c.Writer.Flush() // c.Writer.Flush()
go func() { go func() {
if consumeQuota { if consumeQuota {
@ -397,21 +397,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta) err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.LogError(ctx, "error consuming token remain quota: "+err.Error())
} }
err = model.CacheUpdateUserQuota(userId) err = model.CacheUpdateUserQuota(userId)
if err != nil { if err != nil {
common.SysError("error update user quota cache: " + err.Error()) common.LogError(ctx, "error update user quota cache: "+err.Error())
} }
if quota != 0 { if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId) model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
} }
}() }()
}() }(c.Request.Context())
switch apiType { switch apiType {
case APITypeOpenAI: case APITypeOpenAI:
if isStream { if isStream {
@ -549,24 +549,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return nil return nil
} }
case APITypeXunfei: case APITypeXunfei:
if isStream { auth := c.Request.Header.Get("Authorization")
auth := c.Request.Header.Get("Authorization") auth = strings.TrimPrefix(auth, "Bearer ")
auth = strings.TrimPrefix(auth, "Bearer ") splits := strings.Split(auth, "|")
splits := strings.Split(auth, "|") if len(splits) != 3 {
if len(splits) != 3 { return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
}
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
} }
var err *OpenAIErrorWithStatusCode
var usage *Usage
if isStream {
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
} else {
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
}
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
case APITypeAIProxyLibrary: case APITypeAIProxyLibrary:
if isStream { if isStream {
err, usage := aiProxyLibraryStreamHandler(c, resp) err, usage := aiProxyLibraryStreamHandler(c, resp)

View File

@ -150,7 +150,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
OpenAIError: OpenAIError{ OpenAIError: OpenAIError{
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
Type: "one_api_error", Type: "upstream_error",
Code: "bad_response_status_code", Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode), Param: strconv.Itoa(resp.StatusCode),
}, },

View File

@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
Role: "assistant", Role: "assistant",
Content: response.Payload.Choices.Text[0].Content, Content: response.Payload.Choices.Text[0].Content,
}, },
FinishReason: stopFinishReason,
} }
fullTextResponse := OpenAITextResponse{ fullTextResponse := OpenAITextResponse{
Object: "chat.completion", Object: "chat.completion",
@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
} }
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
}
setEventStreamHeaders(c)
var usage Usage var usage Usage
query := c.Request.URL.Query() c.Stream(func(w io.Writer) bool {
apiVersion := query.Get("api-version") select {
if apiVersion == "" { case xunfeiResponse := <-dataChan:
apiVersion = c.GetString("api_version") usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
}
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
} }
if apiVersion == "" { var usage Usage
apiVersion = "v1.1" var content string
common.SysLog("api_version not found, use default: " + apiVersion) var xunfeiResponse XunfeiChatResponse
stop := false
for !stop {
select {
case xunfeiResponse = <-dataChan:
content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
case stop = <-stopChan:
}
} }
domain := "general"
if apiVersion == "v2.1" { xunfeiResponse.Payload.Choices.Text[0].Content = content
domain = "generalv2"
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
return nil, &usage
}
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
d := websocket.Dialer{ d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
} }
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) conn, resp, err := d.Dial(authUrl, nil)
if err != nil || resp.StatusCode != 101 { if err != nil || resp.StatusCode != 101 {
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil return nil, nil, err
} }
data := requestOpenAI2Xunfei(textRequest, appId, domain) data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data) err = conn.WriteJSON(data)
if err != nil { if err != nil {
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil return nil, nil, err
} }
dataChan := make(chan XunfeiChatResponse) dataChan := make(chan XunfeiChatResponse)
stopChan := make(chan bool) stopChan := make(chan bool)
go func() { go func() {
@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
} }
stopChan <- true stopChan <- true
}() }()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool { return dataChan, stopChan, nil
select {
case xunfeiResponse := <-dataChan:
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
return nil, &usage
} }
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
var xunfeiResponse XunfeiChatResponse query := c.Request.URL.Query()
responseBody, err := io.ReadAll(resp.Body) apiVersion := query.Get("api-version")
if err != nil { if apiVersion == "" {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil apiVersion = c.GetString("api_version")
} }
err = resp.Body.Close() if apiVersion == "" {
if err != nil { apiVersion = "v1.1"
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil common.SysLog("api_version not found, use default: " + apiVersion)
} }
err = json.Unmarshal(responseBody, &xunfeiResponse) domain := "general"
if err != nil { if apiVersion == "v2.1" {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil domain = "generalv2"
} }
if xunfeiResponse.Header.Code != 0 { authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return &OpenAIErrorWithStatusCode{ return domain, authUrl
OpenAIError: OpenAIError{
Message: xunfeiResponse.Header.Message,
Type: "xunfei_error",
Param: "",
Code: xunfeiResponse.Header.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
} }

View File

@ -205,15 +205,20 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode = RelayModeAudio
} }
var err *OpenAIErrorWithStatusCode var err *OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case RelayModeImagesGenerations: case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode) err = relayImageHelper(c, relayMode)
case RelayModeAudio:
err = relayAudioHelper(c, relayMode)
default: default:
err = relayTextHelper(c, relayMode) err = relayTextHelper(c, relayMode)
} }
if err != nil { if err != nil {
requestId := c.GetString(common.RequestIdKey)
retryTimesStr := c.Query("retry") retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr) retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" { if retryTimesStr == "" {
@ -225,12 +230,13 @@ func Relay(c *gin.Context) {
if err.StatusCode == http.StatusTooManyRequests { if err.StatusCode == http.StatusTooManyRequests {
//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" //err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
} }
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{ c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError, "error": err.OpenAIError,
}) })
} }
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %v ", channelId, err)) common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")

View File

@ -21,7 +21,7 @@ var buildFS embed.FS
var indexPage []byte var indexPage []byte
func main() { func main() {
common.SetupGinLog() common.SetupLogger()
common.SysLog("One API " + common.Version + " started") common.SysLog("One API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" { if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
@ -86,11 +86,12 @@ func main() {
controller.InitTokenEncoders() controller.InitTokenEncoders()
// Initialize HTTP server // Initialize HTTP server
server := gin.Default() server := gin.New()
server.Use(gin.Recovery())
// This will cause SSE not to work!!! // This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression)) //server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.CORS()) server.Use(middleware.RequestId())
middleware.SetUpLogger(server)
// Initialize session store // Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret)) store := cookie.NewStore([]byte(common.SessionSecret))
server.Use(sessions.Sessions("session", store)) server.Use(sessions.Sessions("session", store))

View File

@ -100,34 +100,16 @@ func TokenAuth() func(c *gin.Context) {
} }
token, err := model.ValidateUserToken(key) token, err := model.ValidateUserToken(key)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{ abortWithMessage(c, http.StatusUnauthorized, err.Error())
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
return return
} }
userEnabled, err := model.IsUserEnabled(token.UserId) userEnabled, err := model.IsUserEnabled(token.UserId)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{ abortWithMessage(c, http.StatusInternalServerError, err.Error())
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
return return
} }
if !userEnabled { if !userEnabled {
c.JSON(http.StatusForbidden, gin.H{ abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
"error": gin.H{
"message": "用户已被封禁",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
c.Set("id", token.UserId) c.Set("id", token.UserId)
@ -143,13 +125,7 @@ func TokenAuth() func(c *gin.Context) {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])
} else { } else {
c.JSON(http.StatusForbidden, gin.H{ abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
"error": gin.H{
"message": "普通用户不支持指定渠道",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} }

View File

@ -2,7 +2,6 @@ package middleware
import ( import (
"fmt" "fmt"
"log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
@ -26,39 +25,22 @@ func Distribute() func(c *gin.Context) {
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{ abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
channel, err = model.GetChannelById(id, true) channel, err = model.GetChannelById(id, true)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{ abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
if channel.Status != common.ChannelStatusEnabled { if channel.Status != common.ChannelStatusEnabled {
c.JSON(http.StatusForbidden, gin.H{ abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
"error": gin.H{
"message": "该渠道已被禁用",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} else { } else {
// Select a channel for the user // Select a channel for the user
var modelRequest ModelRequest var modelRequest ModelRequest
var err error
if strings.HasPrefix(c.Request.URL.Path, "/mj") { if strings.HasPrefix(c.Request.URL.Path, "/mj") {
// Midjourney // Midjourney
if modelRequest.Model == "" { if modelRequest.Model == "" {
@ -67,17 +49,17 @@ func Distribute() func(c *gin.Context) {
} else { } else {
err := common.UnmarshalBodyReusable(c, &modelRequest) err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil { if err != nil {
log.Println(err) abortWithMessage(c, http.StatusBadRequest, "无效的请求")
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的请求",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} }
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable" modelRequest.Model = "text-moderation-stable"
@ -98,7 +80,6 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "whisper-1" modelRequest.Model = "whisper-1"
} }
} }
var err error
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
@ -106,17 +87,10 @@ func Distribute() func(c *gin.Context) {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员" message = "数据库一致性已被破坏,请联系管理员"
} }
c.JSON(http.StatusServiceUnavailable, gin.H{ abortWithMessage(c, http.StatusServiceUnavailable, message)
"error": gin.H{
"message": message,
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} }
//log.Printf("Using channel %v", channel)
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)

25
middleware/logger.go Normal file
View File

@ -0,0 +1,25 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[common.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
requestID,
param.StatusCode,
param.Latency,
param.ClientIP,
param.Method,
param.Path,
)
}))
}

18
middleware/request-id.go Normal file
View File

@ -0,0 +1,18 @@
package middleware
import (
"context"
"github.com/gin-gonic/gin"
"one-api/common"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := common.GetTimeString() + common.GetRandomString(8)
c.Set(common.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(common.RequestIdKey, id)
c.Next()
}
}

17
middleware/utils.go Normal file
View File

@ -0,0 +1,17 @@
package middleware
import (
"github.com/gin-gonic/gin"
"one-api/common"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
common.LogError(c.Request.Context(), message)
}

View File

@ -10,15 +10,16 @@ type Ability struct {
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Priority int64 `json:"priority" gorm:"bigint;default:0"`
} }
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{} ability := Ability{}
var err error = nil var err error = nil
if common.UsingSQLite { if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
} else { } else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -40,6 +41,7 @@ func (channel *Channel) AddAbilities() error {
Model: model, Model: model,
ChannelId: channel.Id, ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled, Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
} }
abilities = append(abilities, ability) abilities = append(abilities, ability)
} }

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"one-api/common" "one-api/common"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -159,6 +160,17 @@ func InitChannelCache() {
} }
} }
} }
// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].Priority > channels[j].Priority
})
newGroup2model2channels[group][model] = channels
}
}
channelSyncLock.Lock() channelSyncLock.Lock()
group2model2channels = newGroup2model2channels group2model2channels = newGroup2model2channels
channelSyncLock.Unlock() channelSyncLock.Unlock()
@ -183,6 +195,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 { if len(channels) == 0 {
return nil, errors.New("channel not found") return nil, errors.New("channel not found")
} }
// choose by priority
firstChannel := channels[0]
if firstChannel.Priority > 0 {
return firstChannel, nil
}
idx := rand.Intn(len(channels)) idx := rand.Intn(len(channels))
return channels[idx], nil return channels[idx], nil
} }

View File

@ -24,6 +24,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority int64 `json:"priority" gorm:"bigint;default:0"`
} }
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {

View File

@ -1,6 +1,8 @@
package model package model
import ( import (
"context"
"fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
@ -19,6 +21,7 @@ type Log struct {
PromptTokens int `json:"prompt_tokens" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
TokenId int `json:"token_id" gorm:"default:0;index"` TokenId int `json:"token_id" gorm:"default:0;index"`
Channel int `json:"channel" gorm:"default:0"`
} }
const ( const (
@ -51,7 +54,8 @@ func RecordLog(userId int, logType int, content string) {
} }
} }
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled { if !common.LogConsumeEnabled {
return return
} }
@ -66,15 +70,16 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
TokenName: tokenName, TokenName: tokenName,
ModelName: modelName, ModelName: modelName,
Quota: quota, Quota: quota,
Channel: channelId,
TokenId: tokenId, TokenId: tokenId,
} }
err := DB.Create(log).Error err := DB.Create(log).Error
if err != nil { if err != nil {
common.SysError("failed to record log: " + err.Error()) common.LogError(ctx, "failed to record log: "+err.Error())
} }
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = DB tx = DB
@ -96,6 +101,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if endTimestamp != 0 { if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err return logs, err
} }
@ -139,7 +147,7 @@ type Stat struct {
Tpm int `json:"tpm"` Tpm int `json:"tpm"`
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (stat Stat) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
@ -156,6 +164,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name = ?", modelName)
} }
if channel != 0 {
tx = tx.Where("channel = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&stat) tx.Where("type = ?", LogTypeConsume).Scan(&stat)
return stat return stat
} }
@ -180,3 +192,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx.Where("type = ?", LogTypeConsume).Scan(&token) tx.Where("type = ?", LogTypeConsume).Scan(&token)
return token return token
} }
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}

View File

@ -103,6 +103,7 @@ func SetApiRouter(router *gin.Engine) {
} }
logRoute := apiRouter.Group("/log") logRoute := apiRouter.Group("/log")
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)

View File

@ -8,6 +8,7 @@ import (
) )
func SetRelayRouter(router *gin.Engine) { func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS())
// https://platform.openai.com/docs/api-reference/introduction // https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models") modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth()) modelsRouter.Use(middleware.TokenAuth())

View File

@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
import { Link } from 'react-router-dom'; import { Link } from 'react-router-dom';
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import {renderGroup, renderNumber, renderQuota} from '../helpers/render'; import {renderGroup, renderNumber, renderQuota} from '../helpers/render';
@ -24,7 +24,7 @@ function renderType(type) {
} }
type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
} }
return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>; return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
} }
function renderBalance(type, balance) { function renderBalance(type, balance) {
@ -96,7 +96,7 @@ const ChannelsTable = () => {
}); });
}, []); }, []);
const manageChannel = async (id, action, idx) => { const manageChannel = async (id, action, idx, priority) => {
let data = { id }; let data = { id };
let res; let res;
switch (action) { switch (action) {
@ -111,6 +111,13 @@ const ChannelsTable = () => {
data.status = 2; data.status = 2;
res = await API.put('/api/channel/', data); res = await API.put('/api/channel/', data);
break; break;
case 'priority':
if (priority === '') {
return;
}
data.priority = parseInt(priority);
res = await API.put('/api/channel/', data);
break;
} }
const { success, message } = res.data; const { success, message } = res.data;
if (success) { if (success) {
@ -195,6 +202,7 @@ const ChannelsTable = () => {
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else { } else {
showError(message); showError(message);
showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
} }
}; };
@ -346,6 +354,14 @@ const ChannelsTable = () => {
> >
余额 余额
</Table.HeaderCell> </Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortChannel('priority');
}}
>
优先级
</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell> <Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row> </Table.Row>
</Table.Header> </Table.Header>
@ -385,6 +401,22 @@ const ChannelsTable = () => {
basic basic
/> />
</Table.Cell> </Table.Cell>
<Table.Cell>
<Popup
trigger={<Input type="number" defaultValue={channel.priority} onBlur={(event) => {
manageChannel(
channel.id,
'priority',
idx,
event.target.value,
);
}}>
<input style={{maxWidth:'60px'}} />
</Input>}
content='渠道选择优先级,越高越优先'
basic
/>
</Table.Cell>
<Table.Cell> <Table.Cell>
<div> <div>
<Button <Button
@ -453,7 +485,7 @@ const ChannelsTable = () => {
<Table.Footer> <Table.Footer>
<Table.Row> <Table.Row>
<Table.HeaderCell colSpan='8'> <Table.HeaderCell colSpan='9'>
<Button size='small' as={Link} to='/channel/add' loading={loading}> <Button size='small' as={Link} to='/channel/add' loading={loading}>
添加新的渠道 添加新的渠道
</Button> </Button>

View File

@ -56,9 +56,10 @@ const LogsTable = () => {
token_name: '', token_name: '',
model_name: '', model_name: '',
start_timestamp: timestamp2string(0), start_timestamp: timestamp2string(0),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: ''
}); });
const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs; const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
const [stat, setStat] = useState({ const [stat, setStat] = useState({
quota: 0, quota: 0,
@ -84,7 +85,7 @@ const LogsTable = () => {
const getLogStat = async () => { const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
setStat(data); setStat(data);
@ -109,7 +110,7 @@ const LogsTable = () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) { if (isAdminUser) {
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
} else { } else {
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
} }
@ -205,16 +206,9 @@ const LogsTable = () => {
</Header> </Header>
<Form> <Form>
<Form.Group> <Form.Group>
{ <Form.Input fluid label={'令牌名称'} width={3} value={token_name}
isAdminUser && (
<Form.Input fluid label={'用户名称'} width={2} value={username}
placeholder={'可选值'} name='username'
onChange={handleInputChange} />
)
}
<Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
placeholder={'可选值'} name='token_name' onChange={handleInputChange} /> placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
<Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值' <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
name='model_name' name='model_name'
onChange={handleInputChange} /> onChange={handleInputChange} />
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local' <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
@ -225,6 +219,19 @@ const LogsTable = () => {
onChange={handleInputChange} /> onChange={handleInputChange} />
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button> <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
</Form.Group> </Form.Group>
{
isAdminUser && <>
<Form.Group>
<Form.Input fluid label={'渠道 ID'} width={3} value={channel}
placeholder='可选值' name='channel'
onChange={handleInputChange} />
<Form.Input fluid label={'用户名称'} width={3} value={username}
placeholder={'可选值'} name='username'
onChange={handleInputChange} />
</Form.Group>
</>
}
</Form> </Form>
<Table basic compact size='small'> <Table basic compact size='small'>
<Table.Header> <Table.Header>
@ -238,6 +245,17 @@ const LogsTable = () => {
> >
时间 时间
</Table.HeaderCell> </Table.HeaderCell>
{
isAdminUser && <Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
sortLog('channel');
}}
width={1}
>
渠道
</Table.HeaderCell>
}
{ {
isAdminUser && <Table.HeaderCell isAdminUser && <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
@ -299,16 +317,16 @@ const LogsTable = () => {
onClick={() => { onClick={() => {
sortLog('quota'); sortLog('quota');
}} }}
width={2} width={1}
> >
消耗额度 额度
</Table.HeaderCell> </Table.HeaderCell>
<Table.HeaderCell <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
onClick={() => { onClick={() => {
sortLog('content'); sortLog('content');
}} }}
width={isAdminUser ? 4 : 5} width={isAdminUser ? 4 : 6}
> >
详情 详情
</Table.HeaderCell> </Table.HeaderCell>
@ -326,6 +344,11 @@ const LogsTable = () => {
return ( return (
<Table.Row key={log.id}> <Table.Row key={log.id}>
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell> <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
{
isAdminUser && (
<Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
)
}
{ {
isAdminUser && ( isAdminUser && (
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell> <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
@ -345,7 +368,7 @@ const LogsTable = () => {
<Table.Footer> <Table.Footer>
<Table.Row> <Table.Row>
<Table.HeaderCell colSpan={'9'}> <Table.HeaderCell colSpan={'10'}>
<Select <Select
placeholder='选择明细分类' placeholder='选择明细分类'
options={LOG_OPTIONS} options={LOG_OPTIONS}

View File

@ -1,9 +1,9 @@
import React, {useEffect, useState} from 'react'; import React, {useEffect, useState} from 'react';
import {Divider, Form, Grid, Header} from 'semantic-ui-react'; import {Divider, Form, Grid, Header} from 'semantic-ui-react';
import {API, showError, verifyJSON} from '../helpers'; import {API, showError, showSuccess, timestamp2string, verifyJSON} from '../helpers';
const OperationSetting = () => { const OperationSetting = () => {
let [inputs, setInputs] = useState({ let now = new Date();let [inputs, setInputs] = useState({
QuotaForNewUser: 0, QuotaForNewUser: 0,
QuotaForInviter: 0, QuotaForInviter: 0,
QuotaForInvitee: 0, QuotaForInvitee: 0,
@ -20,28 +20,28 @@ const OperationSetting = () => {
DisplayInCurrencyEnabled: '', DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '', DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '', ApproximateTokenEnabled: '',
RetryTimes: 0, RetryTimes: 0
}); });
const [originInputs, setOriginInputs] = useState({}); const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
const getOptions = async () => { const getOptions = async () => {
const res = await API.get('/api/option/'); const res = await API.get('/api/option/');
const {success, message, data} = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
let newInputs = {}; let newInputs = {};
data.forEach((item) => { data.forEach((item) => {
if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2); item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
newInputs[item.key] = item.value;
});
setInputs(newInputs);
setOriginInputs(newInputs);
} else {
showError(message);
} }
}; newInputs[item.key] = item.value;
});
setInputs(newInputs);
setOriginInputs(newInputs);
} else {
showError(message);
}
};
useEffect(() => { useEffect(() => {
getOptions().then(); getOptions().then();
@ -73,72 +73,73 @@ const OperationSetting = () => {
} }
}; };
const submitConfig = async (group) => { const submitConfig = async (group) => {
switch (group) { switch (group) {
case 'monitor': case 'monitor':
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
}
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
}
break;
case 'stable':
await updateOption('StablePrice', inputs.StablePrice);
await updateOption('NormalPrice', inputs.NormalPrice);
await updateOption('BasePrice', inputs.BasePrice);
localStorage.setItem('stable_price', inputs.StablePrice);
localStorage.setItem('normal_price', inputs.NormalPrice);
localStorage.setItem('base_price', inputs.BasePrice);
break;
case 'ratio':
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
showError('模型倍率不是合法的 JSON 字符串');
return;
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
}
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
}
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
}
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
}
break;
case 'general':
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink);
}
if (originInputs['ChatLink'] !== inputs.ChatLink) {
await updateOption('ChatLink', inputs.ChatLink);
}
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
}
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes);
}
break;
} }
}; if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
}
break;
case 'ratio':
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
showError('模型倍率不是合法的 JSON 字符串');
return;
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
}
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
}
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
}
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
}
break;
case 'general':
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink);
}
if (originInputs['ChatLink'] !== inputs.ChatLink) {
await updateOption('ChatLink', inputs.ChatLink);
}
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
}
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes);
}
break;
}
};
return ( const deleteHistoryLogs = async () => {
console.log(inputs);
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`${data} 条日志已清理!`);
return;
}
showError('日志清理失败:' + message);
};return (
<Grid columns={1}> <Grid columns={1}>
<Grid.Column> <Grid.Column>
<Form loading={loading}> <Form loading={loading}>
@ -187,12 +188,7 @@ const OperationSetting = () => {
/> />
</Form.Group> </Form.Group>
<Form.Group inline> <Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox <Form.Checkbox
checked={inputs.DisplayInCurrencyEnabled === 'true'} checked={inputs.DisplayInCurrencyEnabled === 'true'}
label='以货币形式显示额度' label='以货币形式显示额度'
@ -214,7 +210,28 @@ const OperationSetting = () => {
</Form.Group> </Form.Group>
<Form.Button onClick={() => { <Form.Button onClick={() => {
submitConfig('general').then(); submitConfig('general').then();
}}>保存通用设置</Form.Button> }}>保存通用设置</Form.Button><Divider />
<Header as='h3'>
日志设置
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths={4}>
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
name='history_timestamp'
onChange={(e, { name, value }) => {
setHistoryTimestamp(value);
}} />
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
<Divider/> <Divider/>
<Header as='h3'> <Header as='h3'>
监控设置 监控设置

View File

@ -96,7 +96,7 @@ const TokensTable = () => {
let nextUrl; let nextUrl;
if (nextLink) { if (nextLink) {
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`; nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} else { } else {
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} }