Merge remote-tracking branch 'upstream/main'
This commit is contained in:
commit
d956ee5a6b
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,4 +4,5 @@ upload
|
|||||||
*.exe
|
*.exe
|
||||||
*.db
|
*.db
|
||||||
build
|
build
|
||||||
*.db-journal
|
*.db-journal
|
||||||
|
logs
|
@ -290,6 +290,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
|||||||
|
|
||||||
注意,具体的 API Base 的格式取决于你所使用的客户端。
|
注意,具体的 API Base 的格式取决于你所使用的客户端。
|
||||||
|
|
||||||
|
例如对于 OpenAI 的官方库:
|
||||||
|
```bash
|
||||||
|
OPENAI_API_KEY="sk-xxxxxx"
|
||||||
|
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
|
||||||
|
```
|
||||||
|
|
||||||
```mermaid
|
```mermaid
|
||||||
graph LR
|
graph LR
|
||||||
A(用户)
|
A(用户)
|
||||||
@ -346,7 +352,7 @@ graph LR
|
|||||||
### 命令行参数
|
### 命令行参数
|
||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||||
+ 例子:`--port 3000`
|
+ 例子:`--port 3000`
|
||||||
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存。
|
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下。
|
||||||
+ 例子:`--log-dir ./logs`
|
+ 例子:`--log-dir ./logs`
|
||||||
3. `--version`: 打印系统版本号并退出。
|
3. `--version`: 打印系统版本号并退出。
|
||||||
4. `--help`: 查看命令的使用帮助和参数说明。
|
4. `--help`: 查看命令的使用帮助和参数说明。
|
||||||
|
@ -106,6 +106,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
|
|||||||
var BatchUpdateEnabled = false
|
var BatchUpdateEnabled = false
|
||||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RequestIdKey = "X-Oneapi-Request-Id"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RoleGuestUser = 0
|
RoleGuestUser = 0
|
||||||
RoleCommonUser = 1
|
RoleCommonUser = 1
|
||||||
|
@ -12,7 +12,7 @@ var (
|
|||||||
Port = flag.Int("port", 3000, "the listening port")
|
Port = flag.Int("port", 3000, "the listening port")
|
||||||
PrintVersion = flag.Bool("version", false, "print version and exit")
|
PrintVersion = flag.Bool("version", false, "print version and exit")
|
||||||
PrintHelp = flag.Bool("help", false, "print help and exit")
|
PrintHelp = flag.Bool("help", false, "print help and exit")
|
||||||
LogDir = flag.String("log-dir", "", "specify the log directory")
|
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
|
||||||
)
|
)
|
||||||
|
|
||||||
func printHelp() {
|
func printHelp() {
|
||||||
|
@ -1,29 +1,47 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetupGinLog() {
|
const (
|
||||||
|
loggerINFO = "INFO"
|
||||||
|
loggerWarn = "WARN"
|
||||||
|
loggerError = "ERR"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxLogCount = 1000000
|
||||||
|
|
||||||
|
var logCount int
|
||||||
|
var setupLogLock sync.Mutex
|
||||||
|
var setupLogWorking bool
|
||||||
|
|
||||||
|
func SetupLogger() {
|
||||||
if *LogDir != "" {
|
if *LogDir != "" {
|
||||||
commonLogPath := filepath.Join(*LogDir, "common.log")
|
ok := setupLogLock.TryLock()
|
||||||
errorLogPath := filepath.Join(*LogDir, "error.log")
|
if !ok {
|
||||||
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
log.Println("setup log is already working")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
setupLogLock.Unlock()
|
||||||
|
setupLogWorking = false
|
||||||
|
}()
|
||||||
|
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
|
||||||
|
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("failed to open log file")
|
log.Fatal("failed to open log file")
|
||||||
}
|
}
|
||||||
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
||||||
if err != nil {
|
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
||||||
log.Fatal("failed to open log file")
|
|
||||||
}
|
|
||||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
|
|
||||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,6 +55,36 @@ func SysError(s string) {
|
|||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LogInfo(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerINFO, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogWarn(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerWarn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogError(ctx context.Context, msg string) {
|
||||||
|
logHelper(ctx, loggerError, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func logHelper(ctx context.Context, level string, msg string) {
|
||||||
|
writer := gin.DefaultErrorWriter
|
||||||
|
if level == loggerINFO {
|
||||||
|
writer = gin.DefaultWriter
|
||||||
|
}
|
||||||
|
id := ctx.Value(RequestIdKey)
|
||||||
|
now := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||||
|
logCount++ // we don't need accurate count, so no lock here
|
||||||
|
if logCount > maxLogCount && !setupLogWorking {
|
||||||
|
logCount = 0
|
||||||
|
setupLogWorking = true
|
||||||
|
go func() {
|
||||||
|
SetupLogger()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func FatalLog(v ...any) {
|
func FatalLog(v ...any) {
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||||
|
@ -171,6 +171,11 @@ func GetTimestamp() int64 {
|
|||||||
return time.Now().Unix()
|
return time.Now().Unix()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetTimeString() string {
|
||||||
|
now := time.Now()
|
||||||
|
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
||||||
|
}
|
||||||
|
|
||||||
func Max(a int, b int) int {
|
func Max(a int, b int) int {
|
||||||
if a >= b {
|
if a >= b {
|
||||||
return a
|
return a
|
||||||
@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
|
|||||||
}
|
}
|
||||||
return num
|
return num
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MessageWithRequestId(message string, id string) string {
|
||||||
|
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||||
|
}
|
||||||
|
@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
openAIError := OpenAIError{
|
openAIError := OpenAIError{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
Type: "one_api_error",
|
Type: "upstream_error",
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"error": openAIError,
|
"error": openAIError,
|
||||||
|
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -18,19 +19,21 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
username := c.Query("username")
|
username := c.Query("username")
|
||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
|
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserLogs(c *gin.Context) {
|
func GetUserLogs(c *gin.Context) {
|
||||||
@ -46,34 +49,36 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchAllLogs(c *gin.Context) {
|
func SearchAllLogs(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
logs, err := model.SearchAllLogs(keyword)
|
logs, err := model.SearchAllLogs(keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUserLogs(c *gin.Context) {
|
func SearchUserLogs(c *gin.Context) {
|
||||||
@ -81,17 +86,18 @@ func SearchUserLogs(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
logs, err := model.SearchUserLogs(userId, keyword)
|
logs, err := model.SearchUserLogs(userId, keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": logs,
|
"data": logs,
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsStat(c *gin.Context) {
|
func GetLogsStat(c *gin.Context) {
|
||||||
@ -101,9 +107,10 @@ func GetLogsStat(c *gin.Context) {
|
|||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
username := c.Query("username")
|
username := c.Query("username")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
|
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": gin.H{
|
"data": gin.H{
|
||||||
@ -111,6 +118,7 @@ func GetLogsStat(c *gin.Context) {
|
|||||||
//"token": tokenNum,
|
//"token": tokenNum,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetLogsSelfStat(c *gin.Context) {
|
func GetLogsSelfStat(c *gin.Context) {
|
||||||
@ -120,9 +128,10 @@ func GetLogsSelfStat(c *gin.Context) {
|
|||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
|
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": gin.H{
|
"data": gin.H{
|
||||||
@ -130,4 +139,30 @@ func GetLogsSelfStat(c *gin.Context) {
|
|||||||
//"token": tokenNum,
|
//"token": tokenNum,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteHistoryLogs(c *gin.Context) {
|
||||||
|
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
|
||||||
|
if targetTimestamp == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "target timestamp is required",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
count, err := model.DeleteOldLog(targetTimestamp)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": count,
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -17,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
|
|
||||||
@ -91,7 +93,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
var audioResponse AudioResponse
|
var audioResponse AudioResponse
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
quota := countTokenText(audioResponse.Text, audioModel)
|
quota := countTokenText(audioResponse.Text, audioModel)
|
||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
@ -106,13 +108,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}()
|
}(c.Request.Context())
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
consumeQuota := c.GetBool("consume_quota")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
var textResponse ImageResponse
|
var textResponse ImageResponse
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}(c.Request.Context())
|
||||||
|
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -37,6 +38,7 @@ func init() {
|
|||||||
|
|
||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
consumeQuota := c.GetBool("consume_quota")
|
||||||
@ -210,6 +212,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
// in this case, we do not pre-consume quota
|
// in this case, we do not pre-consume quota
|
||||||
// because the user has enough quota
|
// because the user has enough quota
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||||
}
|
}
|
||||||
if consumeQuota && preConsumedQuota > 0 {
|
if consumeQuota && preConsumedQuota > 0 {
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
@ -348,13 +351,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
if preConsumedQuota != 0 {
|
if preConsumedQuota != 0 {
|
||||||
go func() {
|
go func(ctx context.Context) {
|
||||||
// return pre-consumed quota
|
// return pre-consumed quota
|
||||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
|
||||||
}
|
}
|
||||||
}()
|
}(c.Request.Context())
|
||||||
}
|
}
|
||||||
return relayErrorHandler(resp)
|
return relayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
@ -381,9 +384,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
|
|
||||||
var textResponse TextResponse
|
var textResponse TextResponse
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
|
|
||||||
defer func() {
|
defer func(ctx context.Context) {
|
||||||
// c.Writer.Flush()
|
// c.Writer.Flush()
|
||||||
go func() {
|
go func() {
|
||||||
if consumeQuota {
|
if consumeQuota {
|
||||||
@ -406,21 +408,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||||
}
|
}
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
}
|
}
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}()
|
}(c.Request.Context())
|
||||||
switch apiType {
|
switch apiType {
|
||||||
case APITypeOpenAI:
|
case APITypeOpenAI:
|
||||||
if isStream {
|
if isStream {
|
||||||
@ -558,24 +560,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
case APITypeXunfei:
|
case APITypeXunfei:
|
||||||
if isStream {
|
auth := c.Request.Header.Get("Authorization")
|
||||||
auth := c.Request.Header.Get("Authorization")
|
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
splits := strings.Split(auth, "|")
|
||||||
splits := strings.Split(auth, "|")
|
if len(splits) != 3 {
|
||||||
if len(splits) != 3 {
|
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if usage != nil {
|
|
||||||
textResponse.Usage = *usage
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
|
|
||||||
}
|
}
|
||||||
|
var err *OpenAIErrorWithStatusCode
|
||||||
|
var usage *Usage
|
||||||
|
if isStream {
|
||||||
|
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||||
|
} else {
|
||||||
|
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
case APITypeAIProxyLibrary:
|
case APITypeAIProxyLibrary:
|
||||||
if isStream {
|
if isStream {
|
||||||
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
||||||
|
@ -147,7 +147,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
|
|||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
OpenAIError: OpenAIError{
|
OpenAIError: OpenAIError{
|
||||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||||
Type: "one_api_error",
|
Type: "upstream_error",
|
||||||
Code: "bad_response_status_code",
|
Code: "bad_response_status_code",
|
||||||
Param: strconv.Itoa(resp.StatusCode),
|
Param: strconv.Itoa(resp.StatusCode),
|
||||||
},
|
},
|
||||||
|
@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
|||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: response.Payload.Choices.Text[0].Content,
|
Content: response.Payload.Choices.Text[0].Content,
|
||||||
},
|
},
|
||||||
|
FinishReason: stopFinishReason,
|
||||||
}
|
}
|
||||||
fullTextResponse := OpenAITextResponse{
|
fullTextResponse := OpenAITextResponse{
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||||
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
setEventStreamHeaders(c)
|
||||||
var usage Usage
|
var usage Usage
|
||||||
query := c.Request.URL.Query()
|
c.Stream(func(w io.Writer) bool {
|
||||||
apiVersion := query.Get("api-version")
|
select {
|
||||||
if apiVersion == "" {
|
case xunfeiResponse := <-dataChan:
|
||||||
apiVersion = c.GetString("api_version")
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||||
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||||
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||||
|
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||||
|
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
if apiVersion == "" {
|
var usage Usage
|
||||||
apiVersion = "v1.1"
|
var content string
|
||||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
var xunfeiResponse XunfeiChatResponse
|
||||||
|
stop := false
|
||||||
|
for !stop {
|
||||||
|
select {
|
||||||
|
case xunfeiResponse = <-dataChan:
|
||||||
|
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||||
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||||
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||||
|
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
||||||
|
case stop = <-stopChan:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
domain := "general"
|
|
||||||
if apiVersion == "v2.1" {
|
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||||
domain = "generalv2"
|
|
||||||
|
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
||||||
d := websocket.Dialer{
|
d := websocket.Dialer{
|
||||||
HandshakeTimeout: 5 * time.Second,
|
HandshakeTimeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
conn, resp, err := d.Dial(authUrl, nil)
|
||||||
if err != nil || resp.StatusCode != 101 {
|
if err != nil || resp.StatusCode != 101 {
|
||||||
return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
data := requestOpenAI2Xunfei(textRequest, appId, domain)
|
||||||
err = conn.WriteJSON(data)
|
err = conn.WriteJSON(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataChan := make(chan XunfeiChatResponse)
|
dataChan := make(chan XunfeiChatResponse)
|
||||||
stopChan := make(chan bool)
|
stopChan := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
|
|||||||
}
|
}
|
||||||
stopChan <- true
|
stopChan <- true
|
||||||
}()
|
}()
|
||||||
setEventStreamHeaders(c)
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
return dataChan, stopChan, nil
|
||||||
select {
|
|
||||||
case xunfeiResponse := <-dataChan:
|
|
||||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
|
||||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
|
||||||
usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
|
|
||||||
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(response)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling stream response: " + err.Error())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
|
||||||
return true
|
|
||||||
case <-stopChan:
|
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return nil, &usage
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
|
||||||
var xunfeiResponse XunfeiChatResponse
|
query := c.Request.URL.Query()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
apiVersion := query.Get("api-version")
|
||||||
if err != nil {
|
if apiVersion == "" {
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
apiVersion = c.GetString("api_version")
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
if apiVersion == "" {
|
||||||
if err != nil {
|
apiVersion = "v1.1"
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(responseBody, &xunfeiResponse)
|
domain := "general"
|
||||||
if err != nil {
|
if apiVersion == "v2.1" {
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
domain = "generalv2"
|
||||||
}
|
}
|
||||||
if xunfeiResponse.Header.Code != 0 {
|
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||||
return &OpenAIErrorWithStatusCode{
|
return domain, authUrl
|
||||||
OpenAIError: OpenAIError{
|
|
||||||
Message: xunfeiResponse.Header.Message,
|
|
||||||
Type: "xunfei_error",
|
|
||||||
Param: "",
|
|
||||||
Code: xunfeiResponse.Header.Code,
|
|
||||||
},
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
|
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
c.Writer.Header().Set("Content-Type", "application/json")
|
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
|
||||||
_, err = c.Writer.Write(jsonResponse)
|
|
||||||
return nil, &fullTextResponse.Usage
|
|
||||||
}
|
}
|
||||||
|
@ -197,6 +197,7 @@ func Relay(c *gin.Context) {
|
|||||||
err = relayTextHelper(c, relayMode)
|
err = relayTextHelper(c, relayMode)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
retryTimesStr := c.Query("retry")
|
retryTimesStr := c.Query("retry")
|
||||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||||
if retryTimesStr == "" {
|
if retryTimesStr == "" {
|
||||||
@ -208,12 +209,13 @@ func Relay(c *gin.Context) {
|
|||||||
if err.StatusCode == http.StatusTooManyRequests {
|
if err.StatusCode == http.StatusTooManyRequests {
|
||||||
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
}
|
}
|
||||||
|
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||||
c.JSON(err.StatusCode, gin.H{
|
c.JSON(err.StatusCode, gin.H{
|
||||||
"error": err.OpenAIError,
|
"error": err.OpenAIError,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
9
main.go
9
main.go
@ -42,7 +42,7 @@ func main() {
|
|||||||
common.SysLog("Sentry initialized")
|
common.SysLog("Sentry initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
common.SetupGinLog()
|
common.SetupLogger()
|
||||||
common.SysLog("One API " + common.Version + " started")
|
common.SysLog("One API " + common.Version + " started")
|
||||||
if os.Getenv("GIN_MODE") != "debug" {
|
if os.Getenv("GIN_MODE") != "debug" {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
@ -108,11 +108,12 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize HTTP server
|
// Initialize HTTP server
|
||||||
server := gin.Default()
|
server := gin.New()
|
||||||
|
server.Use(gin.Recovery())
|
||||||
// This will cause SSE not to work!!!
|
// This will cause SSE not to work!!!
|
||||||
//server.Use(gzip.Gzip(gzip.DefaultCompression))
|
//server.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||||
server.Use(middleware.CORS())
|
server.Use(middleware.RequestId())
|
||||||
|
middleware.SetUpLogger(server)
|
||||||
// Initialize session store
|
// Initialize session store
|
||||||
store := cookie.NewStore([]byte(common.SessionSecret))
|
store := cookie.NewStore([]byte(common.SessionSecret))
|
||||||
server.Use(sessions.Sessions("session", store))
|
server.Use(sessions.Sessions("session", store))
|
||||||
|
@ -91,34 +91,16 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
key = parts[0]
|
key = parts[0]
|
||||||
token, err := model.ValidateUserToken(key)
|
token, err := model.ValidateUserToken(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
"error": gin.H{
|
|
||||||
"message": err.Error(),
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userEnabled, err := model.IsUserEnabled(token.UserId)
|
userEnabled, err := model.IsUserEnabled(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
"error": gin.H{
|
|
||||||
"message": err.Error(),
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !userEnabled {
|
if !userEnabled {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||||
"error": gin.H{
|
|
||||||
"message": "用户已被封禁",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
@ -134,13 +116,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("channelId", parts[1])
|
||||||
} else {
|
} else {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
"error": gin.H{
|
|
||||||
"message": "普通用户不支持指定渠道",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,34 +26,16 @@ func Distribute() func(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的渠道 ID",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err = model.GetChannelById(id, true)
|
channel, err = model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的渠道 ID",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
c.JSON(http.StatusForbidden, gin.H{
|
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||||
"error": gin.H{
|
|
||||||
"message": "该渠道已被禁用",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -64,13 +46,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||||
"error": gin.H{
|
|
||||||
"message": "无效的请求",
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
@ -100,13 +76,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||||
"error": gin.H{
|
|
||||||
"message": message,
|
|
||||||
"type": "one_api_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
25
middleware/logger.go
Normal file
25
middleware/logger.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetUpLogger(server *gin.Engine) {
|
||||||
|
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||||
|
var requestID string
|
||||||
|
if param.Keys != nil {
|
||||||
|
requestID = param.Keys[common.RequestIdKey].(string)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||||
|
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||||
|
requestID,
|
||||||
|
param.StatusCode,
|
||||||
|
param.Latency,
|
||||||
|
param.ClientIP,
|
||||||
|
param.Method,
|
||||||
|
param.Path,
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
}
|
18
middleware/request-id.go
Normal file
18
middleware/request-id.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RequestId() func(c *gin.Context) {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
id := common.GetTimeString() + common.GetRandomString(8)
|
||||||
|
c.Set(common.RequestIdKey, id)
|
||||||
|
ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
c.Header(common.RequestIdKey, id)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
17
middleware/utils.go
Normal file
17
middleware/utils.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
|
"type": "one_api_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
common.LogError(c.Request.Context(), message)
|
||||||
|
}
|
@ -13,6 +13,7 @@ type Ability struct {
|
|||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
|
AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
|
||||||
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
|
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
|
||||||
|
Priority int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
|
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
|
||||||
@ -33,9 +34,9 @@ func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||||
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
err = DB.Where(cmd, group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
|
||||||
} else {
|
} else {
|
||||||
err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error
|
err = DB.Where(cmd, group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -59,6 +60,7 @@ func (channel *Channel) AddAbilities() error {
|
|||||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||||
AllowStreaming: channel.AllowStreaming,
|
AllowStreaming: channel.AllowStreaming,
|
||||||
AllowNonStreaming: channel.AllowNonStreaming,
|
AllowNonStreaming: channel.AllowNonStreaming,
|
||||||
|
Priority: channel.Priority,
|
||||||
}
|
}
|
||||||
abilities = append(abilities, ability)
|
abilities = append(abilities, ability)
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -165,6 +166,17 @@ func InitChannelCache() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sort by priority
|
||||||
|
for group, model2channels := range newGroup2model2channels {
|
||||||
|
for model, channels := range model2channels {
|
||||||
|
sort.Slice(channels, func(i, j int) bool {
|
||||||
|
return channels[i].Priority > channels[j].Priority
|
||||||
|
})
|
||||||
|
newGroup2model2channels[group][model] = channels
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
channelSyncLock.Lock()
|
channelSyncLock.Lock()
|
||||||
group2model2channels = newGroup2model2channels
|
group2model2channels = newGroup2model2channels
|
||||||
channelSyncLock.Unlock()
|
channelSyncLock.Unlock()
|
||||||
@ -197,6 +209,12 @@ func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*C
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// choose by priority
|
||||||
|
firstChannel := filteredChannels[0]
|
||||||
|
if firstChannel.Priority > 0 {
|
||||||
|
return firstChannel, nil
|
||||||
|
}
|
||||||
|
|
||||||
idx := rand.Intn(len(filteredChannels))
|
idx := rand.Intn(len(filteredChannels))
|
||||||
return filteredChannels[idx], nil
|
return filteredChannels[idx], nil
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@ type Channel struct {
|
|||||||
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
|
AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
|
||||||
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
|
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
|
||||||
|
Priority int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
|
25
model/log.go
25
model/log.go
@ -1,6 +1,8 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
)
|
)
|
||||||
@ -17,6 +19,7 @@ type Log struct {
|
|||||||
Quota int `json:"quota" gorm:"default:0"`
|
Quota int `json:"quota" gorm:"default:0"`
|
||||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||||
|
Channel int `json:"channel" gorm:"default:0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -44,7 +47,9 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
|
||||||
|
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
if !common.LogConsumeEnabled {
|
if !common.LogConsumeEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -59,14 +64,15 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
|
|||||||
TokenName: tokenName,
|
TokenName: tokenName,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
Quota: quota,
|
Quota: quota,
|
||||||
|
Channel: channelId,
|
||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to record log: " + err.Error())
|
common.LogError(ctx, "failed to record log: "+err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
|
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
||||||
var tx *gorm.DB
|
var tx *gorm.DB
|
||||||
if logType == LogTypeUnknown {
|
if logType == LogTypeUnknown {
|
||||||
tx = DB
|
tx = DB
|
||||||
@ -88,6 +94,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
if endTimestamp != 0 {
|
if endTimestamp != 0 {
|
||||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||||
}
|
}
|
||||||
|
if channel != 0 {
|
||||||
|
tx = tx.Where("channel = ?", channel)
|
||||||
|
}
|
||||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
@ -125,7 +134,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
|||||||
return logs, err
|
return logs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
|
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
|
||||||
tx := DB.Table("logs").Select("sum(quota)")
|
tx := DB.Table("logs").Select("sum(quota)")
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
@ -142,6 +151,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
}
|
}
|
||||||
|
if channel != 0 {
|
||||||
|
tx = tx.Where("channel = ?", channel)
|
||||||
|
}
|
||||||
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
||||||
return quota
|
return quota
|
||||||
}
|
}
|
||||||
@ -166,3 +178,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
tx.Where("type = ?", LogTypeConsume).Scan(&token)
|
tx.Where("type = ?", LogTypeConsume).Scan(&token)
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||||
|
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||||
|
return result.RowsAffected, result.Error
|
||||||
|
}
|
||||||
|
@ -100,6 +100,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
}
|
}
|
||||||
logRoute := apiRouter.Group("/log")
|
logRoute := apiRouter.Group("/log")
|
||||||
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
|
logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
|
||||||
|
logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
|
||||||
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
|
logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
|
||||||
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
|
logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
|
||||||
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
|
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func SetRelayRouter(router *gin.Engine) {
|
func SetRelayRouter(router *gin.Engine) {
|
||||||
|
router.Use(middleware.CORS())
|
||||||
// https://platform.openai.com/docs/api-reference/introduction
|
// https://platform.openai.com/docs/api-reference/introduction
|
||||||
modelsRouter := router.Group("/v1/models")
|
modelsRouter := router.Group("/v1/models")
|
||||||
modelsRouter.Use(middleware.TokenAuth())
|
modelsRouter.Use(middleware.TokenAuth())
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
|
import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
|
||||||
import { Link } from 'react-router-dom';
|
import { Link } from 'react-router-dom';
|
||||||
import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
|
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
|
||||||
|
|
||||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
|
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
|
||||||
import { renderGroup, renderNumber } from '../helpers/render';
|
import { renderGroup, renderNumber } from '../helpers/render';
|
||||||
@ -24,7 +24,7 @@ function renderType(type) {
|
|||||||
}
|
}
|
||||||
type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
|
type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
|
||||||
}
|
}
|
||||||
return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
|
return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderBalance(type, balance) {
|
function renderBalance(type, balance) {
|
||||||
@ -96,7 +96,7 @@ const ChannelsTable = () => {
|
|||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const manageChannel = async (id, action, idx) => {
|
const manageChannel = async (id, action, idx, priority) => {
|
||||||
let data = { id };
|
let data = { id };
|
||||||
let res;
|
let res;
|
||||||
switch (action) {
|
switch (action) {
|
||||||
@ -111,6 +111,13 @@ const ChannelsTable = () => {
|
|||||||
data.status = 2;
|
data.status = 2;
|
||||||
res = await API.put('/api/channel/', data);
|
res = await API.put('/api/channel/', data);
|
||||||
break;
|
break;
|
||||||
|
case 'priority':
|
||||||
|
if (priority === '') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data.priority = parseInt(priority);
|
||||||
|
res = await API.put('/api/channel/', data);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
const { success, message } = res.data;
|
const { success, message } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
@ -195,6 +202,7 @@ const ChannelsTable = () => {
|
|||||||
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
|
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
|
showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -334,6 +342,14 @@ const ChannelsTable = () => {
|
|||||||
>
|
>
|
||||||
余额
|
余额
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
|
<Table.HeaderCell
|
||||||
|
style={{ cursor: 'pointer' }}
|
||||||
|
onClick={() => {
|
||||||
|
sortChannel('priority');
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
优先级
|
||||||
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell>操作</Table.HeaderCell>
|
<Table.HeaderCell>操作</Table.HeaderCell>
|
||||||
</Table.Row>
|
</Table.Row>
|
||||||
</Table.Header>
|
</Table.Header>
|
||||||
@ -372,6 +388,22 @@ const ChannelsTable = () => {
|
|||||||
basic
|
basic
|
||||||
/>
|
/>
|
||||||
</Table.Cell>
|
</Table.Cell>
|
||||||
|
<Table.Cell>
|
||||||
|
<Popup
|
||||||
|
trigger={<Input type="number" defaultValue={channel.priority} onBlur={(event) => {
|
||||||
|
manageChannel(
|
||||||
|
channel.id,
|
||||||
|
'priority',
|
||||||
|
idx,
|
||||||
|
event.target.value,
|
||||||
|
);
|
||||||
|
}}>
|
||||||
|
<input style={{maxWidth:'60px'}} />
|
||||||
|
</Input>}
|
||||||
|
content='渠道选择优先级,越高越优先'
|
||||||
|
basic
|
||||||
|
/>
|
||||||
|
</Table.Cell>
|
||||||
<Table.Cell>
|
<Table.Cell>
|
||||||
<div>
|
<div>
|
||||||
<Button
|
<Button
|
||||||
@ -440,7 +472,7 @@ const ChannelsTable = () => {
|
|||||||
|
|
||||||
<Table.Footer>
|
<Table.Footer>
|
||||||
<Table.Row>
|
<Table.Row>
|
||||||
<Table.HeaderCell colSpan='8'>
|
<Table.HeaderCell colSpan='9'>
|
||||||
<Button size='small' as={Link} to='/channel/add' loading={loading}>
|
<Button size='small' as={Link} to='/channel/add' loading={loading}>
|
||||||
添加新的渠道
|
添加新的渠道
|
||||||
</Button>
|
</Button>
|
||||||
|
@ -4,7 +4,7 @@ import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
|||||||
import { UserContext } from '../context/User';
|
import { UserContext } from '../context/User';
|
||||||
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
|
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
|
||||||
import Turnstile from 'react-turnstile';
|
import Turnstile from 'react-turnstile';
|
||||||
import { getOAuthState, onGitHubOAuthClicked } from './utils';stream/main
|
import { getOAuthState, onGitHubOAuthClicked } from './utils';
|
||||||
|
|
||||||
const LoginForm = () => {
|
const LoginForm = () => {
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
@ -47,12 +47,6 @@ const LoginForm = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const onGitHubOAuthClicked = () => {
|
|
||||||
window.open(
|
|
||||||
`https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email`
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const onDiscordOAuthClicked = () => {
|
const onDiscordOAuthClicked = () => {
|
||||||
window.open(
|
window.open(
|
||||||
`https://discord.com/oauth2/authorize?response_type=code&client_id=${status.discord_client_id}&redirect_uri=${window.location.origin}/oauth/discord&scope=identify`,
|
`https://discord.com/oauth2/authorize?response_type=code&client_id=${status.discord_client_id}&redirect_uri=${window.location.origin}/oauth/discord&scope=identify`,
|
||||||
|
@ -56,9 +56,10 @@ const LogsTable = () => {
|
|||||||
token_name: '',
|
token_name: '',
|
||||||
model_name: '',
|
model_name: '',
|
||||||
start_timestamp: timestamp2string(0),
|
start_timestamp: timestamp2string(0),
|
||||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
|
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||||
|
channel: ''
|
||||||
});
|
});
|
||||||
const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
|
const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
|
||||||
|
|
||||||
const [stat, setStat] = useState({
|
const [stat, setStat] = useState({
|
||||||
quota: 0,
|
quota: 0,
|
||||||
@ -84,7 +85,7 @@ const LogsTable = () => {
|
|||||||
const getLogStat = async () => {
|
const getLogStat = async () => {
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
|
let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
setStat(data);
|
setStat(data);
|
||||||
@ -109,7 +110,7 @@ const LogsTable = () => {
|
|||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
if (isAdminUser) {
|
if (isAdminUser) {
|
||||||
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||||
} else {
|
} else {
|
||||||
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||||
}
|
}
|
||||||
@ -205,16 +206,9 @@ const LogsTable = () => {
|
|||||||
</Header>
|
</Header>
|
||||||
<Form>
|
<Form>
|
||||||
<Form.Group>
|
<Form.Group>
|
||||||
{
|
<Form.Input fluid label={'令牌名称'} width={3} value={token_name}
|
||||||
isAdminUser && (
|
|
||||||
<Form.Input fluid label={'用户名称'} width={2} value={username}
|
|
||||||
placeholder={'可选值'} name='username'
|
|
||||||
onChange={handleInputChange} />
|
|
||||||
)
|
|
||||||
}
|
|
||||||
<Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
|
|
||||||
placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
|
placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
|
||||||
<Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
|
<Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
|
||||||
name='model_name'
|
name='model_name'
|
||||||
onChange={handleInputChange} />
|
onChange={handleInputChange} />
|
||||||
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
|
<Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
|
||||||
@ -225,6 +219,19 @@ const LogsTable = () => {
|
|||||||
onChange={handleInputChange} />
|
onChange={handleInputChange} />
|
||||||
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
|
<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
|
||||||
</Form.Group>
|
</Form.Group>
|
||||||
|
{
|
||||||
|
isAdminUser && <>
|
||||||
|
<Form.Group>
|
||||||
|
<Form.Input fluid label={'渠道 ID'} width={3} value={channel}
|
||||||
|
placeholder='可选值' name='channel'
|
||||||
|
onChange={handleInputChange} />
|
||||||
|
<Form.Input fluid label={'用户名称'} width={3} value={username}
|
||||||
|
placeholder={'可选值'} name='username'
|
||||||
|
onChange={handleInputChange} />
|
||||||
|
|
||||||
|
</Form.Group>
|
||||||
|
</>
|
||||||
|
}
|
||||||
</Form>
|
</Form>
|
||||||
<Table basic compact size='small'>
|
<Table basic compact size='small'>
|
||||||
<Table.Header>
|
<Table.Header>
|
||||||
@ -238,6 +245,17 @@ const LogsTable = () => {
|
|||||||
>
|
>
|
||||||
时间
|
时间
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
|
{
|
||||||
|
isAdminUser && <Table.HeaderCell
|
||||||
|
style={{ cursor: 'pointer' }}
|
||||||
|
onClick={() => {
|
||||||
|
sortLog('channel');
|
||||||
|
}}
|
||||||
|
width={1}
|
||||||
|
>
|
||||||
|
渠道
|
||||||
|
</Table.HeaderCell>
|
||||||
|
}
|
||||||
{
|
{
|
||||||
isAdminUser && <Table.HeaderCell
|
isAdminUser && <Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
@ -299,16 +317,16 @@ const LogsTable = () => {
|
|||||||
onClick={() => {
|
onClick={() => {
|
||||||
sortLog('quota');
|
sortLog('quota');
|
||||||
}}
|
}}
|
||||||
width={2}
|
width={1}
|
||||||
>
|
>
|
||||||
消耗额度
|
额度
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell
|
<Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
sortLog('content');
|
sortLog('content');
|
||||||
}}
|
}}
|
||||||
width={isAdminUser ? 4 : 5}
|
width={isAdminUser ? 4 : 6}
|
||||||
>
|
>
|
||||||
详情
|
详情
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
@ -326,6 +344,11 @@ const LogsTable = () => {
|
|||||||
return (
|
return (
|
||||||
<Table.Row key={log.id}>
|
<Table.Row key={log.id}>
|
||||||
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
|
<Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
|
||||||
|
{
|
||||||
|
isAdminUser && (
|
||||||
|
<Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
|
||||||
|
)
|
||||||
|
}
|
||||||
{
|
{
|
||||||
isAdminUser && (
|
isAdminUser && (
|
||||||
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
|
<Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
|
||||||
@ -345,7 +368,7 @@ const LogsTable = () => {
|
|||||||
|
|
||||||
<Table.Footer>
|
<Table.Footer>
|
||||||
<Table.Row>
|
<Table.Row>
|
||||||
<Table.HeaderCell colSpan={'9'}>
|
<Table.HeaderCell colSpan={'10'}>
|
||||||
<Select
|
<Select
|
||||||
placeholder='选择明细分类'
|
placeholder='选择明细分类'
|
||||||
options={LOG_OPTIONS}
|
options={LOG_OPTIONS}
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
|
import { Divider, Form, Grid, Header } from 'semantic-ui-react';
|
||||||
import { API, showError, verifyJSON } from '../helpers';
|
import { API, showError, showSuccess, timestamp2string, verifyJSON } from '../helpers';
|
||||||
|
|
||||||
const OperationSetting = () => {
|
const OperationSetting = () => {
|
||||||
|
let now = new Date();
|
||||||
let [inputs, setInputs] = useState({
|
let [inputs, setInputs] = useState({
|
||||||
QuotaForNewUser: 0,
|
QuotaForNewUser: 0,
|
||||||
QuotaForInviter: 0,
|
QuotaForInviter: 0,
|
||||||
@ -20,10 +21,11 @@ const OperationSetting = () => {
|
|||||||
DisplayInCurrencyEnabled: '',
|
DisplayInCurrencyEnabled: '',
|
||||||
DisplayTokenStatEnabled: '',
|
DisplayTokenStatEnabled: '',
|
||||||
ApproximateTokenEnabled: '',
|
ApproximateTokenEnabled: '',
|
||||||
RetryTimes: 0,
|
RetryTimes: 0
|
||||||
});
|
});
|
||||||
const [originInputs, setOriginInputs] = useState({});
|
const [originInputs, setOriginInputs] = useState({});
|
||||||
let [loading, setLoading] = useState(false);
|
let [loading, setLoading] = useState(false);
|
||||||
|
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
|
||||||
|
|
||||||
const getOptions = async () => {
|
const getOptions = async () => {
|
||||||
const res = await API.get('/api/option/');
|
const res = await API.get('/api/option/');
|
||||||
@ -130,6 +132,17 @@ const OperationSetting = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const deleteHistoryLogs = async () => {
|
||||||
|
console.log(inputs);
|
||||||
|
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
|
||||||
|
const { success, message, data } = res.data;
|
||||||
|
if (success) {
|
||||||
|
showSuccess(`${data} 条日志已清理!`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
showError('日志清理失败:' + message);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Grid columns={1}>
|
<Grid columns={1}>
|
||||||
<Grid.Column>
|
<Grid.Column>
|
||||||
@ -179,12 +192,6 @@ const OperationSetting = () => {
|
|||||||
/>
|
/>
|
||||||
</Form.Group>
|
</Form.Group>
|
||||||
<Form.Group inline>
|
<Form.Group inline>
|
||||||
<Form.Checkbox
|
|
||||||
checked={inputs.LogConsumeEnabled === 'true'}
|
|
||||||
label='启用额度消费日志记录'
|
|
||||||
name='LogConsumeEnabled'
|
|
||||||
onChange={handleInputChange}
|
|
||||||
/>
|
|
||||||
<Form.Checkbox
|
<Form.Checkbox
|
||||||
checked={inputs.DisplayInCurrencyEnabled === 'true'}
|
checked={inputs.DisplayInCurrencyEnabled === 'true'}
|
||||||
label='以货币形式显示额度'
|
label='以货币形式显示额度'
|
||||||
@ -208,6 +215,28 @@ const OperationSetting = () => {
|
|||||||
submitConfig('general').then();
|
submitConfig('general').then();
|
||||||
}}>保存通用设置</Form.Button>
|
}}>保存通用设置</Form.Button>
|
||||||
<Divider />
|
<Divider />
|
||||||
|
<Header as='h3'>
|
||||||
|
日志设置
|
||||||
|
</Header>
|
||||||
|
<Form.Group inline>
|
||||||
|
<Form.Checkbox
|
||||||
|
checked={inputs.LogConsumeEnabled === 'true'}
|
||||||
|
label='启用额度消费日志记录'
|
||||||
|
name='LogConsumeEnabled'
|
||||||
|
onChange={handleInputChange}
|
||||||
|
/>
|
||||||
|
</Form.Group>
|
||||||
|
<Form.Group widths={4}>
|
||||||
|
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
|
||||||
|
name='history_timestamp'
|
||||||
|
onChange={(e, { name, value }) => {
|
||||||
|
setHistoryTimestamp(value);
|
||||||
|
}} />
|
||||||
|
</Form.Group>
|
||||||
|
<Form.Button onClick={() => {
|
||||||
|
deleteHistoryLogs().then();
|
||||||
|
}}>清理历史日志</Form.Button>
|
||||||
|
<Divider />
|
||||||
<Header as='h3'>
|
<Header as='h3'>
|
||||||
监控设置
|
监控设置
|
||||||
</Header>
|
</Header>
|
||||||
|
@ -96,7 +96,7 @@ const TokensTable = () => {
|
|||||||
let nextUrl;
|
let nextUrl;
|
||||||
|
|
||||||
if (nextLink) {
|
if (nextLink) {
|
||||||
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
|
nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||||
} else {
|
} else {
|
||||||
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user